//! OpenAI-compatible `v1/completions` protocol types.
//!
//! File: `openai_protocol/completion.rs`

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value};
5use validator::Validate;
6
7use super::{
8    common::*,
9    sampling_params::{validate_top_k_value, validate_top_p_value},
10};
11use crate::validated::Normalizable;
12
13// ============================================================================
14// Completions API (v1/completions) - DEPRECATED but still supported
15// ============================================================================
16
/// Request body for the legacy `v1/completions` endpoint.
///
/// Mirrors the OpenAI Completions API surface plus a set of engine-specific
/// sampling extensions. Cross-field invariants (stream/stream_options,
/// min/max tokens, exclusive structured-output constraints, best_of vs. n)
/// are enforced by `validate_completion_cross_parameters` via the
/// schema-level validator below.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_completion_cross_parameters"))]
pub struct CompletionRequest {
    /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
    pub model: String,

    /// The prompt(s) to generate completions for.
    /// Array form must be non-empty (see `validate_completion_prompt`).
    #[validate(custom(function = "validate_completion_prompt"))]
    pub prompt: StringOrArray,

    /// Generates `best_of` completions server-side and returns the "best".
    /// Must be strictly greater than `n` when both are set, and is rejected
    /// when streaming (see cross-parameter validation).
    #[validate(range(min = 0, max = 20))]
    pub best_of: Option<u32>,

    /// Echo back the prompt in addition to the completion
    #[serde(default)]
    pub echo: bool,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,

    /// Modify the likelihood of specified tokens appearing in the completion
    pub logit_bias: Option<HashMap<String, f32>>,

    /// Include the log probabilities on the `logprobs` most likely tokens
    #[validate(range(min = 0, max = 5))]
    pub logprobs: Option<u32>,

    /// The maximum number of tokens to generate
    #[validate(range(min = 0))]
    pub max_tokens: Option<u32>,

    /// How many completions to generate for each prompt
    #[validate(range(min = 1, max = 128))]
    pub n: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,

    /// If specified, our system will make a best effort to sample deterministically
    pub seed: Option<i64>,

    /// Up to 4 sequences where the API will stop generating further tokens
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,

    /// Whether to stream back partial progress.
    /// An explicit JSON `null` is treated as `false` by the custom deserializer.
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub stream: bool,

    /// Options for streaming response; only valid when `stream` is true
    /// (enforced by cross-parameter validation).
    pub stream_options: Option<StreamOptions>,

    /// The suffix that comes after a completion of inserted text
    pub suffix: Option<String>,

    /// What sampling temperature to use, between 0 and 2
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature (nucleus sampling)
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,

    /// A unique identifier representing your end-user
    pub user: Option<String>,

    // =============================================================================
    // Engine-Specific Sampling Parameters
    // =============================================================================
    /// Top-k sampling parameter (-1 to disable)
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,

    /// Min-p nucleus sampling parameter
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,

    /// Minimum number of tokens to generate; must not exceed `max_tokens`
    /// (enforced by cross-parameter validation).
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,

    /// Repetition penalty for reducing repetitive text
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,

    // The three structured-output constraints below are mutually exclusive;
    // at most one may be set (enforced by cross-parameter validation).
    /// Regex constraint for output generation
    pub regex: Option<String>,

    /// EBNF grammar constraint for structured output
    pub ebnf: Option<String>,

    /// JSON schema constraint for structured output
    pub json_schema: Option<String>,

    /// Specific token IDs to use as stop conditions
    pub stop_token_ids: Option<Vec<u32>>,

    /// Skip trimming stop tokens from output
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Ignore end-of-sequence tokens during generation
    #[serde(default)]
    pub ignore_eos: bool,

    /// Skip special tokens during detokenization (defaults to true)
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    /// Path to LoRA adapter(s) for model customization
    pub lora_path: Option<String>,

    /// Session parameters for continual prompting
    pub session_params: Option<HashMap<String, Value>>,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,

    /// Sampling seed for deterministic outputs
    pub sampling_seed: Option<u64>,

    /// Additional fields including bootstrap info for PD routing.
    /// Flattened so unknown top-level JSON keys are captured instead of rejected.
    #[serde(flatten)]
    pub other: Map<String, Value>,
}
147
148impl Normalizable for CompletionRequest {}
149
150fn validate_completion_prompt(prompt: &StringOrArray) -> Result<(), validator::ValidationError> {
151    match prompt {
152        StringOrArray::String(_) => {}
153        StringOrArray::Array(arr) => {
154            if arr.is_empty() {
155                let mut error = validator::ValidationError::new("prompt_empty");
156                error.message = Some("prompt array cannot be empty".into());
157                return Err(error);
158            }
159        }
160    }
161
162    Ok(())
163}
164
165fn validate_completion_cross_parameters(
166    req: &CompletionRequest,
167) -> Result<(), validator::ValidationError> {
168    if req.stream_options.is_some() && !req.stream {
169        let mut error = validator::ValidationError::new("stream_options_requires_stream");
170        error.message =
171            Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
172        return Err(error);
173    }
174
175    if let (Some(min), Some(max)) = (req.min_tokens, req.max_tokens) {
176        if min > max {
177            let mut error = validator::ValidationError::new("min_tokens_exceeds_max");
178            error.message = Some("min_tokens cannot exceed max_tokens".into());
179            return Err(error);
180        }
181    }
182
183    let constraint_count =
184        req.regex.is_some() as u8 + req.ebnf.is_some() as u8 + req.json_schema.is_some() as u8;
185    if constraint_count > 1 {
186        let mut error = validator::ValidationError::new("multiple_constraints");
187        error.message = Some(
188            "only one structured output constraint (regex, ebnf, or json_schema) can be active at a time"
189                .into(),
190        );
191        return Err(error);
192    }
193
194    if let (Some(best_of), Some(n)) = (req.best_of, req.n) {
195        if best_of <= n {
196            let mut error = validator::ValidationError::new("best_of_less_than_n");
197            error.message = Some("best_of must be greater than n".into());
198            return Err(error);
199        }
200    }
201
202    if req.stream && req.best_of.is_some() {
203        let mut error = validator::ValidationError::new("best_of_not_supported_with_stream");
204        error.message = Some("best_of is not supported when stream is enabled".into());
205        return Err(error);
206    }
207
208    Ok(())
209}
210
211impl GenerationRequest for CompletionRequest {
212    fn is_stream(&self) -> bool {
213        self.stream
214    }
215
216    fn get_model(&self) -> Option<&str> {
217        Some(&self.model)
218    }
219
220    fn extract_text_for_routing(&self) -> String {
221        match &self.prompt {
222            StringOrArray::String(s) => s.clone(),
223            StringOrArray::Array(v) => v.join(" "),
224        }
225    }
226}
227
228// ============================================================================
229// Response Types
230// ============================================================================
231
/// Non-streaming response body for `v1/completions`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionResponse {
    /// Unique identifier for this completion.
    pub id: String,
    /// Object type discriminator; "text_completion" for this endpoint.
    pub object: String,
    /// Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    /// Model that produced the completion.
    pub model: String,
    /// One entry per generated completion (`n`/`best_of` controls the count).
    pub choices: Vec<CompletionChoice>,
    /// Token accounting; omitted from JSON when `None` (skip_serializing_none).
    pub usage: Option<Usage>,
    /// Backend configuration fingerprint; omitted from JSON when `None`.
    pub system_fingerprint: Option<String>,
}
243
/// A single generated completion within a `CompletionResponse`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionChoice {
    /// The generated completion text.
    pub text: String,
    /// Zero-based index of this choice within `choices`.
    pub index: u32,
    /// Per-token log probabilities; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    /// Why generation stopped: "stop", "length", "content_filter", etc.
    /// Intentionally serialized as `null` when absent (no skip attribute).
    pub finish_reason: Option<String>,
    /// Information about which stop condition was matched
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>, // Can be string or integer
}
255
/// One server-sent-event chunk of a streaming `v1/completions` response.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionStreamResponse {
    /// Unique identifier shared by all chunks of one completion.
    pub id: String,
    /// Object type discriminator; "text_completion" for this endpoint.
    pub object: String,
    /// Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    /// Incremental choice deltas carried by this chunk.
    pub choices: Vec<CompletionStreamChoice>,
    /// Model that is producing the completion.
    pub model: String,
    /// Backend configuration fingerprint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Token accounting, typically only on the final chunk when
    /// `stream_options` requests it; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
268
/// A single incremental choice delta within a `CompletionStreamResponse`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionStreamChoice {
    /// Newly generated text fragment for this chunk.
    pub text: String,
    /// Zero-based index of the choice this fragment belongs to.
    pub index: u32,
    /// Per-token log probabilities for this fragment; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    /// Set on the final chunk of a choice ("stop", "length", ...);
    /// serialized as `null` on intermediate chunks.
    pub finish_reason: Option<String>,
}