1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
/// Identifies the API a [`Model`] is bound to (stored in `Model::api`).
///
/// Every variant carries an explicit `serde` rename, so its serialized form is
/// the kebab-case string shown on the variant. The `Display` impl for this
/// type writes the same strings.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Api {
    /// Serialized as `"openai-completions"`.
    #[serde(rename = "openai-completions")]
    OpenAiCompletions,
    /// Serialized as `"openai-responses"`.
    #[serde(rename = "openai-responses")]
    OpenAiResponses,
    /// Serialized as `"anthropic-messages"`.
    #[serde(rename = "anthropic-messages")]
    AnthropicMessages,
    /// Serialized as `"google-generative-ai"`.
    #[serde(rename = "google-generative-ai")]
    GoogleGenerativeAi,
    /// Serialized as `"google-vertex"`.
    #[serde(rename = "google-vertex")]
    GoogleVertex,
    /// Serialized as `"mistral-conversations"`.
    #[serde(rename = "mistral-conversations")]
    MistralConversations,
    /// Serialized as `"azure-openai-responses"`.
    #[serde(rename = "azure-openai-responses")]
    AzureOpenAiResponses,
    /// Serialized as `"bedrock-converse-stream"`.
    #[serde(rename = "bedrock-converse-stream")]
    BedrockConverseStream,
}
39
40impl fmt::Display for Api {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Api::OpenAiCompletions => write!(f, "openai-completions"),
44 Api::OpenAiResponses => write!(f, "openai-responses"),
45 Api::AnthropicMessages => write!(f, "anthropic-messages"),
46 Api::GoogleGenerativeAi => write!(f, "google-generative-ai"),
47 Api::GoogleVertex => write!(f, "google-vertex"),
48 Api::MistralConversations => write!(f, "mistral-conversations"),
49 Api::AzureOpenAiResponses => write!(f, "azure-openai-responses"),
50 Api::BedrockConverseStream => write!(f, "bedrock-converse-stream"),
51 }
52 }
53}
54
/// Cache retention setting with three values.
///
/// Serialized in lowercase (`"none"`, `"short"`, `"long"`). The default is
/// [`CacheRetention::None`].
// NOTE(review): what "short" and "long" mean in concrete terms is not visible
// in this file; confirm against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheRetention {
    /// Default value; serialized as `"none"`.
    #[default]
    None,
    /// Serialized as `"short"`.
    Short,
    /// Serialized as `"long"`.
    Long,
}
67
/// Thinking level selector, from [`ThinkingLevel::Off`] up to
/// [`ThinkingLevel::XHigh`].
///
/// Serialized in lowercase (`"off"`, `"minimal"`, `"low"`, `"medium"`,
/// `"high"`, `"xhigh"`). The default is [`ThinkingLevel::Off`].
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum ThinkingLevel {
    /// Default value; serialized as `"off"`. `as_str` returns `None` for it.
    #[default]
    Off,
    /// Serialized as `"minimal"`.
    Minimal,
    /// Serialized as `"low"`.
    Low,
    /// Serialized as `"medium"`.
    Medium,
    /// Serialized as `"high"`.
    High,
    /// Serialized as `"xhigh"`.
    XHigh,
}
87
88impl ThinkingLevel {
89 pub fn as_str(&self) -> Option<&str> {
91 match self {
92 ThinkingLevel::Off => None,
93 ThinkingLevel::Minimal => Some("minimal"),
94 ThinkingLevel::Low => Some("low"),
95 ThinkingLevel::Medium => Some("medium"),
96 ThinkingLevel::High => Some("high"),
97 ThinkingLevel::XHigh => Some("xhigh"),
98 }
99 }
100}
101
/// Kind of input listed in `Model::input`.
///
/// Serialized in lowercase (`"text"`, `"image"`).
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum InputModality {
    /// Serialized as `"text"`.
    Text,
    /// Serialized as `"image"`; `Model::supports_vision` checks for this variant.
    Image,
}
112
/// Four `f64` cost components: input, output, cache read and cache write.
///
/// Used as the type of both `Model::cost` and `Usage::cost`. Any field missing
/// during deserialization falls back to `0.0`, and `Cost::default()` is all
/// zeros.
// NOTE(review): units are not stated in this file. `Usage::calculate_cost`
// stores computed amounts here; whether `Model::cost` holds per-million-token
// rates should be confirmed against its callers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct Cost {
    /// Input component; `0.0` when absent.
    #[serde(default)]
    pub input: f64,
    /// Output component; `0.0` when absent.
    #[serde(default)]
    pub output: f64,
    /// Cache-read component; `0.0` when absent.
    #[serde(default)]
    pub cache_read: f64,
    /// Cache-write component; `0.0` when absent.
    #[serde(default)]
    pub cache_write: f64,
}
130
131impl Cost {
132 pub fn total(&self) -> f64 {
134 self.input + self.output + self.cache_read + self.cache_write
135 }
136}
137
/// Stop reason, one of five values.
///
/// Serialized in camelCase (`"stop"`, `"length"`, `"toolUse"`, `"error"`,
/// `"aborted"`).
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
// NOTE(review): the exact condition that produces each variant is not visible
// in this file; the variant docs below only record the wire names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub enum StopReason {
    /// Serialized as `"stop"`.
    Stop,
    /// Serialized as `"length"`.
    Length,
    /// Serialized as `"toolUse"`.
    ToolUse,
    /// Serialized as `"error"`.
    Error,
    /// Serialized as `"aborted"`.
    Aborted,
}
154
/// Token counters plus the cost derived from them.
///
/// Every field falls back to its type's default (zero counts, all-zero
/// [`Cost`]) when missing during deserialization. `total_tokens` and `cost`
/// are overwritten by `Usage::calculate_cost`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Input token count.
    #[serde(default)]
    pub input: usize,
    /// Output token count.
    #[serde(default)]
    pub output: usize,
    /// Cache-read token count.
    #[serde(default)]
    pub cache_read: usize,
    /// Cache-write token count.
    #[serde(default)]
    pub cache_write: usize,
    /// Total token count; `calculate_cost` sets it to the sum of the four
    /// counters above.
    #[serde(default)]
    pub total_tokens: usize,
    /// Cost breakdown; filled in by `calculate_cost`.
    #[serde(default)]
    pub cost: Cost,
}
177
178impl Usage {
179 pub fn calculate_cost(
183 &mut self,
184 input_cost_per_million: Option<f64>,
185 output_cost_per_million: Option<f64>,
186 ) {
187 self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
188 self.cost.input = input_cost_per_million.unwrap_or(1.0) * self.input as f64 / 1_000_000.0;
189 self.cost.output =
190 output_cost_per_million.unwrap_or(1.0) * self.output as f64 / 1_000_000.0;
191 self.cost.cache_read = (self.cache_read as f64) / 1_000_000.0;
192 self.cost.cache_write = (self.cache_write as f64) / 1_000_000.0;
193 }
194}
195
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default)]
202pub struct CompatSettings {
203 #[serde(default = "default_true")]
205 pub supports_store: bool,
206 #[serde(default = "default_true")]
208 pub supports_developer_role: bool,
209 #[serde(default = "default_true")]
211 pub supports_reasoning_effort: bool,
212 #[serde(default = "default_true")]
214 pub supports_usage_in_streaming: bool,
215 #[serde(default)]
217 pub max_tokens_field: Option<MaxTokensField>,
218 #[serde(default = "default_false")]
220 pub requires_tool_result_name: bool,
221 #[serde(default = "default_false")]
223 pub requires_assistant_after_tool_result: bool,
224 #[serde(default = "default_false")]
226 pub requires_thinking_as_text: bool,
227 #[serde(default)]
229 pub thinking_format: Option<ThinkingFormat>,
230}
231
/// Returns `true`.
///
/// Referenced by name from the `#[serde(default = "default_true")]`
/// attributes on the `supports_*` fields of `CompatSettings`.
fn default_true() -> bool {
    true
}
/// Returns `false`.
///
/// Referenced by name from the `#[serde(default = "default_false")]`
/// attributes on the `requires_*` fields of `CompatSettings`.
fn default_false() -> bool {
    false
}
238
/// Selector stored in `CompatSettings::max_tokens_field`.
///
/// Serialized in kebab-case (`"max-completion-tokens"`, `"max-tokens"`).
// NOTE(review): the variant names resemble snake_case request parameters
// (`max_completion_tokens` / `max_tokens`), while the wire form here is
// kebab-case. Confirm that producers and consumers of this setting expect the
// hyphenated spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum MaxTokensField {
    /// Serialized as `"max-completion-tokens"`.
    MaxCompletionTokens,
    /// Serialized as `"max-tokens"`.
    MaxTokens,
}
248
/// Selector stored in `CompatSettings::thinking_format`.
///
/// Serialized as the variant name lowercased, with no word separators
/// (`"openai"`, `"openrouter"`, `"deepseek"`, `"zai"`, `"qwen"`,
/// `"qwenchattemplate"`).
// NOTE(review): `QwenChatTemplate` therefore serializes as one run-together
// word; confirm that is the spelling expected by stored configurations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingFormat {
    /// Serialized as `"openai"`.
    OpenAI,
    /// Serialized as `"openrouter"`.
    OpenRouter,
    /// Serialized as `"deepseek"`.
    DeepSeek,
    /// Serialized as `"zai"`.
    Zai,
    /// Serialized as `"qwen"`.
    Qwen,
    /// Serialized as `"qwenchattemplate"`.
    QwenChatTemplate,
}
266
/// Five-step complexity scale.
///
/// The derived `Ord` follows declaration order, so
/// `Trivial < Simple < Moderate < Complex < Research`. The default is
/// [`Complexity::Complex`]. This type has no serde derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum Complexity {
    /// Lowest step; `cost_tier` is 0.
    Trivial,
    /// `cost_tier` is 1.
    Simple,
    /// `cost_tier` is 2.
    Moderate,
    /// Default value; `cost_tier` is 3.
    #[default]
    Complex,
    /// Highest step; `cost_tier` is 4.
    Research,
}
282
283impl Complexity {
284 pub fn cost_tier(&self) -> u8 {
286 match self {
287 Self::Trivial => 0,
288 Self::Simple => 1,
289 Self::Moderate => 2,
290 Self::Complex => 3,
291 Self::Research => 4,
292 }
293 }
294}
295
/// Result of a tool call: a call id, the content, and a status string.
///
/// All three fields are required when deserializing. The constructors
/// `ToolResult::success` and `ToolResult::error` set `status` to `"success"`
/// and `"error"` respectively. The field is public, so other strings are
/// possible.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Identifier of the tool call this result belongs to.
    pub tool_call_id: String,
    /// Result content as text.
    pub content: String,
    /// Status label; `is_error` compares it against `"error"`.
    pub status: String,
}
306
307impl ToolResult {
308 pub fn success(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
310 Self {
311 tool_call_id: tool_call_id.into(),
312 content: content.into(),
313 status: "success".to_string(),
314 }
315 }
316
317 pub fn error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
319 Self {
320 tool_call_id: tool_call_id.into(),
321 content: content.into(),
322 status: "error".to_string(),
323 }
324 }
325
326 pub fn is_error(&self) -> bool {
328 self.status == "error"
329 }
330}
331
/// Images API selector; currently a single variant.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
// NOTE(review): there is no serde rename here, so the variant serializes under
// its Rust name (`"OpenRouter"`), while the `Display` impl writes
// `"openrouter"`. `Api` keeps the two identical; confirm the difference here is
// intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImagesApi {
    /// Displayed as `"openrouter"`.
    OpenRouter,
}
339
340impl std::fmt::Display for ImagesApi {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 match self {
343 ImagesApi::OpenRouter => write!(f, "openrouter"),
344 }
345 }
346}
347
/// Parameters of an image generation request.
///
/// The container-level `#[serde(default)]` means any field missing during
/// deserialization is taken from this type's `Default` impl: an empty
/// `prompt`, `n = Some(1)`, `response_format = Some("b64_json")`, and `None`
/// for the rest.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ImageGenerationRequest {
    /// Prompt text; defaults to the empty string.
    pub prompt: String,
    /// Optional model name; defaults to `None`.
    pub model: Option<String>,
    /// Optional size; defaults to `None`.
    // NOTE(review): the accepted format (e.g. "WIDTHxHEIGHT") is not visible
    // in this file.
    pub size: Option<String>,
    /// Optional count; defaults to `Some(1)`.
    pub n: Option<u32>,
    /// Optional response format; defaults to `Some("b64_json")`.
    pub response_format: Option<String>,
}
363
364impl Default for ImageGenerationRequest {
365 fn default() -> Self {
366 Self {
367 prompt: String::new(),
368 model: None,
369 size: None,
370 n: Some(1),
371 response_format: Some("b64_json".to_string()),
372 }
373 }
374}
375
/// Result of an image generation request.
///
/// Both fields fall back to their defaults (empty list, `None`) when missing
/// during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ImageGenerationResponse {
    /// One byte buffer per image.
    // NOTE(review): presumably decoded image bytes; the encoding/container
    // format is not visible in this file. Confirm against the producer.
    pub images: Vec<Vec<u8>>,
    /// Optional revised prompt; defaults to `None`.
    pub revised_prompt: Option<String>,
}
385
/// Description of a model: identity, API binding, endpoint, capabilities,
/// cost and limits.
///
/// When deserializing, `id`, `name`, `api`, `provider`, `base_url`,
/// `context_window` and `max_tokens` are required. The remaining fields are
/// marked `#[serde(default)]` and fall back to their type's default when
/// absent. Field names are serialized as written (snake_case).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    /// Model identifier.
    pub id: String,
    /// Model name.
    pub name: String,
    /// API this model is bound to.
    pub api: Api,
    /// Provider name.
    pub provider: String,
    /// Base URL of the endpoint.
    pub base_url: String,
    /// Reasoning flag returned by `supports_reasoning`; `false` when absent.
    #[serde(default)]
    pub reasoning: bool,
    /// Input modalities; empty when absent. `supports_vision` checks this list
    /// for [`InputModality::Image`].
    #[serde(default)]
    pub input: Vec<InputModality>,
    /// Cost figures; all zeros when absent.
    #[serde(default)]
    pub cost: Cost,
    /// Context window size.
    // NOTE(review): presumably measured in tokens; confirm.
    pub context_window: usize,
    /// Maximum token count.
    // NOTE(review): presumably the output-token cap; confirm.
    pub max_tokens: usize,
    /// Header name/value pairs; empty when absent.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional compatibility settings; `None` when absent.
    #[serde(default)]
    pub compat: Option<CompatSettings>,
}
421
422impl Model {
423 pub fn new(
425 id: impl Into<String>,
426 name: impl Into<String>,
427 api: Api,
428 provider: impl Into<String>,
429 base_url: impl Into<String>,
430 ) -> Self {
431 Self {
432 id: id.into(),
433 name: name.into(),
434 api,
435 provider: provider.into(),
436 base_url: base_url.into(),
437 reasoning: false,
438 input: vec![InputModality::Text],
439 cost: Cost::default(),
440 context_window: 128_000,
441 max_tokens: 32_000,
442 headers: HashMap::new(),
443 compat: None,
444 }
445 }
446
447 pub fn supports_vision(&self) -> bool {
449 self.input.contains(&InputModality::Image)
450 }
451
452 pub fn supports_reasoning(&self) -> bool {
454 self.reasoning
455 }
456}
457
// Unit tests for the types in this file: serde round-trips, Display strings,
// defaults and the small helper methods.
#[cfg(test)]
mod tests {
    use super::*;

    // A populated `Model` survives a JSON round-trip with its identity,
    // capability flags and cost figures intact.
    #[test]
    fn model_roundtrip() {
        let mut model = Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiCompletions,
            "openai",
            "https://api.openai.com/v1",
        );
        model.reasoning = true;
        model.input.push(InputModality::Image);
        model.cost = Cost {
            input: 5.0,
            output: 15.0,
            cache_read: 2.5,
            cache_write: 0.0,
        };
        model.compat = Some(CompatSettings::default());

        let json = serde_json::to_string(&model).unwrap();
        let deserialized: Model = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, "gpt-4o");
        assert_eq!(deserialized.name, "GPT-4o");
        assert_eq!(deserialized.api, Api::OpenAiCompletions);
        assert_eq!(deserialized.provider, "openai");
        assert!(deserialized.reasoning);
        assert!(deserialized.supports_vision());
        assert!(deserialized.supports_reasoning());
        assert_eq!(deserialized.cost.input, 5.0);
        assert_eq!(deserialized.cost.output, 15.0);
    }

    // With no rates supplied, `calculate_cost` uses 1.0 per million tokens for
    // every counter and sums all four counters into `total_tokens`.
    #[test]
    fn usage_calculate_cost() {
        let mut usage = Usage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 200_000,
            cache_write: 100_000,
            ..Default::default()
        };
        usage.calculate_cost(None, None);

        assert_eq!(usage.total_tokens, 1_800_000);
        assert_eq!(usage.cost.input, 1.0);
        assert_eq!(usage.cost.output, 0.5);
        assert_eq!(usage.cost.cache_read, 0.2);
        assert_eq!(usage.cost.cache_write, 0.1);
    }

    // `Cost::total` adds the four components; the default cost totals zero.
    #[test]
    fn cost_total() {
        let cost = Cost {
            input: 3.0,
            output: 6.0,
            cache_read: 1.0,
            cache_write: 0.5,
        };
        assert!((cost.total() - 10.5).abs() < f64::EPSILON);

        let default_cost = Cost::default();
        assert_eq!(default_cost.total(), 0.0);
    }

    // Every `Api` variant displays as its kebab-case identifier.
    #[test]
    fn api_display() {
        assert_eq!(Api::OpenAiCompletions.to_string(), "openai-completions");
        assert_eq!(Api::OpenAiResponses.to_string(), "openai-responses");
        assert_eq!(Api::AnthropicMessages.to_string(), "anthropic-messages");
        assert_eq!(Api::GoogleGenerativeAi.to_string(), "google-generative-ai");
        assert_eq!(Api::GoogleVertex.to_string(), "google-vertex");
        assert_eq!(
            Api::MistralConversations.to_string(),
            "mistral-conversations"
        );
        assert_eq!(
            Api::AzureOpenAiResponses.to_string(),
            "azure-openai-responses"
        );
        assert_eq!(
            Api::BedrockConverseStream.to_string(),
            "bedrock-converse-stream"
        );
    }

    // Every `Api` variant survives a JSON round-trip.
    #[test]
    fn api_serde_roundtrip() {
        for api in [
            Api::OpenAiCompletions,
            Api::OpenAiResponses,
            Api::AnthropicMessages,
            Api::GoogleGenerativeAi,
            Api::GoogleVertex,
            Api::MistralConversations,
            Api::AzureOpenAiResponses,
            Api::BedrockConverseStream,
        ] {
            let json = serde_json::to_string(&api).unwrap();
            let back: Api = serde_json::from_str(&json).unwrap();
            assert_eq!(api, back);
        }
    }

    // `ThinkingLevel` round-trips through JSON, defaults to `Off`, serializes
    // in lowercase, and `as_str` returns `None` only for `Off`.
    #[test]
    fn thinking_level_serde() {
        for level in [
            ThinkingLevel::Off,
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
            ThinkingLevel::XHigh,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let back: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(level, back);
        }
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Off);
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::High).unwrap(),
            "\"high\""
        );
        assert_eq!(
            serde_json::to_string(&ThinkingLevel::Off).unwrap(),
            "\"off\""
        );
        assert!(ThinkingLevel::Off.as_str().is_none());
        assert_eq!(ThinkingLevel::High.as_str(), Some("high"));
        assert_eq!(ThinkingLevel::XHigh.as_str(), Some("xhigh"));
    }

    // `StopReason` uses camelCase on the wire in both directions.
    #[test]
    fn stop_reason_serde() {
        assert_eq!(
            serde_json::to_string(&StopReason::ToolUse).unwrap(),
            "\"toolUse\""
        );
        let back: StopReason = serde_json::from_str("\"toolUse\"").unwrap();
        assert_eq!(back, StopReason::ToolUse);
    }

    // The `ToolResult` constructors set the matching status string, and
    // `is_error` reflects it.
    #[test]
    fn tool_result_helpers() {
        let success = ToolResult::success("call_1", "result text");
        assert_eq!(success.tool_call_id, "call_1");
        assert_eq!(success.content, "result text");
        assert_eq!(success.status, "success");
        assert!(!success.is_error());

        let error = ToolResult::error("call_2", "something failed");
        assert!(error.is_error());
        assert_eq!(error.status, "error");
    }

    // `CacheRetention` defaults to `None`.
    #[test]
    fn cache_retention_default() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
    }
}