Skip to main content

oxi_ai/
types.rs

1//! Core domain types for oxi-ai
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8/// Provider API identifier.
9///
10/// Selects the wire-format / protocol dialect spoken to a particular LLM provider.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum Api {
14    /// OpenAI Chat Completions API.
15    #[serde(rename = "openai-completions")]
16    OpenAiCompletions,
17    /// OpenAI Responses API.
18    #[serde(rename = "openai-responses")]
19    OpenAiResponses,
20    /// Anthropic Messages API.
21    #[serde(rename = "anthropic-messages")]
22    AnthropicMessages,
23    /// Google Generative AI (Gemini) API.
24    #[serde(rename = "google-generative-ai")]
25    GoogleGenerativeAi,
26    /// Google Vertex AI endpoint.
27    #[serde(rename = "google-vertex")]
28    GoogleVertex,
29    /// Mistral Conversations API.
30    #[serde(rename = "mistral-conversations")]
31    MistralConversations,
32    /// Azure OpenAI Responses API.
33    #[serde(rename = "azure-openai-responses")]
34    AzureOpenAiResponses,
35    /// AWS Bedrock Converse Stream API.
36    #[serde(rename = "bedrock-converse-stream")]
37    BedrockConverseStream,
38}
39
40impl fmt::Display for Api {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            Api::OpenAiCompletions => write!(f, "openai-completions"),
44            Api::OpenAiResponses => write!(f, "openai-responses"),
45            Api::AnthropicMessages => write!(f, "anthropic-messages"),
46            Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47            Api::GoogleVertex => write!(f, "google-vertex"),
48            Api::MistralConversations => write!(f, "mistral-conversations"),
49            Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50            Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51        }
52    }
53}
54
55/// Cache retention preference
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum CacheRetention {
59    /// No caching (default).
60    #[default]
61    None,
62    /// Short-lived cache.
63    Short,
64    /// Long-lived cache.
65    Long,
66}
67
68/// Model thinking/reasoning level
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
70#[serde(rename_all = "lowercase")]
71#[non_exhaustive]
72pub enum ThinkingLevel {
73    /// Extended reasoning disabled (default).
74    #[default]
75    Off,
76    /// Minimal reasoning.
77    Minimal,
78    /// Low reasoning.
79    Low,
80    /// Medium reasoning.
81    Medium,
82    /// High reasoning.
83    High,
84    /// Very high reasoning.
85    XHigh,
86}
87
88impl ThinkingLevel {
89    /// Returns the reasoning level as a string. Returns `None` for `Off`.
90    pub fn as_str(&self) -> Option<&str> {
91        match self {
92            ThinkingLevel::Off => None,
93            ThinkingLevel::Minimal => Some("minimal"),
94            ThinkingLevel::Low => Some("low"),
95            ThinkingLevel::Medium => Some("medium"),
96            ThinkingLevel::High => Some("high"),
97            ThinkingLevel::XHigh => Some("xhigh"),
98        }
99    }
100}
101
102/// Input modalities
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104#[serde(rename_all = "lowercase")]
105#[non_exhaustive]
106pub enum InputModality {
107    /// Text input.
108    Text,
109    /// Image input.
110    Image,
111}
112
113/// Cost structure – prices per million tokens.
114#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115#[serde(default)]
116pub struct Cost {
117    /// Input token cost ($/M tokens).
118    #[serde(default)]
119    pub input: f64,
120    /// Output token cost ($/M tokens).
121    #[serde(default)]
122    pub output: f64,
123    /// Cached-input read cost ($/M tokens).
124    #[serde(default)]
125    pub cache_read: f64,
126    /// Cache write cost ($/M tokens).
127    #[serde(default)]
128    pub cache_write: f64,
129}
130
131impl Cost {
132    /// Sum of all cost components.
133    pub fn total(&self) -> f64 {
134        self.input + self.output + self.cache_read + self.cache_write
135    }
136}
137
138/// Stop reason – why the model finished generating.
139#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141#[non_exhaustive]
142pub enum StopReason {
143    /// Normal stop – the model finished its response.
144    Stop,
145    /// Hit the maximum output token limit.
146    Length,
147    /// Stopped to invoke a tool.
148    ToolUse,
149    /// An error occurred during generation.
150    Error,
151    /// Generation was aborted by the client.
152    Aborted,
153}
154
155/// Token usage statistics.
156#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157pub struct Usage {
158    /// Number of input (prompt) tokens.
159    #[serde(default)]
160    pub input: usize,
161    /// Number of output (completion) tokens.
162    #[serde(default)]
163    pub output: usize,
164    /// Number of tokens read from cache.
165    #[serde(default)]
166    pub cache_read: usize,
167    /// Number of tokens written to cache.
168    #[serde(default)]
169    pub cache_write: usize,
170    /// Total tokens (input + output + cache).
171    #[serde(default)]
172    pub total_tokens: usize,
173    /// Computed cost in dollars.
174    #[serde(default)]
175    pub cost: Cost,
176}
177
178impl Usage {
179    /// Recalculate `total_tokens` and per-component costs from raw token counts.
180    ///
181    /// If pricing parameters are provided, they override the default $1/M rate.
182    pub fn calculate_cost(
183        &mut self,
184        input_cost_per_million: Option<f64>,
185        output_cost_per_million: Option<f64>,
186    ) {
187        self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188        self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189        self.cost.output =
190            output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191        self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192        self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193    }
194}
195
196/// Compatibility settings for OpenAI-compatible APIs.
197///
198/// Not every OpenAI-compatible provider supports every feature.
199/// These flags let the streaming layer adapt its request shape.
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203    /// Whether the provider supports the `store` parameter.
204    #[serde(default = "default_true")]
205    pub supports_store: bool,
206    /// Whether the provider recognises the `developer` role.
207    #[serde(default = "default_true")]
208    pub supports_developer_role: bool,
209    /// Whether the provider supports `reasoning_effort`.
210    #[serde(default = "default_true")]
211    pub supports_reasoning_effort: bool,
212    /// Whether the provider returns usage data in streaming responses.
213    #[serde(default = "default_true")]
214    pub supports_usage_in_streaming: bool,
215    /// Which JSON field name to use for the max-tokens parameter.
216    #[serde(default)]
217    pub max_tokens_field: Option<MaxTokensField>,
218    /// Whether tool results must include the tool name.
219    #[serde(default = "default_false")]
220    pub requires_tool_result_name: bool,
221    /// Whether an assistant message must follow every tool result.
222    #[serde(default = "default_false")]
223    pub requires_assistant_after_tool_result: bool,
224    /// Whether thinking should be sent as plain text.
225    #[serde(default = "default_false")]
226    pub requires_thinking_as_text: bool,
227    /// Provider-specific thinking wire-format.
228    #[serde(default)]
229    pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Serde default provider for boolean fields that are on unless stated otherwise.
///
/// Always returns `true`.
fn default_true() -> bool {
    true
}
/// Serde default provider for boolean fields that are off unless stated otherwise.
///
/// Always returns `false`.
fn default_false() -> bool {
    false
}
238
239/// Which JSON field to use for the maximum output token count.
240#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
241#[serde(rename_all = "kebab-case")]
242pub enum MaxTokensField {
243    /// Use `max_completion_tokens`.
244    MaxCompletionTokens,
245    /// Use `max_tokens`.
246    MaxTokens,
247}
248
249/// Provider-specific wire format for extended thinking.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum ThinkingFormat {
253    /// OpenAI native thinking format.
254    OpenAI,
255    /// OpenRouter thinking format.
256    OpenRouter,
257    /// DeepSeek thinking format.
258    DeepSeek,
259    /// Zai thinking format.
260    Zai,
261    /// Qwen API thinking format.
262    Qwen,
263    /// Qwen chat-template thinking format.
264    QwenChatTemplate,
265}
266
/// How demanding a task is, used when deciding where to route it.
///
/// Variants are declared from least to most demanding, so the derived ordering
/// satisfies `Trivial < Simple < Moderate < Complex < Research`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum Complexity {
    /// Single-step tasks (e.g., "translate this text").
    Trivial,
    /// Routine tasks that need moderate reasoning (e.g., "write a function").
    Simple,
    /// Tasks that need multi-step reasoning (e.g., "architect a service").
    Moderate,
    /// Tasks that need deep analysis (e.g., "write a full codebase"); the `Default`.
    #[default]
    Complex,
    /// Research-grade tasks that need the best models.
    Research,
}
282
283impl Complexity {
284    /// Returns the relative cost tier (0=cheapest, 4=most expensive) for routing
285    pub fn cost_tier(&self) -> u8 {
286        match self {
287            Self::Trivial => 0,
288            Self::Simple => 1,
289            Self::Moderate => 2,
290            Self::Complex => 3,
291            Self::Research => 4,
292        }
293    }
294}
295
296/// Tool result returned by agent tool execution.
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct ToolResult {
299    /// ID of the tool call this result corresponds to.
300    pub tool_call_id: String,
301    /// Human-readable result or error text.
302    pub content: String,
303    /// `"success"` or `"error"`.
304    pub status: String,
305}
306
307impl ToolResult {
308    /// Create a successful tool result.
309    pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
310        Self {
311            tool_call_id: tool_call_id.into(),
312            content: content.into(),
313            status: "success".to_string(),
314        }
315    }
316
317    /// Create an error tool result.
318    pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
319        Self {
320            tool_call_id: tool_call_id.into(),
321            content: content.into(),
322            status: "error".to_string(),
323        }
324    }
325
326    /// Returns `true` if this result represents an error.
327    pub fn is_error(&self) -> bool {
328        self.status == "error"
329    }
330}
331
332/// Images API provider type.
333#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
334#[non_exhaustive]
335pub enum ImagesApi {
336    /// OpenRouter API (supports multiple image generation models).
337    OpenRouter,
338}
339
340impl std::fmt::Display for ImagesApi {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        match self {
343            ImagesApi::OpenRouter => write!(f, "openrouter"),
344        }
345    }
346}
347
348/// Request for image generation via an Images API provider.
349#[derive(Debug, Clone, Serialize, Deserialize)]
350#[serde(default)]
351pub struct ImageGenerationRequest {
352    /// The text prompt describing the desired image.
353    pub prompt: String,
354    /// Model identifier (e.g. `"openai/dall-e-3"`, `"black-forest-labs/flux-1-dev"`).
355    pub model: Option<String>,
356    /// Output size. Provider-dependent. Examples: `"1024x1024"`, `"1024x1792"`.
357    pub size: Option<String>,
358    /// Number of images to generate (default 1).
359    pub n: Option<u32>,
360    /// Output format: `"url"` (default) or `"b64_json"`.
361    pub response_format: Option<String>,
362}
363
364impl Default for ImageGenerationRequest {
365    fn default() -> Self {
366        Self {
367            prompt: String::new(),
368            model: None,
369            size: None,
370            n: Some(1),
371            response_format: Some("b64_json".to_string()),
372        }
373    }
374}
375
376/// Response from an image generation API call.
377#[derive(Debug, Clone, Serialize, Deserialize, Default)]
378#[serde(default)]
379pub struct ImageGenerationResponse {
380    /// Vector of generated image bytes (one per `n`). Raw PNG/JPEG data.
381    pub images: Vec<Vec<u8>>,
382    /// Revised prompt from the model (providers may rewrite prompts).
383    pub revised_prompt: Option<String>,
384}
385
386/// LLM model definition.
387///
388/// Describes a model's capabilities, endpoint, and cost structure.
389#[derive(Debug, Clone, Serialize, Deserialize)]
390pub struct Model {
391    /// Unique model identifier (e.g. `"gpt-4o"`, `"claude-3-5-sonnet"`).
392    pub id: String,
393    /// Human-readable display name.
394    pub name: String,
395    /// Which API dialect this model speaks.
396    pub api: Api,
397    /// Provider name (e.g. `"openai"`, `"anthropic"`).
398    pub provider: String,
399    /// Base URL for the provider API.
400    pub base_url: String,
401    /// Whether this model supports extended reasoning / thinking.
402    #[serde(default)]
403    pub reasoning: bool,
404    /// Supported input modalities.
405    #[serde(default)]
406    pub input: Vec<InputModality>,
407    /// Pricing information.
408    #[serde(default)]
409    pub cost: Cost,
410    /// Maximum context window in tokens.
411    pub context_window: usize,
412    /// Maximum output tokens per request.
413    pub max_tokens: usize,
414    /// Extra HTTP headers to send with every request.
415    #[serde(default)]
416    pub headers: HashMap<String, String>,
417    /// Compatibility tweaks for non-standard providers.
418    #[serde(default)]
419    pub compat: Option<CompatSettings>,
420}
421
422impl Model {
423    /// Create a new model with sensible defaults.
424    pub fn new(
425        id: impl Into<String>,
426        name: impl Into<String>,
427        api: Api,
428        provider: impl Into<String>,
429        base_url: impl Into<String>,
430    ) -> Self {
431        Self {
432            id: id.into(),
433            name: name.into(),
434            api,
435            provider: provider.into(),
436            base_url: base_url.into(),
437            reasoning: false,
438            input: vec![InputModality::Text],
439            cost: Cost::default(),
440            context_window: 128_000,
441            max_tokens: 32_000,
442            headers: HashMap::new(),
443            compat: None,
444        }
445    }
446
447    /// Returns `true` if the model accepts image inputs.
448    pub fn supports_vision(&self) -> bool {
449        self.input.contains(&InputModality::Image)
450    }
451
452    /// Returns `true` if the model supports extended reasoning.
453    pub fn supports_reasoning(&self) -> bool {
454        self.reasoning
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` as JSON and decodes it again, panicking on any serde error.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn model_roundtrip() {
        let mut original = Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiCompletions,
            "openai",
            "https://api.openai.com/v1",
        );
        original.reasoning = true;
        original.input.push(InputModality::Image);
        original.cost = Cost {
            input: 5.0,
            output: 15.0,
            cache_read: 2.5,
            cache_write: 0.0,
        };
        original.compat = Some(CompatSettings::default());

        let restored: Model = json_roundtrip(&original);

        // Identity fields survive the trip.
        assert_eq!(restored.id, "gpt-4o");
        assert_eq!(restored.name, "GPT-4o");
        assert_eq!(restored.api, Api::OpenAiCompletions);
        assert_eq!(restored.provider, "openai");
        // Capabilities survive the trip.
        assert!(restored.reasoning);
        assert!(restored.supports_vision());
        assert!(restored.supports_reasoning());
        // Pricing survives the trip.
        assert_eq!(restored.cost.input, 5.0);
        assert_eq!(restored.cost.output, 15.0);
    }

    #[test]
    fn usage_calculate_cost() {
        let mut counted = Usage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 200_000,
            cache_write: 100_000,
            ..Default::default()
        };
        counted.calculate_cost(None, None);

        assert_eq!(counted.total_tokens, 1_800_000);
        assert_eq!(counted.cost.input, 1.0);
        assert_eq!(counted.cost.output, 0.5);
        assert_eq!(counted.cost.cache_read, 0.2);
        assert_eq!(counted.cost.cache_write, 0.1);
    }

    #[test]
    fn cost_total() {
        let priced = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        let deviation = (priced.total() - 10.5).abs();
        assert!(deviation < f64::EPSILON);

        let zeroed = Cost::default();
        assert_eq!(zeroed.total(), 0.0);
    }

    #[test]
    fn api_display() {
        let expected = [
            (Api::OpenAiCompletions, "openai-completions"),
            (Api::OpenAiResponses, "openai-responses"),
            (Api::AnthropicMessages, "anthropic-messages"),
            (Api::GoogleGenerativeAi, "google-generative-ai"),
            (Api::GoogleVertex, "google-vertex"),
            (Api::MistralConversations, "mistral-conversations"),
            (Api::AzureOpenAiResponses, "azure-openai-responses"),
            (Api::BedrockConverseStream, "bedrock-converse-stream"),
        ];
        for (api, wire_name) in expected {
            assert_eq!(api.to_string(), wire_name);
        }
    }

    #[test]
    fn api_serde_roundtrip() {
        let every_api = [
            Api::OpenAiCompletions,
            Api::OpenAiResponses,
            Api::AnthropicMessages,
            Api::GoogleGenerativeAi,
            Api::GoogleVertex,
            Api::MistralConversations,
            Api::AzureOpenAiResponses,
            Api::BedrockConverseStream,
        ];
        for api in every_api {
            let restored: Api = json_roundtrip(&api);
            assert_eq!(api, restored);
        }
    }

    #[test]
    fn thinking_level_serde() {
        let every_level = [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ];
        for level in every_level {
            let restored: ThinkingLevel = json_roundtrip(&level);
            assert_eq!(level, restored);
        }

        // The default level is `Off`.
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);

        // Serialized names are lowercase.
        let high_json = serde_json::to_string(&ThinkingLevel::High).unwrap();
        assert_eq!(high_json, "\"high\"");
        let off_json = serde_json::to_string(&ThinkingLevel::Off).unwrap();
        assert_eq!(off_json, "\"off\"");

        // `as_str` mirrors the names, except `Off` which has none.
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));
    }

    #[test]
    fn stop_reason_serde() {
        let encoded = serde_json::to_string(&StopReason::ToolUse).unwrap();
        assert_eq!(encoded, "\"toolUse\"");

        let decoded: StopReason = serde_json::from_str("\"toolUse\"").unwrap();
        assert_eq!(decoded, StopReason::ToolUse);
    }

    #[test]
    fn tool_result_helpers() {
        let ok = ToolResult::success("call_1", "result text");
        assert_eq!(ok.tool_call_id, "call_1");
        assert_eq!(ok.content, "result text");
        assert_eq!(ok.status, "success");
        assert!(!ok.is_error());

        let failed = ToolResult::error("call_2", "something failed");
        assert!(failed.is_error());
        assert_eq!(failed.status, "error");
    }

    #[test]
    fn cache_retention_default() {
        let fallback = CacheRetention::default();
        assert_eq!(fallback, CacheRetention::None);
    }
}