chat_gpt_lib_rs/api_resources/chat.rs
//! This module provides functionality for creating chat-based completions using the
//! [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat).
//!
//! The Chat API is designed for conversational interactions, where each request includes a list
//! of messages with a role (system, user, or assistant). The model responds based on the context
//! established by these messages, allowing for more interactive and context-aware responses
//! compared to plain completions.
//!
//! # Overview
//!
//! The core usage involves calling [`create_chat_completion`] with a [`CreateChatCompletionRequest`],
//! which includes a sequence of [`ChatMessage`] items. Each `ChatMessage` has a `role` and `content`.
//! The API then returns a [`CreateChatCompletionResponse`] containing one or more
//! [`ChatCompletionChoice`] objects (depending on the `n` parameter).
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::chat::{create_chat_completion, CreateChatCompletionRequest, ChatMessage, ChatRole};
//! use chat_gpt_lib_rs::api_resources::models::Model;
//! use chat_gpt_lib_rs::error::OpenAIError;
//! use chat_gpt_lib_rs::OpenAIClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     let client = OpenAIClient::new(None)?; // Reads API key from OPENAI_API_KEY
//!
//!     let request = CreateChatCompletionRequest {
//!         model: Model::O1Mini,
//!         messages: vec![
//!             ChatMessage {
//!                 role: ChatRole::System,
//!                 content: "You are a helpful assistant.".to_string(),
//!                 name: None,
//!             },
//!             ChatMessage {
//!                 role: ChatRole::User,
//!                 content: "Write a tagline for an ice cream shop.".to_string(),
//!                 name: None,
//!             },
//!         ],
//!         max_tokens: Some(50),
//!         temperature: Some(0.7),
//!         ..Default::default()
//!     };
//!
//!     let response = create_chat_completion(&client, &request).await?;
//!
//!     for choice in &response.choices {
//!         println!("Assistant: {}", choice.message.content);
//!     }
//!
//!     Ok(())
//! }
//! ```

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::api::{post_json, post_json_stream};
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use crate::api_resources::models::Model;

/// The role of a message in the chat sequence.
///
/// Typically one of `system`, `user`, `assistant`. OpenAI may add or adjust roles in the future.
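/// A minimal round-trip sketch of the wire encoding (this assumes `serde_json`, which this
/// crate already depends on): variants serialize to lowercase strings, and unrecognized roles
/// fall back to [`ChatRole::Other`] via `#[serde(other)]`.
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::ChatRole;
///
/// // `rename_all = "lowercase"` maps `ChatRole::User` to the JSON string "user".
/// assert_eq!(serde_json::to_string(&ChatRole::User).unwrap(), "\"user\"");
///
/// // A role this enum does not know about deserializes as `Other`.
/// let parsed: ChatRole = serde_json::from_str("\"developer\"").unwrap();
/// assert_eq!(parsed, ChatRole::Other);
/// ```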
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// For system-level instructions (e.g. "You are a helpful assistant.")
    System,
    /// For user-supplied messages
    User,
    /// For assistant messages (responses from the model)
    Assistant,
    /// For tool messages (e.g. results returned from a tool invocation)
    Tool,
    /// For function messages (the legacy function-calling role)
    Function,
    /// Experimental or extended role types, if they become available
    #[serde(other)]
    Other,
}

/// A single message in a chat conversation.
///
/// Each message has:
/// - A [`ChatRole`], indicating who is sending the message (system, user, or assistant).
/// - The message `content`.
/// - An optional `name` for the user or system, if applicable; see the sketch below.
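///
/// A sketch of a named user message (the name value is illustrative):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::{ChatMessage, ChatRole};
///
/// let msg = ChatMessage {
///     role: ChatRole::User,
///     content: "Hi team!".to_string(),
///     name: Some("alice".to_string()), // distinguishes speakers in multi-user conversations
/// };
/// assert_eq!(msg.name.as_deref(), Some("alice"));
/// ```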
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    /// The role of the sender (system, user, or assistant).
    pub role: ChatRole,
    /// The content of the message.
    pub content: String,
    /// The (optional) name of the user or system. This can be used to identify
    /// the speaker when multiple users or participants exist in a conversation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// A request struct for creating chat completions with the OpenAI Chat Completions API.
///
/// # Fields
/// - `model`: The model to use (e.g., `Model::Other("gpt-3.5-turbo".to_string())`).
/// - `messages`: A list of [`ChatMessage`] items providing the conversation history.
/// - `stream`: Whether or not to stream responses via server-sent events.
/// - `max_tokens`, `temperature`, `top_p`, etc.: Parameters controlling the generation.
/// - `n`: Number of chat completion choices to generate.
/// - `logit_bias`, `user`: Additional advanced parameters.
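///
/// Most fields are optional, so requests are typically built with struct-update syntax from
/// `Default`. A minimal sketch (the model name, temperature, and bias values below are
/// illustrative, not recommendations):
///
/// ```rust
/// use std::collections::HashMap;
/// use chat_gpt_lib_rs::api_resources::chat::{ChatMessage, ChatRole, CreateChatCompletionRequest};
/// use chat_gpt_lib_rs::api_resources::models::Model;
///
/// // Suppress a hypothetical token by mapping its ID (as a string) to -100.
/// let mut bias = HashMap::new();
/// bias.insert("50256".to_string(), -100);
///
/// let request = CreateChatCompletionRequest {
///     model: Model::Other("gpt-3.5-turbo".to_string()),
///     messages: vec![ChatMessage {
///         role: ChatRole::User,
///         content: "Hello!".to_string(),
///         name: None,
///     }],
///     temperature: Some(0.2),
///     logit_bias: Some(bias),
///     ..Default::default()
/// };
/// assert_eq!(request.messages.len(), 1);
/// ```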
#[derive(Debug, Serialize, Default, Clone)]
pub struct CreateChatCompletionRequest {
    /// **Required**. The model used for this chat request.
    /// Examples: `Model::O1Mini`, `Model::Other("gpt-4".to_string())`.
    pub model: Model,

    /// **Required**. The messages that make up the conversation so far.
    pub messages: Vec<ChatMessage>,

    /// Sampling temperature, between 0 and 2. Lower values make the output more
    /// deterministic; higher values make it more random.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// The nucleus sampling parameter. Like `temperature`, but a value like 0.1 means only
    /// the top 10% probability mass is considered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// How many chat completion choices to generate for each input message. Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// If set, partial message deltas are sent as data-only server-sent events (SSE) as they become available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// The maximum number of tokens allowed for the generated answer. Defaults to the model's
    /// maximum context length minus the prompt length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// A map from token (encoded as a string) to an associated bias value from -100 to 100
    /// that adjusts the likelihood of that token appearing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing your end-user, which can help OpenAI monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// The response returned by the OpenAI Chat Completions API.
///
/// Includes one or more chat-based completion choices and any usage statistics.
#[derive(Debug, Deserialize)]
pub struct CreateChatCompletionResponse {
    /// An identifier for this chat completion (e.g., "chatcmpl-xxxxxx").
    pub id: String,
    /// The object type, usually "chat.completion".
    pub object: String,
    /// The creation time in epoch seconds.
    pub created: u64,
    /// The base model used for this request.
    pub model: String,
    /// A list of generated chat completion choices.
    pub choices: Vec<ChatCompletionChoice>,
    /// Token usage data (optional field).
    #[serde(default)]
    pub usage: Option<ChatCompletionUsage>,
}

/// A single chat completion choice within a [`CreateChatCompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChoice {
    /// The index of this choice (useful if `n` > 1).
    pub index: u32,
    /// The chat message object containing the role and content.
    pub message: ChatMessage,
    /// Why the chat completion ended (e.g., "stop", "length").
    pub finish_reason: Option<String>,
}

/// Token usage data, if requested or included by default.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionUsage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the generated answer.
    pub completion_tokens: u32,
    /// Total number of tokens consumed by this request.
    pub total_tokens: u32,
}

// --- Streaming Types ---
//
// The streaming endpoint returns partial updates (chunks) with a slightly different
// JSON structure. We define separate types to deserialize these chunks.

/// Represents the delta (partial update) in a streaming chat completion.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionDelta {
    /// May be present in the first chunk, indicating the role (typically "assistant").
    pub role: Option<String>,
    /// Partial content for the message.
    pub content: Option<String>,
}

/// A single choice within a streaming chat completion chunk.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChunkChoice {
    /// The index of this choice within the chunk.
    pub index: u32,
    /// The delta containing the partial message update.
    pub delta: ChatCompletionDelta,
    /// Optional log probabilities for this choice.
    pub logprobs: Option<serde_json::Value>,
    /// Optional finish reason indicating why generation ended (if applicable).
    pub finish_reason: Option<String>,
}

/// A streaming chat completion chunk returned by the API.
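///
/// A deserialization sketch; the JSON shape below mirrors the fields of this struct and is
/// illustrative rather than a verbatim API payload:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::chat::CreateChatCompletionChunk;
///
/// let data = r#"{
///     "id": "chatcmpl-123",
///     "object": "chat.completion.chunk",
///     "created": 1234567890,
///     "model": "gpt-3.5-turbo",
///     "choices": [{
///         "index": 0,
///         "delta": { "role": "assistant", "content": "Hel" },
///         "finish_reason": null
///     }]
/// }"#;
///
/// let chunk: CreateChatCompletionChunk = serde_json::from_str(data).unwrap();
/// assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hel"));
/// ```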
#[derive(Debug, Deserialize)]
pub struct CreateChatCompletionChunk {
    /// The unique identifier for this chat completion chunk.
    pub id: String,
    /// The type of the returned object (e.g., "chat.completion.chunk").
    pub object: String,
    /// The creation time (in epoch seconds) for this chunk.
    pub created: u64,
    /// The model used to generate the completion.
    pub model: String,
    /// A list of choices contained in this chunk.
    pub choices: Vec<ChatCompletionChunkChoice>,
}

/// Creates a chat-based completion using the [OpenAI Chat Completions API](https://platform.openai.com/docs/api-reference/chat).
///
/// # Parameters
/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
/// * `request` - A [`CreateChatCompletionRequest`] specifying the messages, model, and other parameters.
///
/// # Returns
/// A [`CreateChatCompletionResponse`] containing one or more [`ChatCompletionChoice`] items.
///
/// # Errors
/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g., invalid request).
pub async fn create_chat_completion(
    client: &OpenAIClient,
    request: &CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
    // According to the OpenAI docs, the endpoint for chat completions is:
    // POST /v1/chat/completions
    let endpoint = "chat/completions";
    post_json(client, endpoint, request).await
}

/// Creates a streaming chat-based completion using the OpenAI Chat Completions API.
///
/// When `stream` is set to `Some(true)`, partial updates (chunks) are returned.
/// Each item in the stream is a partial update represented by [`CreateChatCompletionChunk`].
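///
/// # Example
///
/// A consumption sketch, assuming `tokio` and `tokio-stream` (both already used by this
/// crate); content deltas are appended in arrival order to rebuild the full reply:
///
/// ```rust,no_run
/// use chat_gpt_lib_rs::api_resources::chat::{
///     create_chat_completion_stream, ChatMessage, ChatRole, CreateChatCompletionRequest,
/// };
/// use chat_gpt_lib_rs::api_resources::models::Model;
/// use chat_gpt_lib_rs::error::OpenAIError;
/// use chat_gpt_lib_rs::OpenAIClient;
/// use tokio_stream::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), OpenAIError> {
///     let client = OpenAIClient::new(None)?;
///     let request = CreateChatCompletionRequest {
///         model: Model::Other("gpt-3.5-turbo".to_string()),
///         messages: vec![ChatMessage {
///             role: ChatRole::User,
///             content: "Stream a short haiku.".to_string(),
///             name: None,
///         }],
///         stream: Some(true),
///         ..Default::default()
///     };
///
///     let stream = create_chat_completion_stream(&client, &request).await?;
///     tokio::pin!(stream); // pin so we can call `next()` on the `impl Stream`
///
///     let mut full_text = String::new();
///     while let Some(chunk) = stream.next().await {
///         // Each chunk holds zero or more choices; append any partial content.
///         for choice in chunk?.choices {
///             if let Some(piece) = choice.delta.content {
///                 full_text.push_str(&piece);
///             }
///         }
///     }
///     println!("{full_text}");
///     Ok(())
/// }
/// ```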
pub async fn create_chat_completion_stream(
    client: &OpenAIClient,
    request: &CreateChatCompletionRequest,
) -> Result<
    impl tokio_stream::Stream<Item = Result<CreateChatCompletionChunk, OpenAIError>>,
    OpenAIError,
> {
    let endpoint = "chat/completions";
    post_json_stream(client, endpoint, request).await
}

#[cfg(test)]
mod tests {
    //! # Tests for the `chat` module
    //!
    //! We use [`wiremock`](https://crates.io/crates/wiremock) to mock responses from the
    //! `/v1/chat/completions` endpoint. These tests ensure that:
    //! 1. A successful JSON body is deserialized into [`CreateChatCompletionResponse`].
    //! 2. Non-2xx responses with an OpenAI-style error body map to [`OpenAIError::APIError`].
    //! 3. Malformed or mismatched JSON produces an [`OpenAIError::DeserializeError`].

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn test_create_chat_completion_success() {
        // Start a local mock server
        let mock_server = MockServer::start().await;

        // Mock successful response JSON
        let success_body = json!({
            "id": "chatcmpl-12345",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "o1-mini",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Here is a witty ice cream tagline!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri()) // override the base URL to point at the mock server
            .build()
            .unwrap();

        // Build a minimal request
        let req = CreateChatCompletionRequest {
            model: Model::Other("o1-mini".to_string()),
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "Write me an ice cream tagline.".to_string(),
                name: None,
            }],
            max_tokens: Some(50),
            ..Default::default()
        };

        // Call the function under test
        let result = create_chat_completion(&client, &req).await;
        assert!(result.is_ok(), "Expected success, got: {:?}", result);

        let resp = result.unwrap();
        assert_eq!(resp.id, "chatcmpl-12345");
        assert_eq!(resp.object, "chat.completion");
        assert_eq!(resp.model, "o1-mini");
        assert_eq!(resp.choices.len(), 1);

        let first_choice = &resp.choices[0];
        assert_eq!(first_choice.message.role, ChatRole::Assistant);
        assert_eq!(
            first_choice.message.content,
            "Here is a witty ice cream tagline!"
        );
        assert_eq!(resp.usage.as_ref().unwrap().total_tokens, 15);
    }

    #[tokio::test]
    async fn test_create_chat_completion_api_error() {
        // Mock a 400 error with an OpenAI-style error body
        let mock_server = MockServer::start().await;
        let error_body = json!({
            "error": {
                "message": "Invalid model ID",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateChatCompletionRequest {
            model: Model::Other("non_existent_model".to_string()),
            messages: vec![],
            ..Default::default()
        };

        let result = create_chat_completion(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(
                    message.contains("Invalid model ID"),
                    "Expected an API error with 'Invalid model ID', got: {}",
                    message
                );
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_chat_completion_deserialize_error() {
        // Mock a 200 response whose body does not match the expected schema:
        // `created` is a string instead of a number, and `choices` is not an array.
        let mock_server = MockServer::start().await;
        let malformed_json = r#"{
            "id": "chatcmpl-12345",
            "object": "chat.completion",
            "created": "not_a_number",
            "model": "o1-mini",
            "choices": "should_be_an_array"
        }"#;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateChatCompletionRequest {
            model: Model::Other("o1-mini".to_string()),
            messages: vec![],
            ..Default::default()
        };

        let result = create_chat_completion(&client, &req).await;

        // Expect a deserialization error
        match result {
            Err(OpenAIError::DeserializeError(_)) => {} // success
            other => panic!("Expected DeserializeError, got: {:?}", other),
        }
    }
}
444}