chat_gpt_lib_rs/api_resources/chat.rs

//! This module provides functionality for creating chat-based completions using the
//! [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat).
//!
//! The Chat API is designed for conversational interactions, where each request includes a list
//! of messages with a role (system, user, or assistant). The model responds based on the context
//! established by these messages, allowing for more interactive and context-aware responses
//! compared to plain completions.
//!
//! # Overview
//!
//! The core usage involves calling [`create_chat_completion`] with a [`CreateChatCompletionRequest`],
//! which includes a sequence of [`ChatMessage`] items. Each `ChatMessage` has a `role` and `content`.
//! The API then returns a [`CreateChatCompletionResponse`] containing one or more
//! [`ChatCompletionChoice`] objects (depending on the `n` parameter).
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::chat::{create_chat_completion, CreateChatCompletionRequest, ChatMessage, ChatRole};
//! use chat_gpt_lib_rs::api_resources::models::Model;
//! use chat_gpt_lib_rs::error::OpenAIError;
//! use chat_gpt_lib_rs::OpenAIClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     let client = OpenAIClient::new(None)?; // Reads API key from OPENAI_API_KEY
//!
//!     let request = CreateChatCompletionRequest {
//!         model: Model::O1Mini,
//!         messages: vec![
//!             ChatMessage {
//!                 role: ChatRole::System,
//!                 content: "You are a helpful assistant.".to_string(),
//!                 name: None,
//!             },
//!             ChatMessage {
//!                 role: ChatRole::User,
//!                 content: "Write a tagline for an ice cream shop.".to_string(),
//!                 name: None,
//!             },
//!         ],
//!         max_tokens: Some(50),
//!         temperature: Some(0.7),
//!         ..Default::default()
//!     };
//!
//!     let response = create_chat_completion(&client, &request).await?;
//!
//!     for choice in &response.choices {
//!         println!("Assistant: {}", choice.message.content);
//!     }
//!
//!     Ok(())
//! }
//! ```

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::api::{post_json, post_json_stream};
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use crate::api_resources::models::Model;
/// The role of a message in the chat sequence.
///
/// Typically one of `system`, `user`, `assistant`. OpenAI may add or adjust roles in the future.
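///
/// Roles serialize as lowercase strings, and unrecognized role strings deserialize to
/// [`ChatRole::Other`]. A quick sketch (assumes `serde_json`, which this crate already
/// depends on):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::ChatRole;
///
/// // `rename_all = "lowercase"` turns each variant name into a lowercase JSON string.
/// assert_eq!(serde_json::to_string(&ChatRole::User).unwrap(), "\"user\"");
///
/// // `#[serde(other)]` catches role strings this enum does not know about yet.
/// let parsed: ChatRole = serde_json::from_str("\"developer\"").unwrap();
/// assert_eq!(parsed, ChatRole::Other);
/// ```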
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// For system-level instructions (e.g. "You are a helpful assistant.")
    System,
    /// For user-supplied messages
    User,
    /// For assistant messages (responses from the model)
    Assistant,
    /// For messages carrying the output of a tool call
    Tool,
    /// For messages carrying the output of a function call (the legacy precursor to tool calls)
    Function,
    /// Experimental or extended role types, if they become available
    #[serde(other)]
    Other,
}

/// A single message in a chat conversation.
///
/// Each message has:
/// - A [`ChatRole`], indicating who is sending the message (system, user, or assistant).
/// - The message `content`.
/// - An optional `name` for the user or system, if applicable.
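///
/// When serialized, `name` is omitted entirely if it is `None`. A small sketch (assumes
/// `serde_json`, which this crate already depends on):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::{ChatMessage, ChatRole};
///
/// let msg = ChatMessage {
///     role: ChatRole::User,
///     content: "Hi".to_string(),
///     name: None,
/// };
/// // `skip_serializing_if = "Option::is_none"` drops the `name` key entirely.
/// assert_eq!(
///     serde_json::to_string(&msg).unwrap(),
///     r#"{"role":"user","content":"Hi"}"#
/// );
/// ```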
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    /// The role of the sender (system, user, or assistant).
    pub role: ChatRole,
    /// The content of the message.
    pub content: String,
    /// The (optional) name of the user or system. This can be used to identify
    /// the speaker when multiple users or participants exist in a conversation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// A request struct for creating chat completions with the OpenAI Chat Completions API.
///
/// # Fields
/// - `model`: The ID of the model to use (e.g., "gpt-3.5-turbo").
/// - `messages`: A list of [`ChatMessage`] items providing the conversation history.
/// - `stream`: Whether or not to stream responses via server-sent events.
/// - `max_tokens`, `temperature`, `top_p`, etc.: Parameters controlling the generation.
/// - `n`: Number of chat completion choices to generate.
/// - `logit_bias`, `user`: Additional advanced parameters.
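///
/// Most optional fields can be left at their defaults via `..Default::default()`. As an
/// example, `logit_bias` maps token IDs (encoded as strings) to bias values; the token IDs
/// below are purely illustrative:
///
/// ```rust
/// use std::collections::HashMap;
/// use chat_gpt_lib_rs::api_resources::chat::CreateChatCompletionRequest;
/// use chat_gpt_lib_rs::api_resources::models::Model;
///
/// let mut bias = HashMap::new();
/// bias.insert("50256".to_string(), -100); // effectively ban this token
/// bias.insert("15496".to_string(), 5);    // mildly encourage this token
///
/// let request = CreateChatCompletionRequest {
///     model: Model::Other("gpt-4".to_string()),
///     logit_bias: Some(bias),
///     ..Default::default()
/// };
/// assert!(request.messages.is_empty()); // defaults apply to every unset field
/// ```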
#[derive(Debug, Serialize, Default, Clone)]
pub struct CreateChatCompletionRequest {
    /// **Required**. The model used for this chat request.
    /// Examples: `Model::O1Mini`, `Model::Other("gpt-4".to_string())`.
    pub model: Model,

    /// **Required**. The messages that make up the conversation so far.
    pub messages: Vec<ChatMessage>,

    /// Controls the randomness of the output: 0 is the most deterministic, 2 the most creative.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// The nucleus sampling parameter. Like `temperature`, but a value like 0.1 means only
    /// the top 10% probability mass is considered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// How many chat completion choices to generate for each input message. Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// If set, partial message deltas are sent as data-only server-sent events (SSE) as they become available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// The maximum number of tokens allowed for the generated answer. Defaults to the model's
    /// maximum context length minus the prompt tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// A map between tokens (encoded as strings) and associated bias values from -100 to 100
    /// that adjust the likelihood of those tokens appearing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// The response returned by the OpenAI Chat Completions API.
///
/// Includes one or more chat-based completion choices and any usage statistics.
#[derive(Debug, Deserialize)]
pub struct CreateChatCompletionResponse {
    /// An identifier for this chat completion (e.g., "chatcmpl-xxxxxx").
    pub id: String,
    /// The object type, usually "chat.completion".
    pub object: String,
    /// The creation time in epoch seconds.
    pub created: u64,
    /// The base model used for this request.
    pub model: String,
    /// A list of generated chat completion choices.
    pub choices: Vec<ChatCompletionChoice>,
    /// Token usage data (optional field).
    #[serde(default)]
    pub usage: Option<ChatCompletionUsage>,
}

/// A single chat completion choice within a [`CreateChatCompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChoice {
    /// The index of this choice (useful if `n` > 1).
    pub index: u32,
    /// The chat message object containing the role and content.
    pub message: ChatMessage,
    /// Why the chat completion ended (e.g., "stop", "length").
    pub finish_reason: Option<String>,
}

/// Token usage data, if requested or included by default.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionUsage {
    /// Number of tokens used in the prompt so far.
    pub prompt_tokens: u32,
    /// Number of tokens used in the generated answer.
    pub completion_tokens: u32,
    /// Total number of tokens consumed by this request.
    pub total_tokens: u32,
}

// --- Streaming Types ---
//
// The streaming endpoint returns partial updates (chunks) with a slightly different
// JSON structure. We define separate types to deserialize these chunks.

/// Represents the delta (partial update) in a streaming chat completion.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionDelta {
    /// May be present in the first chunk, indicating the role (typically "assistant").
    pub role: Option<String>,
    /// Partial content for the message.
    pub content: Option<String>,
}

/// A single choice within a streaming chat completion chunk.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChunkChoice {
    /// The index of this choice within the chunk.
    pub index: u32,
    /// The delta containing the partial message update.
    pub delta: ChatCompletionDelta,
    /// Optional log probabilities for this choice.
    pub logprobs: Option<serde_json::Value>,
    /// Optional finish reason indicating why generation ended (if applicable).
    pub finish_reason: Option<String>,
}

/// A streaming chat completion chunk returned by the API.
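///
/// Each SSE `data:` payload decodes into one of these chunks. The JSON below is an
/// illustrative payload (hand-written, not captured from the live API) showing how a chunk
/// deserializes:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::CreateChatCompletionChunk;
///
/// let json = r#"{
///     "id": "chatcmpl-123",
///     "object": "chat.completion.chunk",
///     "created": 1700000000,
///     "model": "gpt-4",
///     "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]
/// }"#;
///
/// let chunk: CreateChatCompletionChunk = serde_json::from_str(json).unwrap();
/// assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hel"));
/// ```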
#[derive(Debug, Deserialize)]
pub struct CreateChatCompletionChunk {
    /// The unique identifier for this chat completion chunk.
    pub id: String,
    /// The type of the returned object (e.g., "chat.completion.chunk").
    pub object: String,
    /// The creation time (in epoch seconds) for this chunk.
    pub created: u64,
    /// The model used to generate the completion.
    pub model: String,
    /// A list of choices contained in this chunk.
    pub choices: Vec<ChatCompletionChunkChoice>,
}

/// Creates a chat-based completion using the [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat).
///
/// # Parameters
/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
/// * `request` - A [`CreateChatCompletionRequest`] specifying the messages, model, and other parameters.
///
/// # Returns
/// A [`CreateChatCompletionResponse`] containing one or more [`ChatCompletionChoice`] items.
///
/// # Errors
/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g., invalid request).
pub async fn create_chat_completion(
    client: &OpenAIClient,
    request: &CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
    // According to the OpenAI docs, the endpoint for chat completions is:
    // POST /v1/chat/completions
    let endpoint = "chat/completions";
    post_json(client, endpoint, request).await
}

/// Creates a streaming chat-based completion using the OpenAI Chat Completions API.
///
/// When `stream` is set to `Some(true)`, partial updates (chunks) are returned.
/// Each item in the stream is a partial update represented by [`CreateChatCompletionChunk`].
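///
/// # Example
///
/// A sketch of draining the stream; it assumes `tokio` and `tokio_stream` are available to
/// the caller (both are already used by this crate):
///
/// ```rust,no_run
/// use chat_gpt_lib_rs::api_resources::chat::{create_chat_completion_stream, CreateChatCompletionRequest, ChatMessage, ChatRole};
/// use chat_gpt_lib_rs::api_resources::models::Model;
/// use chat_gpt_lib_rs::error::OpenAIError;
/// use chat_gpt_lib_rs::OpenAIClient;
/// use tokio_stream::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), OpenAIError> {
///     let client = OpenAIClient::new(None)?;
///     let request = CreateChatCompletionRequest {
///         model: Model::Other("gpt-4".to_string()),
///         messages: vec![ChatMessage {
///             role: ChatRole::User,
///             content: "Tell me a short story.".to_string(),
///             name: None,
///         }],
///         stream: Some(true),
///         ..Default::default()
///     };
///
///     let stream = create_chat_completion_stream(&client, &request).await?;
///     tokio::pin!(stream); // the returned stream may not be `Unpin`
///
///     while let Some(chunk) = stream.next().await {
///         let chunk = chunk?;
///         for choice in &chunk.choices {
///             if let Some(text) = &choice.delta.content {
///                 print!("{text}");
///             }
///         }
///     }
///     Ok(())
/// }
/// ```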
pub async fn create_chat_completion_stream(
    client: &OpenAIClient,
    request: &CreateChatCompletionRequest,
) -> Result<
    impl tokio_stream::Stream<Item = Result<CreateChatCompletionChunk, OpenAIError>>,
    OpenAIError,
> {
    let endpoint = "chat/completions";
    post_json_stream(client, endpoint, request).await
}

#[cfg(test)]
mod tests {
    //! # Tests for the `chat` module
    //!
    //! We use [`wiremock`](https://crates.io/crates/wiremock) to mock responses from the
    //! `/v1/chat/completions` endpoint. These tests ensure that:
    //! 1. A successful JSON body is deserialized into [`CreateChatCompletionResponse`].
    //! 2. Non-2xx responses with an OpenAI-style error body map to [`OpenAIError::APIError`].
    //! 3. Malformed or mismatched JSON produces an [`OpenAIError::DeserializeError`].

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn test_create_chat_completion_success() {
        // Start a local mock server
        let mock_server = MockServer::start().await;

        // Mock successful response JSON
        let success_body = json!({
            "id": "chatcmpl-12345",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "o1-mini",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Here is a witty ice cream tagline!",
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri()) // override base URL to mock server
            .build()
            .unwrap();

        // Build a minimal request
        let req = CreateChatCompletionRequest {
            model: Model::Other("o1-mini".to_string()),
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Write me an ice cream tagline.".to_string(),
                name: None,
            }],
            max_tokens: Some(50),
            ..Default::default()
        };

        // Call the function under test
        let result = create_chat_completion(&client, &req).await;
        assert!(result.is_ok(), "Expected success, got: {:?}", result);

        let resp = result.unwrap();
        assert_eq!(resp.id, "chatcmpl-12345");
        assert_eq!(resp.object, "chat.completion");
        assert_eq!(resp.model, "o1-mini");
        assert_eq!(resp.choices.len(), 1);

        let first_choice = &resp.choices[0];
        assert_eq!(first_choice.message.role, ChatRole::Assistant);
        assert_eq!(
            first_choice.message.content,
            "Here is a witty ice cream tagline!"
        );
        assert_eq!(resp.usage.as_ref().unwrap().total_tokens, 15);
    }

    #[tokio::test]
    async fn test_create_chat_completion_api_error() {
        // Mock a 400 error with OpenAI-style error body
        let mock_server = MockServer::start().await;
        let error_body = json!({
            "error": {
                "message": "Invalid model ID",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateChatCompletionRequest {
            model: Model::Other("non_existent_model".to_string()),
            messages: vec![],
            ..Default::default()
        };

        let result = create_chat_completion(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(
                    message.contains("Invalid model ID"),
                    "Expected an API error with 'Invalid model ID', got: {}",
                    message
                );
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_chat_completion_deserialize_error() {
        // Mock a 200 response with malformed or mismatched JSON
        let mock_server = MockServer::start().await;
        let malformed_json = r#"{
            "id": "chatcmpl-12345",
            "object": "chat.completion",
            "created": "not_a_number",   // string instead of number
            "model": "o1-mini",
            "choices": "should_be_an_array"
        }"#;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateChatCompletionRequest {
            model: Model::Other("o1-mini".to_string()),
            messages: vec![],
            ..Default::default()
        };

        let result = create_chat_completion(&client, &req).await;

        // Expect a deserialization error
        match result {
            Err(OpenAIError::DeserializeError(_)) => {} // success
            other => panic!("Expected DeserializeError, got: {:?}", other),
        }
    }
}