1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value};
5use validator::Validate;
6
7use super::{
8 common::*,
9 sampling_params::{validate_top_k_value, validate_top_p_value},
10};
11use crate::validated::Normalizable;
12
13#[serde_with::skip_serializing_none]
18#[derive(Debug, Clone, Deserialize, Serialize, Validate, schemars::JsonSchema)]
19#[validate(schema(function = "validate_completion_cross_parameters"))]
20pub struct CompletionRequest {
21 pub model: String,
23
24 #[validate(custom(function = "validate_completion_prompt"))]
26 pub prompt: StringOrArray,
27
28 #[validate(range(min = 0, max = 20))]
30 pub best_of: Option<u32>,
31
32 #[serde(default)]
34 pub echo: bool,
35
36 #[validate(range(min = -2.0, max = 2.0))]
38 pub frequency_penalty: Option<f32>,
39
40 pub logit_bias: Option<HashMap<String, f32>>,
42
43 #[validate(range(min = 0, max = 5))]
45 pub logprobs: Option<u32>,
46
47 #[validate(range(min = 0))]
49 pub max_tokens: Option<u32>,
50
51 #[validate(range(min = 1, max = 128))]
53 pub n: Option<u32>,
54
55 #[validate(range(min = -2.0, max = 2.0))]
57 pub presence_penalty: Option<f32>,
58
59 pub seed: Option<i64>,
61
62 #[validate(custom(function = "validate_stop"))]
64 pub stop: Option<StringOrArray>,
65
66 #[serde(default, deserialize_with = "deserialize_null_as_false")]
68 pub stream: bool,
69
70 pub stream_options: Option<StreamOptions>,
72
73 pub suffix: Option<String>,
75
76 #[validate(range(min = 0.0, max = 2.0))]
78 pub temperature: Option<f32>,
79
80 #[validate(custom(function = "validate_top_p_value"))]
82 pub top_p: Option<f32>,
83
84 pub user: Option<String>,
86
87 #[validate(custom(function = "validate_top_k_value"))]
92 pub top_k: Option<i32>,
93
94 #[validate(range(min = 0.0, max = 1.0))]
96 pub min_p: Option<f32>,
97
98 #[validate(range(min = 1))]
100 pub min_tokens: Option<u32>,
101
102 #[validate(range(min = 0.0, max = 2.0))]
104 pub repetition_penalty: Option<f32>,
105
106 pub regex: Option<String>,
108
109 pub ebnf: Option<String>,
111
112 pub json_schema: Option<String>,
114
115 pub stop_token_ids: Option<Vec<u32>>,
117
118 #[serde(default)]
120 pub no_stop_trim: bool,
121
122 #[serde(default)]
124 pub ignore_eos: bool,
125
126 #[serde(default = "default_true")]
128 pub skip_special_tokens: bool,
129
130 pub lora_path: Option<String>,
132
133 pub session_params: Option<HashMap<String, Value>>,
135
136 #[serde(default)]
138 pub return_hidden_states: bool,
139
140 pub sampling_seed: Option<u64>,
142
143 #[serde(flatten)]
145 pub other: Map<String, Value>,
146}
147
148impl Normalizable for CompletionRequest {}
149
150fn validate_completion_prompt(prompt: &StringOrArray) -> Result<(), validator::ValidationError> {
151 match prompt {
152 StringOrArray::String(_) => {}
153 StringOrArray::Array(arr) => {
154 if arr.is_empty() {
155 let mut error = validator::ValidationError::new("prompt_empty");
156 error.message = Some("prompt array cannot be empty".into());
157 return Err(error);
158 }
159 }
160 }
161
162 Ok(())
163}
164
165fn validate_completion_cross_parameters(
166 req: &CompletionRequest,
167) -> Result<(), validator::ValidationError> {
168 if req.stream_options.is_some() && !req.stream {
169 let mut error = validator::ValidationError::new("stream_options_requires_stream");
170 error.message =
171 Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
172 return Err(error);
173 }
174
175 if let (Some(min), Some(max)) = (req.min_tokens, req.max_tokens) {
176 if min > max {
177 let mut error = validator::ValidationError::new("min_tokens_exceeds_max");
178 error.message = Some("min_tokens cannot exceed max_tokens".into());
179 return Err(error);
180 }
181 }
182
183 let constraint_count =
184 req.regex.is_some() as u8 + req.ebnf.is_some() as u8 + req.json_schema.is_some() as u8;
185 if constraint_count > 1 {
186 let mut error = validator::ValidationError::new("multiple_constraints");
187 error.message = Some(
188 "only one structured output constraint (regex, ebnf, or json_schema) can be active at a time"
189 .into(),
190 );
191 return Err(error);
192 }
193
194 if let (Some(best_of), Some(n)) = (req.best_of, req.n) {
195 if best_of <= n {
196 let mut error = validator::ValidationError::new("best_of_less_than_n");
197 error.message = Some("best_of must be greater than n".into());
198 return Err(error);
199 }
200 }
201
202 if req.stream && req.best_of.is_some() {
203 let mut error = validator::ValidationError::new("best_of_not_supported_with_stream");
204 error.message = Some("best_of is not supported when stream is enabled".into());
205 return Err(error);
206 }
207
208 Ok(())
209}
210
211impl GenerationRequest for CompletionRequest {
212 fn is_stream(&self) -> bool {
213 self.stream
214 }
215
216 fn get_model(&self) -> Option<&str> {
217 Some(&self.model)
218 }
219
220 fn extract_text_for_routing(&self) -> String {
221 match &self.prompt {
222 StringOrArray::String(s) => s.clone(),
223 StringOrArray::Array(v) => v.join(" "),
224 }
225 }
226}
227
228#[serde_with::skip_serializing_none]
233#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
234pub struct CompletionResponse {
235 pub id: String,
236 pub object: String, pub created: u64,
238 pub model: String,
239 pub choices: Vec<CompletionChoice>,
240 pub usage: Option<Usage>,
241 pub system_fingerprint: Option<String>,
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
245pub struct CompletionChoice {
246 pub text: String,
247 pub index: u32,
248 #[serde(skip_serializing_if = "Option::is_none")]
249 pub logprobs: Option<LogProbs>,
250 pub finish_reason: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
253 pub matched_stop: Option<Value>, }
255
256#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
257pub struct CompletionStreamResponse {
258 pub id: String,
259 pub object: String, pub created: u64,
261 pub choices: Vec<CompletionStreamChoice>,
262 pub model: String,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 pub system_fingerprint: Option<String>,
265 #[serde(skip_serializing_if = "Option::is_none")]
266 pub usage: Option<Usage>,
267}
268
269#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
270pub struct CompletionStreamChoice {
271 pub text: String,
272 pub index: u32,
273 #[serde(skip_serializing_if = "Option::is_none")]
274 pub logprobs: Option<LogProbs>,
275 pub finish_reason: Option<String>,
276}