use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct ModelObject {
8 pub id: String,
10
11 pub display_name: String,
13
14 pub created_at: DateTime<Utc>,
16
17 #[serde(rename = "type")]
19 pub object_type: String,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24pub struct ModelListParams {
25 pub before_id: Option<String>,
27
28 pub after_id: Option<String>,
30
31 pub limit: Option<u32>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ModelList {
38 pub data: Vec<ModelObject>,
40
41 pub first_id: Option<String>,
43
44 pub last_id: Option<String>,
46
47 pub has_more: bool,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ModelCapabilities {
54 pub max_context_length: u64,
56
57 pub max_output_tokens: u64,
59
60 pub capabilities: Vec<ModelCapability>,
62
63 pub family: String,
65
66 pub generation: String,
68
69 pub supports_vision: bool,
71
72 pub supports_tools: bool,
74
75 pub supports_system_messages: bool,
77
78 pub supports_streaming: bool,
80
81 pub supported_languages: Vec<String>,
83
84 pub training_cutoff: Option<DateTime<Utc>>,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90#[serde(rename_all = "snake_case")]
91pub enum ModelCapability {
92 TextGeneration,
94 Vision,
96 ToolUse,
98 CodeGeneration,
100 Mathematical,
102 Creative,
104 Analysis,
106 Summarization,
108 Translation,
110 LongContext,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ModelPricing {
117 pub model_id: String,
119
120 pub input_price_per_million: f64,
122
123 pub output_price_per_million: f64,
125
126 pub batch_input_price_per_million: Option<f64>,
128
129 pub batch_output_price_per_million: Option<f64>,
131
132 pub cache_write_price_per_million: Option<f64>,
134
135 pub cache_read_price_per_million: Option<f64>,
137
138 pub tier: PricingTier,
140
141 pub currency: String,
143
144 pub updated_at: DateTime<Utc>,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150#[serde(rename_all = "lowercase")]
151pub enum PricingTier {
152 Premium,
154 Standard,
156 Fast,
158 Legacy,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct ModelComparison {
165 pub models: Vec<ModelObject>,
167
168 pub capabilities: Vec<ModelCapabilities>,
170
171 pub pricing: Vec<ModelPricing>,
173
174 pub performance: Vec<ModelPerformance>,
176
177 pub summary: ComparisonSummary,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct ModelPerformance {
184 pub model_id: String,
186
187 pub speed_score: u8,
189
190 pub quality_score: u8,
192
193 pub avg_response_time_ms: Option<u64>,
195
196 pub tokens_per_second: Option<f64>,
198
199 pub cost_efficiency_score: u8,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct ComparisonSummary {
206 pub fastest_model: String,
208
209 pub highest_quality_model: String,
211
212 pub most_cost_effective_model: String,
214
215 pub best_overall_model: String,
217
218 pub key_differences: Vec<String>,
220
221 pub use_case_recommendations: HashMap<String, String>,
223}
224
225#[derive(Debug, Clone, Default)]
227pub struct ModelRequirements {
228 pub max_input_cost_per_token: Option<f64>,
230
231 pub max_output_cost_per_token: Option<f64>,
233
234 pub min_context_length: Option<u64>,
236
237 pub required_capabilities: Vec<ModelCapability>,
239
240 pub preferred_family: Option<String>,
242
243 pub min_speed_score: Option<u8>,
245
246 pub min_quality_score: Option<u8>,
248
249 pub requires_vision: Option<bool>,
251
252 pub requires_tools: Option<bool>,
254
255 pub max_response_time_ms: Option<u64>,
257
258 pub preferred_languages: Vec<String>,
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct ModelUsageRecommendations {
265 pub use_case: String,
267
268 pub recommended_models: Vec<ModelRecommendation>,
270
271 pub guidelines: Vec<String>,
273
274 pub recommended_parameters: RecommendedParameters,
276
277 pub pitfalls: Vec<String>,
279
280 pub expected_performance: PerformanceExpectations,
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct ModelRecommendation {
287 pub model_id: String,
289
290 pub reason: String,
292
293 pub confidence_score: u8,
295
296 pub cost_range: CostRange,
298
299 pub strengths: Vec<String>,
301
302 pub limitations: Vec<String>,
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct RecommendedParameters {
309 pub temperature_range: (f32, f32),
311
312 pub max_tokens_range: (u32, u32),
314
315 pub top_p_range: Option<(f32, f32)>,
317
318 pub use_streaming: Option<bool>,
320
321 pub system_message_patterns: Vec<String>,
323}
324
325#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct PerformanceExpectations {
328 pub response_time_range_ms: (u64, u64),
330
331 pub cost_range: CostRange,
333
334 pub quality_level: QualityLevel,
336
337 pub success_rate_percentage: f32,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct CostRange {
344 pub min_cost_usd: f64,
346
347 pub max_cost_usd: f64,
349
350 pub typical_cost_usd: f64,
352}
353
354#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
356#[serde(rename_all = "lowercase")]
357pub enum QualityLevel {
358 Excellent,
360 Good,
362 Acceptable,
364 Basic,
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct CostEstimation {
371 pub model_id: String,
373
374 pub input_tokens: u64,
376
377 pub output_tokens: u64,
379
380 pub input_cost_usd: f64,
382
383 pub output_cost_usd: f64,
385
386 pub total_cost_usd: f64,
388
389 pub batch_discount_usd: Option<f64>,
391
392 pub cache_savings_usd: Option<f64>,
394
395 pub final_cost_usd: f64,
397
398 pub breakdown: CostBreakdown,
400}
401
402#[derive(Debug, Clone, Serialize, Deserialize)]
404pub struct CostBreakdown {
405 pub cost_per_input_token_usd: f64,
407
408 pub cost_per_output_token_usd: f64,
410
411 pub effective_cost_per_token_usd: f64,
413
414 pub cost_vs_alternatives: HashMap<String, f64>,
416}
417
418impl ModelListParams {
419 pub fn new() -> Self {
421 Self::default()
422 }
423
424 pub fn before_id(mut self, before_id: impl Into<String>) -> Self {
426 self.before_id = Some(before_id.into());
427 self
428 }
429
430 pub fn after_id(mut self, after_id: impl Into<String>) -> Self {
432 self.after_id = Some(after_id.into());
433 self
434 }
435
436 pub fn limit(mut self, limit: u32) -> Self {
438 self.limit = Some(limit.min(1000).max(1));
439 self
440 }
441}
442
443impl ModelRequirements {
444 pub fn new() -> Self {
446 Self::default()
447 }
448
449 pub fn max_input_cost_per_token(mut self, cost: f64) -> Self {
451 self.max_input_cost_per_token = Some(cost);
452 self
453 }
454
455 pub fn max_output_cost_per_token(mut self, cost: f64) -> Self {
457 self.max_output_cost_per_token = Some(cost);
458 self
459 }
460
461 pub fn min_context_length(mut self, length: u64) -> Self {
463 self.min_context_length = Some(length);
464 self
465 }
466
467 pub fn require_capability(mut self, capability: ModelCapability) -> Self {
469 self.required_capabilities.push(capability);
470 self
471 }
472
473 pub fn capabilities(mut self, capabilities: Vec<ModelCapability>) -> Self {
475 self.required_capabilities = capabilities;
476 self
477 }
478
479 pub fn preferred_family(mut self, family: impl Into<String>) -> Self {
481 self.preferred_family = Some(family.into());
482 self
483 }
484
485 pub fn require_vision(mut self) -> Self {
487 self.requires_vision = Some(true);
488 self
489 }
490
491 pub fn require_tools(mut self) -> Self {
493 self.requires_tools = Some(true);
494 self
495 }
496
497 pub fn min_quality_score(mut self, score: u8) -> Self {
499 self.min_quality_score = Some(score.min(10));
500 self
501 }
502
503 pub fn min_speed_score(mut self, score: u8) -> Self {
505 self.min_speed_score = Some(score.min(10));
506 self
507 }
508}
509
510impl ModelObject {
511 pub fn is_alias(&self) -> bool {
513 self.id.contains("latest") || self.id.ends_with("-0")
514 }
515
516 pub fn family(&self) -> String {
518 let parts: Vec<&str> = self.id.split('-').collect();
519 if parts.len() >= 3 {
520 format!("{}-{}", parts[0], parts[1])
521 } else {
522 parts[0].to_string()
523 }
524 }
525
526 pub fn is_family(&self, family: &str) -> bool {
528 self.id.starts_with(family)
529 }
530
531 pub fn model_size(&self) -> Option<String> {
533 if self.id.contains("opus") {
534 Some("opus".to_string())
535 } else if self.id.contains("sonnet") {
536 Some("sonnet".to_string())
537 } else if self.id.contains("haiku") {
538 Some("haiku".to_string())
539 } else {
540 None
541 }
542 }
543}
544
545impl ModelComparison {
546 pub fn best_for_speed(&self) -> Option<&ModelObject> {
548 self.performance
549 .iter()
550 .max_by_key(|p| p.speed_score)
551 .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
552 }
553
554 pub fn best_for_quality(&self) -> Option<&ModelObject> {
556 self.performance
557 .iter()
558 .max_by_key(|p| p.quality_score)
559 .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
560 }
561
562 pub fn most_cost_effective(&self) -> Option<&ModelObject> {
564 self.performance
565 .iter()
566 .max_by_key(|p| p.cost_efficiency_score)
567 .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
568 }
569}
570
571impl CostEstimation {
572 pub fn cost_per_1k_tokens(&self) -> f64 {
574 let total_tokens = self.input_tokens + self.output_tokens;
575 if total_tokens > 0 {
576 (self.final_cost_usd * 1000.0) / total_tokens as f64
577 } else {
578 0.0
579 }
580 }
581
582 pub fn savings_percentage(&self) -> f64 {
584 let original_cost = self.input_cost_usd + self.output_cost_usd;
585 if original_cost > 0.0 {
586 ((original_cost - self.final_cost_usd) / original_cost) * 100.0
587 } else {
588 0.0
589 }
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods set exactly the fields they are named after.
    #[test]
    fn test_model_list_params_builder() {
        let params = ModelListParams::new()
            .limit(50)
            .after_id("model_123");

        assert_eq!(params.limit, Some(50));
        assert_eq!(params.after_id, Some("model_123".to_string()));
        assert_eq!(params.before_id, None);
    }

    /// Chained requirement builders accumulate their settings.
    #[test]
    fn test_model_requirements_builder() {
        let requirements = ModelRequirements::new()
            .max_input_cost_per_token(0.01)
            .min_context_length(100000)
            .require_vision()
            .require_capability(ModelCapability::ToolUse);

        assert_eq!(requirements.max_input_cost_per_token, Some(0.01));
        assert_eq!(requirements.min_context_length, Some(100000));
        assert_eq!(requirements.requires_vision, Some(true));
        assert!(requirements
            .required_capabilities
            .contains(&ModelCapability::ToolUse));
    }

    /// ID-parsing helpers on `ModelObject`.
    #[test]
    fn test_model_object_methods() {
        let model = ModelObject {
            id: "claude-3-5-sonnet-latest".to_string(),
            display_name: "Claude 3.5 Sonnet".to_string(),
            created_at: Utc::now(),
            object_type: "model".to_string(),
        };

        assert!(model.is_alias());
        assert_eq!(model.family(), "claude-3");
        assert!(model.is_family("claude-3-5"));
        assert_eq!(model.model_size(), Some("sonnet".to_string()));
    }

    /// Derived figures on `CostEstimation`.
    #[test]
    fn test_cost_estimation_calculations() {
        let estimation = CostEstimation {
            model_id: "test-model".to_string(),
            input_tokens: 1000,
            output_tokens: 500,
            input_cost_usd: 0.01,
            output_cost_usd: 0.03,
            total_cost_usd: 0.04,
            batch_discount_usd: Some(0.005),
            cache_savings_usd: None,
            final_cost_usd: 0.035,
            breakdown: CostBreakdown {
                cost_per_input_token_usd: 0.00001,
                cost_per_output_token_usd: 0.00006,
                effective_cost_per_token_usd: 0.000023,
                cost_vs_alternatives: HashMap::new(),
            },
        };

        // 0.035 USD over 1500 tokens, scaled to 1000 tokens.
        assert!((estimation.cost_per_1k_tokens() - 0.02333).abs() < 0.001);
        // (0.04 - 0.035) / 0.04 = 12.5 %.
        assert!((estimation.savings_percentage() - 12.5).abs() < 0.1);
    }

    /// `limit` clamps out-of-range values instead of rejecting them.
    #[test]
    fn test_limit_validation() {
        // Above the upper bound.
        let params = ModelListParams::new().limit(2000);
        assert_eq!(params.limit, Some(1000));

        // Below the lower bound.
        let params = ModelListParams::new().limit(0);
        assert_eq!(params.limit, Some(1));
    }

    /// `ModelCapability` round-trips through its snake_case JSON form.
    #[test]
    fn test_model_capability_serialization() {
        let capability = ModelCapability::Vision;
        let serialized = serde_json::to_string(&capability).unwrap();
        assert_eq!(serialized, "\"vision\"");

        let deserialized: ModelCapability = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, ModelCapability::Vision);
    }
}