Skip to main content

openai_protocol/
common.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
9// ============================================================================
10// Default value helpers
11// ============================================================================
12
13/// Default model value when not specified
14pub(crate) fn default_model() -> String {
15    UNKNOWN_MODEL_ID.to_string()
16}
17
/// Helper for serde `default` attributes: always returns `true`.
pub fn default_true() -> bool {
    true
}
22
23// ============================================================================
24// GenerationRequest Trait
25// ============================================================================
26
/// Uniform read-only view over the different generation request types.
///
/// Per the original authors this is implemented by ChatCompletionRequest,
/// CompletionRequest, GenerateRequest, EmbeddingRequest, RerankRequest, and
/// ResponsesRequest. Implementors must be `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Returns `true` when the request asks for a streaming response.
    fn is_stream(&self) -> bool;

    /// Returns the model name, or `None` when the request does not specify one.
    fn get_model(&self) -> Option<&str>;

    /// Returns the text content that routing decisions should be based on.
    fn extract_text_for_routing(&self) -> String;
}
40
41// ============================================================================
42// String/Array Utilities
43// ============================================================================
44
45/// A type that can be either a single string or an array of strings
46#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
47#[serde(untagged)]
48pub enum StringOrArray {
49    String(String),
50    Array(Vec<String>),
51}
52
53impl StringOrArray {
54    /// Get the number of items in the StringOrArray
55    pub fn len(&self) -> usize {
56        match self {
57            StringOrArray::String(_) => 1,
58            StringOrArray::Array(arr) => arr.len(),
59        }
60    }
61
62    /// Check if the StringOrArray is empty
63    pub fn is_empty(&self) -> bool {
64        match self {
65            StringOrArray::String(s) => s.is_empty(),
66            StringOrArray::Array(arr) => arr.is_empty(),
67        }
68    }
69
70    /// Convert to a vector of strings (clones the data)
71    pub fn to_vec(&self) -> Vec<String> {
72        match self {
73            StringOrArray::String(s) => vec![s.clone()],
74            StringOrArray::Array(arr) => arr.clone(),
75        }
76    }
77
78    /// Returns an iterator over string references without cloning.
79    /// Use this instead of `to_vec()` when you only need to iterate.
80    pub fn iter(&self) -> StringOrArrayIter<'_> {
81        StringOrArrayIter {
82            inner: self,
83            index: 0,
84        }
85    }
86
87    /// Returns the first string, or None if empty
88    pub fn first(&self) -> Option<&str> {
89        match self {
90            StringOrArray::String(s) => {
91                if s.is_empty() {
92                    None
93                } else {
94                    Some(s)
95                }
96            }
97            StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98        }
99    }
100}
101
102/// Iterator over StringOrArray that yields string references without cloning
103pub struct StringOrArrayIter<'a> {
104    inner: &'a StringOrArray,
105    index: usize,
106}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109    type Item = &'a str;
110
111    fn next(&mut self) -> Option<Self::Item> {
112        match self.inner {
113            StringOrArray::String(s) => {
114                if self.index == 0 {
115                    self.index = 1;
116                    Some(s.as_str())
117                } else {
118                    None
119                }
120            }
121            StringOrArray::Array(arr) => {
122                if self.index < arr.len() {
123                    let item = &arr[self.index];
124                    self.index += 1;
125                    Some(item.as_str())
126                } else {
127                    None
128                }
129            }
130        }
131    }
132
133    fn size_hint(&self) -> (usize, Option<usize>) {
134        let remaining = match self.inner {
135            StringOrArray::String(_) => 1 - self.index,
136            StringOrArray::Array(arr) => arr.len() - self.index,
137        };
138        (remaining, Some(remaining))
139    }
140}
141
142impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144/// Validates stop sequences (max 4, non-empty strings)
145/// Used by both ChatCompletionRequest and ResponsesRequest
146pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147    match stop {
148        StringOrArray::String(s) => {
149            if s.is_empty() {
150                return Err(validator::ValidationError::new(
151                    "stop sequences cannot be empty",
152                ));
153            }
154        }
155        StringOrArray::Array(arr) => {
156            if arr.len() > 4 {
157                return Err(validator::ValidationError::new(
158                    "maximum 4 stop sequences allowed",
159                ));
160            }
161            for s in arr {
162                if s.is_empty() {
163                    return Err(validator::ValidationError::new(
164                        "stop sequences cannot be empty",
165                    ));
166                }
167            }
168        }
169    }
170    Ok(())
171}
172
173// ============================================================================
174// Content Parts (for multimodal messages)
175// ============================================================================
176
177#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
178#[serde(tag = "type")]
179pub enum ContentPart {
180    #[serde(rename = "text")]
181    Text { text: String },
182    #[serde(rename = "image_url")]
183    ImageUrl { image_url: ImageUrl },
184    #[serde(rename = "video_url")]
185    VideoUrl { video_url: VideoUrl },
186}
187
188#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
189pub struct ImageUrl {
190    pub url: String,
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub detail: Option<String>, // "auto", "low", or "high"
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
196pub struct VideoUrl {
197    pub url: String,
198}
199
200// ============================================================================
201// Response Format (for structured outputs)
202// ============================================================================
203
204#[derive(Debug, Clone, Deserialize, Serialize)]
205#[serde(tag = "type")]
206pub enum ResponseFormat {
207    #[serde(rename = "text")]
208    Text,
209    #[serde(rename = "json_object")]
210    JsonObject,
211    #[serde(rename = "json_schema")]
212    JsonSchema { json_schema: JsonSchemaFormat },
213}
214
215#[derive(Debug, Clone, Deserialize, Serialize)]
216pub struct JsonSchemaFormat {
217    pub name: String,
218    pub schema: Value,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub strict: Option<bool>,
221}
222
223// ============================================================================
224// Streaming
225// ============================================================================
226
227#[derive(Debug, Clone, Deserialize, Serialize)]
228pub struct StreamOptions {
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub include_usage: Option<bool>,
231}
232
233#[serde_with::skip_serializing_none]
234#[derive(Debug, Clone, Deserialize, Serialize)]
235pub struct ToolCallDelta {
236    pub index: u32,
237    pub id: Option<String>,
238    #[serde(rename = "type")]
239    pub tool_type: Option<String>,
240    pub function: Option<FunctionCallDelta>,
241}
242
243#[serde_with::skip_serializing_none]
244#[derive(Debug, Clone, Deserialize, Serialize)]
245pub struct FunctionCallDelta {
246    pub name: Option<String>,
247    pub arguments: Option<String>,
248}
249
250// ============================================================================
251// Tools and Function Calling
252// ============================================================================
253
254/// Tool choice value for simple string options
255#[derive(Debug, Clone, Deserialize, Serialize)]
256#[serde(rename_all = "snake_case")]
257pub enum ToolChoiceValue {
258    Auto,
259    Required,
260    None,
261}
262
263/// Tool choice for both Chat Completion and Responses APIs
264#[derive(Debug, Clone, Deserialize, Serialize)]
265#[serde(untagged)]
266pub enum ToolChoice {
267    Value(ToolChoiceValue),
268    Function {
269        #[serde(rename = "type")]
270        tool_type: String, // "function"
271        function: FunctionChoice,
272    },
273    AllowedTools {
274        #[serde(rename = "type")]
275        tool_type: String, // "allowed_tools"
276        mode: String, // "auto" | "required" TODO: need validation
277        tools: Vec<ToolReference>,
278    },
279}
280
281impl Default for ToolChoice {
282    fn default() -> Self {
283        Self::Value(ToolChoiceValue::Auto)
284    }
285}
286
287impl ToolChoice {
288    /// Serialize tool_choice to string for ResponsesResponse
289    ///
290    /// Returns the JSON-serialized tool_choice or "auto" as default
291    pub fn serialize_to_string(tool_choice: Option<&ToolChoice>) -> String {
292        tool_choice
293            .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
294            .unwrap_or_else(|| "auto".to_string())
295    }
296}
297
298/// Function choice specification for ToolChoice::Function
299#[derive(Debug, Clone, Deserialize, Serialize)]
300pub struct FunctionChoice {
301    pub name: String,
302}
303
304/// Tool reference for ToolChoice::AllowedTools
305///
306/// Represents a reference to a specific tool in the allowed_tools array.
307/// Different tool types have different required fields.
308#[derive(Debug, Clone, Deserialize, Serialize)]
309#[serde(tag = "type")]
310#[serde(rename_all = "snake_case")]
311pub enum ToolReference {
312    /// Reference to a function tool
313    #[serde(rename = "function")]
314    Function { name: String },
315
316    /// Reference to an MCP tool
317    #[serde(rename = "mcp")]
318    Mcp {
319        server_label: String,
320        #[serde(skip_serializing_if = "Option::is_none")]
321        name: Option<String>,
322    },
323
324    /// File search hosted tool
325    #[serde(rename = "file_search")]
326    FileSearch,
327
328    /// Web search preview hosted tool
329    #[serde(rename = "web_search_preview")]
330    WebSearchPreview,
331
332    /// Computer use preview hosted tool
333    #[serde(rename = "computer_use_preview")]
334    ComputerUsePreview,
335
336    /// Code interpreter hosted tool
337    #[serde(rename = "code_interpreter")]
338    CodeInterpreter,
339
340    /// Image generation hosted tool
341    #[serde(rename = "image_generation")]
342    ImageGeneration,
343}
344
345impl ToolReference {
346    /// Get a unique identifier for this tool reference
347    pub fn identifier(&self) -> String {
348        match self {
349            ToolReference::Function { name } => format!("function:{name}"),
350            ToolReference::Mcp { server_label, name } => {
351                if let Some(n) = name {
352                    format!("mcp:{server_label}:{n}")
353                } else {
354                    format!("mcp:{server_label}")
355                }
356            }
357            ToolReference::FileSearch => "file_search".to_string(),
358            ToolReference::WebSearchPreview => "web_search_preview".to_string(),
359            ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
360            ToolReference::CodeInterpreter => "code_interpreter".to_string(),
361            ToolReference::ImageGeneration => "image_generation".to_string(),
362        }
363    }
364
365    /// Get the tool name if this is a function tool
366    pub fn function_name(&self) -> Option<&str> {
367        match self {
368            ToolReference::Function { name } => Some(name.as_str()),
369            _ => None,
370        }
371    }
372}
373
374#[derive(Debug, Clone, Deserialize, Serialize)]
375pub struct Tool {
376    #[serde(rename = "type")]
377    pub tool_type: String, // "function"
378    pub function: Function,
379}
380
381#[serde_with::skip_serializing_none]
382#[derive(Debug, Clone, Deserialize, Serialize)]
383pub struct Function {
384    pub name: String,
385    pub description: Option<String>,
386    pub parameters: Value, // JSON Schema
387    /// Whether to enable strict schema adherence (OpenAI structured outputs)
388    pub strict: Option<bool>,
389}
390
391#[derive(Debug, Clone, Deserialize, Serialize)]
392pub struct ToolCall {
393    pub id: String,
394    #[serde(rename = "type")]
395    pub tool_type: String, // "function"
396    pub function: FunctionCallResponse,
397}
398
/// The deprecated `function_call` field of the OpenAI API.
///
/// On the wire this is either the string `"none"`, the string `"auto"`, or an
/// object of the form `{"name": "function_name"}`. Serialization and
/// deserialization are implemented by hand to support that shape.
#[derive(Clone, Debug)]
pub enum FunctionCall {
    /// The string `"none"`.
    None,
    /// The string `"auto"`.
    Auto,
    /// The object form naming a specific function.
    Function { name: String },
}
407
408impl Serialize for FunctionCall {
409    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
410        match self {
411            FunctionCall::None => serializer.serialize_str("none"),
412            FunctionCall::Auto => serializer.serialize_str("auto"),
413            FunctionCall::Function { name } => {
414                use serde::ser::SerializeMap;
415                let mut map = serializer.serialize_map(Some(1))?;
416                map.serialize_entry("name", name)?;
417                map.end()
418            }
419        }
420    }
421}
422
423impl<'de> Deserialize<'de> for FunctionCall {
424    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
425        let value = Value::deserialize(deserializer)?;
426        match &value {
427            Value::String(s) => match s.as_str() {
428                "none" => Ok(FunctionCall::None),
429                "auto" => Ok(FunctionCall::Auto),
430                other => Err(serde::de::Error::custom(format!(
431                    "unknown function_call value: \"{other}\""
432                ))),
433            },
434            Value::Object(map) => {
435                if let Some(Value::String(name)) = map.get("name") {
436                    Ok(FunctionCall::Function { name: name.clone() })
437                } else {
438                    Err(serde::de::Error::custom(
439                        "function_call object must have a \"name\" string field",
440                    ))
441                }
442            }
443            _ => Err(serde::de::Error::custom(
444                "function_call must be a string or object",
445            )),
446        }
447    }
448}
449
450#[derive(Debug, Clone, Deserialize, Serialize)]
451pub struct FunctionCallResponse {
452    pub name: String,
453    #[serde(default)]
454    pub arguments: Option<String>, // JSON string
455}
456
457// ============================================================================
458// Usage and Logging
459// ============================================================================
460
461#[derive(Debug, Clone, Deserialize, Serialize)]
462pub struct Usage {
463    pub prompt_tokens: u32,
464    pub completion_tokens: u32,
465    pub total_tokens: u32,
466    #[serde(skip_serializing_if = "Option::is_none")]
467    pub completion_tokens_details: Option<CompletionTokensDetails>,
468}
469
470impl Usage {
471    /// Create a Usage from prompt and completion token counts
472    pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
473        Self {
474            prompt_tokens,
475            completion_tokens,
476            total_tokens: prompt_tokens + completion_tokens,
477            completion_tokens_details: None,
478        }
479    }
480
481    /// Add reasoning token details to this Usage
482    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
483        if reasoning_tokens > 0 {
484            self.completion_tokens_details = Some(CompletionTokensDetails {
485                reasoning_tokens: Some(reasoning_tokens),
486            });
487        }
488        self
489    }
490}
491
492#[derive(Debug, Clone, Deserialize, Serialize)]
493pub struct CompletionTokensDetails {
494    pub reasoning_tokens: Option<u32>,
495}
496
497/// Usage information (used by rerank and other endpoints)
498#[serde_with::skip_serializing_none]
499#[derive(Debug, Clone, Deserialize, Serialize)]
500pub struct UsageInfo {
501    pub prompt_tokens: u32,
502    pub completion_tokens: u32,
503    pub total_tokens: u32,
504    pub reasoning_tokens: Option<u32>,
505    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
506}
507
508#[derive(Debug, Clone, Deserialize, Serialize)]
509pub struct PromptTokenUsageInfo {
510    pub cached_tokens: u32,
511}
512
513#[derive(Debug, Clone, Deserialize, Serialize)]
514pub struct LogProbs {
515    pub tokens: Vec<String>,
516    pub token_logprobs: Vec<Option<f32>>,
517    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
518    pub text_offset: Vec<u32>,
519}
520
521#[derive(Debug, Clone, Deserialize, Serialize)]
522#[serde(untagged)]
523pub enum ChatLogProbs {
524    Detailed {
525        #[serde(skip_serializing_if = "Option::is_none")]
526        content: Option<Vec<ChatLogProbsContent>>,
527    },
528    Raw(Value),
529}
530
531#[derive(Debug, Clone, Deserialize, Serialize)]
532pub struct ChatLogProbsContent {
533    pub token: String,
534    pub logprob: f32,
535    pub bytes: Option<Vec<u8>>,
536    pub top_logprobs: Vec<TopLogProb>,
537}
538
539#[derive(Debug, Clone, Deserialize, Serialize)]
540pub struct TopLogProb {
541    pub token: String,
542    pub logprob: f32,
543    pub bytes: Option<Vec<u8>>,
544}
545
546// ============================================================================
547// Error Types
548// ============================================================================
549
550#[derive(Debug, Clone, Deserialize, Serialize)]
551pub struct ErrorResponse {
552    pub error: ErrorDetail,
553}
554
555#[serde_with::skip_serializing_none]
556#[derive(Debug, Clone, Deserialize, Serialize)]
557pub struct ErrorDetail {
558    pub message: String,
559    #[serde(rename = "type")]
560    pub error_type: String,
561    pub param: Option<String>,
562    pub code: Option<String>,
563}
564
565// ============================================================================
566// Input Types
567// ============================================================================
568
569#[derive(Debug, Clone, Deserialize, Serialize)]
570#[serde(untagged)]
571pub enum InputIds {
572    Single(Vec<i32>),
573    Batch(Vec<Vec<i32>>),
574}
575
576/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
577#[derive(Debug, Clone, Deserialize, Serialize)]
578#[serde(untagged)]
579pub enum LoRAPath {
580    Single(Option<String>),
581    Batch(Vec<Option<String>>),
582}
583
584// ============================================================================
585// Redacted Types
586// ============================================================================
587#[derive(Clone, Serialize, Deserialize)]
588pub struct Redacted(pub String);
589
590impl std::fmt::Debug for Redacted {
591    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        f.write_str("[REDACTED]")
593    }
594}
595
596// ============================================================================
597// Response Prompt
598// ============================================================================
599
600/// Reference to a prompt template and its variables.
601#[serde_with::skip_serializing_none]
602#[derive(Debug, Clone, Serialize, Deserialize)]
603pub struct ResponsePrompt {
604    pub id: String,
605    pub variables: Option<HashMap<String, PromptVariable>>,
606    pub version: Option<String>,
607}
608
609/// A prompt variable value: plain string or a typed input (text, image, file).
610///
611/// Variant order matters for `#[serde(untagged)]`: a bare JSON string succeeds
612/// as `String`; a JSON object falls through to `Typed`.
613#[derive(Debug, Clone, Serialize, Deserialize)]
614#[serde(untagged)]
615pub enum PromptVariable {
616    String(String),
617    Typed(PromptVariableTyped),
618}
619
620/// Typed prompt variable input.
621#[serde_with::skip_serializing_none]
622#[derive(Debug, Clone, Serialize, Deserialize)]
623#[serde(tag = "type")]
624#[expect(
625    clippy::enum_variant_names,
626    reason = "variant names match OpenAI API spec"
627)]
628pub enum PromptVariableTyped {
629    #[serde(rename = "input_text")]
630    ResponseInputText { text: String },
631    #[serde(rename = "input_image")]
632    ResponseInputImage {
633        detail: Option<Detail>,
634        file_id: Option<String>,
635        image_url: Option<String>,
636    },
637    #[serde(rename = "input_file")]
638    ResponseInputFile {
639        file_data: Option<String>,
640        file_id: Option<String>,
641        file_url: Option<String>,
642        filename: Option<String>,
643    },
644}
645
646/// Image detail level for [`PromptVariableTyped::InputImage`].
647#[derive(Debug, Clone, Serialize, Deserialize, Default)]
648#[serde(rename_all = "snake_case")]
649pub enum Detail {
650    Low,
651    High,
652    #[default]
653    Auto,
654}