Skip to main content

openai_protocol/
common.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
9// ============================================================================
10// Default value helpers
11// ============================================================================
12
/// Default model value when not specified.
///
/// Returns an owned `String` produced from the `UNKNOWN_MODEL_ID` item that this
/// module imports from its parent (`super`).
// NOTE(review): presumably wired in as a `#[serde(default = "...")]` function by request
// types elsewhere in the crate — confirm against the callers.
pub(crate) fn default_model() -> String {
    UNKNOWN_MODEL_ID.to_string()
}
17
/// Always yields `true`.
///
/// Meant to be named in a serde `default = "..."` attribute so that a boolean
/// field missing from the input deserializes as enabled.
pub fn default_true() -> bool {
    true
}
22
23// ============================================================================
24// GenerationRequest Trait
25// ============================================================================
26
/// Trait for unified access to generation request properties.
///
/// Per the original author's note, implemented by `ChatCompletionRequest`,
/// `CompletionRequest`, `GenerateRequest`, `EmbeddingRequest`, `RerankRequest`, and
/// `ResponsesRequest` (all defined outside the code shown here).
///
/// Every implementor must be `Send + Sync`, as required by the supertrait bounds.
pub trait GenerationRequest: Send + Sync {
    /// Check if the request is for streaming.
    fn is_stream(&self) -> bool;

    /// Get the model name if specified; `None` when the request carries no model.
    fn get_model(&self) -> Option<&str>;

    /// Extract text content for routing decisions.
    ///
    /// Returns an owned `String`; how the text is assembled is up to each implementor.
    fn extract_text_for_routing(&self) -> String;
}
40
41// ============================================================================
42// String/Array Utilities
43// ============================================================================
44
/// A type that can be either a single string or an array of strings
// `untagged`: serde tries the variants in declaration order, so a bare JSON string
// becomes `String` and a JSON array of strings becomes `Array`. On output the value is
// written without any wrapper (plain string or plain array).
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum StringOrArray {
    String(String),
    Array(Vec<String>),
}
52
53impl StringOrArray {
54    /// Get the number of items in the StringOrArray
55    pub fn len(&self) -> usize {
56        match self {
57            StringOrArray::String(_) => 1,
58            StringOrArray::Array(arr) => arr.len(),
59        }
60    }
61
62    /// Check if the StringOrArray is empty
63    pub fn is_empty(&self) -> bool {
64        match self {
65            StringOrArray::String(s) => s.is_empty(),
66            StringOrArray::Array(arr) => arr.is_empty(),
67        }
68    }
69
70    /// Convert to a vector of strings (clones the data)
71    pub fn to_vec(&self) -> Vec<String> {
72        match self {
73            StringOrArray::String(s) => vec![s.clone()],
74            StringOrArray::Array(arr) => arr.clone(),
75        }
76    }
77
78    /// Returns an iterator over string references without cloning.
79    /// Use this instead of `to_vec()` when you only need to iterate.
80    pub fn iter(&self) -> StringOrArrayIter<'_> {
81        StringOrArrayIter {
82            inner: self,
83            index: 0,
84        }
85    }
86
87    /// Returns the first string, or None if empty
88    pub fn first(&self) -> Option<&str> {
89        match self {
90            StringOrArray::String(s) => {
91                if s.is_empty() {
92                    None
93                } else {
94                    Some(s)
95                }
96            }
97            StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98        }
99    }
100}
101
/// Iterator over StringOrArray that yields string references without cloning
pub struct StringOrArrayIter<'a> {
    /// The value being iterated; borrowed for the iterator's whole lifetime.
    inner: &'a StringOrArray,
    /// Position of the next item to yield. For the single-string form this is 0 before
    /// the string has been yielded and 1 afterwards.
    index: usize,
}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109    type Item = &'a str;
110
111    fn next(&mut self) -> Option<Self::Item> {
112        match self.inner {
113            StringOrArray::String(s) => {
114                if self.index == 0 {
115                    self.index = 1;
116                    Some(s.as_str())
117                } else {
118                    None
119                }
120            }
121            StringOrArray::Array(arr) => {
122                if self.index < arr.len() {
123                    let item = &arr[self.index];
124                    self.index += 1;
125                    Some(item.as_str())
126                } else {
127                    None
128                }
129            }
130        }
131    }
132
133    fn size_hint(&self) -> (usize, Option<usize>) {
134        let remaining = match self.inner {
135            StringOrArray::String(_) => 1 - self.index,
136            StringOrArray::Array(arr) => arr.len() - self.index,
137        };
138        (remaining, Some(remaining))
139    }
140}
141
// `size_hint` on this iterator returns the same value for its lower and upper bound,
// which is what `ExactSizeIterator`'s default `len()` relies on.
impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144/// Validates stop sequences (max 4, non-empty strings)
145/// Used by both ChatCompletionRequest and ResponsesRequest
146pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147    match stop {
148        StringOrArray::String(s) => {
149            if s.is_empty() {
150                return Err(validator::ValidationError::new(
151                    "stop sequences cannot be empty",
152                ));
153            }
154        }
155        StringOrArray::Array(arr) => {
156            if arr.len() > 4 {
157                return Err(validator::ValidationError::new(
158                    "maximum 4 stop sequences allowed",
159                ));
160            }
161            for s in arr {
162                if s.is_empty() {
163                    return Err(validator::ValidationError::new(
164                        "stop sequences cannot be empty",
165                    ));
166                }
167            }
168        }
169    }
170    Ok(())
171}
172
173// ============================================================================
174// Content Parts (for multimodal messages)
175// ============================================================================
176
/// One part of a multimodal message's content, selected by the `type` field.
// Internally tagged: each JSON object carries `"type"` set to `"text"`, `"image_url"`
// or `"video_url"` next to the variant's own field.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// Plain text content.
    #[serde(rename = "text")]
    Text { text: String },
    /// An image, referenced through an `ImageUrl` object.
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
    /// A video, referenced through a `VideoUrl` object.
    #[serde(rename = "video_url")]
    VideoUrl { video_url: VideoUrl },
}
187
/// Image reference used by `ContentPart::ImageUrl`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    // Free-form string: the values listed below are not enforced by this type.
    // Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>, // "auto", "low", or "high"
}
194
/// Video reference used by `ContentPart::VideoUrl`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
pub struct VideoUrl {
    /// Location of the video.
    pub url: String,
}
199
200// ============================================================================
201// Response Format (for structured outputs)
202// ============================================================================
203
/// Requested output format, selected by the `type` field.
// Internally tagged: `{"type": "text"}`, `{"type": "json_object"}`, or
// `{"type": "json_schema", "json_schema": {...}}`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    /// Plain text output.
    #[serde(rename = "text")]
    Text,
    /// JSON object output without a caller-supplied schema.
    #[serde(rename = "json_object")]
    JsonObject,
    /// Output described by the attached `JsonSchemaFormat`.
    #[serde(rename = "json_schema")]
    JsonSchema { json_schema: JsonSchemaFormat },
}
214
/// Named JSON schema carried by `ResponseFormat::JsonSchema`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct JsonSchemaFormat {
    /// Name given to the schema.
    pub name: String,
    /// The schema itself, kept as an arbitrary JSON value (not validated by this type).
    pub schema: Value,
    // Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
222
223// ============================================================================
224// Streaming
225// ============================================================================
226
/// Options that apply to streamed responses.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct StreamOptions {
    // Omitted from the serialized output when `None`.
    // NOTE(review): presumably asks for usage statistics in the stream — the consumer
    // of this flag is outside this file, confirm there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
232
/// Tool-call data as it appears in streamed output; only `index` is mandatory.
// `skip_serializing_none` drops every `None` field from the serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ToolCallDelta {
    pub index: u32,
    pub id: Option<String>,
    // Exposed on the wire as `type` (a Rust keyword, hence the rename).
    #[serde(rename = "type")]
    pub tool_type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}
242
/// Function part of a `ToolCallDelta`; both fields are optional.
// `skip_serializing_none` drops every `None` field from the serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
249
250// ============================================================================
251// Tools and Function Calling
252// ============================================================================
253
/// Tool choice value for simple string options
// `rename_all = "snake_case"` maps the variants to the strings `"auto"`, `"required"`
// and `"none"`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    Auto,
    Required,
    None,
}
262
/// Tool choice for both Chat Completion and Responses APIs
// `untagged`: serde tries the variants in declaration order and keeps the first that
// fits. A bare string maps to `Value`; an object with `type` + `function` maps to
// `Function`; an object with `type` + `mode` + `tools` maps to `AllowedTools`.
// NOTE(review): `tool_type` is a free-form `String` in both object variants, so the
// `type` value is not checked during deserialization — the expected literals are only
// documented in the trailing comments.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum ToolChoice {
    Value(ToolChoiceValue),
    Function {
        #[serde(rename = "type")]
        tool_type: String, // "function"
        function: FunctionChoice,
    },
    AllowedTools {
        #[serde(rename = "type")]
        tool_type: String, // "allowed_tools"
        mode: String, // "auto" | "required" TODO: need validation
        tools: Vec<ToolReference>,
    },
}
280
281impl Default for ToolChoice {
282    fn default() -> Self {
283        Self::Value(ToolChoiceValue::Auto)
284    }
285}
286
287impl ToolChoice {
288    /// Serialize tool_choice to string for ResponsesResponse
289    ///
290    /// Returns the JSON-serialized tool_choice or "auto" as default
291    pub fn serialize_to_string(tool_choice: Option<&ToolChoice>) -> String {
292        tool_choice
293            .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
294            .unwrap_or_else(|| "auto".to_string())
295    }
296}
297
/// Function choice specification for ToolChoice::Function
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionChoice {
    /// Name of the chosen function.
    pub name: String,
}
303
/// Tool reference for ToolChoice::AllowedTools
///
/// Represents a reference to a specific tool in the allowed_tools array.
/// Different tool types have different required fields.
// Internally tagged on `type`. Every variant also carries an explicit `rename`, so the
// enum-level `rename_all = "snake_case"` currently has no additional effect; it only
// matters for a future variant added without its own `rename`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ToolReference {
    /// Reference to a function tool
    #[serde(rename = "function")]
    Function { name: String },

    /// Reference to an MCP tool
    #[serde(rename = "mcp")]
    Mcp {
        server_label: String,
        // Optional; left out of the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },

    /// File search hosted tool
    #[serde(rename = "file_search")]
    FileSearch,

    /// Web search preview hosted tool
    #[serde(rename = "web_search_preview")]
    WebSearchPreview,

    /// Computer use preview hosted tool
    #[serde(rename = "computer_use_preview")]
    ComputerUsePreview,

    /// Code interpreter hosted tool
    #[serde(rename = "code_interpreter")]
    CodeInterpreter,

    /// Image generation hosted tool
    #[serde(rename = "image_generation")]
    ImageGeneration,
}
344
345impl ToolReference {
346    /// Get a unique identifier for this tool reference
347    pub fn identifier(&self) -> String {
348        match self {
349            ToolReference::Function { name } => format!("function:{name}"),
350            ToolReference::Mcp { server_label, name } => {
351                if let Some(n) = name {
352                    format!("mcp:{server_label}:{n}")
353                } else {
354                    format!("mcp:{server_label}")
355                }
356            }
357            ToolReference::FileSearch => "file_search".to_string(),
358            ToolReference::WebSearchPreview => "web_search_preview".to_string(),
359            ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
360            ToolReference::CodeInterpreter => "code_interpreter".to_string(),
361            ToolReference::ImageGeneration => "image_generation".to_string(),
362        }
363    }
364
365    /// Get the tool name if this is a function tool
366    pub fn function_name(&self) -> Option<&str> {
367        match self {
368            ToolReference::Function { name } => Some(name.as_str()),
369            _ => None,
370        }
371    }
372}
373
/// A tool definition: a type marker plus the function it describes.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Tool {
    // Exposed on the wire as `type`; stored as a free-form string, so the expected
    // value noted below is not enforced here.
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: Function,
}
380
/// Description of a callable function: name, optional description, and parameters.
// `skip_serializing_none` drops `description` and `strict` from the output when `None`.
// `parameters` is neither optional nor defaulted, so input lacking it fails to
// deserialize.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Function {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Value, // JSON Schema
    /// Whether to enable strict schema adherence (OpenAI structured outputs)
    pub strict: Option<bool>,
}
390
/// A complete tool call: identifier, type marker, and the function invocation.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ToolCall {
    pub id: String,
    // Exposed on the wire as `type`; free-form string, not validated here.
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: FunctionCallResponse,
}
398
/// Deprecated `function_call` field from the OpenAI API.
/// Can be `"none"`, `"auto"`, or `{"name": "function_name"}`.
///
/// Serde and schema support are implemented by hand in this file rather than derived,
/// because the wire form mixes bare strings with an object.
#[derive(Debug, Clone)]
pub enum FunctionCall {
    /// The string `"none"`.
    None,
    /// The string `"auto"`.
    Auto,
    /// The object form `{"name": ...}` naming a specific function.
    Function { name: String },
}
407
408impl Serialize for FunctionCall {
409    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
410        match self {
411            FunctionCall::None => serializer.serialize_str("none"),
412            FunctionCall::Auto => serializer.serialize_str("auto"),
413            FunctionCall::Function { name } => {
414                use serde::ser::SerializeMap;
415                let mut map = serializer.serialize_map(Some(1))?;
416                map.serialize_entry("name", name)?;
417                map.end()
418            }
419        }
420    }
421}
422
423impl<'de> Deserialize<'de> for FunctionCall {
424    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
425        let value = Value::deserialize(deserializer)?;
426        match &value {
427            Value::String(s) => match s.as_str() {
428                "none" => Ok(FunctionCall::None),
429                "auto" => Ok(FunctionCall::Auto),
430                other => Err(serde::de::Error::custom(format!(
431                    "unknown function_call value: \"{other}\""
432                ))),
433            },
434            Value::Object(map) => {
435                if let Some(Value::String(name)) = map.get("name") {
436                    Ok(FunctionCall::Function { name: name.clone() })
437                } else {
438                    Err(serde::de::Error::custom(
439                        "function_call object must have a \"name\" string field",
440                    ))
441                }
442            }
443            _ => Err(serde::de::Error::custom(
444                "function_call must be a string or object",
445            )),
446        }
447    }
448}
449
450impl schemars::JsonSchema for FunctionCall {
451    fn schema_name() -> String {
452        "FunctionCall".to_string()
453    }
454    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
455        use schemars::schema::*;
456        // FunctionCall is either "none", "auto", or {"name": "..."}
457        let string_schema = SchemaObject {
458            instance_type: Some(InstanceType::String.into()),
459            enum_values: Some(vec!["none".into(), "auto".into()]),
460            ..Default::default()
461        };
462        let object_schema = SchemaObject {
463            instance_type: Some(InstanceType::Object.into()),
464            object: Some(Box::new(ObjectValidation {
465                properties: {
466                    let mut map = schemars::Map::new();
467                    map.insert("name".to_string(), gen.subschema_for::<String>());
468                    map
469                },
470                required: {
471                    let mut set = std::collections::BTreeSet::new();
472                    set.insert("name".to_string());
473                    set
474                },
475                ..Default::default()
476            })),
477            ..Default::default()
478        };
479        SchemaObject {
480            subschemas: Some(Box::new(SubschemaValidation {
481                any_of: Some(vec![string_schema.into(), object_schema.into()]),
482                ..Default::default()
483            })),
484            ..Default::default()
485        }
486        .into()
487    }
488}
489
/// Function invocation carried by a `ToolCall`: the function name and its arguments.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionCallResponse {
    pub name: String,
    // `default` lets the field be absent on input (it becomes `None`). There is no
    // skip attribute, so a `None` is still written out as `null` on serialization.
    #[serde(default)]
    pub arguments: Option<String>, // JSON string
}
496
497// ============================================================================
498// Usage and Logging
499// ============================================================================
/// Token usage counts, with optional breakdowns for prompt and completion tokens.
// `skip_serializing_none` drops the two `*_details` fields from the output when `None`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
509
510impl Usage {
511    /// Create a Usage from prompt and completion token counts
512    pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
513        Self {
514            prompt_tokens,
515            completion_tokens,
516            total_tokens: prompt_tokens + completion_tokens,
517            prompt_tokens_details: None,
518            completion_tokens_details: None,
519        }
520    }
521
522    /// Add cached token details to this Usage
523    pub fn with_cached_tokens(mut self, cached_tokens: u32) -> Self {
524        if cached_tokens > 0 {
525            self.prompt_tokens_details = Some(PromptTokenUsageInfo { cached_tokens });
526        }
527        self
528    }
529
530    /// Add reasoning token details to this Usage
531    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
532        if reasoning_tokens > 0 {
533            self.completion_tokens_details = Some(CompletionTokensDetails {
534                reasoning_tokens: Some(reasoning_tokens),
535                accepted_prediction_tokens: None,
536                rejected_prediction_tokens: None,
537            });
538        }
539        self
540    }
541}
542
/// Optional breakdown of completion tokens.
// `skip_serializing_none` drops every `None` field from the serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionTokensDetails {
    pub reasoning_tokens: Option<u32>,
    pub accepted_prediction_tokens: Option<u32>,
    pub rejected_prediction_tokens: Option<u32>,
}
550
/// Usage information (used by rerank and other endpoints)
// Unlike `Usage`, `reasoning_tokens` sits directly on this struct instead of inside a
// completion-details object. `skip_serializing_none` drops `None` fields on output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    pub reasoning_tokens: Option<u32>,
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
561
/// Breakdown of prompt tokens; currently only the cached-token count.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct PromptTokenUsageInfo {
    pub cached_tokens: u32,
}
566
/// Log-probability data stored as four separate lists.
// NOTE(review): presumably the four vectors are index-aligned (one entry per token) —
// nothing in this type enforces equal lengths; confirm with the producer.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    // Entries may be `None` (serialized as `null`).
    pub token_logprobs: Vec<Option<f32>>,
    // Each entry is an optional map from token text to its log probability.
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    pub text_offset: Vec<u32>,
}
574
/// Chat log probabilities: a structured form, or an arbitrary JSON value.
// `untagged`: `Detailed` is tried first, `Raw` second.
// NOTE(review): `content` is optional, so presumably most JSON objects already match
// `Detailed` and `Raw` is only reached by non-object values or by objects whose
// `content` does not parse — confirm this is the intended precedence.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum ChatLogProbs {
    Detailed {
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ChatLogProbsContent>>,
    },
    Raw(Value),
}
584
/// Log-probability entry for one token, together with its top alternatives.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    // No skip attribute: a `None` is written out as `null`.
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogProb>,
}
592
/// One alternative token and its log probability, as listed in `top_logprobs`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    // No skip attribute: a `None` is written out as `null`.
    pub bytes: Option<Vec<u8>>,
}
599
600// ============================================================================
601// Error Types
602// ============================================================================
603
/// Error envelope: a single `error` object holding the details.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}
608
/// Details of an error: message, type, and optional parameter and code.
// `skip_serializing_none` drops `param` and `code` from the output when `None`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ErrorDetail {
    pub message: String,
    // Exposed on the wire as `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    pub param: Option<String>,
    pub code: Option<String>,
}
618
619// ============================================================================
620// Input Types
621// ============================================================================
622
/// Integer input IDs: one flat list, or a batch of lists.
// `untagged`, tried in declaration order: a flat array of integers (including the empty
// array `[]`) becomes `Single`; an array of arrays becomes `Batch`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}
629
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
// `untagged`, tried in declaration order: a string or `null` becomes `Single`; an array
// (whose entries may themselves be `null`) becomes `Batch`.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum LoRAPath {
    Single(Option<String>),
    Batch(Vec<Option<String>>),
}
637
638// ============================================================================
639// Redacted Types
640// ============================================================================
/// String wrapper whose `Debug` output is masked.
// `Debug` is deliberately not derived; the manual impl in this file prints a fixed
// placeholder instead of the value.
// NOTE(review): `Serialize` IS derived and the inner field is `pub`, so the raw string
// is still exposed through serialization and direct access — only `{:?}` formatting is
// protected.
#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct Redacted(pub String);
643
644impl std::fmt::Debug for Redacted {
645    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
646        f.write_str("[REDACTED]")
647    }
648}
649
650// ============================================================================
651// Response Prompt
652// ============================================================================
653
/// Reference to a prompt template and its variables.
// `skip_serializing_none` drops `variables` and `version` from the output when `None`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ResponsePrompt {
    /// Identifier of the prompt template.
    pub id: String,
    /// Optional values for the template's variables, keyed by name.
    pub variables: Option<HashMap<String, PromptVariable>>,
    /// Optional template version.
    pub version: Option<String>,
}
662
/// A prompt variable value: plain string or a typed input (text, image, file).
///
/// Variant order matters for `#[serde(untagged)]`: a bare JSON string succeeds
/// as `String`; a JSON object falls through to `Typed`.
// Keep `String` first when adding variants, otherwise the precedence described above
// changes.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum PromptVariable {
    String(String),
    Typed(PromptVariableTyped),
}
673
/// Typed prompt variable input.
// Internally tagged on `type`: `"input_text"`, `"input_image"` or `"input_file"`.
// `skip_serializing_none` drops the `None` fields of the image and file variants.
// The clippy expectation is needed because every variant shares the `ResponseInput`
// prefix.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(tag = "type")]
#[expect(
    clippy::enum_variant_names,
    reason = "variant names match OpenAI API spec"
)]
pub enum PromptVariableTyped {
    /// Text input.
    #[serde(rename = "input_text")]
    ResponseInputText { text: String },
    /// Image input; every field is optional.
    #[serde(rename = "input_image")]
    ResponseInputImage {
        detail: Option<Detail>,
        file_id: Option<String>,
        image_url: Option<String>,
    },
    /// File input; every field is optional.
    #[serde(rename = "input_file")]
    ResponseInputFile {
        file_data: Option<String>,
        file_id: Option<String>,
        file_url: Option<String>,
        filename: Option<String>,
    },
}
699
/// Image detail level for [`PromptVariableTyped::ResponseInputImage`].
// Serialized as `"low"`, `"high"` or `"auto"` via `rename_all = "snake_case"`.
// `Default::default()` yields `Auto`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Detail {
    Low,
    High,
    #[default]
    Auto,
}