Skip to main content

openai_protocol/
completion.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value};
5
6use super::common::*;
7
// ============================================================================
// Completions API (v1/completions) - DEPRECATED but still supported
// ============================================================================
11
12#[serde_with::skip_serializing_none]
13#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
14pub struct CompletionRequest {
15    /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
16    pub model: String,
17
18    /// The prompt(s) to generate completions for
19    pub prompt: StringOrArray,
20
21    /// The suffix that comes after a completion of inserted text
22    pub suffix: Option<String>,
23
24    /// The maximum number of tokens to generate
25    pub max_tokens: Option<u32>,
26
27    /// What sampling temperature to use, between 0 and 2
28    pub temperature: Option<f32>,
29
30    /// An alternative to sampling with temperature (nucleus sampling)
31    pub top_p: Option<f32>,
32
33    /// How many completions to generate for each prompt
34    pub n: Option<u32>,
35
36    /// Whether to stream back partial progress
37    #[serde(default)]
38    pub stream: bool,
39
40    /// Options for streaming response
41    pub stream_options: Option<StreamOptions>,
42
43    /// Include the log probabilities on the logprobs most likely tokens
44    pub logprobs: Option<u32>,
45
46    /// Echo back the prompt in addition to the completion
47    #[serde(default)]
48    pub echo: bool,
49
50    /// Up to 4 sequences where the API will stop generating further tokens
51    pub stop: Option<StringOrArray>,
52
53    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
54    pub presence_penalty: Option<f32>,
55
56    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
57    pub frequency_penalty: Option<f32>,
58
59    /// Generates best_of completions server-side and returns the "best"
60    pub best_of: Option<u32>,
61
62    /// Modify the likelihood of specified tokens appearing in the completion
63    pub logit_bias: Option<HashMap<String, f32>>,
64
65    /// A unique identifier representing your end-user
66    pub user: Option<String>,
67
68    /// If specified, our system will make a best effort to sample deterministically
69    pub seed: Option<i64>,
70
71    // -------- Engine Specific Sampling Parameters --------
72    /// Top-k sampling parameter (-1 to disable)
73    pub top_k: Option<i32>,
74
75    /// Min-p nucleus sampling parameter
76    pub min_p: Option<f32>,
77
78    /// Minimum number of tokens to generate
79    pub min_tokens: Option<u32>,
80
81    /// Repetition penalty for reducing repetitive text
82    pub repetition_penalty: Option<f32>,
83
84    /// Regex constraint for output generation
85    pub regex: Option<String>,
86
87    /// EBNF grammar constraint for structured output
88    pub ebnf: Option<String>,
89
90    /// JSON schema constraint for structured output
91    pub json_schema: Option<String>,
92
93    /// Specific token IDs to use as stop conditions
94    pub stop_token_ids: Option<Vec<u32>>,
95
96    /// Skip trimming stop tokens from output
97    #[serde(default)]
98    pub no_stop_trim: bool,
99
100    /// Ignore end-of-sequence tokens during generation
101    #[serde(default)]
102    pub ignore_eos: bool,
103
104    /// Skip special tokens during detokenization
105    #[serde(default = "default_true")]
106    pub skip_special_tokens: bool,
107
108    /// Path to LoRA adapter(s) for model customization
109    pub lora_path: Option<String>,
110
111    /// Session parameters for continual prompting
112    pub session_params: Option<HashMap<String, Value>>,
113
114    /// Return model hidden states
115    #[serde(default)]
116    pub return_hidden_states: bool,
117
118    /// Sampling seed for deterministic outputs
119    pub sampling_seed: Option<u64>,
120
121    /// Additional fields including bootstrap info for PD routing
122    #[serde(flatten)]
123    pub other: Map<String, Value>,
124}
125
126impl GenerationRequest for CompletionRequest {
127    fn is_stream(&self) -> bool {
128        self.stream
129    }
130
131    fn get_model(&self) -> Option<&str> {
132        Some(&self.model)
133    }
134
135    fn extract_text_for_routing(&self) -> String {
136        match &self.prompt {
137            StringOrArray::String(s) => s.clone(),
138            StringOrArray::Array(v) => v.join(" "),
139        }
140    }
141}
142
// ============================================================================
// Response Types
// ============================================================================
146
147#[serde_with::skip_serializing_none]
148#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
149pub struct CompletionResponse {
150    pub id: String,
151    pub object: String, // "text_completion"
152    pub created: u64,
153    pub model: String,
154    pub choices: Vec<CompletionChoice>,
155    pub usage: Option<Usage>,
156    pub system_fingerprint: Option<String>,
157}
158
159#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
160pub struct CompletionChoice {
161    pub text: String,
162    pub index: u32,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub logprobs: Option<LogProbs>,
165    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
166    /// Information about which stop condition was matched
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub matched_stop: Option<Value>, // Can be string or integer
169}
170
171#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
172pub struct CompletionStreamResponse {
173    pub id: String,
174    pub object: String, // "text_completion"
175    pub created: u64,
176    pub choices: Vec<CompletionStreamChoice>,
177    pub model: String,
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub system_fingerprint: Option<String>,
180}
181
182#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
183pub struct CompletionStreamChoice {
184    pub text: String,
185    pub index: u32,
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub logprobs: Option<LogProbs>,
188    pub finish_reason: Option<String>,
189}