openai_tools/chat/request.rs
1//! OpenAI Chat Completions API Request Module
2//!
3//! This module provides the functionality to build and send requests to the OpenAI Chat Completions API.
4//! It offers a builder pattern for constructing requests with various parameters and options,
5//! making it easy to interact with OpenAI's conversational AI models.
6//!
7//! # Key Features
8//!
9//! - **Builder Pattern**: Fluent API for constructing requests
10//! - **Structured Output**: Support for JSON schema-based responses
11//! - **Function Calling**: Tool integration for extended model capabilities
//! - **Comprehensive Parameters**: Support for commonly used parameters (temperature, penalties, logprobs, token limits, modalities, etc.)
13//! - **Error Handling**: Robust error management and validation
14//!
15//! # Quick Start
16//!
17//! ```rust,no_run
18//! use openai_tools::chat::request::ChatCompletion;
19//! use openai_tools::common::message::Message;
20//! use openai_tools::common::role::Role;
21//!
22//! #[tokio::main]
23//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
24//! // Initialize the chat completion client
25//! let mut chat = ChatCompletion::new();
26//!
27//! // Create a simple conversation
28//! let messages = vec![
29//! Message::from_string(Role::User, "Hello! How are you?")
30//! ];
31//!
32//! // Send the request and get a response
33//! let response = chat
34//! .model_id("gpt-4o-mini")
35//! .messages(messages)
36//! .temperature(0.7)
37//! .chat()
38//! .await?;
39//!
40//! println!("AI Response: {}",
41//! response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
42//! Ok(())
43//! }
44//! ```
45//!
46//! # Advanced Usage
47//!
48//! ## Structured Output with JSON Schema
49//!
50//! ```rust,no_run
51//! use openai_tools::chat::request::ChatCompletion;
52//! use openai_tools::common::message::Message;
53//! use openai_tools::common::role::Role;
54//! use openai_tools::common::structured_output::Schema;
55//! use serde::{Deserialize, Serialize};
56//!
57//! #[derive(Serialize, Deserialize)]
58//! struct PersonInfo {
59//! name: String,
60//! age: u32,
61//! occupation: String,
62//! }
63//!
64//! #[tokio::main]
65//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
66//! let mut chat = ChatCompletion::new();
67//!
68//! // Define JSON schema for structured output
69//! let mut schema = Schema::chat_json_schema("person_info");
70//! schema.add_property("name", "string", "Person's full name");
71//! schema.add_property("age", "number", "Person's age in years");
72//! schema.add_property("occupation", "string", "Person's job or profession");
73//!
74//! let messages = vec![
75//! Message::from_string(Role::User,
76//! "Extract information about: John Smith, 30 years old, software engineer")
77//! ];
78//!
79//! let response = chat
80//! .model_id("gpt-4o-mini")
81//! .messages(messages)
82//! .json_schema(schema)
83//! .chat()
84//! .await?;
85//!
86//! // Parse structured response
87//! let person: PersonInfo = serde_json::from_str(
88//! response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
89//! )?;
90//!
91//! println!("Extracted: {} (age: {}, job: {})",
92//! person.name, person.age, person.occupation);
93//! Ok(())
94//! }
95//! ```
96//!
97//! ## Function Calling with Tools
98//!
99//! ```rust,no_run
100//! use openai_tools::chat::request::ChatCompletion;
101//! use openai_tools::common::message::Message;
102//! use openai_tools::common::role::Role;
103//! use openai_tools::common::tool::Tool;
104//! use openai_tools::common::parameters::ParameterProperty;
105//!
106//! #[tokio::main]
107//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
108//! let mut chat = ChatCompletion::new();
109//!
110//! // Define a weather checking tool
111//! let weather_tool = Tool::function(
112//! "get_weather",
113//! "Get current weather information for a location",
114//! vec![
115//! ("location", ParameterProperty::from_string("The city and country")),
116//! ("unit", ParameterProperty::from_string("Temperature unit (celsius/fahrenheit)")),
117//! ],
118//! false,
119//! );
120//!
121//! let messages = vec![
122//! Message::from_string(Role::User,
123//! "What's the weather like in Tokyo today?")
124//! ];
125//!
126//! let response = chat
127//! .model_id("gpt-4o-mini")
128//! .messages(messages)
129//! .tools(vec![weather_tool])
130//! .temperature(0.1)
131//! .chat()
132//! .await?;
133//!
134//! // Handle tool calls
135//! if let Some(tool_calls) = &response.choices[0].message.tool_calls {
136//! for call in tool_calls {
137//! println!("Tool called: {}", call.function.name);
138//! if let Ok(args) = call.function.arguments_as_map() {
139//! println!("Arguments: {:?}", args);
140//! }
141//! // Execute the function and continue the conversation...
142//! }
143//! }
144//! Ok(())
145//! }
146//! ```
147//!
148//! # Environment Setup
149//!
150//! Before using this module, ensure you have set up your OpenAI API key:
151//!
152//! ```bash
153//! export OPENAI_API_KEY="your-api-key-here"
154//! ```
155//!
156//! Or create a `.env` file in your project root:
157//!
158//! ```text
159//! OPENAI_API_KEY=your-api-key-here
160//! ```
161//!
162//!
163//! # Error Handling
164//!
165//! All methods return a `Result` type for proper error handling:
166//!
167//! ```rust,no_run
168//! use openai_tools::chat::request::ChatCompletion;
169//! use openai_tools::common::errors::OpenAIToolError;
170//!
171//! #[tokio::main]
172//! async fn main() {
173//! let mut chat = ChatCompletion::new();
174//!
175//! match chat.model_id("gpt-4o-mini").chat().await {
176//! Ok(response) => {
177//! if let Some(content) = &response.choices[0].message.content {
178//! if let Some(text) = &content.text {
179//! println!("Success: {}", text);
180//! }
181//! }
182//! }
183//! Err(OpenAIToolError::RequestError(e)) => {
184//! eprintln!("Network error: {}", e);
185//! }
186//! Err(OpenAIToolError::SerdeJsonError(e)) => {
187//! eprintln!("JSON parsing error: {}", e);
188//! }
189//! Err(e) => {
190//! eprintln!("Other error: {}", e);
191//! }
192//! }
193//! }
194//! ```
195
196use crate::chat::response::Response;
197use crate::common::{
198 errors::{OpenAIToolError, Result},
199 message::Message,
200 structured_output::Schema,
201 tool::Tool,
202};
203use core::str;
204use dotenvy::dotenv;
205use serde::{Deserialize, Serialize};
206use std::collections::HashMap;
207use std::env;
208
209/// Response format structure for OpenAI API requests
210///
211/// This structure is used for structured output when JSON schema is specified.
212#[derive(Debug, Clone, Deserialize, Serialize)]
213struct Format {
214 #[serde(rename = "type")]
215 type_name: String,
216 json_schema: Schema,
217}
218
219impl Format {
220 /// Creates a new Format structure
221 ///
222 /// # Arguments
223 ///
224 /// * `type_name` - The type name for the response format
225 /// * `json_schema` - The JSON schema definition
226 ///
227 /// # Returns
228 ///
229 /// A new Format structure instance
230 pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
231 Self { type_name: type_name.as_ref().to_string(), json_schema }
232 }
233}
234
235/// Request body structure for OpenAI Chat Completions API
236///
237/// This structure represents the parameters that will be sent in the request body
238/// to the OpenAI API. Each field corresponds to the API specification.
239#[derive(Debug, Clone, Deserialize, Serialize, Default)]
240struct Body {
241 model: String,
242 messages: Vec<Message>,
243 /// Whether to store the request and response at OpenAI
244 #[serde(skip_serializing_if = "Option::is_none")]
245 store: Option<bool>,
246 /// Frequency penalty parameter to reduce repetition (-2.0 to 2.0)
247 #[serde(skip_serializing_if = "Option::is_none")]
248 frequency_penalty: Option<f32>,
249 /// Logit bias to adjust the probability of specific tokens
250 #[serde(skip_serializing_if = "Option::is_none")]
251 logit_bias: Option<HashMap<String, i32>>,
252 /// Whether to include probability information for each token
253 #[serde(skip_serializing_if = "Option::is_none")]
254 logprobs: Option<bool>,
255 /// Number of top probabilities to return for each token (0-20)
256 #[serde(skip_serializing_if = "Option::is_none")]
257 top_logprobs: Option<u8>,
258 /// Maximum number of tokens to generate
259 #[serde(skip_serializing_if = "Option::is_none")]
260 max_completion_tokens: Option<u64>,
261 /// Number of responses to generate
262 #[serde(skip_serializing_if = "Option::is_none")]
263 n: Option<u32>,
264 /// Available modalities for the response (e.g., text, audio)
265 #[serde(skip_serializing_if = "Option::is_none")]
266 modalities: Option<Vec<String>>,
267 /// Presence penalty to encourage new topics (-2.0 to 2.0)
268 #[serde(skip_serializing_if = "Option::is_none")]
269 presence_penalty: Option<f32>,
270 /// Temperature parameter to control response randomness (0.0 to 2.0)
271 #[serde(skip_serializing_if = "Option::is_none")]
272 temperature: Option<f32>,
273 /// Response format specification (e.g., JSON schema)
274 #[serde(skip_serializing_if = "Option::is_none")]
275 response_format: Option<Format>,
276 /// Optional tools that can be used by the model
277 #[serde(skip_serializing_if = "Option::is_none")]
278 tools: Option<Vec<Tool>>,
279}
280
281/// OpenAI Chat Completions API client
282///
283/// This structure manages interactions with the OpenAI Chat Completions API.
284/// It handles API key management, request parameter configuration, and API calls.
285///
286/// # Example
287///
288/// ```rust
289/// use openai_tools::chat::request::ChatCompletion;
290/// use openai_tools::common::message::Message;
291/// use openai_tools::common::role::Role;
292///
293/// # #[tokio::main]
294/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
295/// let mut chat = ChatCompletion::new();
296/// let messages = vec![Message::from_string(Role::User, "Hello!")];
297///
298/// let response = chat
299/// .model_id("gpt-4o-mini")
300/// .messages(messages)
301/// .temperature(1.0)
302/// .chat()
303/// .await?;
304/// # Ok::<(), Box<dyn std::error::Error>>(())
305/// # }
306/// ```
307#[derive(Debug, Clone, Default, Deserialize, Serialize)]
308pub struct ChatCompletion {
309 api_key: String,
310 request_body: Body,
311}
312
313impl ChatCompletion {
314 /// Creates a new ChatCompletion instance
315 ///
316 /// Loads the API key from the `OPENAI_API_KEY` environment variable.
317 /// If a `.env` file exists, it will also be loaded.
318 ///
319 /// # Panics
320 ///
321 /// Panics if the `OPENAI_API_KEY` environment variable is not set.
322 ///
323 /// # Returns
324 ///
325 /// A new ChatCompletion instance
326 pub fn new() -> Self {
327 dotenv().ok();
328 let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
329 Self { api_key, request_body: Body::default() }
330 }
331
332 /// Sets the model ID to use
333 ///
334 /// # Arguments
335 ///
336 /// * `model_id` - OpenAI model ID (e.g., `gpt-4o-mini`, `gpt-4o`)
337 ///
338 /// # Returns
339 ///
340 /// A mutable reference to self for method chaining
341 pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
342 self.request_body.model = model_id.as_ref().to_string();
343 self
344 }
345
346 /// Sets the chat message history
347 ///
348 /// # Arguments
349 ///
350 /// * `messages` - Vector of chat messages representing the conversation history
351 ///
352 /// # Returns
353 ///
354 /// A mutable reference to self for method chaining
355 pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
356 self.request_body.messages = messages;
357 self
358 }
359
360 /// Adds a single message to the conversation history
361 ///
362 /// This method appends a new message to the existing conversation history.
363 /// It's useful for building conversations incrementally.
364 ///
365 /// # Arguments
366 ///
367 /// * `message` - The message to add to the conversation
368 ///
369 /// # Returns
370 ///
371 /// A mutable reference to self for method chaining
372 ///
373 /// # Examples
374 ///
375 /// ```rust,no_run
376 /// use openai_tools::chat::request::ChatCompletion;
377 /// use openai_tools::common::message::Message;
378 /// use openai_tools::common::role::Role;
379 ///
380 /// let mut chat = ChatCompletion::new();
381 /// chat.add_message(Message::from_string(Role::User, "Hello!"))
382 /// .add_message(Message::from_string(Role::Assistant, "Hi there!"))
383 /// .add_message(Message::from_string(Role::User, "How are you?"));
384 /// ```
385 pub fn add_message(&mut self, message: Message) -> &mut Self {
386 self.request_body.messages.push(message);
387 self
388 }
389 /// Sets whether to store the request and response at OpenAI
390 ///
391 /// # Arguments
392 ///
393 /// * `store` - `true` to store, `false` to not store
394 ///
395 /// # Returns
396 ///
397 /// A mutable reference to self for method chaining
398 pub fn store(&mut self, store: bool) -> &mut Self {
399 self.request_body.store = Option::from(store);
400 self
401 }
402
403 /// Sets the frequency penalty
404 ///
405 /// A parameter that penalizes based on word frequency to reduce repetition.
406 /// Positive values decrease repetition, negative values increase it.
407 ///
408 /// # Arguments
409 ///
410 /// * `frequency_penalty` - Frequency penalty value (range: -2.0 to 2.0)
411 ///
412 /// # Returns
413 ///
414 /// A mutable reference to self for method chaining
415 pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
416 self.request_body.frequency_penalty = Option::from(frequency_penalty);
417 self
418 }
419
420 /// Sets logit bias to adjust the probability of specific tokens
421 ///
422 /// # Arguments
423 ///
424 /// * `logit_bias` - A map of token IDs to adjustment values
425 ///
426 /// # Returns
427 ///
428 /// A mutable reference to self for method chaining
429 pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
430 self.request_body.logit_bias =
431 Option::from(logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect::<HashMap<String, i32>>());
432 self
433 }
434
435 /// Sets whether to include probability information for each token
436 ///
437 /// # Arguments
438 ///
439 /// * `logprobs` - `true` to include probability information
440 ///
441 /// # Returns
442 ///
443 /// A mutable reference to self for method chaining
444 pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
445 self.request_body.logprobs = Option::from(logprobs);
446 self
447 }
448
449 /// Sets the number of top probabilities to return for each token
450 ///
451 /// # Arguments
452 ///
453 /// * `top_logprobs` - Number of top probabilities (range: 0-20)
454 ///
455 /// # Returns
456 ///
457 /// A mutable reference to self for method chaining
458 pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
459 self.request_body.top_logprobs = Option::from(top_logprobs);
460 self
461 }
462
463 /// Sets the maximum number of tokens to generate
464 ///
465 /// # Arguments
466 ///
467 /// * `max_completion_tokens` - Maximum number of tokens
468 ///
469 /// # Returns
470 ///
471 /// A mutable reference to self for method chaining
472 pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
473 self.request_body.max_completion_tokens = Option::from(max_completion_tokens);
474 self
475 }
476
477 /// Sets the number of responses to generate
478 ///
479 /// # Arguments
480 ///
481 /// * `n` - Number of responses to generate
482 ///
483 /// # Returns
484 ///
485 /// A mutable reference to self for method chaining
486 pub fn n(&mut self, n: u32) -> &mut Self {
487 self.request_body.n = Option::from(n);
488 self
489 }
490
491 /// Sets the available modalities for the response
492 ///
493 /// # Arguments
494 ///
495 /// * `modalities` - List of modalities (e.g., `["text", "audio"]`)
496 ///
497 /// # Returns
498 ///
499 /// A mutable reference to self for method chaining
500 pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
501 self.request_body.modalities = Option::from(modalities.into_iter().map(|m| m.as_ref().to_string()).collect::<Vec<String>>());
502 self
503 }
504
505 /// Sets the presence penalty
506 ///
507 /// A parameter that controls the tendency to include new content in the document.
508 /// Positive values encourage talking about new topics, negative values encourage
509 /// staying on existing topics.
510 ///
511 /// # Arguments
512 ///
513 /// * `presence_penalty` - Presence penalty value (range: -2.0 to 2.0)
514 ///
515 /// # Returns
516 ///
517 /// A mutable reference to self for method chaining
518 pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
519 self.request_body.presence_penalty = Option::from(presence_penalty);
520 self
521 }
522
523 /// Sets the temperature parameter to control response randomness
524 ///
525 /// Higher values (e.g., 1.0) produce more creative and diverse outputs,
526 /// while lower values (e.g., 0.2) produce more deterministic and consistent outputs.
527 ///
528 /// # Arguments
529 ///
530 /// * `temperature` - Temperature parameter (range: 0.0 to 2.0)
531 ///
532 /// # Returns
533 ///
534 /// A mutable reference to self for method chaining
535 pub fn temperature(&mut self, temperature: f32) -> &mut Self {
536 self.request_body.temperature = Option::from(temperature);
537 self
538 }
539
540 /// Sets structured output using JSON schema
541 ///
542 /// Enables receiving responses in a structured JSON format according to the
543 /// specified JSON schema.
544 ///
545 /// # Arguments
546 ///
547 /// * `json_schema` - JSON schema defining the response structure
548 ///
549 /// # Returns
550 ///
551 /// A mutable reference to self for method chaining
552 pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
553 self.request_body.response_format = Option::from(Format::new(String::from("json_schema"), json_schema));
554 self
555 }
556
557 /// Sets the tools that can be called by the model
558 ///
559 /// Enables function calling by providing a list of tools that the model can choose to call.
560 /// When tools are provided, the model may generate tool calls instead of or in addition to
561 /// regular text responses.
562 ///
563 /// # Arguments
564 ///
565 /// * `tools` - Vector of tools available for the model to use
566 ///
567 /// # Returns
568 ///
569 /// A mutable reference to self for method chaining
570 pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
571 self.request_body.tools = Option::from(tools);
572 self
573 }
574
575 /// Gets the current message history
576 ///
577 /// # Returns
578 ///
579 /// A vector containing the message history
580 pub fn get_message_history(&self) -> Vec<Message> {
581 self.request_body.messages.clone()
582 }
583
584 /// Sends the chat completion request to OpenAI API
585 ///
586 /// This method validates the request parameters, constructs the HTTP request,
587 /// and sends it to the OpenAI Chat Completions endpoint.
588 ///
589 /// # Returns
590 ///
591 /// A `Result` containing the API response on success, or an error on failure.
592 ///
593 /// # Errors
594 ///
595 /// Returns an error if:
596 /// - API key is not set
597 /// - Model ID is not set
598 /// - Messages are empty
599 /// - Network request fails
600 /// - Response parsing fails
601 ///
602 /// # Example
603 ///
604 /// ```rust,no_run
605 /// use openai_tools::chat::request::ChatCompletion;
606 /// use openai_tools::common::message::Message;
607 /// use openai_tools::common::role::Role;
608 ///
609 /// # #[tokio::main]
610 /// # async fn main() -> Result<(), Box<dyn std::error::Error>>
611 /// # {
612 /// let mut chat = ChatCompletion::new();
613 /// let messages = vec![Message::from_string(Role::User, "Hello!")];
614 ///
615 /// let response = chat
616 /// .model_id("gpt-4o-mini")
617 /// .messages(messages)
618 /// .temperature(1.0)
619 /// .chat()
620 /// .await?;
621 ///
622 /// println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
623 /// # Ok::<(), Box<dyn std::error::Error>>(())
624 /// # }
625 /// ```
626 pub async fn chat(&mut self) -> Result<Response> {
627 // Check if the API key is set & body is built.
628 if self.api_key.is_empty() {
629 return Err(OpenAIToolError::Error("API key is not set.".into()));
630 }
631 if self.request_body.model.is_empty() {
632 return Err(OpenAIToolError::Error("Model ID is not set.".into()));
633 }
634 if self.request_body.messages.is_empty() {
635 return Err(OpenAIToolError::Error("Messages are not set.".into()));
636 }
637
638 let body = serde_json::to_string(&self.request_body)?;
639 let url = "https://api.openai.com/v1/chat/completions";
640
641 let client = request::Client::new();
642 let mut header = request::header::HeaderMap::new();
643 header.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
644 header.insert("Authorization", request::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
645 header.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust/0.1.0"));
646
647 if cfg!(debug_assertions) {
648 // Replace API key with a placeholder in debug mode
649 let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(&self.api_key, "*************");
650 tracing::info!("Request body: {}", body_for_debug);
651 }
652
653 let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
654 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
655
656 if cfg!(debug_assertions) {
657 tracing::info!("Response content: {}", content);
658 }
659
660 serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
661 }
662}