Skip to main content

openai_protocol/
common.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
9// ============================================================================
10// Default value helpers
11// ============================================================================
12
13/// Default model value when not specified
14pub(crate) fn default_model() -> String {
15    UNKNOWN_MODEL_ID.to_string()
16}
17
/// Serde default helper: a field annotated with `#[serde(default = "default_true")]`
/// becomes `true` when it is absent from the input.
pub fn default_true() -> bool {
    true
}
22
23// ============================================================================
24// GenerationRequest Trait
25// ============================================================================
26
/// Uniform read-only view over generation-style request types.
///
/// Implemented by ChatCompletionRequest, CompletionRequest, GenerateRequest,
/// EmbeddingRequest, RerankRequest, and ResponsesRequest. Implementors must be
/// `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Reports whether the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// The requested model name, or `None` when the request names none.
    fn get_model(&self) -> Option<&str>;

    /// Gathers the request's text into a single string for routing decisions.
    fn extract_text_for_routing(&self) -> String;
}
40
41// ============================================================================
42// String/Array Utilities
43// ============================================================================
44
45/// A type that can be either a single string or an array of strings
46#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
47#[serde(untagged)]
48pub enum StringOrArray {
49    String(String),
50    Array(Vec<String>),
51}
52
53impl StringOrArray {
54    /// Get the number of items in the StringOrArray
55    pub fn len(&self) -> usize {
56        match self {
57            StringOrArray::String(_) => 1,
58            StringOrArray::Array(arr) => arr.len(),
59        }
60    }
61
62    /// Check if the StringOrArray is empty
63    pub fn is_empty(&self) -> bool {
64        match self {
65            StringOrArray::String(s) => s.is_empty(),
66            StringOrArray::Array(arr) => arr.is_empty(),
67        }
68    }
69
70    /// Convert to a vector of strings (clones the data)
71    pub fn to_vec(&self) -> Vec<String> {
72        match self {
73            StringOrArray::String(s) => vec![s.clone()],
74            StringOrArray::Array(arr) => arr.clone(),
75        }
76    }
77
78    /// Returns an iterator over string references without cloning.
79    /// Use this instead of `to_vec()` when you only need to iterate.
80    pub fn iter(&self) -> StringOrArrayIter<'_> {
81        StringOrArrayIter {
82            inner: self,
83            index: 0,
84        }
85    }
86
87    /// Returns the first string, or None if empty
88    pub fn first(&self) -> Option<&str> {
89        match self {
90            StringOrArray::String(s) => {
91                if s.is_empty() {
92                    None
93                } else {
94                    Some(s)
95                }
96            }
97            StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98        }
99    }
100}
101
102/// Iterator over StringOrArray that yields string references without cloning
103pub struct StringOrArrayIter<'a> {
104    inner: &'a StringOrArray,
105    index: usize,
106}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109    type Item = &'a str;
110
111    fn next(&mut self) -> Option<Self::Item> {
112        match self.inner {
113            StringOrArray::String(s) => {
114                if self.index == 0 {
115                    self.index = 1;
116                    Some(s.as_str())
117                } else {
118                    None
119                }
120            }
121            StringOrArray::Array(arr) => {
122                if self.index < arr.len() {
123                    let item = &arr[self.index];
124                    self.index += 1;
125                    Some(item.as_str())
126                } else {
127                    None
128                }
129            }
130        }
131    }
132
133    fn size_hint(&self) -> (usize, Option<usize>) {
134        let remaining = match self.inner {
135            StringOrArray::String(_) => 1 - self.index,
136            StringOrArray::Array(arr) => arr.len() - self.index,
137        };
138        (remaining, Some(remaining))
139    }
140}
141
142impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144/// Validates stop sequences (max 4, non-empty strings)
145/// Used by both ChatCompletionRequest and ResponsesRequest
146pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147    match stop {
148        StringOrArray::String(s) => {
149            if s.is_empty() {
150                return Err(validator::ValidationError::new(
151                    "stop sequences cannot be empty",
152                ));
153            }
154        }
155        StringOrArray::Array(arr) => {
156            if arr.len() > 4 {
157                return Err(validator::ValidationError::new(
158                    "maximum 4 stop sequences allowed",
159                ));
160            }
161            for s in arr {
162                if s.is_empty() {
163                    return Err(validator::ValidationError::new(
164                        "stop sequences cannot be empty",
165                    ));
166                }
167            }
168        }
169    }
170    Ok(())
171}
172
173// ============================================================================
174// Content Parts (for multimodal messages)
175// ============================================================================
176
177#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
178#[serde(tag = "type")]
179pub enum ContentPart {
180    #[serde(rename = "text")]
181    Text { text: String },
182    #[serde(rename = "image_url")]
183    ImageUrl { image_url: ImageUrl },
184    #[serde(rename = "video_url")]
185    VideoUrl { video_url: VideoUrl },
186}
187
188#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
189pub struct ImageUrl {
190    pub url: String,
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub detail: Option<String>, // "auto", "low", or "high"
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
196pub struct VideoUrl {
197    pub url: String,
198}
199
200// ============================================================================
201// Response Format (for structured outputs)
202// ============================================================================
203
204#[derive(Debug, Clone, Deserialize, Serialize)]
205#[serde(tag = "type")]
206pub enum ResponseFormat {
207    #[serde(rename = "text")]
208    Text,
209    #[serde(rename = "json_object")]
210    JsonObject,
211    #[serde(rename = "json_schema")]
212    JsonSchema { json_schema: JsonSchemaFormat },
213}
214
215#[derive(Debug, Clone, Deserialize, Serialize)]
216pub struct JsonSchemaFormat {
217    pub name: String,
218    pub schema: Value,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub strict: Option<bool>,
221}
222
223// ============================================================================
224// Streaming
225// ============================================================================
226
227#[derive(Debug, Clone, Deserialize, Serialize)]
228pub struct StreamOptions {
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub include_usage: Option<bool>,
231}
232
233#[serde_with::skip_serializing_none]
234#[derive(Debug, Clone, Deserialize, Serialize)]
235pub struct ToolCallDelta {
236    pub index: u32,
237    pub id: Option<String>,
238    #[serde(rename = "type")]
239    pub tool_type: Option<String>,
240    pub function: Option<FunctionCallDelta>,
241}
242
243#[serde_with::skip_serializing_none]
244#[derive(Debug, Clone, Deserialize, Serialize)]
245pub struct FunctionCallDelta {
246    pub name: Option<String>,
247    pub arguments: Option<String>,
248}
249
250// ============================================================================
251// Tools and Function Calling
252// ============================================================================
253
254/// Tool choice value for simple string options
255#[derive(Debug, Clone, Deserialize, Serialize)]
256#[serde(rename_all = "snake_case")]
257pub enum ToolChoiceValue {
258    Auto,
259    Required,
260    None,
261}
262
263/// Tool choice for both Chat Completion and Responses APIs
264#[derive(Debug, Clone, Deserialize, Serialize)]
265#[serde(untagged)]
266pub enum ToolChoice {
267    Value(ToolChoiceValue),
268    Function {
269        #[serde(rename = "type")]
270        tool_type: String, // "function"
271        function: FunctionChoice,
272    },
273    AllowedTools {
274        #[serde(rename = "type")]
275        tool_type: String, // "allowed_tools"
276        mode: String, // "auto" | "required" TODO: need validation
277        tools: Vec<ToolReference>,
278    },
279}
280
281impl Default for ToolChoice {
282    fn default() -> Self {
283        Self::Value(ToolChoiceValue::Auto)
284    }
285}
286
287impl ToolChoice {
288    /// Serialize tool_choice to string for ResponsesResponse
289    ///
290    /// Returns the JSON-serialized tool_choice or "auto" as default
291    pub fn serialize_to_string(tool_choice: &Option<ToolChoice>) -> String {
292        tool_choice
293            .as_ref()
294            .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
295            .unwrap_or_else(|| "auto".to_string())
296    }
297}
298
299/// Function choice specification for ToolChoice::Function
300#[derive(Debug, Clone, Deserialize, Serialize)]
301pub struct FunctionChoice {
302    pub name: String,
303}
304
305/// Tool reference for ToolChoice::AllowedTools
306///
307/// Represents a reference to a specific tool in the allowed_tools array.
308/// Different tool types have different required fields.
309#[derive(Debug, Clone, Deserialize, Serialize)]
310#[serde(tag = "type")]
311#[serde(rename_all = "snake_case")]
312pub enum ToolReference {
313    /// Reference to a function tool
314    #[serde(rename = "function")]
315    Function { name: String },
316
317    /// Reference to an MCP tool
318    #[serde(rename = "mcp")]
319    Mcp {
320        server_label: String,
321        #[serde(skip_serializing_if = "Option::is_none")]
322        name: Option<String>,
323    },
324
325    /// File search hosted tool
326    #[serde(rename = "file_search")]
327    FileSearch,
328
329    /// Web search preview hosted tool
330    #[serde(rename = "web_search_preview")]
331    WebSearchPreview,
332
333    /// Computer use preview hosted tool
334    #[serde(rename = "computer_use_preview")]
335    ComputerUsePreview,
336
337    /// Code interpreter hosted tool
338    #[serde(rename = "code_interpreter")]
339    CodeInterpreter,
340
341    /// Image generation hosted tool
342    #[serde(rename = "image_generation")]
343    ImageGeneration,
344}
345
346impl ToolReference {
347    /// Get a unique identifier for this tool reference
348    pub fn identifier(&self) -> String {
349        match self {
350            ToolReference::Function { name } => format!("function:{}", name),
351            ToolReference::Mcp { server_label, name } => {
352                if let Some(n) = name {
353                    format!("mcp:{}:{}", server_label, n)
354                } else {
355                    format!("mcp:{}", server_label)
356                }
357            }
358            ToolReference::FileSearch => "file_search".to_string(),
359            ToolReference::WebSearchPreview => "web_search_preview".to_string(),
360            ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
361            ToolReference::CodeInterpreter => "code_interpreter".to_string(),
362            ToolReference::ImageGeneration => "image_generation".to_string(),
363        }
364    }
365
366    /// Get the tool name if this is a function tool
367    pub fn function_name(&self) -> Option<&str> {
368        match self {
369            ToolReference::Function { name } => Some(name.as_str()),
370            _ => None,
371        }
372    }
373}
374
375#[derive(Debug, Clone, Deserialize, Serialize)]
376pub struct Tool {
377    #[serde(rename = "type")]
378    pub tool_type: String, // "function"
379    pub function: Function,
380}
381
382#[serde_with::skip_serializing_none]
383#[derive(Debug, Clone, Deserialize, Serialize)]
384pub struct Function {
385    pub name: String,
386    pub description: Option<String>,
387    pub parameters: Value, // JSON Schema
388    /// Whether to enable strict schema adherence (OpenAI structured outputs)
389    pub strict: Option<bool>,
390}
391
392#[derive(Debug, Clone, Deserialize, Serialize)]
393pub struct ToolCall {
394    pub id: String,
395    #[serde(rename = "type")]
396    pub tool_type: String, // "function"
397    pub function: FunctionCallResponse,
398}
399
400/// Deprecated `function_call` field from the OpenAI API.
401/// Can be `"none"`, `"auto"`, or `{"name": "function_name"}`.
402#[derive(Debug, Clone)]
403pub enum FunctionCall {
404    None,
405    Auto,
406    Function { name: String },
407}
408
409impl Serialize for FunctionCall {
410    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
411        match self {
412            FunctionCall::None => serializer.serialize_str("none"),
413            FunctionCall::Auto => serializer.serialize_str("auto"),
414            FunctionCall::Function { name } => {
415                use serde::ser::SerializeMap;
416                let mut map = serializer.serialize_map(Some(1))?;
417                map.serialize_entry("name", name)?;
418                map.end()
419            }
420        }
421    }
422}
423
424impl<'de> Deserialize<'de> for FunctionCall {
425    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
426        let value = Value::deserialize(deserializer)?;
427        match &value {
428            Value::String(s) => match s.as_str() {
429                "none" => Ok(FunctionCall::None),
430                "auto" => Ok(FunctionCall::Auto),
431                other => Err(serde::de::Error::custom(format!(
432                    "unknown function_call value: \"{}\"",
433                    other
434                ))),
435            },
436            Value::Object(map) => {
437                if let Some(Value::String(name)) = map.get("name") {
438                    Ok(FunctionCall::Function { name: name.clone() })
439                } else {
440                    Err(serde::de::Error::custom(
441                        "function_call object must have a \"name\" string field",
442                    ))
443                }
444            }
445            _ => Err(serde::de::Error::custom(
446                "function_call must be a string or object",
447            )),
448        }
449    }
450}
451
452#[derive(Debug, Clone, Deserialize, Serialize)]
453pub struct FunctionCallResponse {
454    pub name: String,
455    #[serde(default)]
456    pub arguments: Option<String>, // JSON string
457}
458
459// ============================================================================
460// Usage and Logging
461// ============================================================================
462
463#[derive(Debug, Clone, Deserialize, Serialize)]
464pub struct Usage {
465    pub prompt_tokens: u32,
466    pub completion_tokens: u32,
467    pub total_tokens: u32,
468    #[serde(skip_serializing_if = "Option::is_none")]
469    pub completion_tokens_details: Option<CompletionTokensDetails>,
470}
471
472impl Usage {
473    /// Create a Usage from prompt and completion token counts
474    pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
475        Self {
476            prompt_tokens,
477            completion_tokens,
478            total_tokens: prompt_tokens + completion_tokens,
479            completion_tokens_details: None,
480        }
481    }
482
483    /// Add reasoning token details to this Usage
484    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
485        if reasoning_tokens > 0 {
486            self.completion_tokens_details = Some(CompletionTokensDetails {
487                reasoning_tokens: Some(reasoning_tokens),
488            });
489        }
490        self
491    }
492}
493
494#[derive(Debug, Clone, Deserialize, Serialize)]
495pub struct CompletionTokensDetails {
496    pub reasoning_tokens: Option<u32>,
497}
498
499/// Usage information (used by rerank and other endpoints)
500#[serde_with::skip_serializing_none]
501#[derive(Debug, Clone, Deserialize, Serialize)]
502pub struct UsageInfo {
503    pub prompt_tokens: u32,
504    pub completion_tokens: u32,
505    pub total_tokens: u32,
506    pub reasoning_tokens: Option<u32>,
507    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
508}
509
510#[derive(Debug, Clone, Deserialize, Serialize)]
511pub struct PromptTokenUsageInfo {
512    pub cached_tokens: u32,
513}
514
515#[derive(Debug, Clone, Deserialize, Serialize)]
516pub struct LogProbs {
517    pub tokens: Vec<String>,
518    pub token_logprobs: Vec<Option<f32>>,
519    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
520    pub text_offset: Vec<u32>,
521}
522
523#[derive(Debug, Clone, Deserialize, Serialize)]
524#[serde(untagged)]
525pub enum ChatLogProbs {
526    Detailed {
527        #[serde(skip_serializing_if = "Option::is_none")]
528        content: Option<Vec<ChatLogProbsContent>>,
529    },
530    Raw(Value),
531}
532
533#[derive(Debug, Clone, Deserialize, Serialize)]
534pub struct ChatLogProbsContent {
535    pub token: String,
536    pub logprob: f32,
537    pub bytes: Option<Vec<u8>>,
538    pub top_logprobs: Vec<TopLogProb>,
539}
540
541#[derive(Debug, Clone, Deserialize, Serialize)]
542pub struct TopLogProb {
543    pub token: String,
544    pub logprob: f32,
545    pub bytes: Option<Vec<u8>>,
546}
547
548// ============================================================================
549// Error Types
550// ============================================================================
551
552#[derive(Debug, Clone, Deserialize, Serialize)]
553pub struct ErrorResponse {
554    pub error: ErrorDetail,
555}
556
557#[serde_with::skip_serializing_none]
558#[derive(Debug, Clone, Deserialize, Serialize)]
559pub struct ErrorDetail {
560    pub message: String,
561    #[serde(rename = "type")]
562    pub error_type: String,
563    pub param: Option<String>,
564    pub code: Option<String>,
565}
566
567// ============================================================================
568// Input Types
569// ============================================================================
570
571#[derive(Debug, Clone, Deserialize, Serialize)]
572#[serde(untagged)]
573pub enum InputIds {
574    Single(Vec<i32>),
575    Batch(Vec<Vec<i32>>),
576}
577
578/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
579#[derive(Debug, Clone, Deserialize, Serialize)]
580#[serde(untagged)]
581pub enum LoRAPath {
582    Single(Option<String>),
583    Batch(Vec<Option<String>>),
584}