1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum Api {
14 #[serde(rename = "openai-completions")]
16 OpenAiCompletions,
17 #[serde(rename = "openai-responses")]
19 OpenAiResponses,
20 #[serde(rename = "anthropic-messages")]
22 AnthropicMessages,
23 #[serde(rename = "google-generative-ai")]
25 GoogleGenerativeAi,
26 #[serde(rename = "google-vertex")]
28 GoogleVertex,
29 #[serde(rename = "mistral-conversations")]
31 MistralConversations,
32 #[serde(rename = "azure-openai-responses")]
34 AzureOpenAiResponses,
35 #[serde(rename = "bedrock-converse-stream")]
37 BedrockConverseStream,
38}
39
40impl fmt::Display for Api {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Api::OpenAiCompletions => write!(f, "openai-completions"),
44 Api::OpenAiResponses => write!(f, "openai-responses"),
45 Api::AnthropicMessages => write!(f, "anthropic-messages"),
46 Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47 Api::GoogleVertex => write!(f, "google-vertex"),
48 Api::MistralConversations => write!(f, "mistral-conversations"),
49 Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50 Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum CacheRetention {
59 #[default]
61 None,
62 Short,
64 Long,
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
70#[serde(rename_all = "lowercase")]
71#[non_exhaustive]
72pub enum ThinkingLevel {
73 #[default]
75 Off,
76 Minimal,
78 Low,
80 Medium,
82 High,
84 XHigh,
86}
87
88impl ThinkingLevel {
89 pub fn as_str(&self) -> Option<&str> {
91 match self {
92 ThinkingLevel::Off => None,
93 ThinkingLevel::Minimal => Some("minimal"),
94 ThinkingLevel::Low => Some("low"),
95 ThinkingLevel::Medium => Some("medium"),
96 ThinkingLevel::High => Some("high"),
97 ThinkingLevel::XHigh => Some("xhigh"),
98 }
99 }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104#[serde(rename_all = "lowercase")]
105#[non_exhaustive]
106pub enum InputModality {
107 Text,
109 Image,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115#[serde(default)]
116pub struct Cost {
117 #[serde(default)]
119 pub input: f64,
120 #[serde(default)]
122 pub output: f64,
123 #[serde(default)]
125 pub cache_read: f64,
126 #[serde(default)]
128 pub cache_write: f64,
129}
130
131impl Cost {
132 pub fn total(&self) -> f64 {
134 self.input + self.output + self.cache_read + self.cache_write
135 }
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141#[non_exhaustive]
142pub enum StopReason {
143 Stop,
145 Length,
147 ToolUse,
149 Error,
151 Aborted,
153}
154
155#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157pub struct Usage {
158 #[serde(default)]
160 pub input: usize,
161 #[serde(default)]
163 pub output: usize,
164 #[serde(default)]
166 pub cache_read: usize,
167 #[serde(default)]
169 pub cache_write: usize,
170 #[serde(default)]
172 pub total_tokens: usize,
173 #[serde(default)]
175 pub cost: Cost,
176}
177
178impl Usage {
179 pub fn calculate_cost(
183 &mut self,
184 input_cost_per_million: Option<f64>,
185 output_cost_per_million: Option<f64>,
186 ) {
187 self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188 self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189 self.cost.output =
190 output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191 self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192 self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193 }
194}
195
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203 #[serde(default = "default_true")]
205 pub supports_store: bool,
206 #[serde(default = "default_true")]
208 pub supports_developer_role: bool,
209 #[serde(default = "default_true")]
211 pub supports_reasoning_effort: bool,
212 #[serde(default = "default_true")]
214 pub supports_usage_in_streaming: bool,
215 #[serde(default)]
217 pub max_tokens_field: Option<MaxTokensField>,
218 #[serde(default = "default_false")]
220 pub requires_tool_result_name: bool,
221 #[serde(default = "default_false")]
223 pub requires_assistant_after_tool_result: bool,
224 #[serde(default = "default_false")]
226 pub requires_thinking_as_text: bool,
227 #[serde(default)]
229 pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Serde default helper for flags that are on unless configured otherwise.
fn default_true() -> bool {
    !default_false()
}

/// Serde default helper for flags that are off unless configured otherwise.
fn default_false() -> bool {
    bool::default()
}
238
239#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
241#[serde(rename_all = "kebab-case")]
242pub enum MaxTokensField {
243 MaxCompletionTokens,
245 MaxTokens,
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum ThinkingFormat {
253 OpenAI,
255 OpenRouter,
257 DeepSeek,
259 Zai,
261 Qwen,
263 QwenChatTemplate,
265}
266
/// Coarse difficulty estimate, declared from cheapest to most expensive.
///
/// Because variants are listed in ascending order, the derived `Ord` ranks
/// `Trivial` lowest and `Research` highest. The default is `Complex`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum Complexity {
    /// Tier 0.
    Trivial,
    /// Tier 1.
    Simple,
    /// Tier 2.
    Moderate,
    /// Tier 3; the default.
    #[default]
    Complex,
    /// Tier 4.
    Research,
}

impl Complexity {
    /// Numeric cost tier, from `0` for `Trivial` up to `4` for `Research`.
    pub fn cost_tier(&self) -> u8 {
        let tier = match *self {
            Complexity::Trivial => 0,
            Complexity::Simple => 1,
            Complexity::Moderate => 2,
            Complexity::Complex => 3,
            Complexity::Research => 4,
        };
        tier
    }
}
295
296#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct ToolResult {
299 pub tool_call_id: String,
301 pub content: String,
303 pub status: String,
305}
306
307impl ToolResult {
308 pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
310 Self {
311 tool_call_id: tool_call_id.into(),
312 content: content.into(),
313 status: "success".to_string(),
314 }
315 }
316
317 pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
319 Self {
320 tool_call_id: tool_call_id.into(),
321 content: content.into(),
322 status: "error".to_string(),
323 }
324 }
325
326 pub fn is_error(&self) -> bool {
328 self.status == "error"
329 }
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct Model {
337 pub id: String,
339 pub name: String,
341 pub api: Api,
343 pub provider: String,
345 pub base_url: String,
347 #[serde(default)]
349 pub reasoning: bool,
350 #[serde(default)]
352 pub input: Vec<InputModality>,
353 #[serde(default)]
355 pub cost: Cost,
356 pub context_window: usize,
358 pub max_tokens: usize,
360 #[serde(default)]
362 pub headers: HashMap<String, String>,
363 #[serde(default)]
365 pub compat: Option<CompatSettings>,
366}
367
368impl Model {
369 pub fn new(
371 id: impl Into<String>,
372 name: impl Into<String>,
373 api: Api,
374 provider: impl Into<String>,
375 base_url: impl Into<String>,
376 ) -> Self {
377 Self {
378 id: id.into(),
379 name: name.into(),
380 api,
381 provider: provider.into(),
382 base_url: base_url.into(),
383 reasoning: false,
384 input: vec![InputModality::Text],
385 cost: Cost::default(),
386 context_window: 128_000,
387 max_tokens: 32_000,
388 headers: HashMap::new(),
389 compat: None,
390 }
391 }
392
393 pub fn supports_vision(&self) -> bool {
395 self.input.contains(&InputModality::Image)
396 }
397
398 pub fn supports_reasoning(&self) -> bool {
400 self.reasoning
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_roundtrip() {
        let mut model = Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiCompletions,
            "openai",
            "https://api.openai.com/v1",
        );
        model.reasoning = true;
        model.input.push(InputModality::Image);
        model.cost = Cost {
            input: 5.0,
            output: 15.0,
            cache_read: 2.5,
            cache_write: 0.0,
        };
        model.compat = Some(CompatSettings::default());

        let json = serde_json::to_string(&model).unwrap();
        let deserialized: Model = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, "gpt-4o");
        assert_eq!(deserialized.name, "GPT-4o");
        assert_eq!(deserialized.api, Api::OpenAiCompletions);
        assert_eq!(deserialized.provider, "openai");
        assert_eq!(deserialized.base_url, "https://api.openai.com/v1");
        assert!(deserialized.reasoning);
        assert!(deserialized.supports_vision());
        assert!(deserialized.supports_reasoning());
        assert_eq!(deserialized.cost.input, 5.0);
        assert_eq!(deserialized.cost.output, 15.0);
        assert_eq!(deserialized.cost.cache_read, 2.5);
        assert_eq!(deserialized.cost.cache_write, 0.0);
        assert_eq!(deserialized.context_window, 128_000);
        assert_eq!(deserialized.max_tokens, 32_000);
        assert!(deserialized.headers.is_empty());

        // Compat settings must survive the roundtrip unchanged.
        let original_compat = model.compat.as_ref().unwrap();
        let compat = deserialized.compat.as_ref().unwrap();
        assert_eq!(compat.supports_store, original_compat.supports_store);
        assert_eq!(
            compat.requires_tool_result_name,
            original_compat.requires_tool_result_name
        );
    }

    #[test]
    fn model_new_defaults() {
        let model = Model::new(
            "m",
            "M",
            Api::AnthropicMessages,
            "p",
            "http://localhost",
        );
        assert_eq!(model.input, vec![InputModality::Text]);
        assert!(!model.supports_vision());
        assert!(!model.supports_reasoning());
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 32_000);
        assert_eq!(model.cost.total(), 0.0);
        assert!(model.headers.is_empty());
        assert!(model.compat.is_none());
    }

    #[test]
    fn usage_calculate_cost() {
        let mut usage = Usage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 200_000,
            cache_write: 100_000,
            ..Default::default()
        };
        usage.calculate_cost(None, None);

        assert_eq!(usage.total_tokens, 1_800_000);
        assert_eq!(usage.cost.input, 1.0);
        assert_eq!(usage.cost.output, 0.5);
        assert_eq!(usage.cost.cache_read, 0.2);
        assert_eq!(usage.cost.cache_write, 0.1);
    }

    #[test]
    fn cost_total() {
        let cost = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        assert!((cost.total() - 10.5).abs() < f64::EPSILON);

        let default_cost = Cost::default();
        assert_eq!(default_cost.total(), 0.0);
    }

    #[test]
    fn api_display() {
        assert_eq!(Api::OpenAiCompletions.to_string(), "openai-completions");
        assert_eq!(Api::OpenAiResponses.to_string(), "openai-responses");
        assert_eq!(Api::AnthropicMessages.to_string(), "anthropic-messages");
        assert_eq!(Api::GoogleGenerativeAi.to_string(), "google-generative-ai");
        assert_eq!(Api::GoogleVertex.to_string(), "google-vertex");
        assert_eq!(
            Api::MistralConversations.to_string(),
            "mistral-conversations"
        );
        assert_eq!(
            Api::AzureOpenAiResponses.to_string(),
            "azure-openai-responses"
        );
        assert_eq!(
            Api::BedrockConverseStream.to_string(),
            "bedrock-converse-stream"
        );
    }

    #[test]
    fn api_serde_roundtrip() {
        for api in [
            Api::OpenAiCompletions,
            Api::OpenAiResponses,
            Api::AnthropicMessages,
            Api::GoogleGenerativeAi,
            Api::GoogleVertex,
            Api::MistralConversations,
            Api::AzureOpenAiResponses,
            Api::BedrockConverseStream,
        ] {
            let json = serde_json::to_string(&api).unwrap();
            let back: Api = serde_json::from_str(&json).unwrap();
            assert_eq!(api, back);
            // The JSON string and the Display output must name the API identically.
            assert_eq!(json, format!("\"{api}\""));
        }
    }

    #[test]
    fn thinking_level_serde() {
        for level in [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let back: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(level, back);
        }
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::High).unwrap(),
            "\"high\""
        );
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::Off).unwrap(),
            "\"off\""
        );
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));
    }

    #[test]
    fn stop_reason_serde() {
        assert_eq!(
            serde_json::to_string(&StopReason::ToolUse).unwrap(),
            "\"toolUse\""
        );
        let back: StopReason = serde_json::from_str("\"toolUse\"").unwrap();
        assert_eq!(back, StopReason::ToolUse);

        for (reason, wire) in [
            (StopReason::Stop, "\"stop\""),
            (StopReason::Length, "\"length\""),
            (StopReason::Error, "\"error\""),
            (StopReason::Aborted, "\"aborted\""),
        ] {
            assert_eq!(serde_json::to_string(&reason).unwrap(), wire);
            let back: StopReason = serde_json::from_str(wire).unwrap();
            assert_eq!(back, reason);
        }
    }

    #[test]
    fn input_modality_serde() {
        assert_eq!(
            serde_json::to_string(&InputModality::Image).unwrap(),
            "\"image\""
        );
        let back: InputModality = serde_json::from_str("\"text\"").unwrap();
        assert_eq!(back, InputModality::Text);
    }

    #[test]
    fn compat_settings_missing_fields_use_serde_defaults() {
        let compat: CompatSettings = serde_json::from_str("{}").unwrap();
        assert!(compat.supports_store);
        assert!(compat.supports_developer_role);
        assert!(compat.supports_reasoning_effort);
        assert!(compat.supports_usage_in_streaming);
        assert!(!compat.requires_tool_result_name);
        assert!(!compat.requires_assistant_after_tool_result);
        assert!(!compat.requires_thinking_as_text);
        assert!(compat.max_tokens_field.is_none());
        assert!(compat.thinking_format.is_none());
    }

    #[test]
    fn max_tokens_field_serde() {
        assert_eq!(
            serde_json::to_string(&MaxTokensField::MaxCompletionTokens).unwrap(),
            "\"max-completion-tokens\""
        );
        assert_eq!(
            serde_json::to_string(&MaxTokensField::MaxTokens).unwrap(),
            "\"max-tokens\""
        );
        let back: MaxTokensField = serde_json::from_str("\"max-tokens\"").unwrap();
        assert_eq!(back, MaxTokensField::MaxTokens);
    }

    #[test]
    fn complexity_order_and_tiers() {
        let ordered = [
            Complexity::Trivial,
            Complexity::Simple,
            Complexity::Moderate,
            Complexity::Complex,
            Complexity::Research,
        ];
        for pair in ordered.windows(2) {
            assert!(pair[0] < pair[1]);
            assert!(pair[0].cost_tier() < pair[1].cost_tier());
        }
        assert_eq!(Complexity::default(), Complexity::Complex);
        assert_eq!(Complexity::Trivial.cost_tier(), 0);
        assert_eq!(Complexity::Research.cost_tier(), 4);
    }

    #[test]
    fn tool_result_helpers() {
        let success = ToolResult::success("call_1", "result text");
        assert_eq!(success.tool_call_id, "call_1");
        assert_eq!(success.content, "result text");
        assert_eq!(success.status, "success");
        assert!(!success.is_error());

        let error = ToolResult::error("call_2", "something failed");
        assert!(error.is_error());
        assert_eq!(error.status, "error");
    }

    #[test]
    fn cache_retention_default() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
        assert_eq!(
            serde_json::to_string(&CacheRetention::None).unwrap(),
            "\"none\""
        );
        let back: CacheRetention = serde_json::from_str("\"long\"").unwrap();
        assert_eq!(back, CacheRetention::Long);
    }
}