// openai_tools/chat/mod.rs

//! # Chat Module
//!
//! This module provides functionality for interacting with the OpenAI Chat Completions API.
//! It includes tools for building requests, sending them to OpenAI's chat completion endpoint,
//! and processing the responses.
//!
//! ## Key Features
//!
//! - Chat completion request building and sending
//! - Structured output support with JSON schema
//! - Function calling with tool definitions
//! - Response parsing and processing
//! - Support for various OpenAI models and parameters
//!
//! ## Usage Examples
//!
//! The examples below are marked `no_run`: they are compiled as doc tests but not executed,
//! because running them would send real requests to the OpenAI API.
//!
//! ### Basic Chat Completion
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(1.0)
//!         .chat()
//!         .await?;
//!
//!     println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
//!     Ok(())
//! }
//! ```
//!
//! ### Using JSON Schema for Structured Output
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Debug, Serialize, Deserialize)]
//! struct WeatherInfo {
//!     location: String,
//!     date: String,
//!     weather: String,
//!     temperature: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "What's the weather like tomorrow in Tokyo?"
//!     )];
//!
//!     // Create a JSON schema for structured output
//!     let mut json_schema = Schema::chat_json_schema("weather");
//!     json_schema.add_property("location", "string", "The location for weather check");
//!     json_schema.add_property("date", "string", "The date for weather forecast");
//!     json_schema.add_property("weather", "string", "Weather condition description");
//!     json_schema.add_property("temperature", "string", "Temperature information");
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .json_schema(json_schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse the structured response
//!     let weather: WeatherInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!     println!("Weather in {}: {} on {}, Temperature: {}",
//!              weather.location, weather.weather, weather.date, weather.temperature);
//!     Ok(())
//! }
//! ```
//!
//! ### Using Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProp;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "Please calculate 15 * 23 using the calculator tool"
//!     )];
//!
//!     // Define a calculator function tool
//!     let calculator_tool = Tool::function(
//!         "calculator",
//!         "A calculator that can perform basic arithmetic operations",
//!         vec![
//!             ("operation", ParameterProp::string("The operation to perform (add, subtract, multiply, divide)")),
//!             ("a", ParameterProp::number("The first number")),
//!             ("b", ParameterProp::number("The second number")),
//!         ],
//!         false, // strict mode disabled
//!     );
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.1)
//!         .tools(vec![calculator_tool])
//!         .chat()
//!         .await?;
//!
//!     // Handle function calls in the response
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         for tool_call in tool_calls {
//!             println!("Function called: {}", tool_call.function.name);
//!             if let Ok(args) = tool_call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!             // In a real application, you would execute the function here
//!             // and send the result back to continue the conversation.
//!         }
//!     } else if let Some(content) = &response.choices[0].message.content {
//!         if let Some(text) = &content.text {
//!             println!("{}", text);
//!         }
//!     }
//!     Ok(())
//! }
//! ```

146pub mod request;
147pub mod response;
148
#[cfg(test)]
mod tests {
    // Live integration tests: every test below sends real requests through
    // `ChatCompletion::chat`, so they need network access and whatever credentials
    // `ChatCompletion::new()` picks up (presumably an API key from the environment — confirm
    // against `request.rs`).
    use crate::chat::request::ChatCompletion;
    use crate::common::{errors::OpenAIToolError, message::Message, parameters::ParameterProp, role::Role, structured_output::Schema, tool::Tool};
    use serde::{Deserialize, Serialize};
    use std::sync::Once;
    use tracing_subscriber::EnvFilter;

    /// Total number of attempts made for a request that keeps failing with `RequestError`.
    const MAX_ATTEMPTS: u32 = 3;

    static INIT: Once = Once::new();

    /// Installs a global `tracing` subscriber the first time it is called.
    ///
    /// The filter comes from the `RUST_LOG` environment variable when it is set and defaults to
    /// `info` otherwise. `try_init` is used instead of `init` so that a global subscriber
    /// installed elsewhere in the same test binary does not make these tests panic.
    fn init_tracing() {
        INIT.call_once(|| {
            let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
            let _ = tracing_subscriber::fmt()
                .with_env_filter(filter)
                .with_test_writer() // route output through libtest's capture (`cargo test` / nextest)
                .try_init();
        });
    }

    /// Sends the request configured on `$chat` and evaluates to the successful response.
    ///
    /// `OpenAIToolError::RequestError` is retried until `MAX_ATTEMPTS` attempts have been made,
    /// after which the test panics with "`$failure` (retry limit reached)". Any other error
    /// panics immediately with `$failure`. `$chat` is re-evaluated on every attempt, so pass a
    /// plain local variable. A macro (rather than a function) keeps call sites independent of
    /// the response type and of the receiver kind `ChatCompletion::chat` takes.
    macro_rules! chat_with_retry {
        ($chat:expr, $failure:expr) => {{
            let mut attempt: u32 = 0;
            loop {
                attempt += 1;
                match $chat.chat().await {
                    Ok(response) => break response,
                    Err(OpenAIToolError::RequestError(e)) if attempt < MAX_ATTEMPTS => {
                        tracing::warn!("Request error: {} (attempt {}/{}, retrying...)", e, attempt, MAX_ATTEMPTS);
                    }
                    Err(OpenAIToolError::RequestError(e)) => {
                        tracing::error!("Request error: {}", e);
                        panic!("{} (retry limit reached)", $failure);
                    }
                    Err(e) => {
                        tracing::error!("Error: {}", e);
                        panic!("{}", $failure);
                    }
                }
            }
        }};
    }

    #[tokio::test]
    async fn test_chat_completion() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there!")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);

        let response = chat_with_retry!(chat, "Chat completion failed");
        let content = response.choices[0].message.content.as_ref().expect("Response content should not be empty");
        tracing::info!("{:?}", content);
    }

    #[tokio::test]
    async fn test_chat_completion_2() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "トンネルを抜けると?")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.5);

        let response = chat_with_retry!(chat, "Chat completion failed");
        let content = response.choices[0].message.content.as_ref().expect("Response content should not be empty");
        tracing::info!("{:?}", content);
    }

    /// Structured answer requested by `test_chat_completion_with_json_schema`.
    ///
    /// Every field defaults to an empty string, so a reply that omits a property still parses.
    #[derive(Debug, Serialize, Deserialize)]
    struct Weather {
        #[serde(default)]
        location: String,
        #[serde(default)]
        date: String,
        #[serde(default)]
        weather: String,
        #[serde(default)]
        error: String,
    }

    #[tokio::test]
    async fn test_chat_completion_with_json_schema() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report error.")];

        let mut json_schema = Schema::chat_json_schema("weather");
        json_schema.add_property("location", "string", "The location to check the weather for.");
        json_schema.add_property("date", "string", "The date to check the weather for.");
        json_schema.add_property("weather", "string", "The weather for the location and date.");
        json_schema.add_property("error", "string", "Error message. If there is no error, leave this field empty.");
        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let response = chat_with_retry!(openai, "Chat completion failed");
        tracing::info!("{:#?}", response);

        // Both a missing content object and a missing text body mean the same thing here.
        let text = response.choices[0]
            .message
            .content
            .as_ref()
            .and_then(|content| content.text.as_ref())
            .expect("Response content should not be empty");
        let weather: Weather = serde_json::from_str(text).unwrap_or_else(|e| panic!("Failed to parse weather response: {}", e));
        tracing::info!("{:#?}", weather);
    }

    /// Structured paper summary requested by `test_summarize`.
    #[derive(Deserialize)]
    struct Summary {
        is_survey: bool,
        research_question: String,
        contributions: String,
        dataset: String,
        proposed_method: String,
        experiment_results: String,
        comparison_with_related_works: String,
        future_works: String,
    }

    #[tokio::test]
    async fn test_summarize() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        // Anchor the fixture to the package root so the test does not depend on the working directory.
        let instruction = std::fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/src/test_rsc/sample_instruction.txt"))
            .expect("failed to read src/test_rsc/sample_instruction.txt");

        let messages = vec![Message::from_string(Role::User, instruction)];

        let mut json_schema = Schema::chat_json_schema("summary");
        json_schema.add_property("is_survey", "boolean", "この論文がサーベイ論文かどうかをtrue/falseで判定.");
        json_schema.add_property(
            "research_question",
            "string",
            "この論文のリサーチクエスチョンの説明.この論文の背景や既存研究との関連も含めて記述する.",
        );
        json_schema.add_property("contributions", "string", "この論文のコントリビューションをリスト形式で記述する.");
        json_schema.add_property("dataset", "string", "この論文で使用されているデータセットをリストアップする.");
        json_schema.add_property("proposed_method", "string", "提案手法の詳細な説明.");
        json_schema.add_property("experiment_results", "string", "実験の結果の詳細な説明.");
        json_schema.add_property(
            "comparison_with_related_works",
            "string",
            "関連研究と比較した場合のこの論文の新規性についての説明.可能な限り既存研究を参照しながら記述すること.",
        );
        json_schema.add_property("future_works", "string", "未解決の課題および将来の研究の方向性について記述.");

        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let response = chat_with_retry!(openai, "Chat completion failed");
        tracing::info!("{:#?}", response);

        let text = response.choices[0]
            .message
            .content
            .as_ref()
            .and_then(|content| content.text.as_ref())
            .expect("Response content should not be empty");
        let summary: Summary = serde_json::from_str(text).unwrap_or_else(|e| panic!("Failed to parse summary response: {}", e));

        tracing::info!("Summary.is_survey: {}", summary.is_survey);
        tracing::info!("Summary.research_question: {}", summary.research_question);
        tracing::info!("Summary.contributions: {}", summary.contributions);
        tracing::info!("Summary.dataset: {}", summary.dataset);
        tracing::info!("Summary.proposed_method: {}", summary.proposed_method);
        tracing::info!("Summary.experiment_results: {}", summary.experiment_results);
        tracing::info!("Summary.comparison_with_related_works: {}", summary.comparison_with_related_works);
        tracing::info!("Summary.future_works: {}", summary.future_works);
    }

    #[tokio::test]
    async fn test_chat_completion_with_function_calling() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Please calculate 25 + 17 using the calculator tool.")];

        // Define a calculator function tool
        let calculator_tool = Tool::function(
            "calculator",
            "A calculator that can perform basic arithmetic operations",
            vec![
                ("operation", ParameterProp::string("The operation to perform (add, subtract, multiply, divide)")),
                ("a", ParameterProp::number("The first number")),
                ("b", ParameterProp::number("The second number")),
            ],
            false, // strict mode disabled
        );

        chat.model_id("gpt-4o-mini").messages(messages).temperature(0.1).tools(vec![calculator_tool]);

        let response = chat_with_retry!(chat, "Function calling test failed");
        tracing::info!("Response: {:#?}", response);

        let message = &response.choices[0].message;
        match &message.tool_calls {
            Some(tool_calls) => {
                assert!(!tool_calls.is_empty(), "Tool calls should not be empty");

                for tool_call in tool_calls {
                    tracing::info!("Function called: {}", tool_call.function.name);
                    tracing::info!("Arguments: {:?}", tool_call.function.arguments);

                    // Verify that the calculator function was called
                    assert_eq!(tool_call.function.name, "calculator");

                    // Every declared parameter should be present in the call arguments
                    let args = tool_call.function.arguments_as_map().expect("Tool call arguments should parse into a map");
                    assert!(args.get("operation").is_some(), "missing `operation` argument");
                    assert!(args.get("a").is_some(), "missing `a` argument");
                    assert!(args.get("b").is_some(), "missing `b` argument");

                    tracing::info!("Function call validation passed");
                }
            }
            None => {
                // The prompt explicitly asks for the tool, so a plain-text answer fails the test.
                // Content is logged with `{:?}` so an empty body cannot mask this failure.
                tracing::info!("No tool calls found. Content: {:?}", message.content);
                panic!("Expected tool calls but none found in response");
            }
        }
    }

    // Disabled (commented out, so not compiled); kept for reference. It depends on
    // `Message::get_input_token_count` and the `src/test_rsc/long_text.txt` fixture.
    //
    // #[tokio::test]
    // async fn test_chat_completion_with_long_arguments() {
    //     init_tracing();
    //     let mut openai = ChatCompletion::new();
    //     let text = std::fs::read_to_string("src/test_rsc/long_text.txt").unwrap();
    //     let messages = vec![Message::from_string(Role::User, text)];
    //
    //     let token_count = messages
    //         .iter()
    //         .map(|m| m.get_input_token_count())
    //         .sum::<usize>();
    //     tracing::info!("Token count: {}", token_count);
    //
    //     openai
    //         .model_id(String::from("gpt-4o-mini"))
    //         .messages(messages)
    //         .temperature(1.0);
    //
    //     let mut counter = 3;
    //     loop {
    //         match openai.chat().await {
    //             Ok(response) => {
    //                 println!("{:#?}", response);
    //                 assert!(true);
    //                 break;
    //             }
    //             Err(e) => match e {
    //                 OpenAIToolError::RequestError(e) => {
    //                     tracing::warn!("Request error: {} (retrying... {})", e, counter);
    //                     counter -= 1;
    //                     if counter == 0 {
    //                         assert!(false, "Chat completion failed (retry limit reached)");
    //                     }
    //                     continue;
    //                 }
    //                 _ => {
    //                     tracing::error!("Error: {}", e);
    //                     assert!(false, "Chat completion failed");
    //                 }
    //             },
    //         };
    //     }
    // }
}