openai_tools/chat/
mod.rs

1//! # Chat Module
2//!
3//! This module provides functionality for interacting with the OpenAI Chat Completions API.
4//! It includes tools for building requests, sending them to OpenAI's chat completion endpoint,
5//! and processing the responses.
6//!
7//! ## Key Features
8//!
9//! - Chat completion request building and sending
10//! - Structured output support with JSON schema
11//! - Response parsing and processing
12//! - Support for various OpenAI models and parameters
13//!
14//! ## Usage Examples
15//!
16//! ### Basic Chat Completion
17//!
18//! ```rust,no_run
19//! use openai_tools::chat::request::ChatCompletion;
20//! use openai_tools::common::message::Message;
21//! use openai_tools::common::role::Role;
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
25//!     let mut chat = ChatCompletion::new();
26//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
27//!     
28//!     let response = chat
29//!         .model_id("gpt-4o-mini")
30//!         .messages(messages)
31//!         .temperature(1.0)
32//!         .chat()
33//!         .await?;
34//!         
35//!     println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
36//!     Ok(())
37//! }
38//! ```
39//!
40//! ### Using JSON Schema for Structured Output
41//!
42//! ```rust,no_run
43//! use openai_tools::chat::request::ChatCompletion;
44//! use openai_tools::common::message::Message;
45//! use openai_tools::common::role::Role;
46//! use openai_tools::common::structured_output::Schema;
47//! use serde::{Deserialize, Serialize};
48//!
49//! #[derive(Debug, Serialize, Deserialize)]
50//! struct WeatherInfo {
51//!     location: String,
52//!     date: String,
53//!     weather: String,
54//!     temperature: String,
55//! }
56//!
57//! #[tokio::main]
58//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
59//!     let mut chat = ChatCompletion::new();
60//!     let messages = vec![Message::from_string(
61//!         Role::User,
62//!         "What's the weather like tomorrow in Tokyo?"
63//!     )];
64//!     
65//!     // Create JSON schema for structured output
66//!     let mut json_schema = Schema::chat_json_schema("weather");
67//!     json_schema.add_property("location", "string", "The location for weather check");
68//!     json_schema.add_property("date", "string", "The date for weather forecast");
69//!     json_schema.add_property("weather", "string", "Weather condition description");
70//!     json_schema.add_property("temperature", "string", "Temperature information");
71//!     
72//!     let response = chat
73//!         .model_id("gpt-4o-mini")
74//!         .messages(messages)
75//!         .temperature(0.7)
76//!         .json_schema(json_schema)
77//!         .chat()
78//!         .await?;
79//!         
80//!     // Parse structured response
81//!     let weather: WeatherInfo = serde_json::from_str(
82//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
83//!     )?;
84//!     println!("Weather in {}: {} on {}, Temperature: {}",
85//!              weather.location, weather.weather, weather.date, weather.temperature);
86//!     Ok(())
87//! }
88//! ```
89//!
90//! ### Using Function Calling with Tools
91//!
92//! ```rust,no_run
93//! use openai_tools::chat::request::ChatCompletion;
94//! use openai_tools::common::message::Message;
95//! use openai_tools::common::role::Role;
96//! use openai_tools::common::tool::Tool;
97//! use openai_tools::common::parameters::ParameterProperty;
98//!
99//! #[tokio::main]
100//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
101//!     let mut chat = ChatCompletion::new();
102//!     let messages = vec![Message::from_string(
103//!         Role::User,
104//!         "Please calculate 25 + 17 using the calculator tool"
105//!     )];
106//!     
107//!     // Define a calculator function tool
108//!     let calculator_tool = Tool::function(
109//!         "calculator",
110//!         "A calculator that can perform basic arithmetic operations",
111//!         vec![
112//!             ("operation", ParameterProperty::from_string("The operation to perform (add, subtract, multiply, divide)")),
113//!             ("a", ParameterProperty::from_number("The first number")),
114//!             ("b", ParameterProperty::from_number("The second number")),
115//!         ],
116//!         false, // strict mode
117//!     );
118//!     
119//!     let response = chat
120//!         .model_id("gpt-4o-mini")
121//!         .messages(messages)
122//!         .temperature(0.1)
123//!         .tools(vec![calculator_tool])
124//!         .chat()
125//!         .await?;
126//!         
127//!     // Handle function calls in the response
128//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
129//!         // Add the assistant's message with tool calls to conversation history
130//!         chat.add_message(response.choices[0].message.clone());
131//!         
132//!         for tool_call in tool_calls {
133//!             println!("Function called: {}", tool_call.function.name);
134//!             if let Ok(args) = tool_call.function.arguments_as_map() {
135//!                 println!("Arguments: {:?}", args);
136//!             }
137//!             
138//!             // Execute the function (in this example, we simulate the calculation)
139//!             let result = "42"; // This would be the actual calculation result
140//!             
141//!             // Add the tool call response to continue the conversation
142//!             chat.add_message(Message::from_tool_call_response(result, &tool_call.id));
143//!         }
144//!         
145//!         // Get the final response after tool execution
146//!         let final_response = chat.chat().await?;
147//!         if let Some(content) = &final_response.choices[0].message.content {
148//!             if let Some(text) = &content.text {
149//!                 println!("Final answer: {}", text);
150//!             }
151//!         }
152//!     } else if let Some(content) = &response.choices[0].message.content {
153//!         if let Some(text) = &content.text {
154//!             println!("{}", text);
155//!         }
156//!     }
157//!     Ok(())
158//! }
159//! ```
160
161pub mod request;
162pub mod response;
163
164#[cfg(test)]
165mod tests {
166    use crate::chat::request::ChatCompletion;
167    use crate::common::{
168        errors::OpenAIToolError, message::Message, parameters::ParameterProperty, role::Role, structured_output::Schema, tool::Tool,
169    };
170    use serde::{Deserialize, Serialize};
171    use serde_json;
172    use std::sync::Once;
173    use tracing_subscriber::EnvFilter;
174
175    static INIT: Once = Once::new();
176
177    fn init_tracing() {
178        INIT.call_once(|| {
179            // `RUST_LOG` 環境変数があればそれを使い、なければ "info"
180            let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
181            // try_init()を使用してsubscriberが既に設定されている場合はスキップ
182            let _ = tracing_subscriber::fmt()
183                .with_env_filter(filter)
184                .with_test_writer() // `cargo test` / nextest 用
185                .try_init();
186        });
187    }
188    #[tokio::test]
189    #[test_log::test]
190    async fn test_chat_completion() {
191        init_tracing();
192        let mut chat = ChatCompletion::new();
193        let messages = vec![Message::from_string(Role::User, "Hi there!")];
194
195        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);
196
197        let mut counter = 3;
198        loop {
199            match chat.chat().await {
200                Ok(response) => {
201                    tracing::info!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
202                    assert!(true);
203                    break;
204                }
205                Err(e) => match e {
206                    OpenAIToolError::RequestError(e) => {
207                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
208                        counter -= 1;
209                        if counter == 0 {
210                            assert!(false, "Chat completion failed (retry limit reached)");
211                        }
212                        continue;
213                    }
214                    _ => {
215                        tracing::error!("Error: {}", e);
216                        assert!(false, "Chat completion failed");
217                    }
218                },
219            };
220        }
221    }
222
223    #[tokio::test]
224    #[test_log::test]
225    async fn test_chat_completion_2() {
226        init_tracing();
227        let mut chat = ChatCompletion::new();
228        let messages = vec![Message::from_string(Role::User, "トンネルを抜けると?")];
229
230        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.5);
231
232        let mut counter = 3;
233        loop {
234            match chat.chat().await {
235                Ok(response) => {
236                    println!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
237                    assert!(true);
238                    break;
239                }
240                Err(e) => match e {
241                    OpenAIToolError::RequestError(e) => {
242                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
243                        counter -= 1;
244                        if counter == 0 {
245                            assert!(false, "Chat completion failed (retry limit reached)");
246                        }
247                        continue;
248                    }
249                    _ => {
250                        tracing::error!("Error: {}", e);
251                        assert!(false, "Chat completion failed");
252                    }
253                },
254            };
255        }
256    }
257
258    #[derive(Debug, Serialize, Deserialize)]
259    struct Weather {
260        #[serde(default = "String::new")]
261        location: String,
262        #[serde(default = "String::new")]
263        date: String,
264        #[serde(default = "String::new")]
265        weather: String,
266        #[serde(default = "String::new")]
267        error: String,
268    }
269
270    #[tokio::test]
271    #[test_log::test]
272    async fn test_chat_completion_with_json_schema() {
273        init_tracing();
274        let mut openai = ChatCompletion::new();
275        let messages = vec![Message::from_string(Role::User, "Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report error.")];
276
277        let mut json_schema = Schema::chat_json_schema("weather");
278        json_schema.add_property("location", "string", "The location to check the weather for.");
279        json_schema.add_property("date", "string", "The date to check the weather for.");
280        json_schema.add_property("weather", "string", "The weather for the location and date.");
281        json_schema.add_property("error", "string", "Error message. If there is no error, leave this field empty.");
282        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);
283
284        let mut counter = 3;
285        loop {
286            match openai.chat().await {
287                Ok(response) => {
288                    println!("{:#?}", response);
289                    match serde_json::from_str::<Weather>(
290                        &response.choices[0]
291                            .message
292                            .content
293                            .clone()
294                            .expect("Response content should not be empty")
295                            .text
296                            .expect("Response content should not be empty"),
297                    ) {
298                        Ok(weather) => {
299                            println!("{:#?}", weather);
300                            assert!(true);
301                        }
302                        Err(e) => {
303                            println!("{:#?}", e);
304                            assert!(false);
305                        }
306                    }
307                    break;
308                }
309                Err(e) => match e {
310                    OpenAIToolError::RequestError(e) => {
311                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
312                        counter -= 1;
313                        if counter == 0 {
314                            assert!(false, "Chat completion failed (retry limit reached)");
315                        }
316                        continue;
317                    }
318                    _ => {
319                        tracing::error!("Error: {}", e);
320                        assert!(false, "Chat completion failed");
321                    }
322                },
323            };
324        }
325    }
326
327    #[derive(Deserialize)]
328    struct Summary {
329        pub is_survey: bool,
330        pub research_question: String,
331        pub contributions: String,
332        pub dataset: String,
333        pub proposed_method: String,
334        pub experiment_results: String,
335        pub comparison_with_related_works: String,
336        pub future_works: String,
337    }
338    #[tokio::test]
339    #[test_log::test]
340    async fn test_summarize() {
341        init_tracing();
342        let mut openai = ChatCompletion::new();
343        let instruction = std::fs::read_to_string("src/test_rsc/sample_instruction.txt").unwrap();
344
345        let messages = vec![Message::from_string(Role::User, instruction.clone())];
346
347        let mut json_schema = Schema::chat_json_schema("summary");
348        json_schema.add_property("is_survey", "boolean", "この論文がサーベイ論文かどうかをtrue/falseで判定.");
349        json_schema.add_property(
350            "research_question",
351            "string",
352            "この論文のリサーチクエスチョンの説明.この論文の背景や既存研究との関連も含めて記述する.",
353        );
354        json_schema.add_property("contributions", "string", "この論文のコントリビューションをリスト形式で記述する.");
355        json_schema.add_property("dataset", "string", "この論文で使用されているデータセットをリストアップする.");
356        json_schema.add_property("proposed_method", "string", "提案手法の詳細な説明.");
357        json_schema.add_property("experiment_results", "string", "実験の結果の詳細な説明.");
358        json_schema.add_property(
359            "comparison_with_related_works",
360            "string",
361            "関連研究と比較した場合のこの論文の新規性についての説明.可能な限り既存研究を参照しながら記述すること.",
362        );
363        json_schema.add_property("future_works", "string", "未解決の課題および将来の研究の方向性について記述.");
364
365        openai.model_id(String::from("gpt-4o-mini")).messages(messages).temperature(1.0).json_schema(json_schema);
366
367        let mut counter = 3;
368        loop {
369            match openai.chat().await {
370                Ok(response) => {
371                    println!("{:#?}", response);
372                    match serde_json::from_str::<Summary>(
373                        &response.choices[0]
374                            .message
375                            .content
376                            .clone()
377                            .expect("Response content should not be empty")
378                            .text
379                            .expect("Response content should not be empty"),
380                    ) {
381                        Ok(summary) => {
382                            tracing::info!("Summary.is_survey: {}", summary.is_survey);
383                            tracing::info!("Summary.research_question: {}", summary.research_question);
384                            tracing::info!("Summary.contributions: {}", summary.contributions);
385                            tracing::info!("Summary.dataset: {}", summary.dataset);
386                            tracing::info!("Summary.proposed_method: {}", summary.proposed_method);
387                            tracing::info!("Summary.experiment_results: {}", summary.experiment_results);
388                            tracing::info!("Summary.comparison_with_related_works: {}", summary.comparison_with_related_works);
389                            tracing::info!("Summary.future_works: {}", summary.future_works);
390                            assert!(true);
391                        }
392                        Err(e) => {
393                            tracing::error!("Error: {}", e);
394                            assert!(false);
395                        }
396                    }
397                    break;
398                }
399                Err(e) => match e {
400                    OpenAIToolError::RequestError(e) => {
401                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
402                        counter -= 1;
403                        if counter == 0 {
404                            assert!(false, "Chat completion failed (retry limit reached)");
405                        }
406                        continue;
407                    }
408                    _ => {
409                        tracing::error!("Error: {}", e);
410                        assert!(false, "Chat completion failed");
411                    }
412                },
413            };
414        }
415    }
416
417    #[tokio::test]
418    #[test_log::test]
419    async fn test_chat_completion_with_function_calling() {
420        init_tracing();
421        let mut chat = ChatCompletion::new();
422        let messages = vec![Message::from_string(Role::User, "Please calculate 25 + 17 using the calculator tool.")];
423
424        // Define a calculator function tool
425        let calculator_tool = Tool::function(
426            "calculator",
427            "A calculator that can perform basic arithmetic operations",
428            vec![
429                ("operation", ParameterProperty::from_string("The operation to perform (add, subtract, multiply, divide)")),
430                ("a", ParameterProperty::from_number("The first number")),
431                ("b", ParameterProperty::from_number("The second number")),
432            ],
433            false, // strict mode
434        );
435
436        chat.model_id("gpt-4o-mini").messages(messages).temperature(0.1).tools(vec![calculator_tool]);
437        // First call
438        let mut counter = 3;
439        loop {
440            match chat.chat().await {
441                Ok(response) => {
442                    tracing::info!("First Response: {:#?}", response);
443
444                    let message = response.choices[0].message.clone();
445                    chat.add_message(message.clone());
446
447                    // Check if the response contains tool calls
448                    if let Some(tool_calls) = &message.tool_calls {
449                        assert!(!tool_calls.is_empty(), "Tool calls should not be empty");
450
451                        for tool_call in tool_calls {
452                            tracing::info!("Function called: {}", tool_call.function.name);
453                            tracing::info!("Arguments: {:?}", tool_call.function.arguments);
454
455                            // Verify that the calculator function was called
456                            assert_eq!(tool_call.function.name, "calculator");
457
458                            // Parse the arguments to verify they contain the expected operation
459                            let args = tool_call.function.arguments_as_map().unwrap();
460                            assert!(args.get("operation").is_some());
461                            assert!(args.get("a").is_some());
462                            assert!(args.get("b").is_some());
463
464                            tracing::info!("Function call validation passed");
465
466                            chat.add_message(Message::from_tool_call_response("42", &tool_call.id));
467                        }
468                        assert!(true);
469                    } else {
470                        // If no tool calls, check if the content mentions function calling
471                        tracing::info!(
472                            "No tool calls found. Content: {}",
473                            &response.choices[0]
474                                .message
475                                .content
476                                .clone()
477                                .expect("Response content should not be empty")
478                                .text
479                                .expect("Response content should not be empty")
480                        );
481                        // This might happen if the model decides not to use the tool
482                        // We'll still consider this a valid response for testing purposes
483                        assert!(false, "Expected tool calls but none found in response");
484                    }
485                    break;
486                }
487                Err(e) => match e {
488                    OpenAIToolError::RequestError(e) => {
489                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
490                        counter -= 1;
491                        if counter == 0 {
492                            assert!(false, "Function calling test failed (retry limit reached)");
493                        }
494                        continue;
495                    }
496                    _ => {
497                        tracing::error!("Error: {}", e);
498                        assert!(false, "Function calling test failed");
499                    }
500                },
501            };
502        }
503
504        // Second call to ensure the tool is still available
505        let messages = chat.get_message_history();
506        let mut chat = ChatCompletion::new();
507        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);
508
509        let mut counter = 3;
510        loop {
511            match chat.chat().await {
512                Ok(response) => {
513                    tracing::info!("Second Response: {:#?}", response);
514                    assert!(!response.choices.is_empty(), "Response should contain at least one choice");
515                    let content = response.choices[0]
516                        .message
517                        .content
518                        .clone()
519                        .expect("Response content should not be empty")
520                        .text
521                        .expect("Response content should not be empty");
522                    tracing::info!("Content: {}", content);
523                    // Check if the content contains the expected result
524                    assert!(content.contains("42"), "Expected content to contain '42', found: {}", content);
525                    break;
526                }
527                Err(e) => match e {
528                    OpenAIToolError::RequestError(e) => {
529                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
530                        counter -= 1;
531                        if counter == 0 {
532                            assert!(false, "Function calling test failed (retry limit reached)");
533                        }
534                        continue;
535                    }
536                    _ => {
537                        tracing::error!("Error: {}", e);
538                        assert!(false, "Function calling test failed");
539                    }
540                },
541            };
542        }
543    }
544
545    // #[tokio::test]
546    // async fn test_chat_completion_with_long_arguments() {
547    //     init_tracing();
548    //     let mut openai = ChatCompletion::new();
549    //     let text = std::fs::read_to_string("src/test_rsc/long_text.txt").unwrap();
550    //     let messages = vec![Message::from_string(Role::User, text)];
551
552    //     let token_count = messages
553    //         .iter()
554    //         .map(|m| m.get_input_token_count())
555    //         .sum::<usize>();
556    //     tracing::info!("Token count: {}", token_count);
557
558    //     openai
559    //         .model_id(String::from("gpt-4o-mini"))
560    //         .messages(messages)
561    //         .temperature(1.0);
562
563    //     let mut counter = 3;
564    //     loop {
565    //         match openai.chat().await {
566    //             Ok(response) => {
567    //                 println!("{:#?}", response);
568    //                 assert!(true);
569    //                 break;
570    //             }
571    //             Err(e) => match e {
572    //                 OpenAIToolError::RequestError(e) => {
573    //                     tracing::warn!("Request error: {} (retrying... {})", e, counter);
574    //                     counter -= 1;
575    //                     if counter == 0 {
576    //                         assert!(false, "Chat completion failed (retry limit reached)");
577    //                     }
578    //                     continue;
579    //                 }
580    //                 _ => {
581    //                     tracing::error!("Error: {}", e);
582    //                     assert!(false, "Chat completion failed");
583    //                 }
584    //             },
585    //         };
586    //     }
587    // }
588}