// openai_tools/chat/request.rs

1//! OpenAI Chat Completions API Request Module
2//!
3//! This module provides the functionality to build and send requests to the OpenAI Chat Completions API.
4//! It offers a builder pattern for constructing requests with various parameters and options,
5//! making it easy to interact with OpenAI's conversational AI models.
6//!
7//! # Key Features
8//!
9//! - **Builder Pattern**: Fluent API for constructing requests
10//! - **Structured Output**: Support for JSON schema-based responses
11//! - **Function Calling**: Tool integration for extended model capabilities
12//! - **Comprehensive Parameters**: Full support for all OpenAI API parameters
13//! - **Error Handling**: Robust error management and validation
14//!
15//! # Quick Start
16//!
17//! ```rust,no_run
18//! use openai_tools::chat::request::ChatCompletion;
19//! use openai_tools::common::message::Message;
20//! use openai_tools::common::role::Role;
21//!
22//! #[tokio::main]
23//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
24//!     // Initialize the chat completion client
25//!     let mut chat = ChatCompletion::new();
26//!     
27//!     // Create a simple conversation
28//!     let messages = vec![
29//!         Message::from_string(Role::User, "Hello! How are you?")
30//!     ];
31//!
32//!     // Send the request and get a response
33//!     let response = chat
34//!         .model_id("gpt-4o-mini")
35//!         .messages(messages)
36//!         .temperature(0.7)
37//!         .chat()
38//!         .await?;
39//!         
40//!     println!("AI Response: {}",
41//!              response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
42//!     Ok(())
43//! }
44//! ```
45//!
46//! # Advanced Usage
47//!
48//! ## Structured Output with JSON Schema
49//!
50//! ```rust,no_run
51//! use openai_tools::chat::request::ChatCompletion;
52//! use openai_tools::common::message::Message;
53//! use openai_tools::common::role::Role;
54//! use openai_tools::common::structured_output::Schema;
55//! use serde::{Deserialize, Serialize};
56//!
57//! #[derive(Serialize, Deserialize)]
58//! struct PersonInfo {
59//!     name: String,
60//!     age: u32,
61//!     occupation: String,
62//! }
63//!
64//! #[tokio::main]
65//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
66//!     let mut chat = ChatCompletion::new();
67//!     
68//!     // Define JSON schema for structured output
69//!     let mut schema = Schema::chat_json_schema("person_info");
70//!     schema.add_property("name", "string", "Person's full name");
71//!     schema.add_property("age", "number", "Person's age in years");
72//!     schema.add_property("occupation", "string", "Person's job or profession");
73//!     
74//!     let messages = vec![
75//!         Message::from_string(Role::User,
76//!             "Extract information about: John Smith, 30 years old, software engineer")
77//!     ];
78//!
79//!     let response = chat
80//!         .model_id("gpt-4o-mini")
81//!         .messages(messages)
82//!         .json_schema(schema)
83//!         .chat()
84//!         .await?;
85//!         
86//!     // Parse structured response
87//!     let person: PersonInfo = serde_json::from_str(
88//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
89//!     )?;
90//!     
91//!     println!("Extracted: {} (age: {}, job: {})",
92//!              person.name, person.age, person.occupation);
93//!     Ok(())
94//! }
95//! ```
96//!
97//! ## Function Calling with Tools
98//!
99//! ```rust,no_run
100//! use openai_tools::chat::request::ChatCompletion;
101//! use openai_tools::common::message::Message;
102//! use openai_tools::common::role::Role;
103//! use openai_tools::common::tool::Tool;
104//! use openai_tools::common::parameters::ParameterProperty;
105//!
106//! #[tokio::main]
107//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
108//!     let mut chat = ChatCompletion::new();
109//!     
110//!     // Define a weather checking tool
111//!     let weather_tool = Tool::function(
112//!         "get_weather",
113//!         "Get current weather information for a location",
114//!         vec![
115//!             ("location", ParameterProperty::from_string("The city and country")),
116//!             ("unit", ParameterProperty::from_string("Temperature unit (celsius/fahrenheit)")),
117//!         ],
118//!         false,
119//!     );
120//!     
121//!     let messages = vec![
122//!         Message::from_string(Role::User,
123//!             "What's the weather like in Tokyo today?")
124//!     ];
125//!
126//!     let response = chat
127//!         .model_id("gpt-4o-mini")
128//!         .messages(messages)
129//!         .tools(vec![weather_tool])
130//!         .temperature(0.1)
131//!         .chat()
132//!         .await?;
133//!         
134//!     // Handle tool calls
135//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
136//!         for call in tool_calls {
137//!             println!("Tool called: {}", call.function.name);
138//!             if let Ok(args) = call.function.arguments_as_map() {
139//!                 println!("Arguments: {:?}", args);
140//!             }
141//!             // Execute the function and continue the conversation...
142//!         }
143//!     }
144//!     Ok(())
145//! }
146//! ```
147//!
148//! # Environment Setup
149//!
150//! Before using this module, ensure you have set up your OpenAI API key:
151//!
152//! ```bash
153//! export OPENAI_API_KEY="your-api-key-here"
154//! ```
155//!
156//! Or create a `.env` file in your project root:
157//!
158//! ```text
159//! OPENAI_API_KEY=your-api-key-here
160//! ```
161//!
162//!
163//! # Error Handling
164//!
165//! All methods return a `Result` type for proper error handling:
166//!
167//! ```rust,no_run
168//! use openai_tools::chat::request::ChatCompletion;
169//! use openai_tools::common::errors::OpenAIToolError;
170//!
171//! #[tokio::main]
172//! async fn main() {
173//!     let mut chat = ChatCompletion::new();
174//!     
175//!     match chat.model_id("gpt-4o-mini").chat().await {
176//!         Ok(response) => {
177//!             if let Some(content) = &response.choices[0].message.content {
178//!                 if let Some(text) = &content.text {
179//!                     println!("Success: {}", text);
180//!                 }
181//!             }
182//!         }
183//!         Err(OpenAIToolError::RequestError(e)) => {
184//!             eprintln!("Network error: {}", e);
185//!         }
186//!         Err(OpenAIToolError::SerdeJsonError(e)) => {
187//!             eprintln!("JSON parsing error: {}", e);
188//!         }
189//!         Err(e) => {
190//!             eprintln!("Other error: {}", e);
191//!         }
192//!     }
193//! }
194//! ```
195
196use crate::chat::response::Response;
197use crate::common::{
198    errors::{OpenAIToolError, Result},
199    message::Message,
200    structured_output::Schema,
201    tool::Tool,
202};
203use core::str;
204use dotenvy::dotenv;
205use serde::{Deserialize, Serialize};
206use std::collections::HashMap;
207use std::env;
208
/// Response format structure for OpenAI API requests
///
/// This structure is used for structured output when JSON schema is specified.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct Format {
    /// Format discriminator; serialized as the JSON `type` field
    /// (set to `"json_schema"` by `ChatCompletion::json_schema`).
    #[serde(rename = "type")]
    type_name: String,
    /// JSON schema that the model's response must conform to.
    json_schema: Schema,
}
218
219impl Format {
220    /// Creates a new Format structure
221    ///
222    /// # Arguments
223    ///
224    /// * `type_name` - The type name for the response format
225    /// * `json_schema` - The JSON schema definition
226    ///
227    /// # Returns
228    ///
229    /// A new Format structure instance
230    pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
231        Self { type_name: type_name.as_ref().to_string(), json_schema }
232    }
233}
234
/// Request body structure for OpenAI Chat Completions API
///
/// This structure represents the parameters that will be sent in the request body
/// to the OpenAI API. Each field corresponds to the API specification.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
struct Body {
    /// OpenAI model ID (e.g. `gpt-4o-mini`); required, validated in `chat()`
    model: String,
    /// Conversation history sent to the model; required, validated in `chat()`
    messages: Vec<Message>,
    /// Whether to store the request and response at OpenAI
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    /// Frequency penalty parameter to reduce repetition (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Logit bias to adjust the probability of specific tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, i32>>,
    /// Whether to include probability information for each token
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    /// Number of top probabilities to return for each token (0-20)
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u8>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number of responses to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Available modalities for the response (e.g., text, audio)
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    /// Presence penalty to encourage new topics (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Temperature parameter to control response randomness (0.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Response format specification (e.g., JSON schema)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Format>,
    /// Optional tools that can be used by the model
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
}
280
/// OpenAI Chat Completions API client
///
/// This structure manages interactions with the OpenAI Chat Completions API.
/// It handles API key management, request parameter configuration, and API calls.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::chat::request::ChatCompletion;
/// use openai_tools::common::message::Message;
/// use openai_tools::common::role::Role;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut chat = ChatCompletion::new();
/// let messages = vec![Message::from_string(Role::User, "Hello!")];
///
/// let response = chat
///     .model_id("gpt-4o-mini")
///     .messages(messages)
///     .temperature(1.0)
///     .chat()
///     .await?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// # }
/// ```
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ChatCompletion {
    /// API key loaded from the `OPENAI_API_KEY` environment variable.
    // NOTE(review): this field is included by the derived `Serialize`, so
    // serializing a `ChatCompletion` leaks the key — consider `#[serde(skip)]`.
    api_key: String,
    /// Accumulated request parameters sent by `chat()`.
    request_body: Body,
}
312
313impl ChatCompletion {
314    /// Creates a new ChatCompletion instance
315    ///
316    /// Loads the API key from the `OPENAI_API_KEY` environment variable.
317    /// If a `.env` file exists, it will also be loaded.
318    ///
319    /// # Panics
320    ///
321    /// Panics if the `OPENAI_API_KEY` environment variable is not set.
322    ///
323    /// # Returns
324    ///
325    /// A new ChatCompletion instance
326    pub fn new() -> Self {
327        dotenv().ok();
328        let api_key =
329            env::var("OPENAI_API_KEY").map_err(|e| OpenAIToolError::Error(format!("OPENAI_API_KEY not set in environment: {}", e))).unwrap();
330        Self { api_key, request_body: Body::default() }
331    }
332
333    /// Sets the model ID to use
334    ///
335    /// # Arguments
336    ///
337    /// * `model_id` - OpenAI model ID (e.g., `gpt-4o-mini`, `gpt-4o`)
338    ///
339    /// # Returns
340    ///
341    /// A mutable reference to self for method chaining
342    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
343        self.request_body.model = model_id.as_ref().to_string();
344        self
345    }
346
347    /// Sets the chat message history
348    ///
349    /// # Arguments
350    ///
351    /// * `messages` - Vector of chat messages representing the conversation history
352    ///
353    /// # Returns
354    ///
355    /// A mutable reference to self for method chaining
356    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
357        self.request_body.messages = messages;
358        self
359    }
360
361    /// Adds a single message to the conversation history
362    ///
363    /// This method appends a new message to the existing conversation history.
364    /// It's useful for building conversations incrementally.
365    ///
366    /// # Arguments
367    ///
368    /// * `message` - The message to add to the conversation
369    ///
370    /// # Returns
371    ///
372    /// A mutable reference to self for method chaining
373    ///
374    /// # Examples
375    ///
376    /// ```rust,no_run
377    /// use openai_tools::chat::request::ChatCompletion;
378    /// use openai_tools::common::message::Message;
379    /// use openai_tools::common::role::Role;
380    ///
381    /// let mut chat = ChatCompletion::new();
382    /// chat.add_message(Message::from_string(Role::User, "Hello!"))
383    ///     .add_message(Message::from_string(Role::Assistant, "Hi there!"))
384    ///     .add_message(Message::from_string(Role::User, "How are you?"));
385    /// ```
386    pub fn add_message(&mut self, message: Message) -> &mut Self {
387        self.request_body.messages.push(message);
388        self
389    }
390    /// Sets whether to store the request and response at OpenAI
391    ///
392    /// # Arguments
393    ///
394    /// * `store` - `true` to store, `false` to not store
395    ///
396    /// # Returns
397    ///
398    /// A mutable reference to self for method chaining
399    pub fn store(&mut self, store: bool) -> &mut Self {
400        self.request_body.store = Option::from(store);
401        self
402    }
403
404    /// Sets the frequency penalty
405    ///
406    /// A parameter that penalizes based on word frequency to reduce repetition.
407    /// Positive values decrease repetition, negative values increase it.
408    ///
409    /// # Arguments
410    ///
411    /// * `frequency_penalty` - Frequency penalty value (range: -2.0 to 2.0)
412    ///
413    /// # Returns
414    ///
415    /// A mutable reference to self for method chaining
416    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
417        self.request_body.frequency_penalty = Option::from(frequency_penalty);
418        self
419    }
420
421    /// Sets logit bias to adjust the probability of specific tokens
422    ///
423    /// # Arguments
424    ///
425    /// * `logit_bias` - A map of token IDs to adjustment values
426    ///
427    /// # Returns
428    ///
429    /// A mutable reference to self for method chaining
430    pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
431        self.request_body.logit_bias =
432            Option::from(logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect::<HashMap<String, i32>>());
433        self
434    }
435
436    /// Sets whether to include probability information for each token
437    ///
438    /// # Arguments
439    ///
440    /// * `logprobs` - `true` to include probability information
441    ///
442    /// # Returns
443    ///
444    /// A mutable reference to self for method chaining
445    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
446        self.request_body.logprobs = Option::from(logprobs);
447        self
448    }
449
450    /// Sets the number of top probabilities to return for each token
451    ///
452    /// # Arguments
453    ///
454    /// * `top_logprobs` - Number of top probabilities (range: 0-20)
455    ///
456    /// # Returns
457    ///
458    /// A mutable reference to self for method chaining
459    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
460        self.request_body.top_logprobs = Option::from(top_logprobs);
461        self
462    }
463
464    /// Sets the maximum number of tokens to generate
465    ///
466    /// # Arguments
467    ///
468    /// * `max_completion_tokens` - Maximum number of tokens
469    ///
470    /// # Returns
471    ///
472    /// A mutable reference to self for method chaining
473    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
474        self.request_body.max_completion_tokens = Option::from(max_completion_tokens);
475        self
476    }
477
478    /// Sets the number of responses to generate
479    ///
480    /// # Arguments
481    ///
482    /// * `n` - Number of responses to generate
483    ///
484    /// # Returns
485    ///
486    /// A mutable reference to self for method chaining
487    pub fn n(&mut self, n: u32) -> &mut Self {
488        self.request_body.n = Option::from(n);
489        self
490    }
491
492    /// Sets the available modalities for the response
493    ///
494    /// # Arguments
495    ///
496    /// * `modalities` - List of modalities (e.g., `["text", "audio"]`)
497    ///
498    /// # Returns
499    ///
500    /// A mutable reference to self for method chaining
501    pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
502        self.request_body.modalities = Option::from(modalities.into_iter().map(|m| m.as_ref().to_string()).collect::<Vec<String>>());
503        self
504    }
505
506    /// Sets the presence penalty
507    ///
508    /// A parameter that controls the tendency to include new content in the document.
509    /// Positive values encourage talking about new topics, negative values encourage
510    /// staying on existing topics.
511    ///
512    /// # Arguments
513    ///
514    /// * `presence_penalty` - Presence penalty value (range: -2.0 to 2.0)
515    ///
516    /// # Returns
517    ///
518    /// A mutable reference to self for method chaining
519    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
520        self.request_body.presence_penalty = Option::from(presence_penalty);
521        self
522    }
523
524    /// Sets the temperature parameter to control response randomness
525    ///
526    /// Higher values (e.g., 1.0) produce more creative and diverse outputs,
527    /// while lower values (e.g., 0.2) produce more deterministic and consistent outputs.
528    ///
529    /// # Arguments
530    ///
531    /// * `temperature` - Temperature parameter (range: 0.0 to 2.0)
532    ///
533    /// # Returns
534    ///
535    /// A mutable reference to self for method chaining
536    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
537        self.request_body.temperature = Option::from(temperature);
538        self
539    }
540
541    /// Sets structured output using JSON schema
542    ///
543    /// Enables receiving responses in a structured JSON format according to the
544    /// specified JSON schema.
545    ///
546    /// # Arguments
547    ///
548    /// * `json_schema` - JSON schema defining the response structure
549    ///
550    /// # Returns
551    ///
552    /// A mutable reference to self for method chaining
553    pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
554        self.request_body.response_format = Option::from(Format::new(String::from("json_schema"), json_schema));
555        self
556    }
557
558    /// Sets the tools that can be called by the model
559    ///
560    /// Enables function calling by providing a list of tools that the model can choose to call.
561    /// When tools are provided, the model may generate tool calls instead of or in addition to
562    /// regular text responses.
563    ///
564    /// # Arguments
565    ///
566    /// * `tools` - Vector of tools available for the model to use
567    ///
568    /// # Returns
569    ///
570    /// A mutable reference to self for method chaining
571    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
572        self.request_body.tools = Option::from(tools);
573        self
574    }
575
576    /// Gets the current message history
577    ///
578    /// # Returns
579    ///
580    /// A vector containing the message history
581    pub fn get_message_history(&self) -> Vec<Message> {
582        self.request_body.messages.clone()
583    }
584
585    /// Sends the chat completion request to OpenAI API
586    ///
587    /// This method validates the request parameters, constructs the HTTP request,
588    /// and sends it to the OpenAI Chat Completions endpoint.
589    ///
590    /// # Returns
591    ///
592    /// A `Result` containing the API response on success, or an error on failure.
593    ///
594    /// # Errors
595    ///
596    /// Returns an error if:
597    /// - API key is not set
598    /// - Model ID is not set
599    /// - Messages are empty
600    /// - Network request fails
601    /// - Response parsing fails
602    ///
603    /// # Example
604    ///
605    /// ```rust,no_run
606    /// use openai_tools::chat::request::ChatCompletion;
607    /// use openai_tools::common::message::Message;
608    /// use openai_tools::common::role::Role;
609    ///
610    /// # #[tokio::main]
611    /// # async fn main() -> Result<(), Box<dyn std::error::Error>>
612    /// # {
613    /// let mut chat = ChatCompletion::new();
614    /// let messages = vec![Message::from_string(Role::User, "Hello!")];
615    ///
616    /// let response = chat
617    ///     .model_id("gpt-4o-mini")
618    ///     .messages(messages)
619    ///     .temperature(1.0)
620    ///     .chat()
621    ///     .await?;
622    ///     
623    /// println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
624    /// # Ok::<(), Box<dyn std::error::Error>>(())
625    /// # }
626    /// ```
627    pub async fn chat(&mut self) -> Result<Response> {
628        // Check if the API key is set & body is built.
629        if self.api_key.is_empty() {
630            return Err(OpenAIToolError::Error("API key is not set.".into()));
631        }
632        if self.request_body.model.is_empty() {
633            return Err(OpenAIToolError::Error("Model ID is not set.".into()));
634        }
635        if self.request_body.messages.is_empty() {
636            return Err(OpenAIToolError::Error("Messages are not set.".into()));
637        }
638
639        let body = serde_json::to_string(&self.request_body)?;
640        let url = "https://api.openai.com/v1/chat/completions";
641
642        let client = request::Client::new();
643        let mut header = request::header::HeaderMap::new();
644        header.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
645        header.insert("Authorization", request::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
646        header.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
647
648        if cfg!(debug_assertions) {
649            // Replace API key with a placeholder in debug mode
650            let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(&self.api_key, "*************");
651            tracing::info!("Request body: {}", body_for_debug);
652        }
653
654        let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
655        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
656
657        if cfg!(debug_assertions) {
658            tracing::info!("Response content: {}", content);
659        }
660
661        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
662    }
663}