Skip to main content

openai_protocol/
common.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7// ============================================================================
8// Default value helpers
9// ============================================================================
10
11/// Default model for endpoints where model is optional (e.g., /generate).
12/// Uses UNKNOWN_MODEL_ID so routers treat it as "any available worker."
13pub fn default_unknown_model() -> String {
14    super::UNKNOWN_MODEL_ID.to_string()
15}
16
/// Serde default provider that always yields `true`.
///
/// Intended for `#[serde(default = "default_true")]` on boolean fields that
/// should be enabled when the key is absent.
pub fn default_true() -> bool {
    true
}
21
22/// Deserialize a bool that also accepts JSON `null` (mapped to `false`).
23///
24/// Use with `#[serde(default, deserialize_with = "deserialize_null_as_false")]`
25/// on fields that the OpenAI spec defines as `Optional[bool]` defaulting to `false`.
26pub fn deserialize_null_as_false<'de, D>(deserializer: D) -> Result<bool, D::Error>
27where
28    D: serde::Deserializer<'de>,
29{
30    Option::<bool>::deserialize(deserializer).map(|opt| opt.unwrap_or(false))
31}
32
33// ============================================================================
34// GenerationRequest Trait
35// ============================================================================
36
/// Trait for unified access to generation request properties.
///
/// Implemented by ChatCompletionRequest, CompletionRequest, GenerateRequest,
/// EmbeddingRequest, RerankRequest, and ResponsesRequest.
/// Implementors must be `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Returns `true` if the request is for streaming.
    fn is_stream(&self) -> bool;

    /// Returns the model name, or `None` if the request does not specify one.
    fn get_model(&self) -> Option<&str>;

    /// Returns the text content used for routing decisions, as an owned `String`.
    fn extract_text_for_routing(&self) -> String;
}
50
51// ============================================================================
52// String/Array Utilities
53// ============================================================================
54
55/// A type that can be either a single string or an array of strings
56#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, schemars::JsonSchema)]
57#[serde(untagged)]
58pub enum StringOrArray {
59    String(String),
60    Array(Vec<String>),
61}
62
63impl StringOrArray {
64    /// Get the number of items in the StringOrArray
65    pub fn len(&self) -> usize {
66        match self {
67            StringOrArray::String(_) => 1,
68            StringOrArray::Array(arr) => arr.len(),
69        }
70    }
71
72    /// Check if the StringOrArray is empty
73    pub fn is_empty(&self) -> bool {
74        match self {
75            StringOrArray::String(s) => s.is_empty(),
76            StringOrArray::Array(arr) => arr.is_empty(),
77        }
78    }
79
80    /// Convert to a vector of strings (clones the data)
81    pub fn to_vec(&self) -> Vec<String> {
82        match self {
83            StringOrArray::String(s) => vec![s.clone()],
84            StringOrArray::Array(arr) => arr.clone(),
85        }
86    }
87
88    /// Returns an iterator over string references without cloning.
89    /// Use this instead of `to_vec()` when you only need to iterate.
90    pub fn iter(&self) -> StringOrArrayIter<'_> {
91        StringOrArrayIter {
92            inner: self,
93            index: 0,
94        }
95    }
96
97    /// Returns the first string, or None if empty
98    pub fn first(&self) -> Option<&str> {
99        match self {
100            StringOrArray::String(s) => {
101                if s.is_empty() {
102                    None
103                } else {
104                    Some(s)
105                }
106            }
107            StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
108        }
109    }
110}
111
112/// Iterator over StringOrArray that yields string references without cloning
113pub struct StringOrArrayIter<'a> {
114    inner: &'a StringOrArray,
115    index: usize,
116}
117
118impl<'a> Iterator for StringOrArrayIter<'a> {
119    type Item = &'a str;
120
121    fn next(&mut self) -> Option<Self::Item> {
122        match self.inner {
123            StringOrArray::String(s) => {
124                if self.index == 0 {
125                    self.index = 1;
126                    Some(s.as_str())
127                } else {
128                    None
129                }
130            }
131            StringOrArray::Array(arr) => {
132                if self.index < arr.len() {
133                    let item = &arr[self.index];
134                    self.index += 1;
135                    Some(item.as_str())
136                } else {
137                    None
138                }
139            }
140        }
141    }
142
143    fn size_hint(&self) -> (usize, Option<usize>) {
144        let remaining = match self.inner {
145            StringOrArray::String(_) => 1 - self.index,
146            StringOrArray::Array(arr) => arr.len() - self.index,
147        };
148        (remaining, Some(remaining))
149    }
150}
151
152impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
153
154/// Validates stop sequences (max 4, non-empty strings)
155/// Used by both ChatCompletionRequest and ResponsesRequest
156pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
157    match stop {
158        StringOrArray::String(s) => {
159            if s.is_empty() {
160                return Err(validator::ValidationError::new(
161                    "stop sequences cannot be empty",
162                ));
163            }
164        }
165        StringOrArray::Array(arr) => {
166            if arr.len() > 4 {
167                return Err(validator::ValidationError::new(
168                    "maximum 4 stop sequences allowed",
169                ));
170            }
171            for s in arr {
172                if s.is_empty() {
173                    return Err(validator::ValidationError::new(
174                        "stop sequences cannot be empty",
175                    ));
176                }
177            }
178        }
179    }
180    Ok(())
181}
182
183// ============================================================================
184// Content Parts (for multimodal messages)
185// ============================================================================
186
187#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
188#[serde(tag = "type")]
189pub enum ContentPart {
190    #[serde(rename = "text")]
191    Text { text: String },
192    #[serde(rename = "image_url")]
193    ImageUrl { image_url: ImageUrl },
194    #[serde(rename = "video_url")]
195    VideoUrl { video_url: VideoUrl },
196}
197
198#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
199pub struct ImageUrl {
200    pub url: String,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub detail: Option<String>, // "auto", "low", or "high"
203}
204
205#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
206pub struct VideoUrl {
207    pub url: String,
208}
209
210// ============================================================================
211// Response Format (for structured outputs)
212// ============================================================================
213
214#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
215#[serde(tag = "type")]
216pub enum ResponseFormat {
217    #[serde(rename = "text")]
218    Text,
219    #[serde(rename = "json_object")]
220    JsonObject,
221    #[serde(rename = "json_schema")]
222    JsonSchema { json_schema: JsonSchemaFormat },
223}
224
225#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
226pub struct JsonSchemaFormat {
227    pub name: String,
228    pub schema: Value,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub strict: Option<bool>,
231}
232
233// ============================================================================
234// Streaming
235// ============================================================================
236
237#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
238pub struct StreamOptions {
239    #[serde(skip_serializing_if = "Option::is_none")]
240    pub include_usage: Option<bool>,
241}
242
243#[serde_with::skip_serializing_none]
244#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
245pub struct ToolCallDelta {
246    pub index: u32,
247    pub id: Option<String>,
248    #[serde(rename = "type")]
249    pub tool_type: Option<String>,
250    pub function: Option<FunctionCallDelta>,
251}
252
253#[serde_with::skip_serializing_none]
254#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
255pub struct FunctionCallDelta {
256    pub name: Option<String>,
257    pub arguments: Option<String>,
258}
259
260// ============================================================================
261// Tools and Function Calling
262// ============================================================================
263
264/// Tool choice value for simple string options
265#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
266#[serde(rename_all = "snake_case")]
267pub enum ToolChoiceValue {
268    Auto,
269    Required,
270    None,
271}
272
273/// Tool choice for both Chat Completion and Responses APIs
274#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
275#[serde(untagged)]
276pub enum ToolChoice {
277    Value(ToolChoiceValue),
278    Function {
279        #[serde(rename = "type")]
280        tool_type: String, // "function"
281        function: FunctionChoice,
282    },
283    AllowedTools {
284        #[serde(rename = "type")]
285        tool_type: String, // "allowed_tools"
286        mode: String, // "auto" | "required" TODO: need validation
287        tools: Vec<ToolReference>,
288    },
289}
290
291impl Default for ToolChoice {
292    fn default() -> Self {
293        Self::Value(ToolChoiceValue::Auto)
294    }
295}
296
297impl ToolChoice {
298    /// Serialize tool_choice to string for ResponsesResponse
299    ///
300    /// Returns the JSON-serialized tool_choice or "auto" as default
301    pub fn serialize_to_string(tool_choice: Option<&ToolChoice>) -> String {
302        tool_choice
303            .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
304            .unwrap_or_else(|| "auto".to_string())
305    }
306}
307
308/// Function choice specification for ToolChoice::Function
309#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
310pub struct FunctionChoice {
311    pub name: String,
312}
313
314/// Tool reference for ToolChoice::AllowedTools
315///
316/// Represents a reference to a specific tool in the allowed_tools array.
317/// Different tool types have different required fields.
318#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
319#[serde(tag = "type")]
320#[serde(rename_all = "snake_case")]
321pub enum ToolReference {
322    /// Reference to a function tool
323    #[serde(rename = "function")]
324    Function { name: String },
325
326    /// Reference to an MCP tool
327    #[serde(rename = "mcp")]
328    Mcp {
329        server_label: String,
330        #[serde(skip_serializing_if = "Option::is_none")]
331        name: Option<String>,
332    },
333
334    /// File search hosted tool
335    #[serde(rename = "file_search")]
336    FileSearch,
337
338    /// Web search preview hosted tool
339    #[serde(rename = "web_search_preview")]
340    WebSearchPreview,
341
342    /// Computer use preview hosted tool
343    #[serde(rename = "computer_use_preview")]
344    ComputerUsePreview,
345
346    /// Code interpreter hosted tool
347    #[serde(rename = "code_interpreter")]
348    CodeInterpreter,
349
350    /// Image generation hosted tool
351    #[serde(rename = "image_generation")]
352    ImageGeneration,
353}
354
355impl ToolReference {
356    /// Get a unique identifier for this tool reference
357    pub fn identifier(&self) -> String {
358        match self {
359            ToolReference::Function { name } => format!("function:{name}"),
360            ToolReference::Mcp { server_label, name } => {
361                if let Some(n) = name {
362                    format!("mcp:{server_label}:{n}")
363                } else {
364                    format!("mcp:{server_label}")
365                }
366            }
367            ToolReference::FileSearch => "file_search".to_string(),
368            ToolReference::WebSearchPreview => "web_search_preview".to_string(),
369            ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
370            ToolReference::CodeInterpreter => "code_interpreter".to_string(),
371            ToolReference::ImageGeneration => "image_generation".to_string(),
372        }
373    }
374
375    /// Get the tool name if this is a function tool
376    pub fn function_name(&self) -> Option<&str> {
377        match self {
378            ToolReference::Function { name } => Some(name.as_str()),
379            _ => None,
380        }
381    }
382}
383
384#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
385pub struct Tool {
386    #[serde(rename = "type")]
387    pub tool_type: String, // "function"
388    pub function: Function,
389}
390
391#[serde_with::skip_serializing_none]
392#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
393pub struct Function {
394    pub name: String,
395    pub description: Option<String>,
396    pub parameters: Value, // JSON Schema
397    /// Whether to enable strict schema adherence (OpenAI structured outputs)
398    pub strict: Option<bool>,
399}
400
401#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
402pub struct ToolCall {
403    pub id: String,
404    #[serde(rename = "type")]
405    pub tool_type: String, // "function"
406    pub function: FunctionCallResponse,
407}
408
/// The deprecated `function_call` field from the OpenAI API.
///
/// On the wire this is `"none"`, `"auto"`, or `{"name": "function_name"}`.
#[derive(Clone, Debug)]
pub enum FunctionCall {
    /// The string `"none"`.
    None,
    /// The string `"auto"`.
    Auto,
    /// An object naming a specific function.
    Function { name: String },
}
417
418impl Serialize for FunctionCall {
419    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
420        match self {
421            FunctionCall::None => serializer.serialize_str("none"),
422            FunctionCall::Auto => serializer.serialize_str("auto"),
423            FunctionCall::Function { name } => {
424                use serde::ser::SerializeMap;
425                let mut map = serializer.serialize_map(Some(1))?;
426                map.serialize_entry("name", name)?;
427                map.end()
428            }
429        }
430    }
431}
432
433impl<'de> Deserialize<'de> for FunctionCall {
434    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
435        let value = Value::deserialize(deserializer)?;
436        match &value {
437            Value::String(s) => match s.as_str() {
438                "none" => Ok(FunctionCall::None),
439                "auto" => Ok(FunctionCall::Auto),
440                other => Err(serde::de::Error::custom(format!(
441                    "unknown function_call value: \"{other}\""
442                ))),
443            },
444            Value::Object(map) => {
445                if let Some(Value::String(name)) = map.get("name") {
446                    Ok(FunctionCall::Function { name: name.clone() })
447                } else {
448                    Err(serde::de::Error::custom(
449                        "function_call object must have a \"name\" string field",
450                    ))
451                }
452            }
453            _ => Err(serde::de::Error::custom(
454                "function_call must be a string or object",
455            )),
456        }
457    }
458}
459
460impl schemars::JsonSchema for FunctionCall {
461    fn schema_name() -> String {
462        "FunctionCall".to_string()
463    }
464    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
465        use schemars::schema::*;
466        // FunctionCall is either "none", "auto", or {"name": "..."}
467        let string_schema = SchemaObject {
468            instance_type: Some(InstanceType::String.into()),
469            enum_values: Some(vec!["none".into(), "auto".into()]),
470            ..Default::default()
471        };
472        let object_schema = SchemaObject {
473            instance_type: Some(InstanceType::Object.into()),
474            object: Some(Box::new(ObjectValidation {
475                properties: {
476                    let mut map = schemars::Map::new();
477                    map.insert("name".to_string(), gen.subschema_for::<String>());
478                    map
479                },
480                required: {
481                    let mut set = std::collections::BTreeSet::new();
482                    set.insert("name".to_string());
483                    set
484                },
485                ..Default::default()
486            })),
487            ..Default::default()
488        };
489        SchemaObject {
490            subschemas: Some(Box::new(SubschemaValidation {
491                any_of: Some(vec![string_schema.into(), object_schema.into()]),
492                ..Default::default()
493            })),
494            ..Default::default()
495        }
496        .into()
497    }
498}
499
500#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
501pub struct FunctionCallResponse {
502    pub name: String,
503    #[serde(default)]
504    pub arguments: Option<String>, // JSON string
505}
506
507// ============================================================================
508// Usage and Logging
509// ============================================================================
510#[serde_with::skip_serializing_none]
511#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
512pub struct Usage {
513    pub prompt_tokens: u32,
514    pub completion_tokens: u32,
515    pub total_tokens: u32,
516    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
517    pub completion_tokens_details: Option<CompletionTokensDetails>,
518}
519
520impl Usage {
521    /// Create a Usage from prompt and completion token counts
522    pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
523        Self {
524            prompt_tokens,
525            completion_tokens,
526            total_tokens: prompt_tokens + completion_tokens,
527            prompt_tokens_details: None,
528            completion_tokens_details: None,
529        }
530    }
531
532    /// Add cached token details to this Usage
533    pub fn with_cached_tokens(mut self, cached_tokens: u32) -> Self {
534        if cached_tokens > 0 {
535            self.prompt_tokens_details = Some(PromptTokenUsageInfo { cached_tokens });
536        }
537        self
538    }
539
540    /// Add reasoning token details to this Usage
541    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
542        if reasoning_tokens > 0 {
543            self.completion_tokens_details = Some(CompletionTokensDetails {
544                reasoning_tokens: Some(reasoning_tokens),
545                accepted_prediction_tokens: None,
546                rejected_prediction_tokens: None,
547            });
548        }
549        self
550    }
551}
552
553#[serde_with::skip_serializing_none]
554#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
555pub struct CompletionTokensDetails {
556    pub reasoning_tokens: Option<u32>,
557    pub accepted_prediction_tokens: Option<u32>,
558    pub rejected_prediction_tokens: Option<u32>,
559}
560
561/// Usage information (used by rerank and other endpoints)
562#[serde_with::skip_serializing_none]
563#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
564pub struct UsageInfo {
565    pub prompt_tokens: u32,
566    pub completion_tokens: u32,
567    pub total_tokens: u32,
568    pub reasoning_tokens: Option<u32>,
569    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
570}
571
572#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
573pub struct PromptTokenUsageInfo {
574    pub cached_tokens: u32,
575}
576
577#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
578pub struct LogProbs {
579    pub tokens: Vec<String>,
580    pub token_logprobs: Vec<Option<f32>>,
581    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
582    pub text_offset: Vec<u32>,
583}
584
585#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
586#[serde(untagged)]
587pub enum ChatLogProbs {
588    Detailed {
589        #[serde(skip_serializing_if = "Option::is_none")]
590        content: Option<Vec<ChatLogProbsContent>>,
591    },
592    Raw(Value),
593}
594
595#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
596pub struct ChatLogProbsContent {
597    pub token: String,
598    pub logprob: f32,
599    pub bytes: Option<Vec<u8>>,
600    pub top_logprobs: Vec<TopLogProb>,
601}
602
603#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
604pub struct TopLogProb {
605    pub token: String,
606    pub logprob: f32,
607    pub bytes: Option<Vec<u8>>,
608}
609
610// ============================================================================
611// Error Types
612// ============================================================================
613
614#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
615pub struct ErrorResponse {
616    pub error: ErrorDetail,
617}
618
619#[serde_with::skip_serializing_none]
620#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
621pub struct ErrorDetail {
622    pub message: String,
623    #[serde(rename = "type")]
624    pub error_type: String,
625    pub param: Option<String>,
626    pub code: Option<String>,
627}
628
629// ============================================================================
630// Input Types
631// ============================================================================
632
633#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
634#[serde(untagged)]
635pub enum InputIds {
636    Single(Vec<i32>),
637    Batch(Vec<Vec<i32>>),
638}
639
640/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
641#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
642#[serde(untagged)]
643pub enum LoRAPath {
644    Single(Option<String>),
645    Batch(Vec<Option<String>>),
646}
647
648// ============================================================================
649// Redacted Types
650// ============================================================================
651#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
652pub struct Redacted(pub String);
653
654impl std::fmt::Debug for Redacted {
655    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
656        f.write_str("[REDACTED]")
657    }
658}
659
660// ============================================================================
661// Response Prompt
662// ============================================================================
663
664/// Reference to a prompt template and its variables.
665#[serde_with::skip_serializing_none]
666#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
667pub struct ResponsePrompt {
668    pub id: String,
669    pub variables: Option<HashMap<String, PromptVariable>>,
670    pub version: Option<String>,
671}
672
673/// A prompt variable value: plain string or a typed input (text, image, file).
674///
675/// Variant order matters for `#[serde(untagged)]`: a bare JSON string succeeds
676/// as `String`; a JSON object falls through to `Typed`.
677#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
678#[serde(untagged)]
679pub enum PromptVariable {
680    String(String),
681    Typed(PromptVariableTyped),
682}
683
684/// Typed prompt variable input.
685#[serde_with::skip_serializing_none]
686#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
687#[serde(tag = "type")]
688#[expect(
689    clippy::enum_variant_names,
690    reason = "variant names match OpenAI API spec"
691)]
692pub enum PromptVariableTyped {
693    #[serde(rename = "input_text")]
694    ResponseInputText { text: String },
695    #[serde(rename = "input_image")]
696    ResponseInputImage {
697        detail: Option<Detail>,
698        file_id: Option<String>,
699        image_url: Option<String>,
700    },
701    #[serde(rename = "input_file")]
702    ResponseInputFile {
703        file_data: Option<String>,
704        file_id: Option<String>,
705        file_url: Option<String>,
706        filename: Option<String>,
707    },
708}
709
710/// Image detail level for [`PromptVariableTyped::InputImage`].
711#[derive(Debug, Clone, Serialize, Deserialize, Default, schemars::JsonSchema)]
712#[serde(rename_all = "snake_case")]
713pub enum Detail {
714    Low,
715    High,
716    #[default]
717    Auto,
718}
719
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::json;

    use super::*;

    /// Minimal struct exercising `deserialize_null_as_false` the way real
    /// request types use it (`default` + `deserialize_with`).
    #[derive(Deserialize)]
    struct NullableBoolTest {
        #[serde(default, deserialize_with = "deserialize_null_as_false")]
        field: bool,
    }

    /// Parses `input` as a `NullableBoolTest` and returns its `field`.
    fn parse_field(input: serde_json::Value) -> Result<bool, serde_json::Error> {
        serde_json::from_value::<NullableBoolTest>(input).map(|parsed| parsed.field)
    }

    #[test]
    fn test_deserialize_null_as_false() {
        // Explicit booleans pass through unchanged.
        assert_eq!(parse_field(json!({"field": true})).unwrap(), true);
        assert_eq!(parse_field(json!({"field": false})).unwrap(), false);
        // `null` and a missing key both collapse to `false`.
        assert_eq!(parse_field(json!({"field": null})).unwrap(), false);
        assert_eq!(parse_field(json!({})).unwrap(), false);
    }

    #[test]
    fn test_deserialize_null_as_false_rejects_non_bool() {
        let outcome = parse_field(json!({"field": "yes"}));
        assert!(outcome.is_err());
    }
}
752}