openai_mock/models/completion.rs
//! This module defines the data structures for handling completion requests
//! and responses in the API.

use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use crate::validators::StopSequence;

/// Represents a request payload for the Completions API.
///
/// This structure includes the optional and required fields that
/// configure the behavior of the text generation process.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// ID of the model to use.
    pub model: String,

    /// The prompt(s) to generate completions for.
    ///
    /// Can be a string, an array of strings, or `null`.
    #[serde(default)]
    pub prompt: Option<Value>,

    /// The suffix that comes after the generated completion.
    #[serde(default)]
    pub suffix: Option<String>,

    /// The maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: Option<u32>,

    /// Sampling temperature to use. Higher values make output more random.
    #[serde(default = "default_temperature")]
    pub temperature: Option<f32>,

    /// Nucleus sampling probability.
    #[serde(default = "default_top_p")]
    pub top_p: Option<f32>,

    /// Number of completions to generate for each prompt.
    #[serde(default = "default_n")]
    pub n: Option<i32>,

    /// Whether to stream back partial progress.
    #[serde(default = "default_stream")]
    pub stream: Option<bool>,

    /// Include the log probabilities of the `logprobs` most likely tokens.
    #[serde(default)]
    pub logprobs: Option<u32>,

    /// Echo back the prompt in addition to the completion.
    #[serde(default = "default_echo")]
    pub echo: Option<bool>,

    /// Sequences where the API will stop generating further tokens.
    #[serde(default)]
    pub stop: Option<StopSequence>,

    /// Penalizes tokens that have already appeared, regardless of how often.
    #[serde(default = "default_presence_penalty")]
    pub presence_penalty: Option<f32>,

    /// Penalizes tokens in proportion to how often they have appeared.
    #[serde(default = "default_frequency_penalty")]
    pub frequency_penalty: Option<f32>,

    /// Generates `best_of` completions server-side and returns the best one.
    #[serde(default)]
    pub best_of: Option<i32>,

    /// Modifies the likelihood of specified tokens appearing in the completion.
    #[serde(default)]
    pub logit_bias: Option<HashMap<String, i32>>,

    /// A unique identifier representing the end-user.
    #[serde(default)]
    pub user: Option<String>,
}
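
// A minimal sketch (test-only) of how the `#[serde(default = "...")]`
// attributes behave: fields omitted from the incoming JSON fall back to the
// documented API defaults instead of `None`. The model name and prompt here
// are illustrative values, not fixtures the crate ships with.
#[cfg(test)]
mod request_defaults_tests {
    use super::*;

    #[test]
    fn omitted_fields_fall_back_to_api_defaults() {
        let json = r#"{"model": "text-davinci-003", "prompt": "Say hello"}"#;
        let request: CompletionRequest =
            serde_json::from_str(json).expect("valid request JSON");

        assert_eq!(request.model, "text-davinci-003");
        assert_eq!(request.max_tokens, Some(16));
        assert_eq!(request.temperature, Some(1.0));
        assert_eq!(request.stream, Some(false));
        // Fields with a plain `#[serde(default)]` stay `None` when omitted.
        assert_eq!(request.suffix, None);
    }
}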

// Default values for optional parameters

fn default_max_tokens() -> Option<u32> {
    Some(16)
}

fn default_temperature() -> Option<f32> {
    Some(1.0)
}

fn default_top_p() -> Option<f32> {
    Some(1.0)
}

fn default_n() -> Option<i32> {
    Some(1)
}

fn default_stream() -> Option<bool> {
    Some(false)
}

fn default_echo() -> Option<bool> {
    Some(false)
}

fn default_presence_penalty() -> Option<f32> {
    Some(0.0)
}

fn default_frequency_penalty() -> Option<f32> {
    Some(0.0)
}

impl Default for CompletionRequest {
    /// Provides default values for `CompletionRequest`.
    ///
    /// # Example
    ///
    /// ```
    /// use openai_mock::models::completion::CompletionRequest;
    ///
    /// let default_request = CompletionRequest::default();
    /// assert_eq!(default_request.max_tokens, Some(16));
    /// ```
    fn default() -> Self {
        Self {
            model: String::new(),
            prompt: None,
            suffix: None,
            max_tokens: default_max_tokens(),
            temperature: default_temperature(),
            top_p: default_top_p(),
            n: default_n(),
            stream: default_stream(),
            logprobs: None,
            echo: default_echo(),
            stop: None,
            presence_penalty: default_presence_penalty(),
            frequency_penalty: default_frequency_penalty(),
            best_of: None,
            logit_bias: None,
            user: None,
        }
    }
}
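
// Sketch: because `Default` is implemented above, struct-update syntax lets a
// caller override a few fields while keeping the API defaults for the rest.
// The model name and token limit are illustrative.
#[cfg(test)]
mod struct_update_tests {
    use super::*;

    #[test]
    fn struct_update_overrides_only_named_fields() {
        let request = CompletionRequest {
            model: "text-davinci-003".to_string(),
            max_tokens: Some(64),
            ..Default::default()
        };

        assert_eq!(request.max_tokens, Some(64));
        // Everything else keeps the defaults from `Default::default()`.
        assert_eq!(request.temperature, Some(1.0));
        assert_eq!(request.n, Some(1));
    }
}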

/// Represents a response from the Completions API.
///
/// Contains generated completions along with usage statistics.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Unique identifier for the completion.
    pub id: String,

    /// The object type (e.g., "text_completion").
    pub object: String,

    /// Creation time in epoch seconds.
    pub created: u64,

    /// The model used for the completion.
    pub model: String,

    /// The list of generated completions.
    pub choices: Vec<Choice>,

    /// Usage statistics for the completion.
    pub usage: Usage,
}

/// Represents a single completion choice.
///
/// Contains the generated text and additional metadata.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// The generated text.
    pub text: String,

    /// The index of this choice in the returned list.
    pub index: i32,

    /// The log probabilities of the tokens, if requested.
    #[serde(default)]
    pub logprobs: Option<Logprobs>,

    /// The reason why the completion ended (e.g., "stop", "length").
    #[serde(default)]
    pub finish_reason: Option<String>,
}

/// Represents the log probabilities of tokens.
///
/// The four vectors are parallel: each position describes one generated token.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Logprobs {
    /// List of tokens generated.
    pub tokens: Vec<String>,

    /// Log probability of each token.
    pub token_logprobs: Vec<f32>,

    /// Character offset of each token within the text.
    pub text_offset: Vec<usize>,

    /// Top log probabilities for each token position.
    pub top_logprobs: Vec<HashMap<String, f32>>,
}
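
// Sketch: the four vectors in `Logprobs` are parallel arrays with one entry
// per generated token, so their lengths must agree. The tokens, offsets, and
// probabilities below are illustrative values.
#[cfg(test)]
mod logprobs_tests {
    use super::*;

    #[test]
    fn logprobs_vectors_stay_aligned_per_token() {
        let logprobs = Logprobs {
            tokens: vec!["Hello".to_string(), "!".to_string()],
            token_logprobs: vec![-0.25, -1.10],
            text_offset: vec![0, 5],
            top_logprobs: vec![HashMap::new(), HashMap::new()],
        };

        assert_eq!(logprobs.tokens.len(), logprobs.token_logprobs.len());
        assert_eq!(logprobs.tokens.len(), logprobs.text_offset.len());
        assert_eq!(logprobs.tokens.len(), logprobs.top_logprobs.len());
    }
}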

/// Represents usage statistics for a completion.
///
/// Tracks the number of tokens consumed in the prompt and completion.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// The number of tokens in the prompt.
    pub prompt_tokens: u32,

    /// The number of tokens in the completion.
    pub completion_tokens: u32,

    /// The total number of tokens used.
    pub total_tokens: u32,
}
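
// Sketch: assembling a response the way the mock server might return it and
// round-tripping it through JSON. The id, timestamp, model name, and token
// counts are illustrative values.
#[cfg(test)]
mod response_serialization_tests {
    use super::*;

    #[test]
    fn response_round_trips_through_json() {
        let response = CompletionResponse {
            id: "cmpl-123".to_string(),
            object: "text_completion".to_string(),
            created: 1_700_000_000,
            model: "text-davinci-003".to_string(),
            choices: vec![Choice {
                text: "Hello!".to_string(),
                index: 0,
                logprobs: None,
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 2,
                completion_tokens: 3,
                total_tokens: 5,
            },
        };

        let json = serde_json::to_string(&response).expect("serializable");
        let parsed: CompletionResponse =
            serde_json::from_str(&json).expect("round-trips");
        assert_eq!(parsed.choices[0].finish_reason.as_deref(), Some("stop"));
    }
}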