Skip to main content

openai_protocol/
generate.rs

use std::{collections::HashMap, fmt::Write as _};

use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;

use super::{
    common::{default_true, deserialize_null_as_false, GenerationRequest, InputIds},
    sampling_params::SamplingParams,
};
use crate::validated::Normalizable;
12
13// ============================================================================
14// SGLang Generate API (native format)
15// ============================================================================
16
/// Request body for the SGLang native generate API.
///
/// Exactly one of `text` or `input_ids` must be provided; this is checked by
/// `validate_generate_request`, wired in through the `#[validate(schema(...))]` attribute.
/// `Option` fields marked `skip_serializing_if = "Option::is_none"` are left out of the
/// serialized JSON when unset.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest {
    /// Text input - SGLang native format
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Model name; filled in by `super::common::default_unknown_model` when absent.
    #[serde(default = "super::common::default_unknown_model")]
    pub model: String,

    /// Input IDs for tokenized input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,

    /// Input embeddings for direct embedding input
    /// Can be a 2D array (single request) or 3D array (batch of requests)
    /// Placeholder for future use
    // NOTE: `validate_generate_request` does not consider this field, so a request that
    // carries only `input_embeds` (no `text` / `input_ids`) fails validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_embeds: Option<Value>,

    /// Image input data
    /// Can be an image instance, file name, URL, or base64 encoded string
    /// Supports single images, lists of images, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_data: Option<Value>,

    /// Video input data
    /// Can be a file name, URL, or base64 encoded string
    /// Supports single videos, lists of videos, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub video_data: Option<Value>,

    /// Audio input data
    /// Can be a file name, URL, or base64 encoded string
    /// Supports single audio files, lists of audio, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_data: Option<Value>,

    /// Sampling parameters (sglang style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,

    /// Whether to return logprobs
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_logprob: Option<bool>,

    /// If return logprobs, the start location in the prompt for returning logprobs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprob_start_len: Option<i32>,

    /// If return logprobs, the number of top logprobs to return at each position.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs_num: Option<i32>,

    /// If return logprobs, the token ids to return logprob for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_ids_logprob: Option<Vec<u32>>,

    /// Whether to detokenize tokens in text in the returned logprobs.
    // Plain `#[serde(default)]` only covers a *missing* field (-> `false`); an explicit
    // JSON `null` is a deserialization error for a `bool`. The same holds for
    // `log_metrics` (default `true`), `return_hidden_states`, `background`, `no_logs`,
    // `return_bytes` and `return_entropy`. `stream` instead goes through
    // `deserialize_null_as_false`, which by its name tolerates `null` — see `common`.
    #[serde(default)]
    pub return_text_in_logprobs: bool,

    /// Whether to stream the response
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub stream: bool,

    /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    #[serde(default = "default_true")]
    pub log_metrics: bool,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,

    /// The modalities of the image data [image, multi-images, video]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,

    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, Value>>,

    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<String>,

    /// LoRA adapter ID (if pre-loaded)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_id: Option<String>,

    /// Custom logit processor for advanced sampling control. Must be a serialized instance
    /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    /// Use the processor's `to_str()` method to generate the serialized string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_logit_processor: Option<String>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_host: Option<String>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<i32>,

    /// For disaggregated inference
    // NOTE(review): `i32` makes deserialization fail for room IDs above 2^31 - 1. If room
    // IDs are generated as large random integers (SGLang's Python side appears to use
    // 63-bit values — confirm), this field likely needs a wider type such as `u64`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_room: Option<i32>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_pair_key: Option<String>,

    /// Data parallel rank routing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_parallel_rank: Option<i32>,

    /// Background response
    #[serde(default)]
    pub background: bool,

    /// Conversation ID for tracking
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conversation_id: Option<String>,

    /// Priority for the request
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,

    /// Extra key for classifying the request (e.g. cache_salt)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra_key: Option<String>,

    /// Whether to disallow logging for this request (e.g. due to ZDR)
    #[serde(default)]
    pub no_logs: bool,

    /// Custom metric labels
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_labels: Option<HashMap<String, String>>,

    /// Whether to return bytes for image generation
    #[serde(default)]
    pub return_bytes: bool,

    /// Whether to return entropy
    #[serde(default)]
    pub return_entropy: bool,

    /// Request ID for tracking (inherited from BaseReq in Python)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,
}
172
173impl Normalizable for GenerateRequest {
174    // Use default no-op implementation - no normalization needed for GenerateRequest
175}
176
177/// Validation function for GenerateRequest - ensure exactly one input type is provided
178fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
179    // Exactly one of text or input_ids must be provided
180    // Note: input_embeds not yet supported in Rust implementation
181    let has_text = req.text.is_some();
182    let has_input_ids = req.input_ids.is_some();
183
184    let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
185
186    if count == 0 {
187        return Err(validator::ValidationError::new(
188            "Either text or input_ids should be provided.",
189        ));
190    }
191
192    if count > 1 {
193        return Err(validator::ValidationError::new(
194            "Either text or input_ids should be provided.",
195        ));
196    }
197
198    Ok(())
199}
200
201impl GenerationRequest for GenerateRequest {
202    fn is_stream(&self) -> bool {
203        self.stream
204    }
205
206    fn get_model(&self) -> Option<&str> {
207        Some(self.model.as_str())
208    }
209
210    fn extract_text_for_routing(&self) -> String {
211        // Check fields in priority order: text, input_ids
212        if let Some(ref text) = self.text {
213            return text.clone();
214        }
215
216        if let Some(ref input_ids) = self.input_ids {
217            return match input_ids {
218                InputIds::Single(ids) => ids
219                    .iter()
220                    .map(|&id| id.to_string())
221                    .collect::<Vec<String>>()
222                    .join(" "),
223                InputIds::Batch(batches) => batches
224                    .iter()
225                    .flat_map(|batch| batch.iter().map(|&id| id.to_string()))
226                    .collect::<Vec<String>>()
227                    .join(" "),
228            };
229        }
230
231        // No text input found
232        String::new()
233    }
234}
235
236// ============================================================================
237// SGLang Generate Response Types
238// ============================================================================
239
/// A single SGLang generate completion.
///
/// For n=1 the endpoint returns one object of this shape:
/// ```json
/// {
///   "text": "...",
///   "output_ids": [...],
///   "meta_info": { ... }
/// }
/// ```
///
/// For n>1 the endpoint returns a JSON array of these objects (i.e. a
/// `Vec<GenerateResponse>`); this struct itself only models one element:
/// ```json
/// [
///   {"text": "...", "output_ids": [...], "meta_info": {...}},
///   {"text": "...", "output_ids": [...], "meta_info": {...}}
/// ]
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct GenerateResponse {
    pub text: String,
    pub output_ids: Vec<u32>,
    pub meta_info: GenerateMetaInfo,
}
264
/// Metadata for a single generate completion.
///
/// `#[serde_with::skip_serializing_none]` (placed before the `derive`, as serde_with
/// requires) omits every `Option` field that is `None` from the serialized JSON.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct GenerateMetaInfo {
    pub id: String,
    pub finish_reason: GenerateFinishReason,
    pub prompt_tokens: u32,
    pub weight_version: String,
    // NOTE(review): each logprob entry is `Vec<Option<f64>>`, so every element must be a
    // JSON number or null. Presumably entries look like `[logprob, token_id, ...]`; if a
    // text element is ever included (cf. `GenerateRequest::return_text_in_logprobs`), it
    // would not fit this type — confirm against the producer of these values.
    pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    pub completion_tokens: u32,
    pub cached_tokens: u32,
    // NOTE(review): unit is not visible here (presumably seconds) — confirm.
    pub e2e_latency: f64,
    // Raw JSON so the matched stop condition can be carried in whatever shape the
    // producer uses (presumably a string or a token id — confirm).
    pub matched_stop: Option<Value>,
}
280
/// Finish reason for generate endpoint
///
/// Untagged: when deserializing, serde tries the variants in declaration order and keeps
/// the first that matches. `Length` needs a `type` plus a numeric `length`. `Stop` needs
/// only a `type` (`"length"` or `"stop"`). Anything else, e.g. an unrecognized `type`,
/// is kept verbatim as `Other`. Serialization writes the active variant's content with
/// no wrapper.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum GenerateFinishReason {
    Length {
        #[serde(rename = "type")]
        finish_type: GenerateFinishType,
        length: u32,
    },
    // NOTE(review): `finish_type` is not tied to the variant, so `{"type": "length"}`
    // without a valid `length` deserializes as `Stop { finish_type: Length }`. Unknown
    // keys are also ignored, so extra data on a stop reason (e.g. a matched token) is
    // dropped rather than landing in `Other`.
    Stop {
        #[serde(rename = "type")]
        finish_type: GenerateFinishType,
    },
    Other(Value),
}
296
297#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
298#[serde(rename_all = "lowercase")]
299pub enum GenerateFinishType {
300    Length,
301    Stop,
302}