openai_protocol/
common.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
9// ============================================================================
10// Default value helpers
11// ============================================================================
12
13/// Default model value when not specified
14pub(crate) fn default_model() -> String {
15    UNKNOWN_MODEL_ID.to_string()
16}
17
/// Serde `default` helper that always evaluates to `true`.
///
/// Lets boolean fields default to enabled when they are absent from the payload.
#[inline]
pub fn default_true() -> bool {
    true
}
22
23// ============================================================================
24// GenerationRequest Trait
25// ============================================================================
26
/// Unified, read-only access to properties shared by generation requests.
///
/// Implemented by `ChatCompletionRequest`, `CompletionRequest`, `GenerateRequest`,
/// `EmbeddingRequest`, `RerankRequest`, and `ResponsesRequest`. Every implementor
/// must also be `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Model name requested by the client, or `None` when the request names none.
    fn get_model(&self) -> Option<&str>;

    /// `true` when the client asked for a streamed response.
    fn is_stream(&self) -> bool;

    /// Text extracted from the request for use in routing decisions.
    fn extract_text_for_routing(&self) -> String;
}
40
41// ============================================================================
42// String/Array Utilities
43// ============================================================================
44
/// A type that can be either a single string or an array of strings.
///
/// Deserialized untagged: a JSON string becomes `String`, a JSON array of
/// strings becomes `Array`. Serializes back to the same JSON shape.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single string value (may itself be empty).
    String(String),
    /// A list of string values (may contain no elements).
    Array(Vec<String>),
}
52
53impl StringOrArray {
54    /// Get the number of items in the StringOrArray
55    pub fn len(&self) -> usize {
56        match self {
57            StringOrArray::String(_) => 1,
58            StringOrArray::Array(arr) => arr.len(),
59        }
60    }
61
62    /// Check if the StringOrArray is empty
63    pub fn is_empty(&self) -> bool {
64        match self {
65            StringOrArray::String(s) => s.is_empty(),
66            StringOrArray::Array(arr) => arr.is_empty(),
67        }
68    }
69
70    /// Convert to a vector of strings (clones the data)
71    pub fn to_vec(&self) -> Vec<String> {
72        match self {
73            StringOrArray::String(s) => vec![s.clone()],
74            StringOrArray::Array(arr) => arr.clone(),
75        }
76    }
77
78    /// Returns an iterator over string references without cloning.
79    /// Use this instead of `to_vec()` when you only need to iterate.
80    pub fn iter(&self) -> StringOrArrayIter<'_> {
81        StringOrArrayIter {
82            inner: self,
83            index: 0,
84        }
85    }
86
87    /// Returns the first string, or None if empty
88    pub fn first(&self) -> Option<&str> {
89        match self {
90            StringOrArray::String(s) => {
91                if s.is_empty() {
92                    None
93                } else {
94                    Some(s)
95                }
96            }
97            StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98        }
99    }
100}
101
102/// Iterator over StringOrArray that yields string references without cloning
103pub struct StringOrArrayIter<'a> {
104    inner: &'a StringOrArray,
105    index: usize,
106}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109    type Item = &'a str;
110
111    fn next(&mut self) -> Option<Self::Item> {
112        match self.inner {
113            StringOrArray::String(s) => {
114                if self.index == 0 {
115                    self.index = 1;
116                    Some(s.as_str())
117                } else {
118                    None
119                }
120            }
121            StringOrArray::Array(arr) => {
122                if self.index < arr.len() {
123                    let item = &arr[self.index];
124                    self.index += 1;
125                    Some(item.as_str())
126                } else {
127                    None
128                }
129            }
130        }
131    }
132
133    fn size_hint(&self) -> (usize, Option<usize>) {
134        let remaining = match self.inner {
135            StringOrArray::String(_) => 1 - self.index,
136            StringOrArray::Array(arr) => arr.len() - self.index,
137        };
138        (remaining, Some(remaining))
139    }
140}
141
142impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144/// Validates stop sequences (max 4, non-empty strings)
145/// Used by both ChatCompletionRequest and ResponsesRequest
146pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147    match stop {
148        StringOrArray::String(s) => {
149            if s.is_empty() {
150                return Err(validator::ValidationError::new(
151                    "stop sequences cannot be empty",
152                ));
153            }
154        }
155        StringOrArray::Array(arr) => {
156            if arr.len() > 4 {
157                return Err(validator::ValidationError::new(
158                    "maximum 4 stop sequences allowed",
159                ));
160            }
161            for s in arr {
162                if s.is_empty() {
163                    return Err(validator::ValidationError::new(
164                        "stop sequences cannot be empty",
165                    ));
166                }
167            }
168        }
169    }
170    Ok(())
171}
172
173// ============================================================================
174// Content Parts (for multimodal messages)
175// ============================================================================
176
177#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
178#[serde(tag = "type")]
179pub enum ContentPart {
180    #[serde(rename = "text")]
181    Text { text: String },
182    #[serde(rename = "image_url")]
183    ImageUrl { image_url: ImageUrl },
184    #[serde(rename = "video_url")]
185    VideoUrl { video_url: VideoUrl },
186}
187
/// Image reference carried by a `ContentPart::ImageUrl` part.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ImageUrl {
    /// URL of the image (its format is not validated here).
    pub url: String,
    /// Optional detail level; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>, // "auto", "low", or "high"
}

/// Video reference carried by a `ContentPart::VideoUrl` part.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct VideoUrl {
    /// URL of the video (its format is not validated here).
    pub url: String,
}
199
200// ============================================================================
201// Response Format (for structured outputs)
202// ============================================================================
203
204#[derive(Debug, Clone, Deserialize, Serialize)]
205#[serde(tag = "type")]
206pub enum ResponseFormat {
207    #[serde(rename = "text")]
208    Text,
209    #[serde(rename = "json_object")]
210    JsonObject,
211    #[serde(rename = "json_schema")]
212    JsonSchema { json_schema: JsonSchemaFormat },
213}
214
/// Payload of the `ResponseFormat::JsonSchema` variant.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
    /// Name of the schema.
    pub name: String,
    /// The schema document, kept as an untyped JSON value.
    pub schema: Value,
    /// Optional strictness flag; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
222
223// ============================================================================
224// Streaming
225// ============================================================================
226
/// Options accompanying a streaming request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
    /// `include_usage` flag; omitted from serialized output when `None`.
    /// Presumably asks for token usage to be reported in the stream (verify
    /// against the streaming handler).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
232
233#[derive(Debug, Clone, Deserialize, Serialize)]
234pub struct ToolCallDelta {
235    pub index: u32,
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub id: Option<String>,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    #[serde(rename = "type")]
240    pub tool_type: Option<String>,
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub function: Option<FunctionCallDelta>,
243}
244
245#[derive(Debug, Clone, Deserialize, Serialize)]
246pub struct FunctionCallDelta {
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub name: Option<String>,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub arguments: Option<String>,
251}
252
253// ============================================================================
254// Tools and Function Calling
255// ============================================================================
256
/// Tool choice value for simple string options.
///
/// Serialized in snake_case: `"auto"`, `"required"`, `"none"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    Auto,
    Required,
    None,
}

/// Tool choice for both Chat Completion and Responses APIs.
///
/// Untagged: serde tries the variants in declaration order.
/// - A bare string matches `Value`.
/// - An object with `type` and `function` matches `Function`.
/// - An object with `type`, `mode`, and `tools` matches `AllowedTools`.
///
/// The `type` strings themselves are not checked here.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Simple mode string.
    Value(ToolChoiceValue),
    /// Named-function choice: `{"type": ..., "function": {"name": ...}}`.
    Function {
        #[serde(rename = "type")]
        tool_type: String, // "function"
        function: FunctionChoice,
    },
    /// Allowed-tools choice: a `mode` plus the list of tool references.
    AllowedTools {
        #[serde(rename = "type")]
        tool_type: String, // "allowed_tools"
        mode: String, // "auto" | "required" TODO: need validation
        tools: Vec<ToolReference>,
    },
}
283
284impl Default for ToolChoice {
285    fn default() -> Self {
286        Self::Value(ToolChoiceValue::Auto)
287    }
288}
289
290impl ToolChoice {
291    /// Serialize tool_choice to string for ResponsesResponse
292    ///
293    /// Returns the JSON-serialized tool_choice or "auto" as default
294    pub fn serialize_to_string(tool_choice: &Option<ToolChoice>) -> String {
295        tool_choice
296            .as_ref()
297            .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
298            .unwrap_or_else(|| "auto".to_string())
299    }
300}
301
/// Function choice specification for `ToolChoice::Function`.
///
/// Holds the name of the selected function.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    pub name: String,
}
307
308/// Tool reference for ToolChoice::AllowedTools
309///
310/// Represents a reference to a specific tool in the allowed_tools array.
311/// Different tool types have different required fields.
312#[derive(Debug, Clone, Deserialize, Serialize)]
313#[serde(tag = "type")]
314#[serde(rename_all = "snake_case")]
315pub enum ToolReference {
316    /// Reference to a function tool
317    #[serde(rename = "function")]
318    Function { name: String },
319
320    /// Reference to an MCP tool
321    #[serde(rename = "mcp")]
322    Mcp {
323        server_label: String,
324        #[serde(skip_serializing_if = "Option::is_none")]
325        name: Option<String>,
326    },
327
328    /// File search hosted tool
329    #[serde(rename = "file_search")]
330    FileSearch,
331
332    /// Web search preview hosted tool
333    #[serde(rename = "web_search_preview")]
334    WebSearchPreview,
335
336    /// Computer use preview hosted tool
337    #[serde(rename = "computer_use_preview")]
338    ComputerUsePreview,
339
340    /// Code interpreter hosted tool
341    #[serde(rename = "code_interpreter")]
342    CodeInterpreter,
343
344    /// Image generation hosted tool
345    #[serde(rename = "image_generation")]
346    ImageGeneration,
347}
348
349impl ToolReference {
350    /// Get a unique identifier for this tool reference
351    pub fn identifier(&self) -> String {
352        match self {
353            ToolReference::Function { name } => format!("function:{}", name),
354            ToolReference::Mcp { server_label, name } => {
355                if let Some(n) = name {
356                    format!("mcp:{}:{}", server_label, n)
357                } else {
358                    format!("mcp:{}", server_label)
359                }
360            }
361            ToolReference::FileSearch => "file_search".to_string(),
362            ToolReference::WebSearchPreview => "web_search_preview".to_string(),
363            ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
364            ToolReference::CodeInterpreter => "code_interpreter".to_string(),
365            ToolReference::ImageGeneration => "image_generation".to_string(),
366        }
367    }
368
369    /// Get the tool name if this is a function tool
370    pub fn function_name(&self) -> Option<&str> {
371        match self {
372            ToolReference::Function { name } => Some(name.as_str()),
373            _ => None,
374        }
375    }
376}
377
/// A tool definition: a `type` string plus a [`Function`] description.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
    /// Serialized under the JSON key `"type"`; its value is not validated here.
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: Function,
}

/// A function definition that can be offered as a tool.
///
/// NOTE(review): `parameters` has no serde default, so a definition that omits
/// it presumably fails to deserialize. Confirm that this is intended.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    pub name: String,
    /// Optional description; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: Value, // JSON Schema
    /// Whether to enable strict schema adherence (OpenAI structured outputs)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

/// A tool call: its `id`, its `type` string, and the invoked function.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    pub id: String,
    /// Serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: FunctionCallResponse,
}
403
404#[derive(Debug, Clone, Deserialize, Serialize)]
405#[serde(untagged)]
406pub enum FunctionCall {
407    None,
408    Auto,
409    Function { name: String },
410}
411
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    pub name: String,
    /// Arguments as a JSON-encoded string.
    ///
    /// Defaults to `None` when absent from the input. There is no
    /// `skip_serializing_if`, so `None` serializes as `null`.
    #[serde(default)]
    pub arguments: Option<String>, // JSON string
}
418
419// ============================================================================
420// Usage and Logging
421// ============================================================================
422
/// Token usage counts.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Optional breakdown of completion tokens; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}

/// Breakdown of completion tokens.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    /// Has no `skip_serializing_if`, so `None` serializes as `null`.
    pub reasoning_tokens: Option<u32>,
}

/// Usage information (used by rerank and other endpoints)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}

/// Breakdown of prompt tokens.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    pub cached_tokens: u32,
}
453
/// Log-probability data stored as parallel vectors.
///
/// Presumably the vectors are index-aligned per token; that is not enforced here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<Option<f32>>,
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    pub text_offset: Vec<u32>,
}

/// Chat-style log probabilities.
///
/// Untagged: serde tries `Detailed` first. If that fails, it falls back to
/// `Raw`, which holds the untouched JSON value.
///
/// NOTE(review): `content` is optional and unknown fields are not denied. Most
/// JSON objects therefore deserialize as `Detailed`, and `Raw` mainly catches
/// non-object values or malformed `content`. Confirm that this is intended.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatLogProbs {
    Detailed {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ChatLogProbsContent>>,
    },
    Raw(Value),
}

/// Log-probability entry for one token, plus its top alternatives.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    /// Has no `skip_serializing_if`, so `None` serializes as `null`.
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogProb>,
}

/// One alternative token and its log probability.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    /// Has no `skip_serializing_if`, so `None` serializes as `null`.
    pub bytes: Option<Vec<u8>>,
}
486
487// ============================================================================
488// Error Types
489// ============================================================================
490
/// Error body of the form `{"error": { ... }}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}

/// Details of an error.
///
/// `param` and `code` are omitted from serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    pub message: String,
    /// Serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub error_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
506
507// ============================================================================
508// Input Types
509// ============================================================================
510
/// Input IDs: either one sequence of integers or a batch of sequences.
///
/// Untagged and tried in order. An empty array `[]` therefore deserializes
/// as `Single(vec![])`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}

/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
///
/// Untagged: JSON `null` or a string deserializes as `Single`. An array, whose
/// entries may be `null`, deserializes as `Batch`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
    Single(Option<String>),
    Batch(Vec<Option<String>>),
}