chat_gpt_lib_rs/api_resources/completions.rs

//! This module provides functionality for creating text completions using the
//! [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions).
//!
//! **Note**: This struct (`CreateCompletionRequest`) has been expanded to capture additional
//! fields from the OpenAI specification, including `best_of`, `seed`, `suffix`, etc. Some
//! fields support multiple data types (e.g., `prompt`, `stop`) using `#[serde(untagged)]` enums
//! for flexible deserialization and serialization.
//!
//! # Overview
//!
//! The Completions API can generate or manipulate text based on a given prompt. You specify a model
//! (e.g., `"gpt-3.5-turbo-instruct"`), a prompt, and various parameters like `max_tokens` and
//! `temperature`.
//!
//! **Important**: This request object allows for advanced configurations such as
//! `best_of`, `seed`, and `logit_bias`. Use them carefully, since they can consume many
//! tokens or produce unexpected outputs.
//!
//! Typical usage involves calling [`create_completion`] with a [`CreateCompletionRequest`]:
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::completions::{create_completion, CreateCompletionRequest, PromptInput};
//! use chat_gpt_lib_rs::OpenAIClient;
//! use chat_gpt_lib_rs::error::OpenAIError;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     let client = OpenAIClient::new(None)?; // Reads the API key from OPENAI_API_KEY
//!
//!     let request = CreateCompletionRequest {
//!         model: "gpt-3.5-turbo-instruct".into(),
//!         // Use the `PromptInput::String` variant for a single prompt text
//!         prompt: Some(PromptInput::String("Tell me a joke about cats".to_string())),
//!         max_tokens: Some(50),
//!         temperature: Some(1.0),
//!         ..Default::default()
//!     };
//!
//!     let response = create_completion(&client, &request).await?;
//!     if let Some(choice) = response.choices.first() {
//!         println!("Completion: {}", choice.text);
//!     }
//!     Ok(())
//! }
//! ```

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::api::{post_json, post_json_stream};
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use super::models::Model;

/// Represents the diverse ways a prompt can be supplied:
///
/// - A single string (`"Hello, world!"`)
/// - An array of strings
/// - An array of integers (token IDs)
/// - An array of arrays of integers (multiple sequences of token IDs)
///
/// This enumeration corresponds to the JSON schema's `oneOf` for `prompt`.
/// By using `#[serde(untagged)]`, Serde will automatically handle whichever
/// variant is provided.
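///
/// # Example
///
/// A small sketch of how the untagged representation serializes, assuming `serde_json`
/// (already a dependency of this crate) for inspection:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::PromptInput;
///
/// let single = PromptInput::String("Hello".to_string());
/// let batch = PromptInput::Strings(vec!["One".to_string(), "Two".to_string()]);
/// // Arbitrary illustrative token IDs
/// let tokens = PromptInput::Ints(vec![464, 3290]);
///
/// // Untagged enums serialize to the bare JSON value of the active variant.
/// assert_eq!(serde_json::to_string(&single).unwrap(), "\"Hello\"");
/// assert_eq!(serde_json::to_string(&batch).unwrap(), "[\"One\",\"Two\"]");
/// assert_eq!(serde_json::to_string(&tokens).unwrap(), "[464,3290]");
/// ```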
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PromptInput {
    /// A single string prompt
    String(String),
    /// Multiple string prompts
    Strings(Vec<String>),
    /// A single sequence of token IDs
    Ints(Vec<i64>),
    /// Multiple sequences of token IDs
    MultiInts(Vec<Vec<i64>>),
}

/// Represents the different ways `stop` can be supplied:
///
/// - A single string (e.g. `"\n"`)
/// - An array of up to 4 strings (e.g. `[".END", "Goodbye"]`)
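///
/// # Example
///
/// A minimal sketch of the untagged serialization, again assuming `serde_json`:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::StopSequence;
///
/// let single = StopSequence::Single("\n".to_string());
/// let multiple = StopSequence::Multiple(vec![".END".to_string(), "Goodbye".to_string()]);
///
/// assert_eq!(serde_json::to_string(&single).unwrap(), "\"\\n\"");
/// assert_eq!(serde_json::to_string(&multiple).unwrap(), "[\".END\",\"Goodbye\"]");
/// ```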
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StopSequence {
    /// A single stopping string
    Single(String),
    /// Multiple stopping strings
    Multiple(Vec<String>),
}

/// Placeholder for potential streaming options, per the spec reference:
/// `#/components/schemas/ChatCompletionStreamOptions`.
///
/// If you plan to implement streaming logic, define fields here accordingly.
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct ChatCompletionStreamOptions {
    // For now, this is an empty placeholder.
    // Extend or remove based on your streaming logic requirements.
}

/// A request struct for creating text completions with the OpenAI API.
///
/// This struct fully reflects the extended specification from OpenAI,
/// including fields such as `best_of`, `seed`, and `suffix`.
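///
/// # Example
///
/// A minimal sketch showing that fields left as `None` are omitted from the serialized
/// request body (via `skip_serializing_if`), assuming `serde_json` for inspection:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::{CreateCompletionRequest, PromptInput};
///
/// let request = CreateCompletionRequest {
///     model: "gpt-3.5-turbo-instruct".into(),
///     prompt: Some(PromptInput::String("Say hello".to_string())),
///     max_tokens: Some(16),
///     ..Default::default()
/// };
///
/// let body = serde_json::to_value(&request).unwrap();
/// assert_eq!(body["prompt"], "Say hello");
/// assert_eq!(body["max_tokens"], 16);
/// // Fields left as `None` never reach the wire.
/// assert!(body.get("temperature").is_none());
/// ```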
#[derive(Debug, Serialize, Default, Clone)]
pub struct CreateCompletionRequest {
    /// **Required.** ID of the model to use. For example: `"gpt-3.5-turbo-instruct"`, `"davinci-002"`,
    /// or `"text-davinci-003"`.
    pub model: Model,

    /// The prompt(s) to generate completions for.
    /// Defaults to `<|endoftext|>` if not provided.
    ///
    /// Can be a single string, an array of strings, an array of integers (token IDs),
    /// or an array of arrays of integers (multiple token sequences).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_prompt")]
    pub prompt: Option<PromptInput>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate
    /// in the completion. Defaults to 16.
    ///
    /// The combined length of prompt + `max_tokens` cannot exceed the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_max_tokens")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between `0` and `2`. Higher values like `0.8` will make the
    /// output more random, while lower values like `0.2` will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_temperature")]
    pub temperature: Option<f64>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model
    /// considers the results of the tokens with `top_p` probability mass. So `0.1` means only
    /// the tokens comprising the top 10% probability mass are considered.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_top_p")]
    pub top_p: Option<f64>,

    /// How many completions to generate for each prompt. Defaults to 1.
    ///
    /// **Note**: Because this parameter generates many completions, it can quickly consume your
    /// token quota. Use carefully and ensure you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_n")]
    pub n: Option<u32>,

    /// Generates `best_of` completions server-side and returns the "best" (the one with the
    /// highest log probability per token). Must be greater than `n`. Defaults to 1.
    ///
    /// **Note**: This parameter can quickly consume your token quota if `best_of` is large.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_best_of")]
    pub best_of: Option<u32>,

    /// Whether to stream back partial progress. Defaults to `false`.
    ///
    /// If set to `true`, tokens will be sent as data-only server-sent events (SSE) as they
    /// become available, with the stream terminated by a `data: [DONE]` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub stream: Option<bool>,

    /// Additional options that could be used in streaming scenarios.
    /// This is a placeholder for any extended streaming logic.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ChatCompletionStreamOptions>,

    /// Include the log probabilities on the `logprobs` most likely tokens, along with the chosen tokens.
    /// A value of `5` returns the 5 most likely tokens. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,

    /// Echo back the prompt in addition to the completion. Defaults to `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will
    /// not contain the stop sequence. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
    /// the text so far, increasing the model's likelihood to talk about new topics. Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub presence_penalty: Option<f64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
    /// in the text so far, decreasing the model's likelihood to repeat the same line verbatim. Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub frequency_penalty: Option<f64>,

    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Maps token IDs to a bias value from -100 to 100. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing your end-user, which can help OpenAI monitor and detect abuse.
    /// This is optional, but recommended. Example: `"user-1234"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// If specified, the system will make a best effort to sample deterministically.
    /// Repeated requests with the same `seed` and parameters should return the same result (best-effort).
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` in the response
    /// to monitor backend changes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    /// The suffix that comes after a completion of inserted text. This parameter is only supported
    /// for `gpt-3.5-turbo-instruct`. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
}

/// Default prompt is `<|endoftext|>`, per the specification.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_prompt() -> Option<PromptInput> {
    Some(PromptInput::String("<|endoftext|>".to_string()))
}

/// Default max_tokens is `16`.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_max_tokens() -> Option<u32> {
    Some(16)
}

/// Default temperature is `1.0`.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_temperature() -> Option<f64> {
    Some(1.0)
}

/// Default top_p is `1.0`.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_top_p() -> Option<f64> {
    Some(1.0)
}

/// Default `n` is `1`.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_n() -> Option<u32> {
    Some(1)
}

/// Default `best_of` is `1`.
#[allow(dead_code)] // Referenced only via `#[serde(default = ...)]` attributes, so the compiler would otherwise flag it as unused.
fn default_best_of() -> Option<u32> {
    Some(1)
}

/// The response returned by the OpenAI Completions API.
///
/// Contains generated `choices` plus optional `usage` metrics.
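///
/// # Example
///
/// A small sketch of deserializing a JSON payload into this type, mirroring the shape used
/// in this module's tests:
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::CreateCompletionResponse;
///
/// let json = r#"{
///     "id": "cmpl-12345",
///     "object": "text_completion",
///     "created": 1673643147,
///     "model": "text-davinci-003",
///     "choices": [{ "text": "Hello!", "index": 0, "finish_reason": "stop" }]
/// }"#;
///
/// let resp: CreateCompletionResponse = serde_json::from_str(json).unwrap();
/// assert_eq!(resp.choices[0].text, "Hello!");
/// // `usage` is optional and absent in this payload.
/// assert!(resp.usage.is_none());
/// ```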
#[derive(Debug, Deserialize)]
pub struct CreateCompletionResponse {
    /// An identifier for this completion (e.g. `"cmpl-xxxxxxxx"`).
    pub id: String,
    /// The object type, usually `"text_completion"`.
    pub object: String,
    /// The creation time in epoch seconds.
    pub created: u64,
    /// The model used for this request.
    pub model: Model,
    /// A list of generated completions.
    pub choices: Vec<CompletionChoice>,
    /// Token usage data (optional field).
    #[serde(default)]
    pub usage: Option<CompletionUsage>,
}

/// A single generated completion choice within a [`CreateCompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct CompletionChoice {
    /// The generated text.
    pub text: String,
    /// Which completion index this choice corresponds to (useful if `n` > 1).
    pub index: u32,
    /// The reason why the completion ended (e.g., "stop", "length").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// The log probabilities, if `logprobs` was requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}

/// Token usage data, if requested or included by default.
#[derive(Debug, Deserialize)]
pub struct CompletionUsage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the generated completion.
    pub completion_tokens: u32,
    /// Total number of tokens consumed by this request.
    pub total_tokens: u32,
}

/// Creates a text completion using the [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions).
///
/// # Parameters
///
/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
/// * `request` - A [`CreateCompletionRequest`] specifying the prompt, model, and additional parameters.
///
/// # Returns
///
/// A [`CreateCompletionResponse`] containing the generated text (in [`CompletionChoice`])
/// and metadata about usage and indexing.
///
/// # Errors
///
/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g. invalid request).
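///
/// # Example
///
/// A sketch of basic error handling, matching on the variants listed above:
///
/// ```rust,no_run
/// use chat_gpt_lib_rs::api_resources::completions::{create_completion, CreateCompletionRequest, PromptInput};
/// use chat_gpt_lib_rs::OpenAIClient;
/// use chat_gpt_lib_rs::error::OpenAIError;
///
/// #[tokio::main]
/// async fn main() {
///     let client = OpenAIClient::new(None).expect("client");
///     let request = CreateCompletionRequest {
///         model: "gpt-3.5-turbo-instruct".into(),
///         prompt: Some(PromptInput::String("Say hello".to_string())),
///         ..Default::default()
///     };
///
///     match create_completion(&client, &request).await {
///         Ok(resp) => println!("Got {} choice(s)", resp.choices.len()),
///         Err(OpenAIError::APIError { message, .. }) => eprintln!("API error: {}", message),
///         Err(other) => eprintln!("Request failed: {:?}", other),
///     }
/// }
/// ```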
pub async fn create_completion(
    client: &OpenAIClient,
    request: &CreateCompletionRequest,
) -> Result<CreateCompletionResponse, OpenAIError> {
    let endpoint = "completions";
    post_json(client, endpoint, request).await
}

/// Creates a streaming text completion using the OpenAI Completions API.
///
/// When the `stream` field in the request is set to `Some(true)`, the API
/// will return partial responses as a stream. This function returns an asynchronous
/// stream of [`CreateCompletionResponse`] objects. Each item in the stream represents
/// a partial response until the final message is received (typically signaled by `[DONE]`).
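///
/// # Example
///
/// A minimal sketch of consuming the stream; it assumes each partial chunk deserializes
/// into a [`CreateCompletionResponse`], as described above:
///
/// ```rust,no_run
/// use chat_gpt_lib_rs::api_resources::completions::{create_completion_stream, CreateCompletionRequest, PromptInput};
/// use chat_gpt_lib_rs::OpenAIClient;
/// use chat_gpt_lib_rs::error::OpenAIError;
/// use tokio_stream::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), OpenAIError> {
///     let client = OpenAIClient::new(None)?;
///     let request = CreateCompletionRequest {
///         model: "gpt-3.5-turbo-instruct".into(),
///         prompt: Some(PromptInput::String("Write a haiku about Rust".to_string())),
///         stream: Some(true),
///         ..Default::default()
///     };
///
///     let stream = create_completion_stream(&client, &request).await?;
///     // Pin the stream on the stack so `.next()` can be called on it.
///     tokio::pin!(stream);
///     while let Some(chunk) = stream.next().await {
///         let chunk = chunk?;
///         if let Some(choice) = chunk.choices.first() {
///             print!("{}", choice.text);
///         }
///     }
///     Ok(())
/// }
/// ```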
pub async fn create_completion_stream(
    client: &OpenAIClient,
    request: &CreateCompletionRequest,
) -> Result<
    impl tokio_stream::Stream<Item = Result<CreateCompletionResponse, OpenAIError>>,
    OpenAIError,
> {
    let endpoint = "completions";
    post_json_stream(client, endpoint, request).await
}

#[cfg(test)]
mod tests {
    //! # Tests for the `completions` module
    //!
    //! These tests use [`wiremock`](https://crates.io/crates/wiremock) to mock responses from the
    //! `/v1/completions` endpoint. We cover:
    //! 1. A successful JSON response, ensuring we can deserialize a [`CreateCompletionResponse`].
    //! 2. A non-2xx OpenAI-style error, which should map to [`OpenAIError::APIError`].
    //! 3. Malformed JSON that triggers an [`OpenAIError::DeserializeError`].

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

366    async fn test_create_completion_success() {
367        // Spin up a local mock server
368        let mock_server = MockServer::start().await;
369
370        // Mock JSON body for a successful 200 response
371        let success_body = json!({
372            "id": "cmpl-12345",
373            "object": "text_completion",
374            "created": 1673643147,
375            "model": "text-davinci-003",
376            "choices": [{
377                "text": "This is a funny cat joke!",
378                "index": 0,
379                "finish_reason": "stop"
380            }],
381            "usage": {
382                "prompt_tokens": 10,
383                "completion_tokens": 7,
384                "total_tokens": 17
385            }
386        });
387
388        // Expect a POST to /v1/completions
389        Mock::given(method("POST"))
390            .and(path("/completions"))
391            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
392            .mount(&mock_server)
393            .await;
394
395        let client = OpenAIClient::builder()
396            .with_api_key("test-key")
397            // Override the base URL to the mock server
398            .with_base_url(&mock_server.uri())
399            .build()
400            .unwrap();
401
402        // Create a minimal request
403        let req = CreateCompletionRequest {
404            model: "text-davinci-003".into(),
405            prompt: Some(PromptInput::String("Tell me a cat joke".into())),
406            max_tokens: Some(20),
407            ..Default::default()
408        };
409
410        // Call the function under test
411        let result = create_completion(&client, &req).await;
412        assert!(result.is_ok(), "Expected success, got: {:?}", result);
413
414        let resp = result.unwrap();
415        assert_eq!(resp.id, "cmpl-12345");
416        assert_eq!(resp.object, "text_completion");
417        assert_eq!(resp.model, "text-davinci-003".into());
418        assert_eq!(resp.choices.len(), 1);
419
420        let choice = &resp.choices[0];
421        assert_eq!(choice.text, "This is a funny cat joke!");
422        assert_eq!(choice.index, 0);
423        assert_eq!(choice.finish_reason.as_deref(), Some("stop"));
424
425        let usage = resp.usage.as_ref().unwrap();
426        assert_eq!(usage.prompt_tokens, 10);
427        assert_eq!(usage.completion_tokens, 7);
428        assert_eq!(usage.total_tokens, 17);
429    }
430
    #[tokio::test]
    async fn test_create_completion_api_error() {
        let mock_server = MockServer::start().await;

        let error_body = json!({
            "error": {
                "message": "Model is unavailable",
                "type": "invalid_request_error",
                "code": "model_unavailable"
            }
        });

        Mock::given(method("POST"))
            .and(path("/completions"))
            .respond_with(ResponseTemplate::new(404).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateCompletionRequest {
            model: "unknown-model".into(),
            prompt: Some(PromptInput::String("Hello".into())),
            ..Default::default()
        };

        let result = create_completion(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Model is unavailable"));
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_completion_deserialize_error() {
        let mock_server = MockServer::start().await;

        // Return a 200 but with malformed JSON that doesn't match `CreateCompletionResponse`
        let malformed_json = r#"{
            "id": "cmpl-12345",
            "object": "text_completion",
            "created": "invalid_number",
            "model": "text-davinci-003",
            "choices": "should_be_array"
        }"#;

        Mock::given(method("POST"))
            .and(path("/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateCompletionRequest {
            model: "text-davinci-003".into(),
            prompt: Some(PromptInput::String("Hello".into())),
            ..Default::default()
        };

        let result = create_completion(&client, &req).await;
        match result {
            Err(OpenAIError::DeserializeError(_)) => {
                // success
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}