Skip to main content

oxi_ai/
types.rs

1//! Core domain types for oxi-ai
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8/// Provider API identifier.
9///
10/// Selects the wire-format / protocol dialect spoken to a particular LLM provider.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum Api {
14    /// OpenAI Chat Completions API.
15    #[serde(rename = "openai-completions")]
16    OpenAiCompletions,
17    /// OpenAI Responses API.
18    #[serde(rename = "openai-responses")]
19    OpenAiResponses,
20    /// Anthropic Messages API.
21    #[serde(rename = "anthropic-messages")]
22    AnthropicMessages,
23    /// Google Generative AI (Gemini) API.
24    #[serde(rename = "google-generative-ai")]
25    GoogleGenerativeAi,
26    /// Google Vertex AI endpoint.
27    #[serde(rename = "google-vertex")]
28    GoogleVertex,
29    /// Mistral Conversations API.
30    #[serde(rename = "mistral-conversations")]
31    MistralConversations,
32    /// Azure OpenAI Responses API.
33    #[serde(rename = "azure-openai-responses")]
34    AzureOpenAiResponses,
35    /// AWS Bedrock Converse Stream API.
36    #[serde(rename = "bedrock-converse-stream")]
37    BedrockConverseStream,
38}
39
40impl fmt::Display for Api {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            Api::OpenAiCompletions => write!(f, "openai-completions"),
44            Api::OpenAiResponses => write!(f, "openai-responses"),
45            Api::AnthropicMessages => write!(f, "anthropic-messages"),
46            Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47            Api::GoogleVertex => write!(f, "google-vertex"),
48            Api::MistralConversations => write!(f, "mistral-conversations"),
49            Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50            Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51        }
52    }
53}
54
55/// Cache retention preference
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum CacheRetention {
59    /// 캐시를 사용하지 않음 (기본값).
60    #[default]
61    None,
62    /// 단기 캐시 유지.
63    Short,
64    /// 장기 캐시 유지.
65    Long,
66}
67
68/// Model thinking/reasoning level
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
70#[serde(rename_all = "lowercase")]
71#[non_exhaustive]
72pub enum ThinkingLevel {
73    /// 확장 추론 비활성화 (기본값).
74    #[default]
75    Off,
76    /// 최소 수준의 추론.
77    Minimal,
78    /// 낮은 수준의 추론.
79    Low,
80    /// 중간 수준의 추론.
81    Medium,
82    /// 높은 수준의 추론.
83    High,
84    /// 매우 높은 수준의 추론.
85    XHigh,
86}
87
88impl ThinkingLevel {
89    /// 추론 수준을 문자열로 반환. `Off`면 `None`.
90    pub fn as_str(&self) -> Option<&str> {
91        match self {
92            ThinkingLevel::Off => None,
93            ThinkingLevel::Minimal => Some("minimal"),
94            ThinkingLevel::Low => Some("low"),
95            ThinkingLevel::Medium => Some("medium"),
96            ThinkingLevel::High => Some("high"),
97            ThinkingLevel::XHigh => Some("xhigh"),
98        }
99    }
100}
101
102/// Input modalities
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104#[serde(rename_all = "lowercase")]
105#[non_exhaustive]
106pub enum InputModality {
107    /// 텍스트 입력.
108    Text,
109    /// 이미지 입력.
110    Image,
111}
112
113/// Cost structure – prices per million tokens.
114#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115#[serde(default)]
116pub struct Cost {
117    /// Input token cost ($/M tokens).
118    #[serde(default)]
119    pub input: f64,
120    /// Output token cost ($/M tokens).
121    #[serde(default)]
122    pub output: f64,
123    /// Cached-input read cost ($/M tokens).
124    #[serde(default)]
125    pub cache_read: f64,
126    /// Cache write cost ($/M tokens).
127    #[serde(default)]
128    pub cache_write: f64,
129}
130
131impl Cost {
132    /// Sum of all cost components.
133    pub fn total(&self) -> f64 {
134        self.input + self.output + self.cache_read + self.cache_write
135    }
136}
137
138/// Stop reason – why the model finished generating.
139#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141#[non_exhaustive]
142pub enum StopReason {
143    /// Normal stop – the model finished its response.
144    Stop,
145    /// Hit the maximum output token limit.
146    Length,
147    /// Stopped to invoke a tool.
148    ToolUse,
149    /// An error occurred during generation.
150    Error,
151    /// Generation was aborted by the client.
152    Aborted,
153}
154
155/// Token usage statistics.
156#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157pub struct Usage {
158    /// Number of input (prompt) tokens.
159    #[serde(default)]
160    pub input: usize,
161    /// Number of output (completion) tokens.
162    #[serde(default)]
163    pub output: usize,
164    /// Number of tokens read from cache.
165    #[serde(default)]
166    pub cache_read: usize,
167    /// Number of tokens written to cache.
168    #[serde(default)]
169    pub cache_write: usize,
170    /// Total tokens (input + output + cache).
171    #[serde(default)]
172    pub total_tokens: usize,
173    /// Computed cost in dollars.
174    #[serde(default)]
175    pub cost: Cost,
176}
177
178impl Usage {
179    /// Recalculate `total_tokens` and per-component costs from raw token counts.
180    ///
181    /// If pricing parameters are provided, they override the default $1/M rate.
182    pub fn calculate_cost(
183        &mut self,
184        input_cost_per_million: Option<f64>,
185        output_cost_per_million: Option<f64>,
186    ) {
187        self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188        self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189        self.cost.output =
190            output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191        self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192        self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193    }
194}
195
196/// Compatibility settings for OpenAI-compatible APIs.
197///
198/// Not every OpenAI-compatible provider supports every feature.
199/// These flags let the streaming layer adapt its request shape.
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203    /// Whether the provider supports the `store` parameter.
204    #[serde(default = "default_true")]
205    pub supports_store: bool,
206    /// Whether the provider recognises the `developer` role.
207    #[serde(default = "default_true")]
208    pub supports_developer_role: bool,
209    /// Whether the provider supports `reasoning_effort`.
210    #[serde(default = "default_true")]
211    pub supports_reasoning_effort: bool,
212    /// Whether the provider returns usage data in streaming responses.
213    #[serde(default = "default_true")]
214    pub supports_usage_in_streaming: bool,
215    /// Which JSON field name to use for the max-tokens parameter.
216    #[serde(default)]
217    pub max_tokens_field: Option<MaxTokensField>,
218    /// Whether tool results must include the tool name.
219    #[serde(default = "default_false")]
220    pub requires_tool_result_name: bool,
221    /// Whether an assistant message must follow every tool result.
222    #[serde(default = "default_false")]
223    pub requires_assistant_after_tool_result: bool,
224    /// Whether thinking should be sent as plain text.
225    #[serde(default = "default_false")]
226    pub requires_thinking_as_text: bool,
227    /// Provider-specific thinking wire-format.
228    #[serde(default)]
229    pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Serde default provider for boolean fields that are `true` when omitted.
fn default_true() -> bool { true }
/// Serde default provider for boolean fields that are `false` when omitted.
fn default_false() -> bool { false }
238
239/// Which JSON field to use for the maximum output token count.
240#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
241#[serde(rename_all = "kebab-case")]
242pub enum MaxTokensField {
243    /// Use `max_completion_tokens`.
244    MaxCompletionTokens,
245    /// Use `max_tokens`.
246    MaxTokens,
247}
248
249/// Provider-specific wire format for extended thinking.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum ThinkingFormat {
253    /// OpenAI native thinking format.
254    OpenAI,
255    /// OpenRouter thinking format.
256    OpenRouter,
257    /// DeepSeek thinking format.
258    DeepSeek,
259    /// Zai thinking format.
260    Zai,
261    /// Qwen API thinking format.
262    Qwen,
263    /// Qwen chat-template thinking format.
264    QwenChatTemplate,
265}
266
267/// Tool result returned by agent tool execution.
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ToolResult {
270    /// ID of the tool call this result corresponds to.
271    pub tool_call_id: String,
272    /// Human-readable result or error text.
273    pub content: String,
274    /// `"success"` or `"error"`.
275    pub status: String,
276}
277
278impl ToolResult {
279    /// Create a successful tool result.
280    pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
281        Self {
282            tool_call_id: tool_call_id.into(),
283            content: content.into(),
284            status: "success".to_string(),
285        }
286    }
287
288    /// Create an error tool result.
289    pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
290        Self {
291            tool_call_id: tool_call_id.into(),
292            content: content.into(),
293            status: "error".to_string(),
294        }
295    }
296
297    /// Returns `true` if this result represents an error.
298    pub fn is_error(&self) -> bool {
299        self.status == "error"
300    }
301}
302
303/// LLM model definition.
304///
305/// Describes a model's capabilities, endpoint, and cost structure.
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct Model {
308    /// Unique model identifier (e.g. `"gpt-4o"`, `"claude-3-5-sonnet"`).
309    pub id: String,
310    /// Human-readable display name.
311    pub name: String,
312    /// Which API dialect this model speaks.
313    pub api: Api,
314    /// Provider name (e.g. `"openai"`, `"anthropic"`).
315    pub provider: String,
316    /// Base URL for the provider API.
317    pub base_url: String,
318    /// Whether this model supports extended reasoning / thinking.
319    #[serde(default)]
320    pub reasoning: bool,
321    /// Supported input modalities.
322    #[serde(default)]
323    pub input: Vec<InputModality>,
324    /// Pricing information.
325    #[serde(default)]
326    pub cost: Cost,
327    /// Maximum context window in tokens.
328    pub context_window: usize,
329    /// Maximum output tokens per request.
330    pub max_tokens: usize,
331    /// Extra HTTP headers to send with every request.
332    #[serde(default)]
333    pub headers: HashMap<String, String>,
334    /// Compatibility tweaks for non-standard providers.
335    #[serde(default)]
336    pub compat: Option<CompatSettings>,
337}
338
339impl Model {
340    /// Create a new model with sensible defaults.
341    pub fn new(
342        id: impl Into<String>,
343        name: impl Into<String>,
344        api: Api,
345        provider: impl Into<String>,
346        base_url: impl Into<String>,
347    ) -> Self {
348        Self {
349            id: id.into(),
350            name: name.into(),
351            api,
352            provider: provider.into(),
353            base_url: base_url.into(),
354            reasoning: false,
355            input: vec![InputModality::Text],
356            cost: Cost::default(),
357            context_window: 128_000,
358            max_tokens: 32_000,
359            headers: HashMap::new(),
360            compat: None,
361        }
362    }
363
364    /// Returns `true` if the model accepts image inputs.
365    pub fn supports_vision(&self) -> bool {
366        self.input.contains(&InputModality::Image)
367    }
368
369    /// Returns `true` if the model supports extended reasoning.
370    pub fn supports_reasoning(&self) -> bool {
371        self.reasoning
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to a JSON string, panicking on failure.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Parses `json` into a `T`, panicking on failure.
    fn from_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        serde_json::from_str(json).unwrap()
    }

    /// A fully populated model survives a JSON round trip.
    #[test]
    fn model_roundtrip() {
        let original = Model {
            reasoning: true,
            input: vec![InputModality::Text, InputModality::Image],
            cost: Cost {
                input: 5.0,
                output: 15.0,
                cache_read: 2.5,
                cache_write: 0.0,
            },
            compat: Some(CompatSettings::default()),
            ..Model::new(
                "gpt-4o",
                "GPT-4o",
                Api::OpenAiCompletions,
                "openai",
                "https://api.openai.com/v1",
            )
        };

        let restored: Model = from_json(&to_json(&original));

        assert_eq!(restored.id, "gpt-4o");
        assert_eq!(restored.name, "GPT-4o");
        assert_eq!(restored.api, Api::OpenAiCompletions);
        assert_eq!(restored.provider, "openai");
        assert!(restored.reasoning);
        assert!(restored.supports_vision());
        assert!(restored.supports_reasoning());
        assert_eq!(restored.cost.input, 5.0);
        assert_eq!(restored.cost.output, 15.0);
    }

    /// With no explicit prices every component is charged at $1 per million tokens.
    #[test]
    fn usage_calculate_cost() {
        let mut usage = Usage::default();
        usage.input = 1_000_000;
        usage.output = 500_000;
        usage.cache_read = 200_000;
        usage.cache_write = 100_000;

        usage.calculate_cost(None, None);

        assert_eq!(usage.total_tokens, 1_800_000);
        let Cost {
            input,
            output,
            cache_read,
            cache_write,
        } = usage.cost;
        assert_eq!(input, 1.0);
        assert_eq!(output, 0.5);
        assert_eq!(cache_read, 0.2);
        assert_eq!(cache_write, 0.1);
    }

    /// `Cost::total` sums all four components; the default cost totals zero.
    #[test]
    fn cost_total() {
        let priced = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        let difference = (priced.total() - 10.5).abs();
        assert!(difference < f64::EPSILON);

        assert_eq!(Cost::default().total(), 0.0);
    }

    /// `Display` prints the kebab-case identifier of every API.
    #[test]
    fn api_display() {
        let expectations = [
            (Api::OpenAiCompletions, "openai-completions"),
            (Api::OpenAiResponses, "openai-responses"),
            (Api::AnthropicMessages, "anthropic-messages"),
            (Api::GoogleGenerativeAi, "google-generative-ai"),
            (Api::GoogleVertex, "google-vertex"),
            (Api::MistralConversations, "mistral-conversations"),
            (Api::AzureOpenAiResponses, "azure-openai-responses"),
            (Api::BedrockConverseStream, "bedrock-converse-stream"),
        ];
        for (api, expected) in expectations {
            assert_eq!(api.to_string(), expected);
        }
    }

    /// Every API variant survives a JSON round trip.
    #[test]
    fn api_serde_roundtrip() {
        let variants = [
            Api::OpenAiCompletions,
            Api::OpenAiResponses,
            Api::AnthropicMessages,
            Api::GoogleGenerativeAi,
            Api::GoogleVertex,
            Api::MistralConversations,
            Api::AzureOpenAiResponses,
            Api::BedrockConverseStream,
        ];
        for api in variants {
            let restored: Api = from_json(&to_json(&api));
            assert_eq!(api, restored);
        }
    }

    /// Thinking levels round-trip, default to `Off`, and expose lowercase names.
    #[test]
    fn thinking_level_serde() {
        let levels = [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ];
        for level in levels {
            let restored: ThinkingLevel = from_json(&to_json(&level));
            assert_eq!(level, restored);
        }

        // Default variant.
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);

        // Serialized names.
        assert_eq!(to_json(&ThinkingLevel::High), "\"high\"");
        assert_eq!(to_json(&ThinkingLevel::Off), "\"off\"");

        // `as_str` accessor.
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));
    }

    /// Stop reasons use camelCase on the wire.
    #[test]
    fn stop_reason_serde() {
        assert_eq!(to_json(&StopReason::ToolUse), "\"toolUse\"");

        let parsed: StopReason = from_json("\"toolUse\"");
        assert_eq!(parsed, StopReason::ToolUse);
    }

    /// The `success` / `error` constructors set the status `is_error` reads.
    #[test]
    fn tool_result_helpers() {
        let ok = ToolResult::success("call_1", "result text");
        assert_eq!(ok.tool_call_id, "call_1");
        assert_eq!(ok.content, "result text");
        assert_eq!(ok.status, "success");
        assert!(!ok.is_error());

        let failed = ToolResult::error("call_2", "something failed");
        assert!(failed.is_error());
        assert_eq!(failed.status, "error");
    }

    /// Cache retention defaults to `None`.
    #[test]
    fn cache_retention_default() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
    }
}