openai_tools/chat/mod.rs
//! # Chat Module
//!
//! This module provides functionality for interacting with the OpenAI Chat Completions API.
//! It includes tools for building requests, sending them to OpenAI's chat completion endpoint,
//! and processing the responses.
//!
//! ## Key Features
//!
//! - Chat completion request building and sending
//! - Structured output support with JSON schema
//! - Response parsing and processing
//! - Support for various OpenAI models and parameters
//!
//! ## Usage Examples
//!
//! ### Basic Chat Completion
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(1.0)
//!         .chat()
//!         .await?;
//!
//!     println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
//!     Ok(())
//! }
//! ```
//!
//! ### Using JSON Schema for Structured Output
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Debug, Serialize, Deserialize)]
//! struct WeatherInfo {
//!     location: String,
//!     date: String,
//!     weather: String,
//!     temperature: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "What's the weather like tomorrow in Tokyo?"
//!     )];
//!
//!     // Create JSON schema for structured output
//!     let mut json_schema = Schema::chat_json_schema("weather");
//!     json_schema.add_property("location", "string", "The location for weather check");
//!     json_schema.add_property("date", "string", "The date for weather forecast");
//!     json_schema.add_property("weather", "string", "Weather condition description");
//!     json_schema.add_property("temperature", "string", "Temperature information");
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .json_schema(json_schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse structured response
//!     let weather: WeatherInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!     println!("Weather in {}: {} on {}, Temperature: {}",
//!         weather.location, weather.weather, weather.date, weather.temperature);
//!     Ok(())
//! }
//! ```
//!
//! ### Using Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProperty;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(
//!         Role::User,
//!         "Please calculate 25 + 17 using the calculator tool"
//!     )];
//!
//!     // Define a calculator function tool
//!     let calculator_tool = Tool::function(
//!         "calculator",
//!         "A calculator that can perform basic arithmetic operations",
//!         vec![
//!             ("operation", ParameterProperty::from_string("The operation to perform (add, subtract, multiply, divide)")),
//!             ("a", ParameterProperty::from_number("The first number")),
//!             ("b", ParameterProperty::from_number("The second number")),
//!         ],
//!         false, // strict mode disabled
//!     );
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.1)
//!         .tools(vec![calculator_tool])
//!         .chat()
//!         .await?;
//!
//!     // Handle function calls in the response
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         // Add the assistant's message with tool calls to the conversation history
//!         chat.add_message(response.choices[0].message.clone());
//!
//!         for tool_call in tool_calls {
//!             println!("Function called: {}", tool_call.function.name);
//!             if let Ok(args) = tool_call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!
//!             // Execute the function (in this example, we simulate the calculation)
//!             let result = "42"; // This would be the actual calculation result
//!
//!             // Add the tool call response to continue the conversation
//!             chat.add_message(Message::from_tool_call_response(result, &tool_call.id));
//!         }
//!
//!         // Get the final response after tool execution
//!         let final_response = chat.chat().await?;
//!         if let Some(content) = &final_response.choices[0].message.content {
//!             if let Some(text) = &content.text {
//!                 println!("Final answer: {}", text);
//!             }
//!         }
//!     } else if let Some(content) = &response.choices[0].message.content {
//!         if let Some(text) = &content.text {
//!             println!("{}", text);
//!         }
//!     }
//!     Ok(())
//! }
//! ```
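//!
//! ### Retrying Failed Requests
//!
//! Transient network failures surface as `OpenAIToolError::RequestError`. Below is a
//! minimal retry sketch mirroring the pattern used by this module's tests; the retry
//! count and logging are illustrative, not part of the API.
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::errors::OpenAIToolError;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!     chat.model_id("gpt-4o-mini")
//!         .messages(vec![Message::from_string(Role::User, "Hello!")])
//!         .temperature(1.0);
//!
//!     let mut retries_left = 3;
//!     let response = loop {
//!         match chat.chat().await {
//!             Ok(response) => break response,
//!             // Request-level errors (timeouts, connection resets) are worth retrying.
//!             Err(OpenAIToolError::RequestError(e)) if retries_left > 0 => {
//!                 eprintln!("Request error: {} (retries left: {})", e, retries_left);
//!                 retries_left -= 1;
//!             }
//!             // Anything else (e.g. invalid parameters) will not improve on retry.
//!             Err(e) => return Err(e.into()),
//!         }
//!     };
//!
//!     println!("{:?}", response.choices[0].message.content);
//!     Ok(())
//! }
//! ```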

pub mod request;
pub mod response;

#[cfg(test)]
mod tests {
    use crate::chat::request::ChatCompletion;
    use crate::common::{
        errors::OpenAIToolError, message::Message, parameters::ParameterProperty, role::Role, structured_output::Schema, tool::Tool,
    };
    use serde::{Deserialize, Serialize};
    use serde_json;
    use std::sync::Once;
    use tracing_subscriber::EnvFilter;

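    // Install the tracing subscriber at most once per test binary; `Once` prevents
    // double initialization when several tests run in the same process.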
    static INIT: Once = Once::new();

    fn init_tracing() {
        INIT.call_once(|| {
            // Use the `RUST_LOG` environment variable if set; otherwise default to "info".
            let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
            // try_init() skips installation if a subscriber has already been set.
            let _ = tracing_subscriber::fmt()
                .with_env_filter(filter)
                .with_test_writer() // route output through the test harness (`cargo test` / nextest)
                .try_init();
        });
    }

    #[test_log::test(tokio::test)]
    async fn test_chat_completion() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there!")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    tracing::info!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_chat_completion_2() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        // Japanese prompt: "When you pass through the tunnel...?"
        let messages = vec![Message::from_string(Role::User, "トンネルを抜けると?")];

        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.5);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    println!("{:?}", &response.choices[0].message.content.clone().expect("Response content should not be empty"));
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

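    // Target shape for the structured-output test below. Every field falls back to
    // an empty string so deserialization succeeds even if the model omits a field.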
    #[derive(Debug, Serialize, Deserialize)]
    struct Weather {
        #[serde(default = "String::new")]
        location: String,
        #[serde(default = "String::new")]
        date: String,
        #[serde(default = "String::new")]
        weather: String,
        #[serde(default = "String::new")]
        error: String,
    }

    #[test_log::test(tokio::test)]
    async fn test_chat_completion_with_json_schema() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report error.")];

        let mut json_schema = Schema::chat_json_schema("weather");
        json_schema.add_property("location", "string", "The location to check the weather for.");
        json_schema.add_property("date", "string", "The date to check the weather for.");
        json_schema.add_property("weather", "string", "The weather for the location and date.");
        json_schema.add_property("error", "string", "Error message. If there is no error, leave this field empty.");
        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let mut counter = 3;
        loop {
            match openai.chat().await {
                Ok(response) => {
                    println!("{:#?}", response);
                    match serde_json::from_str::<Weather>(
                        &response.choices[0]
                            .message
                            .content
                            .clone()
                            .expect("Response content should not be empty")
                            .text
                            .expect("Response text should not be empty"),
                    ) {
                        Ok(weather) => {
                            println!("{:#?}", weather);
                        }
                        Err(e) => {
                            panic!("Failed to deserialize structured response: {:#?}", e);
                        }
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

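    // Structured summary produced by `test_summarize`; the field names must match
    // the properties registered on the JSON schema in that test.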
    #[derive(Deserialize)]
    struct Summary {
        pub is_survey: bool,
        pub research_question: String,
        pub contributions: String,
        pub dataset: String,
        pub proposed_method: String,
        pub experiment_results: String,
        pub comparison_with_related_works: String,
        pub future_works: String,
    }

    #[test_log::test(tokio::test)]
    async fn test_summarize() {
        init_tracing();
        let mut openai = ChatCompletion::new();
        let instruction = std::fs::read_to_string("src/test_rsc/sample_instruction.txt").unwrap();

        let messages = vec![Message::from_string(Role::User, instruction)];

        let mut json_schema = Schema::chat_json_schema("summary");
        json_schema.add_property("is_survey", "boolean", "Whether this paper is a survey paper, as true/false.");
        json_schema.add_property(
            "research_question",
            "string",
            "Description of the paper's research question, including its background and relation to existing work.",
        );
        json_schema.add_property("contributions", "string", "The paper's contributions, described as a list.");
        json_schema.add_property("dataset", "string", "A list of the datasets used in the paper.");
        json_schema.add_property("proposed_method", "string", "Detailed description of the proposed method.");
        json_schema.add_property("experiment_results", "string", "Detailed description of the experimental results.");
        json_schema.add_property(
            "comparison_with_related_works",
            "string",
            "Explanation of the paper's novelty compared with related work, referring to existing studies wherever possible.",
        );
        json_schema.add_property("future_works", "string", "Open problems and directions for future research.");

        openai.model_id("gpt-4o-mini").messages(messages).temperature(1.0).json_schema(json_schema);

        let mut counter = 3;
        loop {
            match openai.chat().await {
                Ok(response) => {
                    println!("{:#?}", response);
                    match serde_json::from_str::<Summary>(
                        &response.choices[0]
                            .message
                            .content
                            .clone()
                            .expect("Response content should not be empty")
                            .text
                            .expect("Response text should not be empty"),
                    ) {
                        Ok(summary) => {
                            tracing::info!("Summary.is_survey: {}", summary.is_survey);
                            tracing::info!("Summary.research_question: {}", summary.research_question);
                            tracing::info!("Summary.contributions: {}", summary.contributions);
                            tracing::info!("Summary.dataset: {}", summary.dataset);
                            tracing::info!("Summary.proposed_method: {}", summary.proposed_method);
                            tracing::info!("Summary.experiment_results: {}", summary.experiment_results);
                            tracing::info!("Summary.comparison_with_related_works: {}", summary.comparison_with_related_works);
                            tracing::info!("Summary.future_works: {}", summary.future_works);
                        }
                        Err(e) => {
                            tracing::error!("Error: {}", e);
                            panic!("Failed to deserialize structured summary");
                        }
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Chat completion failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Chat completion failed");
                    }
                },
            };
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_chat_completion_with_function_calling() {
        init_tracing();
        let mut chat = ChatCompletion::new();
        let messages = vec![Message::from_string(Role::User, "Please calculate 25 + 17 using the calculator tool.")];

        // Define a calculator function tool
        let calculator_tool = Tool::function(
            "calculator",
            "A calculator that can perform basic arithmetic operations",
            vec![
                ("operation", ParameterProperty::from_string("The operation to perform (add, subtract, multiply, divide)")),
                ("a", ParameterProperty::from_number("The first number")),
                ("b", ParameterProperty::from_number("The second number")),
            ],
            false, // strict mode disabled
        );

        chat.model_id("gpt-4o-mini").messages(messages).temperature(0.1).tools(vec![calculator_tool]);

        // First call: the model should respond with a tool call.
        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    tracing::info!("First Response: {:#?}", response);

                    let message = response.choices[0].message.clone();
                    chat.add_message(message.clone());

                    // Check if the response contains tool calls
                    if let Some(tool_calls) = &message.tool_calls {
                        assert!(!tool_calls.is_empty(), "Tool calls should not be empty");

                        for tool_call in tool_calls {
                            tracing::info!("Function called: {}", tool_call.function.name);
                            tracing::info!("Arguments: {:?}", tool_call.function.arguments);

                            // Verify that the calculator function was called
                            assert_eq!(tool_call.function.name, "calculator");

                            // Parse the arguments to verify they contain the expected fields
                            let args = tool_call.function.arguments_as_map().unwrap();
                            assert!(args.get("operation").is_some());
                            assert!(args.get("a").is_some());
                            assert!(args.get("b").is_some());

                            tracing::info!("Function call validation passed");

                            chat.add_message(Message::from_tool_call_response("42", &tool_call.id));
                        }
                    } else {
                        // The model occasionally answers directly instead of calling the
                        // tool; since the prompt explicitly requests the tool, treat that
                        // as a test failure.
                        tracing::info!(
                            "No tool calls found. Content: {}",
                            &response.choices[0]
                                .message
                                .content
                                .clone()
                                .expect("Response content should not be empty")
                                .text
                                .expect("Response text should not be empty")
                        );
                        panic!("Expected tool calls but none found in response");
                    }
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Function calling test failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Function calling test failed");
                    }
                },
            };
        }

        // Second call: replay the accumulated history (including the tool result)
        // through a fresh client and check that the final answer uses it.
        let messages = chat.get_message_history();
        let mut chat = ChatCompletion::new();
        chat.model_id("gpt-4o-mini").messages(messages).temperature(1.0);

        let mut counter = 3;
        loop {
            match chat.chat().await {
                Ok(response) => {
                    tracing::info!("Second Response: {:#?}", response);
                    assert!(!response.choices.is_empty(), "Response should contain at least one choice");
                    let content = response.choices[0]
                        .message
                        .content
                        .clone()
                        .expect("Response content should not be empty")
                        .text
                        .expect("Response text should not be empty");
                    tracing::info!("Content: {}", content);
                    // The final answer should include the simulated tool result
                    assert!(content.contains("42"), "Expected content to contain '42', found: {}", content);
                    break;
                }
                Err(e) => match e {
                    OpenAIToolError::RequestError(e) => {
                        tracing::warn!("Request error: {} (retrying... {})", e, counter);
                        counter -= 1;
                        if counter == 0 {
                            panic!("Function calling test failed (retry limit reached)");
                        }
                        continue;
                    }
                    _ => {
                        tracing::error!("Error: {}", e);
                        panic!("Function calling test failed");
                    }
                },
            };
        }
    }

    // #[tokio::test]
    // async fn test_chat_completion_with_long_arguments() {
    //     init_tracing();
    //     let mut openai = ChatCompletion::new();
    //     let text = std::fs::read_to_string("src/test_rsc/long_text.txt").unwrap();
    //     let messages = vec![Message::from_string(Role::User, text)];

    //     let token_count = messages
    //         .iter()
    //         .map(|m| m.get_input_token_count())
    //         .sum::<usize>();
    //     tracing::info!("Token count: {}", token_count);

    //     openai
    //         .model_id(String::from("gpt-4o-mini"))
    //         .messages(messages)
    //         .temperature(1.0);

    //     let mut counter = 3;
    //     loop {
    //         match openai.chat().await {
    //             Ok(response) => {
    //                 println!("{:#?}", response);
    //                 break;
    //             }
    //             Err(e) => match e {
    //                 OpenAIToolError::RequestError(e) => {
    //                     tracing::warn!("Request error: {} (retrying... {})", e, counter);
    //                     counter -= 1;
    //                     if counter == 0 {
    //                         panic!("Chat completion failed (retry limit reached)");
    //                     }
    //                     continue;
    //                 }
    //                 _ => {
    //                     tracing::error!("Error: {}", e);
    //                     panic!("Chat completion failed");
    //                 }
    //             },
    //         };
    //     }
    // }
}