chat_gpt_lib_rs/api_resources/completions.rs
//! This module provides functionality for creating text completions using the
//! [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions).
//!
//! **Note**: This struct (`CreateCompletionRequest`) has been expanded to capture additional
//! fields from the OpenAI specification, including `best_of`, `seed`, `suffix`, etc. Some
//! fields support multiple data types (e.g., `prompt`, `stop`) using `#[serde(untagged)]` enums
//! for flexible deserialization and serialization.
//!
//! # Overview
//!
//! The Completions API can generate or manipulate text based on a given prompt. You specify a
//! model (e.g., `"gpt-3.5-turbo-instruct"`), a prompt, and various parameters like `max_tokens`
//! and `temperature`.
//!
//! **Important**: This request object allows for advanced configurations such as `best_of`,
//! `seed`, and `logit_bias`. Use them carefully, as they can consume many tokens or produce
//! unexpected outputs.
//!
//! Typical usage involves calling [`create_completion`] with a [`CreateCompletionRequest`]:
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::completions::{create_completion, CreateCompletionRequest, PromptInput};
//! use chat_gpt_lib_rs::OpenAIClient;
//! use chat_gpt_lib_rs::error::OpenAIError;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     let client = OpenAIClient::new(None)?; // Reads API key from OPENAI_API_KEY
//!
//!     let request = CreateCompletionRequest {
//!         model: "gpt-3.5-turbo-instruct".into(),
//!         // `PromptInput::String` variant if we just have a single prompt text
//!         prompt: Some(PromptInput::String("Tell me a joke about cats".to_string())),
//!         max_tokens: Some(50),
//!         temperature: Some(1.0),
//!         ..Default::default()
//!     };
//!
//!     let response = create_completion(&client, &request).await?;
//!     if let Some(choice) = response.choices.get(0) {
//!         println!("Completion: {}", choice.text);
//!     }
//!     Ok(())
//! }
//! ```

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::api::{post_json, post_json_stream};
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use super::models::Model;

/// Represents the diverse ways a prompt can be supplied:
///
/// - A single string (`"Hello, world!"`)
/// - An array of strings
/// - An array of integers (token IDs)
/// - An array of arrays of integers (multiple sequences of token IDs)
///
/// This enumeration corresponds to the JSON schema's `oneOf` for `prompt`.
/// By using `#[serde(untagged)]`, Serde will automatically handle whichever
/// variant is provided.
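///
/// # Example
///
/// A minimal sketch of how the untagged variants serialize (using `serde_json`,
/// which this crate already depends on):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::PromptInput;
///
/// // A single prompt string serializes as a bare JSON string.
/// let single = PromptInput::String("Hello, world!".to_string());
/// assert_eq!(serde_json::to_string(&single).unwrap(), "\"Hello, world!\"");
///
/// // Multiple token-ID sequences serialize as nested JSON arrays.
/// let tokens = PromptInput::MultiInts(vec![vec![1, 2, 3], vec![4, 5]]);
/// assert_eq!(serde_json::to_string(&tokens).unwrap(), "[[1,2,3],[4,5]]");
/// ```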
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PromptInput {
    /// A single string prompt
    String(String),
    /// Multiple string prompts
    Strings(Vec<String>),
    /// A single sequence of token IDs
    Ints(Vec<i64>),
    /// Multiple sequences of token IDs
    MultiInts(Vec<Vec<i64>>),
}

/// Represents the different ways `stop` can be supplied:
///
/// - A single string (e.g. `"\n"`)
/// - An array of up to 4 strings (e.g. `[".END", "Goodbye"]`)
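///
/// A small sketch of constructing each variant (the value itself is set on
/// [`CreateCompletionRequest::stop`]):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::StopSequence;
///
/// // Stop at the first newline.
/// let _single = StopSequence::Single("\n".to_string());
/// // Stop at whichever of these sequences appears first.
/// let _multiple = StopSequence::Multiple(vec![".END".to_string(), "Goodbye".to_string()]);
/// ```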
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StopSequence {
    /// A single stopping string
    Single(String),
    /// Multiple stopping strings
    Multiple(Vec<String>),
}

/// Placeholder for potential streaming options, per the spec reference:
/// `#/components/schemas/ChatCompletionStreamOptions`.
///
/// If you plan to implement streaming logic, define fields here accordingly.
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct ChatCompletionStreamOptions {
    // For now, this is an empty placeholder.
    // Extend or remove based on your streaming logic requirements.
}

/// A request struct for creating text completions with the OpenAI API.
///
/// This struct fully reflects the extended specification from OpenAI,
/// including fields such as `best_of`, `seed`, and `suffix`.
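///
/// # Example
///
/// A sketch of a more fully configured request; the field values here are illustrative
/// only, and everything not set explicitly falls back to `..Default::default()`:
///
/// ```rust
/// use std::collections::HashMap;
/// use chat_gpt_lib_rs::api_resources::completions::{
///     CreateCompletionRequest, PromptInput, StopSequence,
/// };
///
/// let request = CreateCompletionRequest {
///     model: "gpt-3.5-turbo-instruct".into(),
///     prompt: Some(PromptInput::Strings(vec![
///         "Write a haiku about rain".to_string(),
///         "Write a haiku about snow".to_string(),
///     ])),
///     max_tokens: Some(64),
///     n: Some(2),
///     best_of: Some(3),
///     stop: Some(StopSequence::Single("\n\n".to_string())),
///     // "50256" is just an illustrative token ID; real IDs depend on the tokenizer.
///     logit_bias: Some(HashMap::from([("50256".to_string(), -100)])),
///     user: Some("user-1234".to_string()),
///     ..Default::default()
/// };
/// assert_eq!(request.n, Some(2));
/// ```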
#[derive(Debug, Serialize, Default, Clone)]
pub struct CreateCompletionRequest {
    /// **Required.** ID of the model to use. For example: `"gpt-3.5-turbo-instruct"`, `"davinci-002"`,
    /// or `"text-davinci-003"`.
    pub model: Model,

    /// **Required.** The prompt(s) to generate completions for.
    /// Defaults to `<|endoftext|>` if not provided.
    ///
    /// Can be a single string, an array of strings, an array of integers (token IDs),
    /// or an array of arrays of integers (multiple token sequences).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_prompt")]
    pub prompt: Option<PromptInput>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate
    /// in the completion. Defaults to 16.
    ///
    /// The combined length of prompt + `max_tokens` cannot exceed the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_max_tokens")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between `0` and `2`. Higher values like `0.8` will make the
    /// output more random, while lower values like `0.2` will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_temperature")]
    pub temperature: Option<f64>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model
    /// considers the results of the tokens with `top_p` probability mass. So `0.1` means only
    /// the tokens comprising the top 10% probability mass are considered.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_top_p")]
    pub top_p: Option<f64>,

    /// How many completions to generate for each prompt. Defaults to 1.
    ///
    /// **Note**: Because this parameter generates many completions, it can quickly consume your
    /// token quota. Use carefully and ensure you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_n")]
    pub n: Option<u32>,

    /// Generates `best_of` completions server-side and returns the "best" (the one with the
    /// highest log probability per token). Must be greater than `n`. Defaults to 1.
    ///
    /// **Note**: This parameter can quickly consume your token quota if `best_of` is large.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default = "default_best_of")]
    pub best_of: Option<u32>,

    /// Whether to stream back partial progress. Defaults to `false`.
    ///
    /// If set to `true`, tokens will be sent as data-only server-sent events (SSE) as they
    /// become available, with the stream terminated by a `data: [DONE]` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub stream: Option<bool>,

    /// Additional options that could be used in streaming scenarios.
    /// This is a placeholder for any extended streaming logic.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ChatCompletionStreamOptions>,

    /// Include the log probabilities on the `logprobs` most likely tokens, along with the chosen tokens.
    /// A value of `5` returns the 5 most likely tokens. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,

    /// Echo back the prompt in addition to the completion. Defaults to `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will
    /// not contain the stop sequence. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
    /// the text so far, increasing the model's likelihood to talk about new topics. Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub presence_penalty: Option<f64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
    /// in the text so far, decreasing the model's likelihood to repeat the same line verbatim. Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub frequency_penalty: Option<f64>,

    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Maps token IDs to a bias value from -100 to 100. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// This is optional, but recommended. Example: `"user-1234"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// If specified, the system will make a best effort to sample deterministically.
    /// Repeated requests with the same `seed` and parameters should return the same result (best-effort).
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` in the response
    /// to monitor backend changes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    /// The suffix that comes after a completion of inserted text. This parameter is only supported
    /// for `gpt-3.5-turbo-instruct`. Defaults to `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
}

/// Default prompt is `<|endoftext|>`, per the specification.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_prompt() -> Option<PromptInput> {
    Some(PromptInput::String("<|endoftext|>".to_string()))
}

/// Default max_tokens is `16`.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_max_tokens() -> Option<u32> {
    Some(16)
}

/// Default temperature is `1.0`.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_temperature() -> Option<f64> {
    Some(1.0)
}

/// Default top_p is `1.0`.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_top_p() -> Option<f64> {
    Some(1.0)
}

/// Default `n` is `1`.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_n() -> Option<u32> {
    Some(1)
}

/// Default `best_of` is `1`.
#[allow(dead_code)] // This way, Serde can still invoke them at runtime, but the compiler won’t complain.
fn default_best_of() -> Option<u32> {
    Some(1)
}

/// The response returned by the OpenAI Completions API.
///
/// Contains generated `choices` plus optional `usage` metrics.
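///
/// # Example
///
/// A sketch of deserializing a response body (the JSON shown is a trimmed-down,
/// illustrative payload, not a verbatim API response):
///
/// ```rust
/// use chat_gpt_lib_rs::api_resources::completions::CreateCompletionResponse;
///
/// let body = r#"{
///     "id": "cmpl-12345",
///     "object": "text_completion",
///     "created": 1673643147,
///     "model": "gpt-3.5-turbo-instruct",
///     "choices": [
///         { "text": "Hello there!", "index": 0, "finish_reason": "stop" }
///     ],
///     "usage": { "prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8 }
/// }"#;
///
/// let parsed: CreateCompletionResponse = serde_json::from_str(body).unwrap();
/// assert_eq!(parsed.choices[0].text, "Hello there!");
/// ```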
#[derive(Debug, Deserialize)]
pub struct CreateCompletionResponse {
    /// An identifier for this completion (e.g. `"cmpl-xxxxxxxx"`).
    pub id: String,
    /// The object type, usually `"text_completion"`.
    pub object: String,
    /// The creation time in epoch seconds.
    pub created: u64,
    /// The model used for this request.
    pub model: Model,
    /// A list of generated completions.
    pub choices: Vec<CompletionChoice>,
    /// Token usage data (optional field).
    #[serde(default)]
    pub usage: Option<CompletionUsage>,
}

/// A single generated completion choice within a [`CreateCompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct CompletionChoice {
    /// The generated text.
    pub text: String,
    /// Which completion index this choice corresponds to (useful if `n` > 1).
    pub index: u32,
    /// The reason why the completion ended (e.g., "stop", "length").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// The log probabilities, if `logprobs` was requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}

/// Token usage data, if requested or included by default.
#[derive(Debug, Deserialize)]
pub struct CompletionUsage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the generated completion.
    pub completion_tokens: u32,
    /// Total number of tokens consumed by this request.
    pub total_tokens: u32,
}

/// Creates a text completion using the [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions).
///
/// # Parameters
///
/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
/// * `request` - A [`CreateCompletionRequest`] specifying the prompt, model, and additional parameters.
///
/// # Returns
///
/// A [`CreateCompletionResponse`] containing the generated text (in [`CompletionChoice`])
/// and metadata about usage and indexing.
///
/// # Errors
///
/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g. invalid request).
pub async fn create_completion(
    client: &OpenAIClient,
    request: &CreateCompletionRequest,
) -> Result<CreateCompletionResponse, OpenAIError> {
    let endpoint = "completions";
    post_json(client, endpoint, request).await
}

/// Creates a streaming text completion using the OpenAI Completions API.
///
/// When the `stream` field in the request is set to `Some(true)`, the API
/// will return partial responses as a stream. This function returns an asynchronous
/// stream of [`CreateCompletionResponse`] objects. Each item in the stream represents
/// a partial response until the final message is received (typically signaled by `[DONE]`).
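///
/// # Example
///
/// A sketch of consuming the stream (pinning the returned stream with `Box::pin` so that
/// `tokio_stream::StreamExt::next` can be called on it):
///
/// ```rust,no_run
/// use chat_gpt_lib_rs::api_resources::completions::{
///     create_completion_stream, CreateCompletionRequest, PromptInput,
/// };
/// use chat_gpt_lib_rs::{error::OpenAIError, OpenAIClient};
/// use tokio_stream::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), OpenAIError> {
///     let client = OpenAIClient::new(None)?;
///     let request = CreateCompletionRequest {
///         model: "gpt-3.5-turbo-instruct".into(),
///         prompt: Some(PromptInput::String("Stream a short poem".to_string())),
///         stream: Some(true),
///         ..Default::default()
///     };
///
///     let mut stream = Box::pin(create_completion_stream(&client, &request).await?);
///     while let Some(chunk) = stream.next().await {
///         let chunk = chunk?;
///         if let Some(choice) = chunk.choices.first() {
///             print!("{}", choice.text);
///         }
///     }
///     Ok(())
/// }
/// ```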
pub async fn create_completion_stream(
    client: &OpenAIClient,
    request: &CreateCompletionRequest,
) -> Result<
    impl tokio_stream::Stream<Item = Result<CreateCompletionResponse, OpenAIError>>,
    OpenAIError,
> {
    let endpoint = "completions";
    post_json_stream(client, endpoint, request).await
}

#[cfg(test)]
mod tests {
    //! # Tests for the `completions` module
    //!
    //! These tests use [`wiremock`](https://crates.io/crates/wiremock) to mock responses from the
    //! `/v1/completions` endpoint. We cover:
    //! 1. A successful JSON response, ensuring we can deserialize a [`CreateCompletionResponse`].
    //! 2. A non-2xx OpenAI-style error, which should map to [`OpenAIError::APIError`].
    //! 3. Malformed JSON that triggers a [`OpenAIError::DeserializeError`].

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn test_create_completion_success() {
        // Spin up a local mock server
        let mock_server = MockServer::start().await;

        // Mock JSON body for a successful 200 response
        let success_body = json!({
            "id": "cmpl-12345",
            "object": "text_completion",
            "created": 1673643147,
            "model": "text-davinci-003",
            "choices": [{
                "text": "This is a funny cat joke!",
                "index": 0,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 7,
                "total_tokens": 17
            }
        });

        // Expect a POST to /v1/completions
        Mock::given(method("POST"))
            .and(path("/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            // Override the base URL to the mock server
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        // Create a minimal request
        let req = CreateCompletionRequest {
            model: "text-davinci-003".into(),
            prompt: Some(PromptInput::String("Tell me a cat joke".into())),
            max_tokens: Some(20),
            ..Default::default()
        };

        // Call the function under test
        let result = create_completion(&client, &req).await;
        assert!(result.is_ok(), "Expected success, got: {:?}", result);

        let resp = result.unwrap();
        assert_eq!(resp.id, "cmpl-12345");
        assert_eq!(resp.object, "text_completion");
        assert_eq!(resp.model, "text-davinci-003".into());
        assert_eq!(resp.choices.len(), 1);

        let choice = &resp.choices[0];
        assert_eq!(choice.text, "This is a funny cat joke!");
        assert_eq!(choice.index, 0);
        assert_eq!(choice.finish_reason.as_deref(), Some("stop"));

        let usage = resp.usage.as_ref().unwrap();
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 7);
        assert_eq!(usage.total_tokens, 17);
    }

    #[tokio::test]
    async fn test_create_completion_api_error() {
        let mock_server = MockServer::start().await;

        let error_body = json!({
            "error": {
                "message": "Model is unavailable",
                "type": "invalid_request_error",
                "code": "model_unavailable"
            }
        });

        Mock::given(method("POST"))
            .and(path("/completions"))
            .respond_with(ResponseTemplate::new(404).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateCompletionRequest {
            model: "unknown-model".into(),
            prompt: Some(PromptInput::String("Hello".into())),
            ..Default::default()
        };

        let result = create_completion(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Model is unavailable"));
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_completion_deserialize_error() {
        let mock_server = MockServer::start().await;

        // Return a 200 but with malformed JSON that doesn't match `CreateCompletionResponse`
        let malformed_json = r#"{
            "id": "cmpl-12345",
            "object": "text_completion",
            "created": "invalid_number",
            "model": "text-davinci-003",
            "choices": "should_be_array"
        }"#;

        Mock::given(method("POST"))
            .and(path("/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateCompletionRequest {
            model: "text-davinci-003".into(),
            prompt: Some(PromptInput::String("Hello".into())),
            ..Default::default()
        };

        let result = create_completion(&client, &req).await;
        match result {
            Err(OpenAIError::DeserializeError(_)) => {
                // success
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}