openai_tools/chat/mod.rs
//! # Chat Module
//!
//! This module provides functionality for interacting with the OpenAI Chat Completions API.
//! It includes tools for building requests, sending them to OpenAI's chat completion endpoint,
//! and processing the responses.
//!
//! ## Key Features
//!
//! - Chat completion request building and sending
//! - Structured output support with JSON schema
//! - Response parsing and processing
//! - Support for various OpenAI models and parameters
//!
//! ## Usage Examples
//!
//! ### Basic Chat Completion
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(1.0)
//!         .chat()
//!         .await?;
//!
//!     println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
//!     Ok(())
//! }
//! ```
//!
//! ### Using JSON Schema for Structured Output
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Debug, Serialize, Deserialize)]
//! struct WeatherInfo {
//!     location: String,
//!     date: String,
//!     weather: String,
//!     temperature: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "What's the weather like tomorrow in Tokyo?"
//!     )];
//!
//!     // Create a JSON schema for structured output
//!     let mut json_schema = Schema::chat_json_schema("weather");
//!     json_schema.add_property("location", "string", "The location for the weather check");
//!     json_schema.add_property("date", "string", "The date for the weather forecast");
//!     json_schema.add_property("weather", "string", "Weather condition description");
//!     json_schema.add_property("temperature", "string", "Temperature information");
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .json_schema(json_schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse the structured response
//!     let weather: WeatherInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!     println!("Weather in {}: {} on {}, Temperature: {}",
//!         weather.location, weather.weather, weather.date, weather.temperature);
//!     Ok(())
//! }
//! ```
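//!
//! The target struct should mirror the schema's properties. As the tests in this file do
//! with their `Weather` struct, fields can be marked `#[serde(default)]` so deserialization
//! still succeeds when the model omits a value. A minimal, self-contained sketch (the
//! `WeatherInfo` shape here is illustrative):
//!
//! ```rust
//! use serde::Deserialize;
//!
//! #[derive(Debug, Deserialize)]
//! struct WeatherInfo {
//!     location: String,
//!     // Falls back to an empty string when the key is absent from the JSON.
//!     #[serde(default)]
//!     temperature: String,
//! }
//!
//! let parsed: WeatherInfo = serde_json::from_str(r#"{"location": "Tokyo"}"#).unwrap();
//! assert_eq!(parsed.location, "Tokyo");
//! assert_eq!(parsed.temperature, "");
//! ```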
//!
//! ### Using Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProp;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "Please calculate 15 * 23 using the calculator tool"
//!     )];
//!
//!     // Define a calculator function tool
//!     let calculator_tool = Tool::function(
//!         "calculator",
//!         "A calculator that can perform basic arithmetic operations",
//!         vec![
//!             ("operation", ParameterProp::string("The operation to perform (add, subtract, multiply, divide)")),
//!             ("a", ParameterProp::number("The first number")),
//!             ("b", ParameterProp::number("The second number")),
//!         ],
//!         false, // strict mode
//!     );
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.1)
//!         .tools(vec![calculator_tool])
//!         .chat()
//!         .await?;
//!
//!     // Handle function calls in the response
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         for tool_call in tool_calls {
//!             println!("Function called: {}", tool_call.function.name);
//!             if let Ok(args) = tool_call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!             // In a real application, you would execute the function here and send
//!             // the result back to continue the conversation (see the sketch below)
//!         }
//!     } else if let Some(content) = &response.choices[0].message.content {
//!         if let Some(text) = &content.text {
//!             println!("{}", text);
//!         }
//!     }
//!     Ok(())
//! }
//! ```
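//!
//! ### Executing a Tool Call Locally
//!
//! The function-calling example above stops once the tool call has been inspected.
//! Executing the call is application code, not part of this crate: dispatch on
//! `tool_call.function.name`, run your own implementation, and send the result back in a
//! follow-up message to continue the conversation. Below is a minimal sketch of the local
//! dispatch step only; `run_calculator` is a hypothetical handler, and its inputs are
//! hard-coded where a real application would extract them from the map returned by
//! `arguments_as_map()`.
//!
//! ```rust
//! /// Hypothetical local handler for the `calculator` tool defined above.
//! fn run_calculator(operation: &str, a: f64, b: f64) -> Option<f64> {
//!     match operation {
//!         "add" => Some(a + b),
//!         "subtract" => Some(a - b),
//!         "multiply" => Some(a * b),
//!         // Reject division by zero instead of returning infinity.
//!         "divide" if b != 0.0 => Some(a / b),
//!         _ => None,
//!     }
//! }
//!
//! // In a real application, the operation and operands would come from the
//! // parsed tool-call arguments rather than being hard-coded.
//! assert_eq!(run_calculator("multiply", 15.0, 23.0), Some(345.0));
//! ```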

pub mod request;
pub mod response;

#[cfg(test)]
mod tests {
    use crate::chat::request::ChatCompletion;
    use crate::common::{errors::OpenAIToolError, message::Message, parameters::ParameterProp, role::Role, structured_output::Schema, tool::Tool};
    use serde::{Deserialize, Serialize};
    use serde_json;
    use std::sync::Once;
    use tracing_subscriber::EnvFilter;

    static INIT: Once = Once::new();

    fn init_tracing() {
        INIT.call_once(|| {
            // Use the `RUST_LOG` environment variable if set; otherwise default to "info".
            let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
            tracing_subscriber::fmt()
                .with_env_filter(filter)
                .with_test_writer() // route output through the test harness (`cargo test` / nextest)
                .init();
        });
    }

    #[tokio::test]
    async fn test_chat_completion() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there!")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    tracing::info!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[tokio::test]
    async fn test_chat_completion_2() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        // Japanese prompt (roughly, "What lies beyond the tunnel?"); kept in Japanese to exercise non-ASCII input.
        let messages = vec![Message::from_string(Role::User, "トンネルを抜けると?")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.5);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    println!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct Weather {
        #[serde(default)]
        location: String,
        #[serde(default)]
        date: String,
        #[serde(default)]
        weather: String,
        #[serde(default)]
        error: String,
    }

    #[tokio::test]
    async fn test_chat_completion_with_json_schema() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report an error.")];

        let mut json_schema = Schema::chat_json_schema("weather");
        json_schema.add_property("location", "string", "The location to check the weather for.");
        json_schema.add_property("date", "string", "The date to check the weather for.");
        json_schema.add_property("weather", "string", "The weather for the location and date.");
        json_schema.add_property("error", "string", "Error message. If there is no error, leave this field empty.");
        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let mut counter = 3;
        loop {
            match openai.chat().await {
                Ok(response) => {
                    println!("{:#?}", response);
                    match serde_json::from_str::<Weather>(
                        &response.choices[0]
                            .message
                            .content
                            .clone()
                            .expect("Response content should not be empty")
                            .text
                            .expect("Response text should not be empty"),
                    ) {
                        Ok(weather) => {
                            println!("{:#?}", weather);
                        }
                        Err(e) => {
                            panic!("Failed to parse the structured response: {:#?}", e);
                        }
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[derive(Deserialize)]
    struct Summary {
        pub is_survey: bool,
        pub research_question: String,
        pub contributions: String,
        pub dataset: String,
        pub proposed_method: String,
        pub experiment_results: String,
        pub comparison_with_related_works: String,
        pub future_works: String,
    }

    #[tokio::test]
    async fn test_summarize() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        let instruction = std::fs::read_to_string("src/test_rsc/sample_instruction.txt").unwrap();

        let messages = vec![Message::from_string(Role::User, instruction.clone())];

        let mut json_schema = Schema::chat_json_schema("summary");
        json_schema.add_property("is_survey", "boolean", "Whether this paper is a survey paper (true/false).");
        json_schema.add_property(
            "research_question",
            "string",
            "A description of this paper's research question, including its background and relation to existing work.",
        );
        json_schema.add_property("contributions", "string", "The contributions of this paper, described as a list.");
        json_schema.add_property("dataset", "string", "A list of the datasets used in this paper.");
        json_schema.add_property("proposed_method", "string", "A detailed description of the proposed method.");
        json_schema.add_property("experiment_results", "string", "A detailed description of the experimental results.");
        json_schema.add_property(
            "comparison_with_related_works",
            "string",
            "A description of this paper's novelty relative to related work, referring to existing studies where possible.",
        );
        json_schema.add_property("future_works", "string", "A description of open problems and directions for future research.");

        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let mut counter = 3;
        loop {
            match openai.chat().await {
                Ok(response) => {
                    println!("{:#?}", response);
                    match serde_json::from_str::<Summary>(
                        &response.choices[0]
                            .message
                            .content
                            .clone()
                            .expect("Response content should not be empty")
                            .text
                            .expect("Response text should not be empty"),
                    ) {
                        Ok(summary) => {
                            tracing::info!("Summary.is_survey: {}", summary.is_survey);
                            tracing::info!("Summary.research_question: {}", summary.research_question);
                            tracing::info!("Summary.contributions: {}", summary.contributions);
                            tracing::info!("Summary.dataset: {}", summary.dataset);
                            tracing::info!("Summary.proposed_method: {}", summary.proposed_method);
                            tracing::info!("Summary.experiment_results: {}", summary.experiment_results);
                            tracing::info!("Summary.comparison_with_related_works: {}", summary.comparison_with_related_works);
                            tracing::info!("Summary.future_works: {}", summary.future_works);
                        }
                        Err(e) => {
                            tracing::error!("Error: {}", e);
                            panic!("Failed to parse the structured response");
                        }
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[tokio::test]
    async fn test_chat_completion_with_function_calling() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Please calculate 25 + 17 using the calculator tool.")];

        // Define a calculator function tool
        let calculator_tool = Tool::function(
            "calculator",
            "A calculator that can perform basic arithmetic operations",
            vec![
                ("operation", ParameterProp::string("The operation to perform (add, subtract, multiply, divide)")),
                ("a", ParameterProp::number("The first number")),
                ("b", ParameterProp::number("The second number")),
            ],
            false, // strict mode
        );

        chat.model_id("gpt-4o-mini").messages(messages).temperature(0.1).tools(vec![calculator_tool]);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    tracing::info!("Response: {:#?}", response);

                    // Check whether the response contains tool calls
                    if let Some(tool_calls) = &response.choices[0].message.tool_calls {
                        assert!(!tool_calls.is_empty(), "Tool calls should not be empty");

                        for tool_call in tool_calls {
                            tracing::info!("Function called: {}", tool_call.function.name);
                            tracing::info!("Arguments: {:?}", tool_call.function.arguments);

                            // Verify that the calculator function was called
                            assert_eq!(tool_call.function.name, "calculator");

                            // Parse the arguments and verify they contain the expected fields
                            let args = tool_call.function.arguments_as_map().unwrap();
                            assert!(args.get("operation").is_some());
                            assert!(args.get("a").is_some());
                            assert!(args.get("b").is_some());

                            tracing::info!("Function call validation passed");
                        }
                    } else {
                        // The model may occasionally answer without using the tool;
                        // this test treats that as a failure.
                        tracing::info!(
                            "No tool calls found. Content: {}",
                            &response.choices[0]
                                .message
                                .content
                                .clone()
                                .expect("Response content should not be empty")
                                .text
                                .expect("Response text should not be empty")
                        );
                        panic!("Expected tool calls but none found in response");
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Function calling test failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Function calling test failed");
                    }
                },
            };
        }
    }

    // #[tokio::test]
    // async fn test_chat_completion_with_long_arguments() {
    //     init_tracing();
    //     let mut openai = ChatCompletion::new();
    //     let text = std::fs::read_to_string("src/test_rsc/long_text.txt").unwrap();
    //     let messages = vec![Message::from_string(Role::User, text)];

    //     let token_count = messages
    //         .iter()
    //         .map(|m| m.get_input_token_count())
    //         .sum::<usize>();
    //     tracing::info!("Token count: {}", token_count);

    //     openai
    //         .model_id("gpt-4o-mini")
    //         .messages(messages)
    //         .temperature(1.0);

    //     let mut counter = 3;
    //     loop {
    //         match openai.chat().await {
    //             Ok(response) => {
    //                 println!("{:#?}", response);
    //                 break;
    //             }
    //             Err(e) => match e {
    //                 OpenAIToolError::RequestError(e) => {
    //                     tracing::warn!("Request error: {} (retrying... {})", e, counter);
    //                     counter -= 1;
    //                     if counter == 0 {
    //                         panic!("Chat completion failed (retry limit reached)");
    //                     }
    //                     continue;
    //                 }
    //                 _ => {
    //                     tracing::error!("Error: {}", e);
    //                     panic!("Chat completion failed");
    //                 }
    //             },
    //         };
    //     }
    // }
}