// openai_tools/chat/request.rs

//! OpenAI Chat Completions API Request Module
//!
//! This module provides the functionality to build and send requests to the OpenAI Chat Completions API.
//! It offers a builder pattern for constructing requests with various parameters and options,
//! making it easy to interact with OpenAI's conversational AI models.
//!
//! # Key Features
//!
//! - **Builder Pattern**: Fluent API for constructing requests
//! - **Structured Output**: Support for JSON schema-based responses
//! - **Function Calling**: Tool integration for extended model capabilities
//! - **Comprehensive Parameters**: Full support for all OpenAI API parameters
//! - **Error Handling**: Robust error management and validation
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Initialize the chat completion client
//!     let mut chat = ChatCompletion::new();
//!
//!     // Create a simple conversation
//!     let messages = vec![
//!         Message::from_string(Role::User, "Hello! How are you?")
//!     ];
//!
//!     // Send the request and get a response
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .chat()
//!         .await?;
//!
//!     println!("AI Response: {}",
//!              response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
//!     Ok(())
//! }
//! ```
//!
//! # Advanced Usage
//!
//! ## Structured Output with JSON Schema
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Serialize, Deserialize)]
//! struct PersonInfo {
//!     name: String,
//!     age: u32,
//!     occupation: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define JSON schema for structured output
//!     let mut schema = Schema::chat_json_schema("person_info");
//!     schema.add_property("name", "string", "Person's full name");
//!     schema.add_property("age", "number", "Person's age in years");
//!     schema.add_property("occupation", "string", "Person's job or profession");
//!
//!     let messages = vec![
//!         Message::from_string(Role::User,
//!             "Extract information about: John Smith, 30 years old, software engineer")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .json_schema(schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse structured response
//!     let person: PersonInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!
//!     println!("Extracted: {} (age: {}, job: {})",
//!              person.name, person.age, person.occupation);
//!     Ok(())
//! }
//! ```
//!
//! ## Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProperty;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define a weather checking tool
//!     let weather_tool = Tool::function(
//!         "get_weather",
//!         "Get current weather information for a location",
//!         vec![
//!             ("location", ParameterProperty::from_string("The city and country")),
//!             ("unit", ParameterProperty::from_string("Temperature unit (celsius/fahrenheit)")),
//!         ],
//!         false,
//!     );
//!
//!     let messages = vec![
//!         Message::from_string(Role::User,
//!             "What's the weather like in Tokyo today?")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .tools(vec![weather_tool])
//!         .temperature(0.1)
//!         .chat()
//!         .await?;
//!
//!     // Handle tool calls
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         for call in tool_calls {
//!             println!("Tool called: {}", call.function.name);
//!             if let Ok(args) = call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!             // Execute the function and continue the conversation...
//!         }
//!     }
//!     Ok(())
//! }
//! ```
//!
//! # Environment Setup
//!
//! Before using this module, ensure you have set up your OpenAI API key:
//!
//! ```bash
//! export OPENAI_API_KEY="your-api-key-here"
//! ```
//!
//! Or create a `.env` file in your project root:
//!
//! ```text
//! OPENAI_API_KEY=your-api-key-here
//! ```
//!
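//! `ChatCompletion::new()` loads the `.env` file (if any) and panics when the
//! key is missing, so it can help to check the variable up front. A minimal
//! sketch (using `dotenvy`, which this crate already depends on):
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//!
//! // Load .env (if present) and verify the key before constructing the
//! // client, so missing configuration surfaces as a normal error path
//! // instead of a panic.
//! dotenvy::dotenv().ok();
//! if std::env::var("OPENAI_API_KEY").is_ok() {
//!     let _chat = ChatCompletion::new();
//!     // ... build and send requests ...
//! } else {
//!     eprintln!("OPENAI_API_KEY is not set; skipping API calls.");
//! }
//! ```
//!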
//! # Error Handling
//!
//! All methods return a `Result` type for proper error handling:
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::errors::OpenAIToolError;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
//!
//!     match chat.model_id("gpt-4o-mini").messages(messages).chat().await {
//!         Ok(response) => {
//!             if let Some(content) = &response.choices[0].message.content {
//!                 if let Some(text) = &content.text {
//!                     println!("Success: {}", text);
//!                 }
//!             }
//!         }
//!         Err(OpenAIToolError::RequestError(e)) => {
//!             eprintln!("Network error: {}", e);
//!         }
//!         Err(OpenAIToolError::SerdeJsonError(e)) => {
//!             eprintln!("JSON parsing error: {}", e);
//!         }
//!         Err(e) => {
//!             eprintln!("Other error: {}", e);
//!         }
//!     }
//! }
//! ```

use crate::chat::response::Response;
use crate::common::{
    errors::{OpenAIToolError, Result},
    message::Message,
    structured_output::Schema,
    tool::Tool,
};
use dotenvy::dotenv;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;

/// Response format structure for OpenAI API requests
///
/// This structure is used for structured output when a JSON schema is specified.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct Format {
    #[serde(rename = "type")]
    type_name: String,
    json_schema: Schema,
}

impl Format {
    /// Creates a new Format structure
    ///
    /// # Arguments
    ///
    /// * `type_name` - The type name for the response format
    /// * `json_schema` - The JSON schema definition
    ///
    /// # Returns
    ///
    /// A new Format structure instance
    pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
        Self { type_name: type_name.as_ref().to_string(), json_schema }
    }
}

/// Request body structure for OpenAI Chat Completions API
///
/// This structure represents the parameters that will be sent in the request body
/// to the OpenAI API. Each field corresponds to the API specification.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
struct Body {
    model: String,
    messages: Vec<Message>,
    /// Whether to store the request and response at OpenAI
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    /// Frequency penalty parameter to reduce repetition (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Logit bias to adjust the probability of specific tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, i32>>,
    /// Whether to include probability information for each token
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    /// Number of top probabilities to return for each token (0-20)
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u8>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number of responses to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Available modalities for the response (e.g., text, audio)
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    /// Presence penalty to encourage new topics (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Temperature parameter to control response randomness (0.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Response format specification (e.g., JSON schema)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Format>,
    /// Optional tools that can be used by the model
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
}

/// OpenAI Chat Completions API client
///
/// This structure manages interactions with the OpenAI Chat Completions API.
/// It handles API key management, request parameter configuration, and API calls.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::chat::request::ChatCompletion;
/// use openai_tools::common::message::Message;
/// use openai_tools::common::role::Role;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut chat = ChatCompletion::new();
/// let messages = vec![Message::from_string(Role::User, "Hello!")];
///
/// let response = chat
///     .model_id("gpt-4o-mini")
///     .messages(messages)
///     .temperature(1.0)
///     .chat()
///     .await?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// # }
/// ```
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ChatCompletion {
    api_key: String,
    request_body: Body,
}

impl ChatCompletion {
    /// Creates a new ChatCompletion instance
    ///
    /// Loads the API key from the `OPENAI_API_KEY` environment variable.
    /// If a `.env` file exists, it will also be loaded.
    ///
    /// # Panics
    ///
    /// Panics if the `OPENAI_API_KEY` environment variable is not set.
    ///
    /// # Returns
    ///
    /// A new ChatCompletion instance
    pub fn new() -> Self {
        dotenv().ok();
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
        Self { api_key, request_body: Body::default() }
    }

    /// Sets the model ID to use
    ///
    /// # Arguments
    ///
    /// * `model_id` - OpenAI model ID (e.g., `gpt-4o-mini`, `gpt-4o`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
        self.request_body.model = model_id.as_ref().to_string();
        self
    }

    /// Sets the chat message history
    ///
    /// # Arguments
    ///
    /// * `messages` - Vector of chat messages representing the conversation history
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
        self.request_body.messages = messages;
        self
    }

    /// Adds a single message to the conversation history
    ///
    /// This method appends a new message to the existing conversation history.
    /// It's useful for building conversations incrementally.
    ///
    /// # Arguments
    ///
    /// * `message` - The message to add to the conversation
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.add_message(Message::from_string(Role::User, "Hello!"))
    ///     .add_message(Message::from_string(Role::Assistant, "Hi there!"))
    ///     .add_message(Message::from_string(Role::User, "How are you?"));
    /// ```
    pub fn add_message(&mut self, message: Message) -> &mut Self {
        self.request_body.messages.push(message);
        self
    }

    /// Sets whether to store the request and response at OpenAI
    ///
    /// # Arguments
    ///
    /// * `store` - `true` to store, `false` to not store
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn store(&mut self, store: bool) -> &mut Self {
        self.request_body.store = Some(store);
        self
    }

    /// Sets the frequency penalty
    ///
    /// A parameter that penalizes tokens based on how often they have already
    /// appeared, reducing repetition. Positive values decrease repetition,
    /// negative values increase it.
    ///
    /// # Arguments
    ///
    /// * `frequency_penalty` - Frequency penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.request_body.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets logit bias to adjust the probability of specific tokens
    ///
    /// # Arguments
    ///
    /// * `logit_bias` - A map of token IDs to adjustment values
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
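    ///
    /// # Example
    ///
    /// A minimal sketch; the token ID below is a placeholder, not a real ID
    /// for any particular tokenizer:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use std::collections::HashMap;
    ///
    /// let mut chat = ChatCompletion::new();
    /// // Map token IDs (as strings) to bias values; OpenAI accepts -100 to 100,
    /// // where -100 effectively bans a token and 100 effectively forces it.
    /// let mut bias: HashMap<String, i32> = HashMap::new();
    /// bias.insert("50256".to_string(), -100);
    /// chat.logit_bias(bias);
    /// ```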
    pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
        self.request_body.logit_bias =
            Some(logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect::<HashMap<String, i32>>());
        self
    }

    /// Sets whether to include probability information for each token
    ///
    /// # Arguments
    ///
    /// * `logprobs` - `true` to include probability information
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
        self.request_body.logprobs = Some(logprobs);
        self
    }

    /// Sets the number of top probabilities to return for each token
    ///
    /// # Arguments
    ///
    /// * `top_logprobs` - Number of top probabilities (range: 0-20)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
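    ///
    /// # Example
    ///
    /// A minimal sketch: the OpenAI API expects `logprobs` to be enabled when
    /// `top_logprobs` is requested, so the two are typically set together:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// // Request the 5 most likely alternatives for each generated token.
    /// chat.logprobs(true).top_logprobs(5);
    /// ```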
    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
        self.request_body.top_logprobs = Some(top_logprobs);
        self
    }

    /// Sets the maximum number of tokens to generate
    ///
    /// # Arguments
    ///
    /// * `max_completion_tokens` - Maximum number of tokens
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
        self.request_body.max_completion_tokens = Some(max_completion_tokens);
        self
    }

    /// Sets the number of responses to generate
    ///
    /// # Arguments
    ///
    /// * `n` - Number of responses to generate
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn n(&mut self, n: u32) -> &mut Self {
        self.request_body.n = Some(n);
        self
    }

    /// Sets the available modalities for the response
    ///
    /// # Arguments
    ///
    /// * `modalities` - List of modalities (e.g., `["text", "audio"]`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
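    ///
    /// # Example
    ///
    /// A minimal sketch; note that non-text modalities such as audio are only
    /// honored by models that support them:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// // Restrict the response to text output.
    /// chat.modalities(vec!["text"]);
    /// ```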
    pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
        self.request_body.modalities = Some(modalities.into_iter().map(|m| m.as_ref().to_string()).collect::<Vec<String>>());
        self
    }

    /// Sets the presence penalty
    ///
    /// A parameter that controls the model's tendency to introduce new content.
    /// Positive values encourage talking about new topics, negative values encourage
    /// staying on existing topics.
    ///
    /// # Arguments
    ///
    /// * `presence_penalty` - Presence penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        self.request_body.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the temperature parameter to control response randomness
    ///
    /// Higher values (e.g., 1.0) produce more creative and diverse outputs,
    /// while lower values (e.g., 0.2) produce more deterministic and consistent outputs.
    ///
    /// # Arguments
    ///
    /// * `temperature` - Temperature parameter (range: 0.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
        self.request_body.temperature = Some(temperature);
        self
    }

    /// Sets structured output using JSON schema
    ///
    /// Enables receiving responses in a structured JSON format according to the
    /// specified JSON schema.
    ///
    /// # Arguments
    ///
    /// * `json_schema` - JSON schema defining the response structure
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
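    ///
    /// # Example
    ///
    /// A minimal sketch using this crate's `Schema` helper:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::structured_output::Schema;
    ///
    /// let mut chat = ChatCompletion::new();
    /// // Describe the JSON shape the model should return.
    /// let mut schema = Schema::chat_json_schema("answer");
    /// schema.add_property("answer", "string", "The answer text");
    /// chat.json_schema(schema);
    /// ```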
    pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
        self.request_body.response_format = Some(Format::new("json_schema", json_schema));
        self
    }

    /// Sets the tools that can be called by the model
    ///
    /// Enables function calling by providing a list of tools that the model can choose to call.
    /// When tools are provided, the model may generate tool calls instead of or in addition to
    /// regular text responses.
    ///
    /// # Arguments
    ///
    /// * `tools` - Vector of tools available for the model to use
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
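    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the module-level function-calling example:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::tool::Tool;
    /// use openai_tools::common::parameters::ParameterProperty;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let weather_tool = Tool::function(
    ///     "get_weather",
    ///     "Get current weather information for a location",
    ///     vec![("location", ParameterProperty::from_string("The city and country"))],
    ///     false,
    /// );
    /// chat.tools(vec![weather_tool]);
    /// ```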
    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
        self.request_body.tools = Some(tools);
        self
    }

    /// Gets the current message history
    ///
    /// # Returns
    ///
    /// A vector containing the message history
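    ///
    /// # Example
    ///
    /// A minimal sketch of continuing a conversation from the accumulated
    /// history:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.add_message(Message::from_string(Role::User, "Hello!"));
    ///
    /// // Snapshot the history, append the next turn, and continue.
    /// let mut history = chat.get_message_history();
    /// history.push(Message::from_string(Role::User, "Tell me more."));
    /// chat.messages(history);
    /// ```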
    pub fn get_message_history(&self) -> Vec<Message> {
        self.request_body.messages.clone()
    }

    /// Sends the chat completion request to the OpenAI API
    ///
    /// This method validates the request parameters, constructs the HTTP request,
    /// and sends it to the OpenAI Chat Completions endpoint.
    ///
    /// # Returns
    ///
    /// A `Result` containing the API response on success, or an error on failure.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - API key is not set
    /// - Model ID is not set
    /// - Messages are empty
    /// - Network request fails
    /// - Response parsing fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut chat = ChatCompletion::new();
    /// let messages = vec![Message::from_string(Role::User, "Hello!")];
    ///
    /// let response = chat
    ///     .model_id("gpt-4o-mini")
    ///     .messages(messages)
    ///     .temperature(1.0)
    ///     .chat()
    ///     .await?;
    ///
    /// println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// # }
    /// ```
    pub async fn chat(&mut self) -> Result<Response> {
        // Check that the API key is set and the request body is fully built.
        if self.api_key.is_empty() {
            return Err(OpenAIToolError::Error("API key is not set.".into()));
        }
        if self.request_body.model.is_empty() {
            return Err(OpenAIToolError::Error("Model ID is not set.".into()));
        }
        if self.request_body.messages.is_empty() {
            return Err(OpenAIToolError::Error("Messages are not set.".into()));
        }

        let body = serde_json::to_string(&self.request_body)?;
        let url = "https://api.openai.com/v1/chat/completions";

        let client = reqwest::Client::new();
        let mut header = reqwest::header::HeaderMap::new();
        header.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));
        header.insert("Authorization", reqwest::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
        header.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust/0.1.0"));

        if cfg!(debug_assertions) {
            // Mask the API key with a placeholder before logging the request body.
            let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(&self.api_key, "*************");
            tracing::info!("Request body: {}", body_for_debug);
        }

        let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(debug_assertions) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}