1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[non_exhaustive]
13pub enum Api {
14 #[serde(rename = "openai-completions")]
16 OpenAiCompletions,
17 #[serde(rename = "openai-responses")]
19 OpenAiResponses,
20 #[serde(rename = "anthropic-messages")]
22 AnthropicMessages,
23 #[serde(rename = "google-generative-ai")]
25 GoogleGenerativeAi,
26 #[serde(rename = "google-vertex")]
28 GoogleVertex,
29 #[serde(rename = "mistral-conversations")]
31 MistralConversations,
32 #[serde(rename = "azure-openai-responses")]
34 AzureOpenAiResponses,
35 #[serde(rename = "bedrock-converse-stream")]
37 BedrockConverseStream,
38}
39
40impl fmt::Display for Api {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Api::OpenAiCompletions => write!(f, "openai-completions"),
44 Api::OpenAiResponses => write!(f, "openai-responses"),
45 Api::AnthropicMessages => write!(f, "anthropic-messages"),
46 Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47 Api::GoogleVertex => write!(f, "google-vertex"),
48 Api::MistralConversations => write!(f, "mistral-conversations"),
49 Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50 Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum CacheRetention {
59 #[default]
61 None,
62 Short,
64 Long,
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
70#[serde(rename_all = "lowercase")]
71#[non_exhaustive]
72pub enum ThinkingLevel {
73 #[default]
75 Off,
76 Minimal,
78 Low,
80 Medium,
82 High,
84 XHigh,
86}
87
88impl ThinkingLevel {
89 pub fn as_str(&self) -> Option<&str> {
91 match self {
92 ThinkingLevel::Off => None,
93 ThinkingLevel::Minimal => Some("minimal"),
94 ThinkingLevel::Low => Some("low"),
95 ThinkingLevel::Medium => Some("medium"),
96 ThinkingLevel::High => Some("high"),
97 ThinkingLevel::XHigh => Some("xhigh"),
98 }
99 }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104#[serde(rename_all = "lowercase")]
105#[non_exhaustive]
106pub enum InputModality {
107 Text,
109 Image,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115#[serde(default)]
116pub struct Cost {
117 #[serde(default)]
119 pub input: f64,
120 #[serde(default)]
122 pub output: f64,
123 #[serde(default)]
125 pub cache_read: f64,
126 #[serde(default)]
128 pub cache_write: f64,
129}
130
131impl Cost {
132 pub fn total(&self) -> f64 {
134 self.input + self.output + self.cache_read + self.cache_write
135 }
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141#[non_exhaustive]
142pub enum StopReason {
143 Stop,
145 Length,
147 ToolUse,
149 Error,
151 Aborted,
153}
154
155#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157pub struct Usage {
158 #[serde(default)]
160 pub input: usize,
161 #[serde(default)]
163 pub output: usize,
164 #[serde(default)]
166 pub cache_read: usize,
167 #[serde(default)]
169 pub cache_write: usize,
170 #[serde(default)]
172 pub total_tokens: usize,
173 #[serde(default)]
175 pub cost: Cost,
176}
177
178impl Usage {
179 pub fn calculate_cost(
183 &mut self,
184 input_cost_per_million: Option<f64>,
185 output_cost_per_million: Option<f64>,
186 ) {
187 self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188 self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189 self.cost.output =
190 output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191 self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192 self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193 }
194}
195
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203 #[serde(default = "default_true")]
205 pub supports_store: bool,
206 #[serde(default = "default_true")]
208 pub supports_developer_role: bool,
209 #[serde(default = "default_true")]
211 pub supports_reasoning_effort: bool,
212 #[serde(default = "default_true")]
214 pub supports_usage_in_streaming: bool,
215 #[serde(default)]
217 pub max_tokens_field: Option<MaxTokensField>,
218 #[serde(default = "default_false")]
220 pub requires_tool_result_name: bool,
221 #[serde(default = "default_false")]
223 pub requires_assistant_after_tool_result: bool,
224 #[serde(default = "default_false")]
226 pub requires_thinking_as_text: bool,
227 #[serde(default)]
229 pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Default-value helper referenced by `#[serde(default = "default_true")]`; always `true`.
const fn default_true() -> bool {
    true
}

/// Default-value helper referenced by `#[serde(default = "default_false")]`; always `false`.
const fn default_false() -> bool {
    false
}
238
239#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
241#[serde(rename_all = "kebab-case")]
242pub enum MaxTokensField {
243 MaxCompletionTokens,
245 MaxTokens,
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum ThinkingFormat {
253 OpenAI,
255 OpenRouter,
257 DeepSeek,
259 Zai,
261 Qwen,
263 QwenChatTemplate,
265}
266
267#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ToolResult {
270 pub tool_call_id: String,
272 pub content: String,
274 pub status: String,
276}
277
278impl ToolResult {
279 pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
281 Self {
282 tool_call_id: tool_call_id.into(),
283 content: content.into(),
284 status: "success".to_string(),
285 }
286 }
287
288 pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
290 Self {
291 tool_call_id: tool_call_id.into(),
292 content: content.into(),
293 status: "error".to_string(),
294 }
295 }
296
297 pub fn is_error(&self) -> bool {
299 self.status == "error"
300 }
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct Model {
308 pub id: String,
310 pub name: String,
312 pub api: Api,
314 pub provider: String,
316 pub base_url: String,
318 #[serde(default)]
320 pub reasoning: bool,
321 #[serde(default)]
323 pub input: Vec<InputModality>,
324 #[serde(default)]
326 pub cost: Cost,
327 pub context_window: usize,
329 pub max_tokens: usize,
331 #[serde(default)]
333 pub headers: HashMap<String, String>,
334 #[serde(default)]
336 pub compat: Option<CompatSettings>,
337}
338
339impl Model {
340 pub fn new(
342 id: impl Into<String>,
343 name: impl Into<String>,
344 api: Api,
345 provider: impl Into<String>,
346 base_url: impl Into<String>,
347 ) -> Self {
348 Self {
349 id: id.into(),
350 name: name.into(),
351 api,
352 provider: provider.into(),
353 base_url: base_url.into(),
354 reasoning: false,
355 input: vec![InputModality::Text],
356 cost: Cost::default(),
357 context_window: 128_000,
358 max_tokens: 32_000,
359 headers: HashMap::new(),
360 compat: None,
361 }
362 }
363
364 pub fn supports_vision(&self) -> bool {
366 self.input.contains(&InputModality::Image)
367 }
368
369 pub fn supports_reasoning(&self) -> bool {
371 self.reasoning
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Api` variant paired with the string both `Display` and serde must produce.
    const ALL_APIS: [(Api, &str); 8] = [
        (Api::OpenAiCompletions, "openai-completions"),
        (Api::OpenAiResponses, "openai-responses"),
        (Api::AnthropicMessages, "anthropic-messages"),
        (Api::GoogleGenerativeAi, "google-generative-ai"),
        (Api::GoogleVertex, "google-vertex"),
        (Api::MistralConversations, "mistral-conversations"),
        (Api::AzureOpenAiResponses, "azure-openai-responses"),
        (Api::BedrockConverseStream, "bedrock-converse-stream"),
    ];

    #[test]
    fn model_roundtrip() {
        let mut model = Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiCompletions,
            "openai",
            "https://api.openai.com/v1",
        );
        model.reasoning = true;
        model.input.push(InputModality::Image);
        model.cost = Cost {
            input: 5.0,
            output: 15.0,
            cache_read: 2.5,
            cache_write: 0.0,
        };
        model
            .headers
            .insert("x-test".to_string(), "1".to_string());
        model.compat = Some(CompatSettings::default());

        let json = serde_json::to_string(&model).unwrap();
        let deserialized: Model = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, "gpt-4o");
        assert_eq!(deserialized.name, "GPT-4o");
        assert_eq!(deserialized.api, Api::OpenAiCompletions);
        assert_eq!(deserialized.provider, "openai");
        assert_eq!(deserialized.base_url, "https://api.openai.com/v1");
        assert!(deserialized.reasoning);
        assert_eq!(
            deserialized.input,
            vec![InputModality::Text, InputModality::Image]
        );
        assert!(deserialized.supports_vision());
        assert!(deserialized.supports_reasoning());
        assert_eq!(deserialized.cost.input, 5.0);
        assert_eq!(deserialized.cost.output, 15.0);
        assert_eq!(deserialized.cost.cache_read, 2.5);
        assert_eq!(deserialized.cost.cache_write, 0.0);
        assert_eq!(deserialized.context_window, 128_000);
        assert_eq!(deserialized.max_tokens, 32_000);
        assert_eq!(
            deserialized.headers.get("x-test").map(String::as_str),
            Some("1")
        );

        // Whatever the defaults are, compat must survive the round trip unchanged.
        let sent = model.compat.as_ref().expect("compat was set above");
        let received = deserialized
            .compat
            .as_ref()
            .expect("compat should survive the round trip");
        assert_eq!(sent.supports_store, received.supports_store);
        assert_eq!(sent.supports_developer_role, received.supports_developer_role);
        assert_eq!(sent.requires_thinking_as_text, received.requires_thinking_as_text);
        assert_eq!(sent.max_tokens_field, received.max_tokens_field);
    }

    #[test]
    fn model_deserialize_applies_field_defaults() {
        let json = r#"{
            "id": "m",
            "name": "M",
            "api": "anthropic-messages",
            "provider": "p",
            "base_url": "https://example.invalid",
            "context_window": 1000,
            "max_tokens": 100
        }"#;
        let model: Model = serde_json::from_str(json).unwrap();

        assert_eq!(model.api, Api::AnthropicMessages);
        assert_eq!(model.context_window, 1000);
        assert_eq!(model.max_tokens, 100);
        assert!(!model.reasoning);
        assert!(model.input.is_empty());
        assert_eq!(model.cost.total(), 0.0);
        assert!(model.headers.is_empty());
        assert!(model.compat.is_none());
    }

    #[test]
    fn usage_calculate_cost() {
        let mut usage = Usage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 200_000,
            cache_write: 100_000,
            ..Default::default()
        };
        usage.calculate_cost(None, None);

        assert_eq!(usage.total_tokens, 1_800_000);
        assert_eq!(usage.cost.input, 1.0);
        assert_eq!(usage.cost.output, 0.5);
        assert_eq!(usage.cost.cache_read, 0.2);
        assert_eq!(usage.cost.cache_write, 0.1);
    }

    #[test]
    fn usage_calculate_cost_with_explicit_prices() {
        let mut usage = Usage {
            input: 2_000,
            output: 3_000,
            cache_read: 7,
            cache_write: 0,
            ..Default::default()
        };
        usage.calculate_cost(Some(3.0), Some(15.0));

        assert_eq!(usage.total_tokens, 5_007);
        assert_eq!(usage.cost.input, 3.0 * 2_000.0 / 1_000_000.0);
        assert_eq!(usage.cost.output, 15.0 * 3_000.0 / 1_000_000.0);
        assert_eq!(usage.cost.cache_read, 7.0 / 1_000_000.0);
        assert_eq!(usage.cost.cache_write, 0.0);
    }

    #[test]
    fn usage_deserialize_missing_fields_default_to_zero() {
        let usage: Usage = serde_json::from_str("{}").unwrap();
        assert_eq!(usage.input, 0);
        assert_eq!(usage.output, 0);
        assert_eq!(usage.cache_read, 0);
        assert_eq!(usage.cache_write, 0);
        assert_eq!(usage.total_tokens, 0);
        assert_eq!(usage.cost.total(), 0.0);
    }

    #[test]
    fn cost_total() {
        let cost = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        assert!((cost.total() - 10.5).abs() < f64::EPSILON);

        let default_cost = Cost::default();
        assert_eq!(default_cost.total(), 0.0);
    }

    #[test]
    fn api_display() {
        for (api, expected) in ALL_APIS {
            assert_eq!(api.to_string(), expected);
        }
    }

    #[test]
    fn api_serde_roundtrip() {
        for (api, expected) in ALL_APIS {
            let json = serde_json::to_string(&api).unwrap();
            // Display strings and serde renames are maintained separately; keep them in sync.
            assert_eq!(json, format!("\"{expected}\""));
            let back: Api = serde_json::from_str(&json).unwrap();
            assert_eq!(api, back);
        }
    }

    #[test]
    fn thinking_level_serde() {
        for level in [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let back: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(level, back);
        }
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::High).unwrap(),
            "\"high\""
        );
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::Off).unwrap(),
            "\"off\""
        );
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));

        // Every labelled level must use the same string as its serde representation.
        for level in [
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ] {
            let label = level.as_str().expect("only Off has no label");
            assert_eq!(
                serde_json::to_string(&level).unwrap(),
                format!("\"{label}\"")
            );
        }
    }

    #[test]
    fn stop_reason_serde() {
        for (reason, expected) in [
            (StopReason::Stop, "\"stop\""),
            (StopReason::Length, "\"length\""),
            (StopReason::ToolUse, "\"toolUse\""),
            (StopReason::Error, "\"error\""),
            (StopReason::Aborted, "\"aborted\""),
        ] {
            assert_eq!(serde_json::to_string(&reason).unwrap(), expected);
            let back: StopReason = serde_json::from_str(expected).unwrap();
            assert_eq!(back, reason);
        }
    }

    #[test]
    fn tool_result_helpers() {
        let success = ToolResult::success("call_1", "result text");
        assert_eq!(success.tool_call_id, "call_1");
        assert_eq!(success.content, "result text");
        assert_eq!(success.status, "success");
        assert!(!success.is_error());

        let error = ToolResult::error("call_2", "something failed");
        assert_eq!(error.tool_call_id, "call_2");
        assert_eq!(error.content, "something failed");
        assert!(error.is_error());
        assert_eq!(error.status, "error");
    }

    #[test]
    fn cache_retention_default() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
        assert_eq!(
            serde_json::to_string(&CacheRetention::None).unwrap(),
            "\"none\""
        );
        assert_eq!(
            serde_json::to_string(&CacheRetention::Short).unwrap(),
            "\"short\""
        );
        assert_eq!(
            serde_json::to_string(&CacheRetention::Long).unwrap(),
            "\"long\""
        );
    }

    #[test]
    fn compat_settings_empty_object_uses_serde_defaults() {
        let compat: CompatSettings = serde_json::from_str("{}").unwrap();
        assert!(compat.supports_store);
        assert!(compat.supports_developer_role);
        assert!(compat.supports_reasoning_effort);
        assert!(compat.supports_usage_in_streaming);
        assert!(compat.max_tokens_field.is_none());
        assert!(!compat.requires_tool_result_name);
        assert!(!compat.requires_assistant_after_tool_result);
        assert!(!compat.requires_thinking_as_text);
        assert!(compat.thinking_format.is_none());
    }

    #[test]
    fn max_tokens_field_serde() {
        assert_eq!(
            serde_json::to_string(&MaxTokensField::MaxCompletionTokens).unwrap(),
            "\"max-completion-tokens\""
        );
        assert_eq!(
            serde_json::to_string(&MaxTokensField::MaxTokens).unwrap(),
            "\"max-tokens\""
        );
        let back: MaxTokensField = serde_json::from_str("\"max-tokens\"").unwrap();
        assert_eq!(back, MaxTokensField::MaxTokens);
    }

    #[test]
    fn thinking_format_serde_roundtrip() {
        for format in [
            ThinkingFormat::OpenAI,
            ThinkingFormat::OpenRouter,
            ThinkingFormat::DeepSeek,
            ThinkingFormat::Zai,
            ThinkingFormat::Qwen,
            ThinkingFormat::QwenChatTemplate,
        ] {
            let json = serde_json::to_string(&format).unwrap();
            let back: ThinkingFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(format, back);
        }
        assert_eq!(
            serde_json::to_string(&ThinkingFormat::OpenAI).unwrap(),
            "\"openai\""
        );
    }
}