// openai_tools/chat/request.rs

//! OpenAI Chat Completions API Request Module
//!
//! This module provides the functionality to build and send requests to the OpenAI Chat Completions API.
//! It offers a builder pattern for constructing requests with various parameters and options,
//! making it easy to interact with OpenAI's conversational AI models.
//!
//! # Key Features
//!
//! - **Builder Pattern**: Fluent API for constructing requests
//! - **Structured Output**: Support for JSON schema-based responses
//! - **Function Calling**: Tool integration for extended model capabilities
//! - **Comprehensive Parameters**: Support for the commonly used OpenAI API parameters
//! - **Error Handling**: Robust error management and validation
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Initialize the chat completion client
//!     let mut chat = ChatCompletion::new();
//!
//!     // Create a simple conversation
//!     let messages = vec![
//!         Message::from_string(Role::User, "Hello! How are you?")
//!     ];
//!
//!     // Send the request and get a response
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .chat()
//!         .await?;
//!
//!     println!("AI Response: {}",
//!              response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
//!     Ok(())
//! }
//! ```
//!
//! # Advanced Usage
//!
//! ## Structured Output with JSON Schema
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Serialize, Deserialize)]
//! struct PersonInfo {
//!     name: String,
//!     age: u32,
//!     occupation: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define JSON schema for structured output
//!     let mut schema = Schema::chat_json_schema("person_info");
//!     schema.add_property("name", "string", "Person's full name");
//!     schema.add_property("age", "number", "Person's age in years");
//!     schema.add_property("occupation", "string", "Person's job or profession");
//!
//!     let messages = vec![
//!         Message::from_string(Role::User,
//!             "Extract information about: John Smith, 30 years old, software engineer")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .json_schema(schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse structured response
//!     let person: PersonInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!
//!     println!("Extracted: {} (age: {}, job: {})",
//!              person.name, person.age, person.occupation);
//!     Ok(())
//! }
//! ```
//!
//! ## Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProp;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define a weather checking tool
//!     let weather_tool = Tool::function(
//!         "get_weather",
//!         "Get current weather information for a location",
//!         vec![
//!             ("location", ParameterProp::string("The city and country")),
//!             ("unit", ParameterProp::string("Temperature unit (celsius/fahrenheit)")),
//!         ],
//!         false,
//!     );
//!
//!     let messages = vec![
//!         Message::from_string(Role::User,
//!             "What's the weather like in Tokyo today?")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .tools(vec![weather_tool])
//!         .temperature(0.1)
//!         .chat()
//!         .await?;
//!
//!     // Handle tool calls
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         for call in tool_calls {
//!             println!("Tool called: {}", call.function.name);
//!             if let Ok(args) = call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!             // Execute the function and continue the conversation...
//!         }
//!     }
//!     Ok(())
//! }
//! ```
//!
//! # Environment Setup
//!
//! Before using this module, ensure you have set up your OpenAI API key:
//!
//! ```bash
//! export OPENAI_API_KEY="your-api-key-here"
//! ```
//!
//! Or create a `.env` file in your project root:
//!
//! ```text
//! OPENAI_API_KEY=your-api-key-here
//! ```
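//!
//! To confirm the key is visible before constructing a client (a minimal sketch;
//! `ChatCompletion::new` panics if the variable is missing):
//!
//! ```rust,no_run
//! // Load a .env file if one exists, then check for the key.
//! dotenvy::dotenv().ok();
//! assert!(std::env::var("OPENAI_API_KEY").is_ok(), "OPENAI_API_KEY must be set");
//! ```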
//!
//! # Error Handling
//!
//! All methods return a `Result` type for proper error handling:
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::errors::OpenAIToolError;
//!
//! #[tokio::main]
//! async fn main() {
//!     let mut chat = ChatCompletion::new();
//!
//!     match chat.model_id("gpt-4o-mini").chat().await {
//!         Ok(response) => {
//!             if let Some(content) = &response.choices[0].message.content {
//!                 if let Some(text) = &content.text {
//!                     println!("Success: {}", text);
//!                 }
//!             }
//!         }
//!         Err(OpenAIToolError::RequestError(e)) => {
//!             eprintln!("Network error: {}", e);
//!         }
//!         Err(OpenAIToolError::SerdeJsonError(e)) => {
//!             eprintln!("JSON parsing error: {}", e);
//!         }
//!         Err(e) => {
//!             eprintln!("Other error: {}", e);
//!         }
//!     }
//! }
//! ```

use crate::chat::response::Response;
use crate::common::{
    errors::{OpenAIToolError, Result},
    message::Message,
    structured_output::Schema,
    tool::Tool,
};
use core::str;
use dotenvy::dotenv;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;

/// Response format structure for OpenAI API requests
///
/// This structure is used for structured output when a JSON schema is specified.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct Format {
    #[serde(rename = "type")]
    type_name: String,
    json_schema: Schema,
}

impl Format {
    /// Creates a new Format structure
    ///
    /// # Arguments
    ///
    /// * `type_name` - The type name for the response format
    /// * `json_schema` - The JSON schema definition
    ///
    /// # Returns
    ///
    /// A new Format structure instance
    pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
        Self { type_name: type_name.as_ref().to_string(), json_schema }
    }
}
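
// For reference, a sketch of the serialized shape: with `type_name` set to
// "json_schema", a `Format` value serializes to
// `{"type": "json_schema", "json_schema": { ... }}`, which is the object the
// Chat Completions API expects in the `response_format` field.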

/// Request body structure for OpenAI Chat Completions API
///
/// This structure represents the parameters that will be sent in the request body
/// to the OpenAI API. Each field corresponds to the API specification.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
struct Body {
    model: String,
    messages: Vec<Message>,
    /// Whether to store the request and response at OpenAI
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    /// Frequency penalty parameter to reduce repetition (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Logit bias to adjust the probability of specific tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, i32>>,
    /// Whether to include probability information for each token
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    /// Number of top probabilities to return for each token (0-20)
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u8>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number of responses to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Available modalities for the response (e.g., text, audio)
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    /// Presence penalty to encourage new topics (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Temperature parameter to control response randomness (0.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Response format specification (e.g., JSON schema)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Format>,
    /// Optional tools that can be used by the model
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
}
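
// Because every optional field is marked `skip_serializing_if = "Option::is_none"`,
// a freshly built body serializes to just the required fields, e.g. (sketch):
// `{"model": "gpt-4o-mini", "messages": [...]}`. Unset options are omitted
// entirely rather than being sent as `null`.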

/// OpenAI Chat Completions API client
///
/// This structure manages interactions with the OpenAI Chat Completions API.
/// It handles API key management, request parameter configuration, and API calls.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::chat::request::ChatCompletion;
/// use openai_tools::common::message::Message;
/// use openai_tools::common::role::Role;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut chat = ChatCompletion::new();
/// let messages = vec![Message::from_string(Role::User, "Hello!")];
///
/// let response = chat
///     .model_id("gpt-4o-mini")
///     .messages(messages)
///     .temperature(1.0)
///     .chat()
///     .await?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// # }
/// ```
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ChatCompletion {
    api_key: String,
    request_body: Body,
}

impl ChatCompletion {
    /// Creates a new ChatCompletion instance
    ///
    /// Loads the API key from the `OPENAI_API_KEY` environment variable.
    /// If a `.env` file exists, it will also be loaded.
    ///
    /// # Panics
    ///
    /// Panics if the `OPENAI_API_KEY` environment variable is not set.
    ///
    /// # Returns
    ///
    /// A new ChatCompletion instance
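    ///
    /// # Example
    ///
    /// A minimal sketch (requires `OPENAI_API_KEY` to be set):
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// // Reads OPENAI_API_KEY from the environment (or a .env file)
    /// // and panics if it is missing.
    /// let chat = ChatCompletion::new();
    /// ```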
    pub fn new() -> Self {
        dotenv().ok();
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
        Self { api_key, request_body: Body::default() }
    }

    /// Sets the model ID to use
    ///
    /// # Arguments
    ///
    /// * `model_id` - OpenAI model ID (e.g., `gpt-4o-mini`, `gpt-4o`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
        self.request_body.model = model_id.as_ref().to_string();
        self
    }

    /// Sets the chat message history
    ///
    /// # Arguments
    ///
    /// * `messages` - Vector of chat messages representing the conversation history
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
        self.request_body.messages = messages;
        self
    }

    /// Sets whether to store the request and response at OpenAI
    ///
    /// # Arguments
    ///
    /// * `store` - `true` to store, `false` to not store
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn store(&mut self, store: bool) -> &mut Self {
        self.request_body.store = Some(store);
        self
    }

    /// Sets the frequency penalty
    ///
    /// A parameter that penalizes tokens based on their frequency in the text so far,
    /// reducing repetition. Positive values decrease repetition, negative values increase it.
    ///
    /// # Arguments
    ///
    /// * `frequency_penalty` - Frequency penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.request_body.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets logit bias to adjust the probability of specific tokens
    ///
    /// # Arguments
    ///
    /// * `logit_bias` - A map of token IDs to adjustment values
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
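    ///
    /// # Example
    ///
    /// A minimal sketch; the token ID below is only illustrative, since real
    /// IDs depend on the model's tokenizer:
    ///
    /// ```rust,no_run
    /// use std::collections::HashMap;
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let mut bias = HashMap::new();
    /// // Strongly discourage one (hypothetical) token ID.
    /// bias.insert("50256", -100);
    /// chat.logit_bias(bias);
    /// ```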
    pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
        self.request_body.logit_bias =
            Some(logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect::<HashMap<String, i32>>());
        self
    }

    /// Sets whether to include probability information for each token
    ///
    /// # Arguments
    ///
    /// * `logprobs` - `true` to include probability information
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
        self.request_body.logprobs = Some(logprobs);
        self
    }

    /// Sets the number of top probabilities to return for each token
    ///
    /// # Arguments
    ///
    /// * `top_logprobs` - Number of top probabilities (range: 0-20)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
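    ///
    /// # Example
    ///
    /// A sketch combining this with [`Self::logprobs`], which must be enabled
    /// for top log probabilities to be returned:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// // Request per-token probabilities with the top 5 alternatives each.
    /// chat.logprobs(true).top_logprobs(5);
    /// ```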
    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
        self.request_body.top_logprobs = Some(top_logprobs);
        self
    }

    /// Sets the maximum number of tokens to generate
    ///
    /// # Arguments
    ///
    /// * `max_completion_tokens` - Maximum number of tokens
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
        self.request_body.max_completion_tokens = Some(max_completion_tokens);
        self
    }

    /// Sets the number of responses to generate
    ///
    /// # Arguments
    ///
    /// * `n` - Number of responses to generate
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn n(&mut self, n: u32) -> &mut Self {
        self.request_body.n = Some(n);
        self
    }

    /// Sets the available modalities for the response
    ///
    /// # Arguments
    ///
    /// * `modalities` - List of modalities (e.g., `["text", "audio"]`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
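    ///
    /// # Example
    ///
    /// A minimal sketch requesting text-only output:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.modalities(vec!["text"]);
    /// ```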
    pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
        self.request_body.modalities = Some(modalities.into_iter().map(|m| m.as_ref().to_string()).collect::<Vec<String>>());
        self
    }

    /// Sets the presence penalty
    ///
    /// A parameter that controls the tendency to introduce new content in the response.
    /// Positive values encourage talking about new topics, negative values encourage
    /// staying on existing topics.
    ///
    /// # Arguments
    ///
    /// * `presence_penalty` - Presence penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        self.request_body.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the temperature parameter to control response randomness
    ///
    /// Higher values (e.g., 1.0) produce more creative and diverse outputs,
    /// while lower values (e.g., 0.2) produce more deterministic and consistent outputs.
    ///
    /// # Arguments
    ///
    /// * `temperature` - Temperature parameter (range: 0.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
        self.request_body.temperature = Some(temperature);
        self
    }

    /// Sets structured output using JSON schema
    ///
    /// Enables receiving responses in a structured JSON format according to the
    /// specified JSON schema.
    ///
    /// # Arguments
    ///
    /// * `json_schema` - JSON schema defining the response structure
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
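    ///
    /// # Example
    ///
    /// A sketch mirroring the module-level structured output example:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::structured_output::Schema;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let mut schema = Schema::chat_json_schema("person_info");
    /// schema.add_property("name", "string", "Person's full name");
    /// chat.json_schema(schema);
    /// ```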
    pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
        self.request_body.response_format = Some(Format::new("json_schema", json_schema));
        self
    }

    /// Sets the tools that can be called by the model
    ///
    /// Enables function calling by providing a list of tools that the model can choose to call.
    /// When tools are provided, the model may generate tool calls instead of or in addition to
    /// regular text responses.
    ///
    /// # Arguments
    ///
    /// * `tools` - Vector of tools available for the model to use
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
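    ///
    /// # Example
    ///
    /// A sketch mirroring the module-level function calling example:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::tool::Tool;
    /// use openai_tools::common::parameters::ParameterProp;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let weather_tool = Tool::function(
    ///     "get_weather",
    ///     "Get current weather information for a location",
    ///     vec![("location", ParameterProp::string("The city and country"))],
    ///     false,
    /// );
    /// chat.tools(vec![weather_tool]);
    /// ```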
    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
        self.request_body.tools = Some(tools);
        self
    }

    /// Gets the current message history
    ///
    /// # Returns
    ///
    /// A vector containing the message history
    pub fn get_message_history(&self) -> Vec<Message> {
        self.request_body.messages.clone()
    }

    /// Sends the chat completion request to the OpenAI API
    ///
    /// This method validates the request parameters, constructs the HTTP request,
    /// and sends it to the OpenAI Chat Completions endpoint.
    ///
    /// # Returns
    ///
    /// A `Result` containing the API response on success, or an error on failure.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - API key is not set
    /// - Model ID is not set
    /// - Messages are empty
    /// - Network request fails
    /// - Response parsing fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut chat = ChatCompletion::new();
    /// let messages = vec![Message::from_string(Role::User, "Hello!")];
    ///
    /// let response = chat
    ///     .model_id("gpt-4o-mini")
    ///     .messages(messages)
    ///     .temperature(1.0)
    ///     .chat()
    ///     .await?;
    ///
    /// println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// # }
    /// ```
    pub async fn chat(&mut self) -> Result<Response> {
        // Check that the API key is set and the request body is complete.
        if self.api_key.is_empty() {
            return Err(OpenAIToolError::Error("API key is not set.".into()));
        }
        if self.request_body.model.is_empty() {
            return Err(OpenAIToolError::Error("Model ID is not set.".into()));
        }
        if self.request_body.messages.is_empty() {
            return Err(OpenAIToolError::Error("Messages are not set.".into()));
        }

        let body = serde_json::to_string(&self.request_body)?;
        let url = "https://api.openai.com/v1/chat/completions";

        let client = reqwest::Client::new();
        let mut header = reqwest::header::HeaderMap::new();
        header.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));
        header.insert("Authorization", reqwest::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
        header.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust/0.1.0"));

        if cfg!(debug_assertions) {
            // Defensively mask the API key in the logged request body.
            let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(&self.api_key, "*************");
            tracing::info!("Request body: {}", body_for_debug);
        }

        let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(debug_assertions) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}