Skip to main content

oxi_ai/
types.rs

1//! Core domain types for oxi-ai
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
/// Provider API identifier.
///
/// Selects the wire-format / protocol dialect spoken to a particular LLM provider.
///
/// Every variant carries an explicit `#[serde(rename = "...")]`, so it is
/// (de)serialized as that kebab-case string, e.g. `"openai-completions"`.
/// The enum is `#[non_exhaustive]`: code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Api {
    /// OpenAI Chat Completions API.
    #[serde(rename = "openai-completions")]
    OpenAiCompletions,
    /// OpenAI Responses API.
    #[serde(rename = "openai-responses")]
    OpenAiResponses,
    /// Anthropic Messages API.
    #[serde(rename = "anthropic-messages")]
    AnthropicMessages,
    /// Google Generative AI (Gemini) API.
    #[serde(rename = "google-generative-ai")]
    GoogleGenerativeAi,
    /// Google Vertex AI endpoint.
    #[serde(rename = "google-vertex")]
    GoogleVertex,
    /// Mistral Conversations API.
    #[serde(rename = "mistral-conversations")]
    MistralConversations,
    /// Azure OpenAI Responses API.
    #[serde(rename = "azure-openai-responses")]
    AzureOpenAiResponses,
    /// AWS Bedrock Converse Stream API.
    #[serde(rename = "bedrock-converse-stream")]
    BedrockConverseStream,
}
39
40impl fmt::Display for Api {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            Api::OpenAiCompletions => write!(f, "openai-completions"),
44            Api::OpenAiResponses => write!(f, "openai-responses"),
45            Api::AnthropicMessages => write!(f, "anthropic-messages"),
46            Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47            Api::GoogleVertex => write!(f, "google-vertex"),
48            Api::MistralConversations => write!(f, "mistral-conversations"),
49            Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50            Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51        }
52    }
53}
54
/// Cache retention preference.
///
/// Serialized as the lowercase variant name (`"none"`, `"short"`, `"long"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheRetention {
    /// Do not use caching (default).
    #[default]
    None,
    /// Short-term cache retention.
    Short,
    /// Long-term cache retention.
    Long,
}
67
/// Model thinking/reasoning level.
///
/// Serialized as the lowercase variant name (e.g. `"off"`, `"high"`,
/// `"xhigh"`). The enum is `#[non_exhaustive]`, so code outside this crate
/// must keep a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum ThinkingLevel {
    /// Extended reasoning disabled (default).
    #[default]
    Off,
    /// Minimal level of reasoning.
    Minimal,
    /// Low level of reasoning.
    Low,
    /// Medium level of reasoning.
    Medium,
    /// High level of reasoning.
    High,
    /// Very high level of reasoning.
    XHigh,
}
87
88impl ThinkingLevel {
89    /// 추론 수준을 문자열로 반환. `Off`면 `None`.
90    pub fn as_str(&self) -> Option<&str> {
91        match self {
92            ThinkingLevel::Off => None,
93            ThinkingLevel::Minimal => Some("minimal"),
94            ThinkingLevel::Low => Some("low"),
95            ThinkingLevel::Medium => Some("medium"),
96            ThinkingLevel::High => Some("high"),
97            ThinkingLevel::XHigh => Some("xhigh"),
98        }
99    }
100}
101
/// Input modalities a model can accept.
///
/// Serialized as the lowercase variant name (`"text"`, `"image"`). The enum
/// is `#[non_exhaustive]`, so code outside this crate must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum InputModality {
    /// Text input.
    Text,
    /// Image input.
    Image,
}
112
/// Cost structure – prices per million tokens.
///
/// Every field defaults to `0.0`, both through `Default` and when the field
/// is absent during deserialization.
///
/// The same struct is also used as `Usage::cost`, where
/// `Usage::calculate_cost` stores computed dollar amounts in these fields
/// rather than per-million rates.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct Cost {
    /// Input token cost ($/M tokens).
    #[serde(default)]
    pub input: f64,
    /// Output token cost ($/M tokens).
    #[serde(default)]
    pub output: f64,
    /// Cached-input read cost ($/M tokens).
    #[serde(default)]
    pub cache_read: f64,
    /// Cache write cost ($/M tokens).
    #[serde(default)]
    pub cache_write: f64,
}
130
131impl Cost {
132    /// Sum of all cost components.
133    pub fn total(&self) -> f64 {
134        self.input + self.output + self.cache_read + self.cache_write
135    }
136}
137
/// Stop reason – why the model finished generating.
///
/// Serialized in camelCase (e.g. `ToolUse` becomes `"toolUse"`). The enum is
/// `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub enum StopReason {
    /// Normal stop – the model finished its response.
    Stop,
    /// Hit the maximum output token limit.
    Length,
    /// Stopped to invoke a tool.
    ToolUse,
    /// An error occurred during generation.
    Error,
    /// Generation was aborted by the client.
    Aborted,
}
154
/// Token usage statistics.
///
/// Every field is marked `#[serde(default)]`, so any field missing from the
/// input deserializes to zero (or to `Cost::default()` for `cost`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input (prompt) tokens.
    #[serde(default)]
    pub input: usize,
    /// Number of output (completion) tokens.
    #[serde(default)]
    pub output: usize,
    /// Number of tokens read from cache.
    #[serde(default)]
    pub cache_read: usize,
    /// Number of tokens written to cache.
    #[serde(default)]
    pub cache_write: usize,
    /// Total tokens (input + output + cache).
    ///
    /// Not kept in sync automatically; `calculate_cost` recomputes it.
    #[serde(default)]
    pub total_tokens: usize,
    /// Computed cost in dollars.
    #[serde(default)]
    pub cost: Cost,
}
177
178impl Usage {
179    /// Recalculate `total_tokens` and per-component costs from raw token counts.
180    ///
181    /// If pricing parameters are provided, they override the default $1/M rate.
182    pub fn calculate_cost(
183        &mut self,
184        input_cost_per_million: Option<f64>,
185        output_cost_per_million: Option<f64>,
186    ) {
187        self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188        self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189        self.cost.output =
190            output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191        self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192        self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193    }
194}
195
196/// Compatibility settings for OpenAI-compatible APIs.
197///
198/// Not every OpenAI-compatible provider supports every feature.
199/// These flags let the streaming layer adapt its request shape.
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203    /// Whether the provider supports the `store` parameter.
204    #[serde(default = "default_true")]
205    pub supports_store: bool,
206    /// Whether the provider recognises the `developer` role.
207    #[serde(default = "default_true")]
208    pub supports_developer_role: bool,
209    /// Whether the provider supports `reasoning_effort`.
210    #[serde(default = "default_true")]
211    pub supports_reasoning_effort: bool,
212    /// Whether the provider returns usage data in streaming responses.
213    #[serde(default = "default_true")]
214    pub supports_usage_in_streaming: bool,
215    /// Which JSON field name to use for the max-tokens parameter.
216    #[serde(default)]
217    pub max_tokens_field: Option<MaxTokensField>,
218    /// Whether tool results must include the tool name.
219    #[serde(default = "default_false")]
220    pub requires_tool_result_name: bool,
221    /// Whether an assistant message must follow every tool result.
222    #[serde(default = "default_false")]
223    pub requires_assistant_after_tool_result: bool,
224    /// Whether thinking should be sent as plain text.
225    #[serde(default = "default_false")]
226    pub requires_thinking_as_text: bool,
227    /// Provider-specific thinking wire-format.
228    #[serde(default)]
229    pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Returns `true`.
///
/// Referenced by name from `#[serde(default = "default_true")]` attributes so
/// that a missing boolean field deserializes to `true`.
fn default_true() -> bool {
    true
}
/// Returns `false`.
///
/// Referenced by name from `#[serde(default = "default_false")]` attributes
/// so that a missing boolean field deserializes to `false`.
fn default_false() -> bool {
    false
}
238
/// Which JSON field to use for the maximum output token count.
///
/// Note that the serialized form of this enum is kebab-case
/// (`"max-completion-tokens"` / `"max-tokens"`), which differs from the
/// snake_case request field names mentioned on the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum MaxTokensField {
    /// Use `max_completion_tokens`.
    MaxCompletionTokens,
    /// Use `max_tokens`.
    MaxTokens,
}
248
/// Provider-specific wire format for extended thinking.
///
/// Serialized as the variant name folded to lowercase with no separators
/// (e.g. `"openai"`, `"deepseek"`, `"qwenchattemplate"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingFormat {
    /// OpenAI native thinking format.
    OpenAI,
    /// OpenRouter thinking format.
    OpenRouter,
    /// DeepSeek thinking format.
    DeepSeek,
    /// Zai thinking format.
    Zai,
    /// Qwen API thinking format.
    Qwen,
    /// Qwen chat-template thinking format.
    QwenChatTemplate,
}
266
/// How demanding a task is, used when deciding where to route it.
///
/// Variants are declared from least to most demanding, so the derived
/// ordering satisfies `Trivial < Simple < Moderate < Complex < Research`.
/// Each variant's discriminant doubles as its cost tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum Complexity {
    /// Simple, single-step tasks (e.g., "translate this text")
    Trivial = 0,
    /// Routine tasks needing moderate reasoning (e.g., "write a function")
    Simple = 1,
    /// Tasks requiring multi-step reasoning (e.g., "architect a service")
    Moderate = 2,
    /// Complex tasks needing deep analysis (e.g., "write a full codebase");
    /// this is the default.
    #[default]
    Complex = 3,
    /// Research-grade tasks needing the best models
    Research = 4,
}

impl Complexity {
    /// Relative cost tier used for routing: `0` is the cheapest and `4` the
    /// most expensive.
    pub fn cost_tier(&self) -> u8 {
        // The tier is exactly the explicit discriminant declared on the enum.
        *self as u8
    }
}
295
/// Tool result returned by agent tool execution.
///
/// `status` is a free-form string; the `success` and `error` constructors
/// set it to `"success"` and `"error"` respectively, and `is_error` checks
/// for the exact value `"error"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// ID of the tool call this result corresponds to.
    pub tool_call_id: String,
    /// Human-readable result or error text.
    pub content: String,
    /// `"success"` or `"error"`.
    pub status: String,
}
306
307impl ToolResult {
308    /// Create a successful tool result.
309    pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
310        Self {
311            tool_call_id: tool_call_id.into(),
312            content: content.into(),
313            status: "success".to_string(),
314        }
315    }
316
317    /// Create an error tool result.
318    pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
319        Self {
320            tool_call_id: tool_call_id.into(),
321            content: content.into(),
322            status: "error".to_string(),
323        }
324    }
325
326    /// Returns `true` if this result represents an error.
327    pub fn is_error(&self) -> bool {
328        self.status == "error"
329    }
330}
331
/// LLM model definition.
///
/// Describes a model's capabilities, endpoint, and cost structure.
///
/// When deserializing, `id`, `name`, `api`, `provider`, `base_url`,
/// `context_window` and `max_tokens` are required; the remaining fields are
/// marked `#[serde(default)]` and fall back to their type's default when
/// absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    /// Unique model identifier (e.g. `"gpt-4o"`, `"claude-3-5-sonnet"`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Which API dialect this model speaks.
    pub api: Api,
    /// Provider name (e.g. `"openai"`, `"anthropic"`).
    pub provider: String,
    /// Base URL for the provider API.
    pub base_url: String,
    /// Whether this model supports extended reasoning / thinking.
    #[serde(default)]
    pub reasoning: bool,
    /// Supported input modalities.
    ///
    /// Defaults to an empty list when absent from the serialized form,
    /// whereas `Model::new` starts with `[InputModality::Text]`.
    #[serde(default)]
    pub input: Vec<InputModality>,
    /// Pricing information.
    #[serde(default)]
    pub cost: Cost,
    /// Maximum context window in tokens.
    pub context_window: usize,
    /// Maximum output tokens per request.
    pub max_tokens: usize,
    /// Extra HTTP headers to send with every request.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Compatibility tweaks for non-standard providers.
    #[serde(default)]
    pub compat: Option<CompatSettings>,
}
367
368impl Model {
369    /// Create a new model with sensible defaults.
370    pub fn new(
371        id: impl Into<String>,
372        name: impl Into<String>,
373        api: Api,
374        provider: impl Into<String>,
375        base_url: impl Into<String>,
376    ) -> Self {
377        Self {
378            id: id.into(),
379            name: name.into(),
380            api,
381            provider: provider.into(),
382            base_url: base_url.into(),
383            reasoning: false,
384            input: vec![InputModality::Text],
385            cost: Cost::default(),
386            context_window: 128_000,
387            max_tokens: 32_000,
388            headers: HashMap::new(),
389            compat: None,
390        }
391    }
392
393    /// Returns `true` if the model accepts image inputs.
394    pub fn supports_vision(&self) -> bool {
395        self.input.contains(&InputModality::Image)
396    }
397
398    /// Returns `true` if the model supports extended reasoning.
399    pub fn supports_reasoning(&self) -> bool {
400        self.reasoning
401    }
402}
403
// Unit tests for the domain types in this file: serde round-trips, `Display`
// output, defaults, and the small helper methods.
#[cfg(test)]
mod tests {
    use super::*;

    /// A populated `Model` survives a JSON serialize/deserialize round-trip.
    #[test]
    fn model_roundtrip() {
        let mut model = Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiCompletions,
            "openai",
            "https://api.openai.com/v1",
        );
        // Override several defaults so the round-trip covers non-default values.
        model.reasoning = true;
        model.input.push(InputModality::Image);
        model.cost = Cost {
            input: 5.0,
            output: 15.0,
            cache_read: 2.5,
            cache_write: 0.0,
        };
        model.compat = Some(CompatSettings::default());

        let json = serde_json::to_string(&model).unwrap();
        let deserialized: Model = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, "gpt-4o");
        assert_eq!(deserialized.name, "GPT-4o");
        assert_eq!(deserialized.api, Api::OpenAiCompletions);
        assert_eq!(deserialized.provider, "openai");
        assert!(deserialized.reasoning);
        assert!(deserialized.supports_vision());
        assert!(deserialized.supports_reasoning());
        assert_eq!(deserialized.cost.input, 5.0);
        assert_eq!(deserialized.cost.output, 15.0);
    }

    /// With no explicit rates, every component is charged at 1.0 per million
    /// tokens and `total_tokens` is the sum of the four counters.
    #[test]
    fn usage_calculate_cost() {
        let mut usage = Usage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 200_000,
            cache_write: 100_000,
            ..Default::default()
        };
        usage.calculate_cost(None, None);

        assert_eq!(usage.total_tokens, 1_800_000);
        assert_eq!(usage.cost.input, 1.0);
        assert_eq!(usage.cost.output, 0.5);
        assert_eq!(usage.cost.cache_read, 0.2);
        assert_eq!(usage.cost.cache_write, 0.1);
    }

    /// `Cost::total` sums all four components; the default cost is zero.
    #[test]
    fn cost_total() {
        let cost = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        assert!((cost.total() - 10.5).abs() < f64::EPSILON);

        let default_cost = Cost::default();
        assert_eq!(default_cost.total(), 0.0);
    }

    /// `Display` for `Api` prints the kebab-case wire name of each variant.
    #[test]
    fn api_display() {
        assert_eq!(Api::OpenAiCompletions.to_string(), "openai-completions");
        assert_eq!(Api::OpenAiResponses.to_string(), "openai-responses");
        assert_eq!(Api::AnthropicMessages.to_string(), "anthropic-messages");
        assert_eq!(Api::GoogleGenerativeAi.to_string(), "google-generative-ai");
        assert_eq!(Api::GoogleVertex.to_string(), "google-vertex");
        assert_eq!(
            Api::MistralConversations.to_string(),
            "mistral-conversations"
        );
        assert_eq!(
            Api::AzureOpenAiResponses.to_string(),
            "azure-openai-responses"
        );
        assert_eq!(
            Api::BedrockConverseStream.to_string(),
            "bedrock-converse-stream"
        );
    }

    /// Every `Api` variant round-trips through JSON unchanged.
    #[test]
    fn api_serde_roundtrip() {
        for api in [
            Api::OpenAiCompletions,
            Api::OpenAiResponses,
            Api::AnthropicMessages,
            Api::GoogleGenerativeAi,
            Api::GoogleVertex,
            Api::MistralConversations,
            Api::AzureOpenAiResponses,
            Api::BedrockConverseStream,
        ] {
            let json = serde_json::to_string(&api).unwrap();
            let back: Api = serde_json::from_str(&json).unwrap();
            assert_eq!(api, back);
        }
    }

    /// `ThinkingLevel` round-trips through JSON, defaults to `Off`,
    /// serializes in lowercase, and `as_str` returns `None` only for `Off`.
    #[test]
    fn thinking_level_serde() {
        for level in [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let back: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(level, back);
        }
        // Verify default
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);
        // Verify rename values
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::High).unwrap(),
            "\"high\""
        );
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::Off).unwrap(),
            "\"off\""
        );
        // as_str
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));
    }

    /// `StopReason` uses camelCase on the wire in both directions.
    #[test]
    fn stop_reason_serde() {
        assert_eq!(
            serde_json::to_string(&StopReason::ToolUse).unwrap(),
            "\"toolUse\""
        );
        let back: StopReason = serde_json::from_str("\"toolUse\"").unwrap();
        assert_eq!(back, StopReason::ToolUse);
    }

    /// The `ToolResult` constructors set the expected `status` string and
    /// `is_error` reflects it.
    #[test]
    fn tool_result_helpers() {
        let success = ToolResult::success("call_1", "result text");
        assert_eq!(success.tool_call_id, "call_1");
        assert_eq!(success.content, "result text");
        assert_eq!(success.status, "success");
        assert!(!success.is_error());

        let error = ToolResult::error("call_2", "something failed");
        assert!(error.is_error());
        assert_eq!(error.status, "error");
    }

    /// `CacheRetention` defaults to `None`.
    #[test]
    fn cache_retention_default() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
    }
}