// openai_mock/models/completion.rs

1//! This module defines the data structures for handling completion requests
2//! and responses in the API.
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use crate::validators::StopSequence;
8
/// Represents a request payload for the Completions API.
///
/// This structure includes various optional and required fields that
/// configure the behavior of the text generation process.
///
/// Optional fields use `#[serde(default = "…")]` hooks so that a key
/// *omitted* from the JSON body receives the documented default. Note
/// that serde only applies the default when the key is missing: an
/// explicit JSON `null` still deserializes to `None`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// ID of the model to use. Required — has no serde default.
    pub model: String,

    /// The prompt(s) to generate completions for.
    ///
    /// Can be a string, an array of strings, or `null`; kept as a raw
    /// `serde_json::Value` so all accepted shapes deserialize.
    #[serde(default)]
    pub prompt: Option<Value>,

    /// The suffix that comes after the generated completion.
    #[serde(default)]
    pub suffix: Option<String>,

    /// The maximum number of tokens to generate.
    /// Defaults to `Some(16)` when the key is omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: Option<u32>,

    /// Sampling temperature to use. Higher values make output more random.
    /// Defaults to `Some(1.0)` when the key is omitted.
    #[serde(default = "default_temperature")]
    pub temperature: Option<f32>,

    /// Nucleus sampling probability.
    /// Defaults to `Some(1.0)` when the key is omitted.
    #[serde(default = "default_top_p")]
    pub top_p: Option<f32>,

    /// Number of completions to generate for each prompt.
    /// Defaults to `Some(1)` when the key is omitted.
    #[serde(default = "default_n")]
    pub n: Option<i32>,

    /// Whether to stream back partial progress.
    /// Defaults to `Some(false)` when the key is omitted.
    #[serde(default = "default_stream")]
    pub stream: Option<bool>,

    /// Include the log probabilities of the top tokens.
    #[serde(default)]
    pub logprobs: Option<u32>,

    /// Echo back the prompt in addition to the completion.
    /// Defaults to `Some(false)` when the key is omitted.
    #[serde(default = "default_echo")]
    pub echo: Option<bool>,

    /// Sequences where the API will stop generating further tokens.
    /// Deserialized through the project's [`StopSequence`] validator type.
    #[serde(default)]
    pub stop: Option<StopSequence>,

    /// Penalizes repeated tokens based on presence.
    /// Defaults to `Some(0.0)` when the key is omitted.
    #[serde(default = "default_presence_penalty")]
    pub presence_penalty: Option<f32>,

    /// Penalizes repeated tokens based on frequency.
    /// Defaults to `Some(0.0)` when the key is omitted.
    #[serde(default = "default_frequency_penalty")]
    pub frequency_penalty: Option<f32>,

    /// Generates `best_of` completions server-side and returns the best one.
    #[serde(default)]
    pub best_of: Option<i32>,

    /// Modifies the likelihood of specified tokens appearing in the completion.
    // NOTE(review): keys are strings and biases are i32 here — presumably
    // token IDs mapped to integer bias values; confirm against the API spec.
    #[serde(default)]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing the end-user.
    #[serde(default)]
    pub user: Option<String>,
}
80
81// Default values for optional parameters
82
/// Serde default for `max_tokens`: generate at most 16 tokens.
fn default_max_tokens() -> Option<u32> { Some(16) }
86
/// Serde default for `temperature`: neutral sampling at 1.0.
fn default_temperature() -> Option<f32> { Some(1.0) }
90
/// Serde default for `top_p`: full nucleus (no truncation) at 1.0.
fn default_top_p() -> Option<f32> { Some(1.0) }
94
/// Serde default for `n`: one completion per prompt.
fn default_n() -> Option<i32> { Some(1) }
98
/// Serde default for `stream`: responses are not streamed.
fn default_stream() -> Option<bool> { Some(false) }
102
/// Serde default for `echo`: the prompt is not echoed back.
fn default_echo() -> Option<bool> { Some(false) }
106
/// Serde default for `presence_penalty`: no penalty (0.0).
fn default_presence_penalty() -> Option<f32> { Some(0.0) }
110
/// Serde default for `frequency_penalty`: no penalty (0.0).
fn default_frequency_penalty() -> Option<f32> { Some(0.0) }
114
115impl Default for CompletionRequest {
116    /// Provides default values for `CompletionRequest`.
117    ///
118    /// # Example
119    ///
120    /// ```
121    /// let default_request = CompletionRequest::default();
122    /// ```
123    fn default() -> Self {
124        Self {
125            model: String::new(),
126            prompt: None,
127            suffix: None,
128            max_tokens: default_max_tokens(),
129            temperature: default_temperature(),
130            top_p: default_top_p(),
131            n: default_n(),
132            stream: default_stream(),
133            logprobs: None,
134            echo: default_echo(),
135            stop: None,
136            presence_penalty: default_presence_penalty(),
137            frequency_penalty: default_frequency_penalty(),
138            best_of: None,
139            logit_bias: None,
140            user: None,
141        }
142    }
143}
144
/// Represents a response from the Completions API.
///
/// Contains generated completions along with usage statistics.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Unique identifier for the completion.
    pub id: String,

    /// The object type (e.g., "text_completion").
    pub object: String,

    /// Creation time in epoch seconds.
    pub created: u64,

    /// The model used for the completion.
    pub model: String,

    /// The list of generated completions, one [`Choice`] per result.
    pub choices: Vec<Choice>,

    /// Token-accounting statistics for this completion.
    pub usage: Usage,
}
168
/// Represents a single completion choice.
///
/// Contains the generated text and additional metadata.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// The generated text.
    pub text: String,

    /// The index of this choice in the returned list.
    pub index: i32,

    /// The log probabilities of the tokens, if requested.
    /// `None` when the request did not ask for logprobs.
    #[serde(default)]
    pub logprobs: Option<Logprobs>,

    /// The reason why the completion ended (e.g., "stop", "length").
    /// `None` when the field is absent from the payload.
    #[serde(default)]
    pub finish_reason: Option<String>,
}
188
/// Represents the log probabilities of tokens.
///
/// Provides detailed information about token generation probabilities.
// NOTE(review): the four vectors appear to be parallel (one entry per
// generated token), matching the OpenAI logprobs layout — confirm with
// the code that constructs this type.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Logprobs {
    /// List of tokens generated.
    pub tokens: Vec<String>,

    /// Log probability of each token.
    pub token_logprobs: Vec<f32>,

    /// Indices of tokens in the original text.
    pub text_offset: Vec<usize>,

    /// Top log probabilities for each token position.
    pub top_logprobs: Vec<HashMap<String, f32>>,
}
206
207/// Represents usage statistics for a completion.
208///
209/// Tracks the number of tokens consumed in the prompt and completion.
210#[derive(Debug, Serialize, Deserialize)]
211pub struct Usage {
212    /// The number of tokens in the prompt.
213    pub prompt_tokens: u32,
214
215    /// The number of tokens in the completion.
216    pub completion_tokens: u32,
217
218    /// The total number of tokens used.
219    pub total_tokens: u32,
220}