openai_tools/chat/request.rs
//! OpenAI Chat Completions API Request Module
//!
//! This module provides the functionality to build and send requests to the OpenAI Chat Completions API.
//! It offers a builder pattern for constructing requests with various parameters and options,
//! making it easy to interact with OpenAI's conversational AI models.
//!
//! # Key Features
//!
//! - **Builder Pattern**: Fluent API for constructing requests
//! - **Structured Output**: Support for JSON schema-based responses
//! - **Function Calling**: Tool integration for extended model capabilities
//! - **Comprehensive Parameters**: Support for the commonly used OpenAI API parameters
//! - **Error Handling**: Robust error management and validation
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Initialize the chat completion client
//!     let mut chat = ChatCompletion::new();
//!
//!     // Create a simple conversation
//!     let messages = vec![
//!         Message::from_string(Role::User, "Hello! How are you?")
//!     ];
//!
//!     // Send the request and get a response
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .temperature(0.7)
//!         .chat()
//!         .await?;
//!
//!     println!(
//!         "AI Response: {}",
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     );
//!     Ok(())
//! }
//! ```
//!
//! # Advanced Usage
//!
//! ## Structured Output with JSON Schema
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::structured_output::Schema;
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Serialize, Deserialize)]
//! struct PersonInfo {
//!     name: String,
//!     age: u32,
//!     occupation: String,
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define a JSON schema for structured output
//!     let mut schema = Schema::chat_json_schema("person_info");
//!     schema.add_property("name", "string", "Person's full name");
//!     schema.add_property("age", "number", "Person's age in years");
//!     schema.add_property("occupation", "string", "Person's job or profession");
//!
//!     let messages = vec![
//!         Message::from_string(Role::User,
//!             "Extract information about: John Smith, 30 years old, software engineer")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .json_schema(schema)
//!         .chat()
//!         .await?;
//!
//!     // Parse the structured response
//!     let person: PersonInfo = serde_json::from_str(
//!         response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap()
//!     )?;
//!
//!     println!(
//!         "Extracted: {} (age: {}, job: {})",
//!         person.name, person.age, person.occupation
//!     );
//!     Ok(())
//! }
//! ```
//!
//! ## Function Calling with Tools
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//! use openai_tools::common::tool::Tool;
//! use openai_tools::common::parameters::ParameterProp;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut chat = ChatCompletion::new();
//!
//!     // Define a weather checking tool
//!     let weather_tool = Tool::function(
//!         "get_weather",
//!         "Get current weather information for a location",
//!         vec![
//!             ("location", ParameterProp::string("The city and country")),
//!             ("unit", ParameterProp::string("Temperature unit (celsius/fahrenheit)")),
//!         ],
//!         false,
//!     );
//!
//!     let messages = vec![
//!         Message::from_string(Role::User, "What's the weather like in Tokyo today?")
//!     ];
//!
//!     let response = chat
//!         .model_id("gpt-4o-mini")
//!         .messages(messages)
//!         .tools(vec![weather_tool])
//!         .temperature(0.1)
//!         .chat()
//!         .await?;
//!
//!     // Handle tool calls
//!     if let Some(tool_calls) = &response.choices[0].message.tool_calls {
//!         for call in tool_calls {
//!             println!("Tool called: {}", call.function.name);
//!             if let Ok(args) = call.function.arguments_as_map() {
//!                 println!("Arguments: {:?}", args);
//!             }
//!             // Execute the function and continue the conversation...
//!         }
//!     }
//!     Ok(())
//! }
//! ```
//!
//! # Environment Setup
//!
//! Before using this module, ensure you have set up your OpenAI API key:
//!
//! ```bash
//! export OPENAI_API_KEY="your-api-key-here"
//! ```
//!
//! Or create a `.env` file in your project root:
//!
//! ```text
//! OPENAI_API_KEY=your-api-key-here
//! ```
//!
//! # Error Handling
//!
//! All methods return a `Result` type for proper error handling:
//!
//! ```rust,no_run
//! use openai_tools::chat::request::ChatCompletion;
//! use openai_tools::common::errors::OpenAIToolError;
//! use openai_tools::common::message::Message;
//! use openai_tools::common::role::Role;
//!
//! #[tokio::main]
//! async fn main() {
//!     let mut chat = ChatCompletion::new();
//!     let messages = vec![Message::from_string(Role::User, "Hello!")];
//!
//!     match chat.model_id("gpt-4o-mini").messages(messages).chat().await {
//!         Ok(response) => {
//!             if let Some(content) = &response.choices[0].message.content {
//!                 if let Some(text) = &content.text {
//!                     println!("Success: {}", text);
//!                 }
//!             }
//!         }
//!         Err(OpenAIToolError::RequestError(e)) => {
//!             eprintln!("Network error: {}", e);
//!         }
//!         Err(OpenAIToolError::SerdeJsonError(e)) => {
//!             eprintln!("JSON parsing error: {}", e);
//!         }
//!         Err(e) => {
//!             eprintln!("Other error: {}", e);
//!         }
//!     }
//! }
//! ```

use crate::chat::response::Response;
use crate::common::{
    errors::{OpenAIToolError, Result},
    message::Message,
    structured_output::Schema,
    tool::Tool,
};
use dotenvy::dotenv;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;

/// Response format structure for OpenAI API requests
///
/// This structure is used for structured output when a JSON schema is specified.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct Format {
    #[serde(rename = "type")]
    type_name: String,
    json_schema: Schema,
}

impl Format {
    /// Creates a new Format structure
    ///
    /// # Arguments
    ///
    /// * `type_name` - The type name for the response format
    /// * `json_schema` - The JSON schema definition
    ///
    /// # Returns
    ///
    /// A new Format structure instance
    pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
        Self { type_name: type_name.as_ref().to_string(), json_schema }
    }
}

/// Request body structure for OpenAI Chat Completions API
///
/// This structure represents the parameters that will be sent in the request body
/// to the OpenAI API. Each field corresponds to the API specification.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
struct Body {
    model: String,
    messages: Vec<Message>,
    /// Whether to store the request and response at OpenAI
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    /// Frequency penalty parameter to reduce repetition (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Logit bias to adjust the probability of specific tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, i32>>,
    /// Whether to include probability information for each token
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    /// Number of top probabilities to return for each token (0-20)
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u8>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number of responses to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Available modalities for the response (e.g., text, audio)
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    /// Presence penalty to encourage new topics (-2.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Temperature parameter to control response randomness (0.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Response format specification (e.g., JSON schema)
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Format>,
    /// Optional tools that can be used by the model
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
}

/// OpenAI Chat Completions API client
///
/// This structure manages interactions with the OpenAI Chat Completions API.
/// It handles API key management, request parameter configuration, and API calls.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::chat::request::ChatCompletion;
/// use openai_tools::common::message::Message;
/// use openai_tools::common::role::Role;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut chat = ChatCompletion::new();
/// let messages = vec![Message::from_string(Role::User, "Hello!")];
///
/// let response = chat
///     .model_id("gpt-4o-mini")
///     .messages(messages)
///     .temperature(1.0)
///     .chat()
///     .await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ChatCompletion {
    api_key: String,
    request_body: Body,
}

impl ChatCompletion {
    /// Creates a new ChatCompletion instance
    ///
    /// Loads the API key from the `OPENAI_API_KEY` environment variable.
    /// If a `.env` file exists, it will also be loaded.
    ///
    /// # Panics
    ///
    /// Panics if the `OPENAI_API_KEY` environment variable is not set.
    ///
    /// # Returns
    ///
    /// A new ChatCompletion instance
    pub fn new() -> Self {
        dotenv().ok();
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
        Self { api_key, request_body: Body::default() }
    }

    /// Sets the model ID to use
    ///
    /// # Arguments
    ///
    /// * `model_id` - OpenAI model ID (e.g., `gpt-4o-mini`, `gpt-4o`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
        self.request_body.model = model_id.as_ref().to_string();
        self
    }

    /// Sets the chat message history
    ///
    /// # Arguments
    ///
    /// * `messages` - Vector of chat messages representing the conversation history
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
        self.request_body.messages = messages;
        self
    }

    /// Sets whether to store the request and response at OpenAI
    ///
    /// # Arguments
    ///
    /// * `store` - `true` to store, `false` to not store
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn store(&mut self, store: bool) -> &mut Self {
        self.request_body.store = Some(store);
        self
    }

    /// Sets the frequency penalty
    ///
    /// Penalizes tokens based on how often they have already appeared in the text,
    /// reducing repetition. Positive values decrease repetition, negative values increase it.
    ///
    /// # Arguments
    ///
    /// * `frequency_penalty` - Frequency penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.request_body.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets logit bias to adjust the probability of specific tokens
    ///
    /// # Arguments
    ///
    /// * `logit_bias` - A map from token IDs to bias values
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
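    ///
    /// # Example
    ///
    /// A minimal sketch; the token ID below is an illustrative placeholder, not a
    /// real vocabulary entry for any particular model:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use std::collections::HashMap;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let mut bias = HashMap::new();
    /// bias.insert("50256", -100); // strongly discourage this token
    /// chat.model_id("gpt-4o-mini").logit_bias(bias);
    /// ```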
    pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
        self.request_body.logit_bias =
            Some(logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect::<HashMap<String, i32>>());
        self
    }

    /// Sets whether to include probability information for each token
    ///
    /// # Arguments
    ///
    /// * `logprobs` - `true` to include probability information
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
        self.request_body.logprobs = Some(logprobs);
        self
    }

    /// Sets the number of top probabilities to return for each token
    ///
    /// # Arguments
    ///
    /// * `top_logprobs` - Number of top probabilities (range: 0-20)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
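    ///
    /// # Example
    ///
    /// The OpenAI API requires `logprobs` to be enabled for this parameter to take
    /// effect, so this sketch sets both:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.model_id("gpt-4o-mini").logprobs(true).top_logprobs(5);
    /// ```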
    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
        self.request_body.top_logprobs = Some(top_logprobs);
        self
    }

    /// Sets the maximum number of tokens to generate
    ///
    /// # Arguments
    ///
    /// * `max_completion_tokens` - Maximum number of tokens
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
        self.request_body.max_completion_tokens = Some(max_completion_tokens);
        self
    }

    /// Sets the number of responses to generate
    ///
    /// # Arguments
    ///
    /// * `n` - Number of responses to generate
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
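    ///
    /// # Example
    ///
    /// A sketch that requests three completions and iterates over the returned choices:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut chat = ChatCompletion::new();
    /// let messages = vec![Message::from_string(Role::User, "Suggest a name for a cat.")];
    ///
    /// let response = chat.model_id("gpt-4o-mini").messages(messages).n(3).chat().await?;
    /// for choice in &response.choices {
    ///     if let Some(content) = &choice.message.content {
    ///         if let Some(text) = &content.text {
    ///             println!("Candidate: {}", text);
    ///         }
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```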
    pub fn n(&mut self, n: u32) -> &mut Self {
        self.request_body.n = Some(n);
        self
    }

    /// Sets the available modalities for the response
    ///
    /// # Arguments
    ///
    /// * `modalities` - List of modalities (e.g., `["text", "audio"]`)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
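    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.model_id("gpt-4o-mini").modalities(vec!["text"]);
    /// ```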
    pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
        self.request_body.modalities = Some(modalities.into_iter().map(|m| m.as_ref().to_string()).collect::<Vec<String>>());
        self
    }

    /// Sets the presence penalty
    ///
    /// Controls the model's tendency to introduce new content. Positive values encourage
    /// talking about new topics, negative values encourage staying on existing topics.
    ///
    /// # Arguments
    ///
    /// * `presence_penalty` - Presence penalty value (range: -2.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        self.request_body.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the temperature parameter to control response randomness
    ///
    /// Higher values (e.g., 1.0) produce more creative and diverse outputs,
    /// while lower values (e.g., 0.2) produce more deterministic and consistent outputs.
    ///
    /// # Arguments
    ///
    /// * `temperature` - Temperature parameter (range: 0.0 to 2.0)
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
        self.request_body.temperature = Some(temperature);
        self
    }

    /// Sets structured output using a JSON schema
    ///
    /// Enables receiving responses in a structured JSON format according to the
    /// specified JSON schema.
    ///
    /// # Arguments
    ///
    /// * `json_schema` - JSON schema defining the response structure
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
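    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the module-level structured output example:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::structured_output::Schema;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let mut schema = Schema::chat_json_schema("answer");
    /// schema.add_property("answer", "string", "The answer to the user's question");
    /// chat.model_id("gpt-4o-mini").json_schema(schema);
    /// ```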
    pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
        self.request_body.response_format = Some(Format::new("json_schema", json_schema));
        self
    }

    /// Sets the tools that can be called by the model
    ///
    /// Enables function calling by providing a list of tools that the model can choose to call.
    /// When tools are provided, the model may generate tool calls instead of or in addition to
    /// regular text responses.
    ///
    /// # Arguments
    ///
    /// * `tools` - Vector of tools available for the model to use
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
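    ///
    /// # Example
    ///
    /// A sketch using the same tool shape as the module-level function calling example:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::tool::Tool;
    /// use openai_tools::common::parameters::ParameterProp;
    ///
    /// let mut chat = ChatCompletion::new();
    /// let weather_tool = Tool::function(
    ///     "get_weather",
    ///     "Get current weather information for a location",
    ///     vec![("location", ParameterProp::string("The city and country"))],
    ///     false,
    /// );
    /// chat.model_id("gpt-4o-mini").tools(vec![weather_tool]);
    /// ```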
    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
        self.request_body.tools = Some(tools);
        self
    }

    /// Gets the current message history
    ///
    /// # Returns
    ///
    /// A vector containing the message history
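    ///
    /// # Example
    ///
    /// A sketch of continuing a conversation by extending the stored history:
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// let mut chat = ChatCompletion::new();
    /// chat.model_id("gpt-4o-mini")
    ///     .messages(vec![Message::from_string(Role::User, "Hello!")]);
    ///
    /// let mut history = chat.get_message_history();
    /// history.push(Message::from_string(Role::User, "Tell me more."));
    /// chat.messages(history);
    /// ```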
    pub fn get_message_history(&self) -> Vec<Message> {
        self.request_body.messages.clone()
    }

    /// Sends the chat completion request to the OpenAI API
    ///
    /// This method validates the request parameters, constructs the HTTP request,
    /// and sends it to the OpenAI Chat Completions endpoint.
    ///
    /// # Returns
    ///
    /// A `Result` containing the API response on success, or an error on failure.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The API key is not set
    /// - The model ID is not set
    /// - Messages are empty
    /// - The network request fails
    /// - Response parsing fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::chat::request::ChatCompletion;
    /// use openai_tools::common::message::Message;
    /// use openai_tools::common::role::Role;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut chat = ChatCompletion::new();
    /// let messages = vec![Message::from_string(Role::User, "Hello!")];
    ///
    /// let response = chat
    ///     .model_id("gpt-4o-mini")
    ///     .messages(messages)
    ///     .temperature(1.0)
    ///     .chat()
    ///     .await?;
    ///
    /// println!("{}", response.choices[0].message.content.as_ref().unwrap().text.as_ref().unwrap());
    /// # Ok(())
    /// # }
    /// ```
    pub async fn chat(&mut self) -> Result<Response> {
        // Check that the API key, model ID, and messages are all set.
        if self.api_key.is_empty() {
            return Err(OpenAIToolError::Error("API key is not set.".into()));
        }
        if self.request_body.model.is_empty() {
            return Err(OpenAIToolError::Error("Model ID is not set.".into()));
        }
        if self.request_body.messages.is_empty() {
            return Err(OpenAIToolError::Error("Messages are not set.".into()));
        }

        let body = serde_json::to_string(&self.request_body)?;
        let url = "https://api.openai.com/v1/chat/completions";

        let client = reqwest::Client::new();
        let mut header = reqwest::header::HeaderMap::new();
        header.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));
        header.insert("Authorization", reqwest::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
        header.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust/0.1.0"));

        if cfg!(debug_assertions) {
            // Mask the API key with a placeholder before logging in debug mode.
            let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(&self.api_key, "*************");
            tracing::info!("Request body: {}", body_for_debug);
        }

        let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(debug_assertions) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}