openai_protocol/
generate.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8    common::{default_true, GenerationRequest, InputIds},
9    sampling_params::SamplingParams,
10};
11use crate::validated::Normalizable;
12
13// ============================================================================
14// SGLang Generate API (native format)
15// ============================================================================
16
/// Request body for the SGLang native generate API.
///
/// Besides per-field deserialization, the whole struct is checked by the
/// schema-level validator `validate_generate_request`, which requires exactly
/// one of `text` / `input_ids` to be set.
///
/// Serialization notes:
/// - `Option` fields marked `skip_serializing_if = "Option::is_none"` are left
///   out of the serialized output when unset.
/// - `bool` fields marked `#[serde(default)]` deserialize to `false` when the
///   key is absent.
#[derive(Clone, Debug, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest {
    /// Text input - SGLang native format
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Optional model name; exposed through `GenerationRequest::get_model`.
    ///
    /// NOTE(review): unlike every other optional field here, this one has no
    /// `skip_serializing_if`, so `None` is serialized as `null` rather than
    /// omitted - confirm this is intentional.
    pub model: Option<String>,

    /// Input IDs for tokenized input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,

    /// Input embeddings for direct embedding input
    /// Can be a 2D array (single request) or 3D array (batch of requests)
    /// Placeholder for future use (not considered by request validation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_embeds: Option<Value>,

    /// Image input data
    /// Can be an image instance, file name, URL, or base64 encoded string
    /// Supports single images, lists of images, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_data: Option<Value>,

    /// Video input data
    /// Can be a file name, URL, or base64 encoded string
    /// Supports single videos, lists of videos, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub video_data: Option<Value>,

    /// Audio input data
    /// Can be a file name, URL, or base64 encoded string
    /// Supports single audio files, lists of audio, or nested lists for batch processing
    /// Placeholder for future use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_data: Option<Value>,

    /// Sampling parameters (sglang style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,

    /// Whether to return logprobs
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_logprob: Option<bool>,

    /// If return logprobs, the start location in the prompt for returning logprobs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprob_start_len: Option<i32>,

    /// If return logprobs, the number of top logprobs to return at each position.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs_num: Option<i32>,

    /// If return logprobs, the token ids to return logprob for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_ids_logprob: Option<Vec<u32>>,

    /// Whether to detokenize tokens in text in the returned logprobs.
    #[serde(default)]
    pub return_text_in_logprobs: bool,

    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,

    /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    /// When the key is absent the value comes from the `default_true` helper.
    #[serde(default = "default_true")]
    pub log_metrics: bool,

    /// Return model hidden states
    #[serde(default)]
    pub return_hidden_states: bool,

    /// The modalities of the image data [image, multi-images, video]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,

    /// Session parameters for continual prompting
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, Value>>,

    /// Path to LoRA adapter(s) for model customization
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<String>,

    /// LoRA adapter ID (if pre-loaded)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_id: Option<String>,

    /// Custom logit processor for advanced sampling control. Must be a serialized instance
    /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    /// Use the processor's `to_str()` method to generate the serialized string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_logit_processor: Option<String>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_host: Option<String>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<i32>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_room: Option<i32>,

    /// For disaggregated inference
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_pair_key: Option<String>,

    /// Data parallel rank routing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_parallel_rank: Option<i32>,

    /// Background response
    #[serde(default)]
    pub background: bool,

    /// Conversation ID for tracking
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conversation_id: Option<String>,

    /// Priority for the request
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,

    /// Extra key for classifying the request (e.g. cache_salt)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra_key: Option<String>,

    /// Whether to disallow logging for this request (e.g. due to ZDR)
    #[serde(default)]
    pub no_logs: bool,

    /// Custom metric labels
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_labels: Option<HashMap<String, String>>,

    /// Whether to return bytes for image generation
    #[serde(default)]
    pub return_bytes: bool,

    /// Whether to return entropy
    #[serde(default)]
    pub return_entropy: bool,

    /// Request ID for tracking (inherited from BaseReq in Python)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,
}
171
172impl Normalizable for GenerateRequest {
173    // Use default no-op implementation - no normalization needed for GenerateRequest
174}
175
176/// Validation function for GenerateRequest - ensure exactly one input type is provided
177fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
178    // Exactly one of text or input_ids must be provided
179    // Note: input_embeds not yet supported in Rust implementation
180    let has_text = req.text.is_some();
181    let has_input_ids = req.input_ids.is_some();
182
183    let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
184
185    if count == 0 {
186        return Err(validator::ValidationError::new(
187            "Either text or input_ids should be provided.",
188        ));
189    }
190
191    if count > 1 {
192        return Err(validator::ValidationError::new(
193            "Either text or input_ids should be provided.",
194        ));
195    }
196
197    Ok(())
198}
199
200impl GenerationRequest for GenerateRequest {
201    fn is_stream(&self) -> bool {
202        self.stream
203    }
204
205    fn get_model(&self) -> Option<&str> {
206        // Generate requests have an optional model field
207        if let Some(s) = &self.model {
208            Some(s.as_str())
209        } else {
210            None
211        }
212    }
213
214    fn extract_text_for_routing(&self) -> String {
215        // Check fields in priority order: text, input_ids
216        if let Some(ref text) = self.text {
217            return text.clone();
218        }
219
220        if let Some(ref input_ids) = self.input_ids {
221            return match input_ids {
222                InputIds::Single(ids) => ids
223                    .iter()
224                    .map(|&id| id.to_string())
225                    .collect::<Vec<String>>()
226                    .join(" "),
227                InputIds::Batch(batches) => batches
228                    .iter()
229                    .flat_map(|batch| batch.iter().map(|&id| id.to_string()))
230                    .collect::<Vec<String>>()
231                    .join(" "),
232            };
233        }
234
235        // No text input found
236        String::new()
237    }
238}
239
240// ============================================================================
241// SGLang Generate Response Types
242// ============================================================================
243
/// SGLang generate response (single completion or array for n>1)
///
/// This struct models ONE completion object. For n>1 the response is an array
/// of such objects; that array wrapper is not represented by this type.
///
/// Format for n=1:
/// ```json
/// {
///   "text": "...",
///   "output_ids": [...],
///   "meta_info": { ... }
/// }
/// ```
///
/// Format for n>1:
/// ```json
/// [
///   {"text": "...", "output_ids": [...], "meta_info": {...}},
///   {"text": "...", "output_ids": [...], "meta_info": {...}}
/// ]
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateResponse {
    /// The `"text"` entry of the completion object.
    pub text: String,
    /// The `"output_ids"` entry - presumably the generated token ids; confirm against the backend.
    pub output_ids: Vec<u32>,
    /// The `"meta_info"` entry; see `GenerateMetaInfo`.
    pub meta_info: GenerateMetaInfo,
}
268
/// Metadata for a single generate completion
///
/// The three `Option` fields are omitted from serialized output when `None`;
/// all other fields are always written and are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateMetaInfo {
    /// Identifier string - presumably the request id; TODO confirm.
    pub id: String,
    /// Why generation ended; see `GenerateFinishReason`.
    pub finish_reason: GenerateFinishReason,
    /// Prompt token count (as named; semantics defined by the backend).
    pub prompt_tokens: u32,
    /// Weight version string reported alongside the completion.
    pub weight_version: String,
    /// Logprob data for input tokens, as nested lists of nullable floats.
    /// NOTE(review): the meaning of each inner element is not visible here - confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    /// Logprob data for output tokens; same layout as `input_token_logprobs`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    /// Completion token count (as named; semantics defined by the backend).
    pub completion_tokens: u32,
    /// Cached token count (as named; semantics defined by the backend).
    pub cached_tokens: u32,
    /// End-to-end latency. NOTE(review): units are not shown in this file - confirm.
    pub e2e_latency: f64,
    /// Matched stop condition, kept as raw JSON because its shape is not fixed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}
286
/// Finish reason for generate endpoint
///
/// Serialized as an internally tagged object: the variant name, lowercased,
/// is stored under the `"type"` key (e.g. `{"type": "stop"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum GenerateFinishReason {
    /// `{"type": "length", "length": N}`.
    /// NOTE(review): `length` is presumably the token limit that was hit - confirm.
    Length {
        length: u32,
    },
    /// `{"type": "stop"}`.
    Stop,
    /// Untagged variant: holds an arbitrary JSON value and is written back
    /// without the `"type"` wrapper. Presumably acts as the catch-all for
    /// finish reasons that do not match the tagged variants above.
    #[serde(untagged)]
    Other(Value),
}