anthropic_sdk/types/
models_api.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5/// Model object returned by the API
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub struct ModelObject {
8    /// Unique model identifier
9    pub id: String,
10    
11    /// Human-readable name for the model
12    pub display_name: String,
13    
14    /// RFC 3339 datetime string representing when the model was released
15    pub created_at: DateTime<Utc>,
16    
17    /// Object type, always "model" for models
18    #[serde(rename = "type")]
19    pub object_type: String,
20}
21
22/// Parameters for listing models
23#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24pub struct ModelListParams {
25    /// ID of the object to use as a cursor for pagination (before)
26    pub before_id: Option<String>,
27    
28    /// ID of the object to use as a cursor for pagination (after)
29    pub after_id: Option<String>,
30    
31    /// Number of items to return per page (1-1000, default 20)
32    pub limit: Option<u32>,
33}
34
35/// Paginated list of models
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ModelList {
38    /// Array of model objects
39    pub data: Vec<ModelObject>,
40    
41    /// First ID in the data list
42    pub first_id: Option<String>,
43    
44    /// Last ID in the data list  
45    pub last_id: Option<String>,
46    
47    /// Indicates if there are more results available
48    pub has_more: bool,
49}
50
51/// Model capabilities and limitations
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ModelCapabilities {
54    /// Maximum context length in tokens
55    pub max_context_length: u64,
56    
57    /// Maximum output tokens per request
58    pub max_output_tokens: u64,
59    
60    /// Supported capabilities
61    pub capabilities: Vec<ModelCapability>,
62    
63    /// Model family (e.g., "claude-3", "claude-3-5")
64    pub family: String,
65    
66    /// Model generation/version
67    pub generation: String,
68    
69    /// Whether the model supports vision (image input)
70    pub supports_vision: bool,
71    
72    /// Whether the model supports tool use/function calling
73    pub supports_tools: bool,
74    
75    /// Whether the model supports system messages
76    pub supports_system_messages: bool,
77    
78    /// Whether the model supports streaming
79    pub supports_streaming: bool,
80    
81    /// Supported languages (ISO codes)
82    pub supported_languages: Vec<String>,
83    
84    /// Training data cutoff date
85    pub training_cutoff: Option<DateTime<Utc>>,
86}
87
88/// Individual model capability
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90#[serde(rename_all = "snake_case")]
91pub enum ModelCapability {
92    /// Text generation and conversation
93    TextGeneration,
94    /// Vision and image understanding
95    Vision,
96    /// Tool use and function calling
97    ToolUse,
98    /// Code generation and analysis
99    CodeGeneration,
100    /// Mathematical reasoning
101    Mathematical,
102    /// Creative writing
103    Creative,
104    /// Analysis and reasoning
105    Analysis,
106    /// Summarization
107    Summarization,
108    /// Translation between languages
109    Translation,
110    /// Long context handling
111    LongContext,
112}
113
114/// Pricing information for a model
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ModelPricing {
117    /// Model ID this pricing applies to
118    pub model_id: String,
119    
120    /// Input token price in USD per 1M tokens
121    pub input_price_per_million: f64,
122    
123    /// Output token price in USD per 1M tokens  
124    pub output_price_per_million: f64,
125    
126    /// Batch input token price in USD per 1M tokens (if available)
127    pub batch_input_price_per_million: Option<f64>,
128    
129    /// Batch output token price in USD per 1M tokens (if available)
130    pub batch_output_price_per_million: Option<f64>,
131    
132    /// Cache write price in USD per 1M tokens (if available)
133    pub cache_write_price_per_million: Option<f64>,
134    
135    /// Cache read price in USD per 1M tokens (if available)
136    pub cache_read_price_per_million: Option<f64>,
137    
138    /// Pricing tier or category
139    pub tier: PricingTier,
140    
141    /// Currency code (usually "USD")
142    pub currency: String,
143    
144    /// When this pricing was last updated
145    pub updated_at: DateTime<Utc>,
146}
147
148/// Pricing tier categories
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
150#[serde(rename_all = "lowercase")]
151pub enum PricingTier {
152    /// Premium/flagship models with highest capabilities
153    Premium,
154    /// Standard models with balanced price/performance
155    Standard,
156    /// Fast/efficient models optimized for speed
157    Fast,
158    /// Legacy models with older pricing
159    Legacy,
160}
161
162/// Model comparison result
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct ModelComparison {
165    /// Models being compared
166    pub models: Vec<ModelObject>,
167    
168    /// Capabilities comparison
169    pub capabilities: Vec<ModelCapabilities>,
170    
171    /// Pricing comparison
172    pub pricing: Vec<ModelPricing>,
173    
174    /// Performance characteristics
175    pub performance: Vec<ModelPerformance>,
176    
177    /// Comparison summary and recommendations
178    pub summary: ComparisonSummary,
179}
180
181/// Performance characteristics for a model
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct ModelPerformance {
184    /// Model ID
185    pub model_id: String,
186    
187    /// Relative speed score (1-10, higher is faster)
188    pub speed_score: u8,
189    
190    /// Relative quality score (1-10, higher is better)
191    pub quality_score: u8,
192    
193    /// Average response time in milliseconds
194    pub avg_response_time_ms: Option<u64>,
195    
196    /// Tokens per second throughput
197    pub tokens_per_second: Option<f64>,
198    
199    /// Cost efficiency score (1-10, higher is more cost effective)
200    pub cost_efficiency_score: u8,
201}
202
203/// Summary of model comparison
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct ComparisonSummary {
206    /// Best model for speed
207    pub fastest_model: String,
208    
209    /// Best model for quality
210    pub highest_quality_model: String,
211    
212    /// Most cost-effective model
213    pub most_cost_effective_model: String,
214    
215    /// Best overall balanced model
216    pub best_overall_model: String,
217    
218    /// Key differences and trade-offs
219    pub key_differences: Vec<String>,
220    
221    /// Recommendations by use case
222    pub use_case_recommendations: HashMap<String, String>,
223}
224
225/// Requirements for model selection
226#[derive(Debug, Clone, Default)]
227pub struct ModelRequirements {
228    /// Maximum cost per input token
229    pub max_input_cost_per_token: Option<f64>,
230    
231    /// Maximum cost per output token
232    pub max_output_cost_per_token: Option<f64>,
233    
234    /// Minimum context length required
235    pub min_context_length: Option<u64>,
236    
237    /// Required capabilities
238    pub required_capabilities: Vec<ModelCapability>,
239    
240    /// Preferred model family
241    pub preferred_family: Option<String>,
242    
243    /// Minimum speed score
244    pub min_speed_score: Option<u8>,
245    
246    /// Minimum quality score
247    pub min_quality_score: Option<u8>,
248    
249    /// Whether vision support is required
250    pub requires_vision: Option<bool>,
251    
252    /// Whether tool use support is required
253    pub requires_tools: Option<bool>,
254    
255    /// Maximum acceptable response time in milliseconds
256    pub max_response_time_ms: Option<u64>,
257    
258    /// Preferred languages (ISO codes)
259    pub preferred_languages: Vec<String>,
260}
261
262/// Usage recommendations for specific use cases
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct ModelUsageRecommendations {
265    /// Use case category
266    pub use_case: String,
267    
268    /// Recommended models in order of preference
269    pub recommended_models: Vec<ModelRecommendation>,
270    
271    /// General guidelines for this use case
272    pub guidelines: Vec<String>,
273    
274    /// Recommended parameters
275    pub recommended_parameters: RecommendedParameters,
276    
277    /// Common pitfalls to avoid
278    pub pitfalls: Vec<String>,
279    
280    /// Expected performance characteristics
281    pub expected_performance: PerformanceExpectations,
282}
283
284/// Individual model recommendation
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct ModelRecommendation {
287    /// Model ID
288    pub model_id: String,
289    
290    /// Reason for recommendation
291    pub reason: String,
292    
293    /// Confidence score (1-10)
294    pub confidence_score: u8,
295    
296    /// Expected cost range for typical usage
297    pub cost_range: CostRange,
298    
299    /// Specific strengths for this use case
300    pub strengths: Vec<String>,
301    
302    /// Potential limitations
303    pub limitations: Vec<String>,
304}
305
306/// Recommended parameters for a use case
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct RecommendedParameters {
309    /// Recommended temperature range
310    pub temperature_range: (f32, f32),
311    
312    /// Recommended max_tokens range
313    pub max_tokens_range: (u32, u32),
314    
315    /// Recommended top_p range
316    pub top_p_range: Option<(f32, f32)>,
317    
318    /// Whether to use streaming
319    pub use_streaming: Option<bool>,
320    
321    /// Recommended system message patterns
322    pub system_message_patterns: Vec<String>,
323}
324
325/// Expected performance for a use case
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct PerformanceExpectations {
328    /// Expected response time range in milliseconds
329    pub response_time_range_ms: (u64, u64),
330    
331    /// Expected cost range for typical request
332    pub cost_range: CostRange,
333    
334    /// Expected quality level
335    pub quality_level: QualityLevel,
336    
337    /// Success rate expectations
338    pub success_rate_percentage: f32,
339}
340
341/// Cost range information
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct CostRange {
344    /// Minimum cost in USD
345    pub min_cost_usd: f64,
346    
347    /// Maximum cost in USD
348    pub max_cost_usd: f64,
349    
350    /// Typical/average cost in USD
351    pub typical_cost_usd: f64,
352}
353
354/// Quality level categories
355#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
356#[serde(rename_all = "lowercase")]
357pub enum QualityLevel {
358    /// Highest quality, most accurate
359    Excellent,
360    /// Good quality, reliable
361    Good,
362    /// Acceptable quality, some limitations
363    Acceptable,
364    /// Lower quality, may need refinement
365    Basic,
366}
367
368/// Cost estimation result
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct CostEstimation {
371    /// Model ID
372    pub model_id: String,
373    
374    /// Input tokens
375    pub input_tokens: u64,
376    
377    /// Output tokens
378    pub output_tokens: u64,
379    
380    /// Input cost in USD
381    pub input_cost_usd: f64,
382    
383    /// Output cost in USD
384    pub output_cost_usd: f64,
385    
386    /// Total cost in USD
387    pub total_cost_usd: f64,
388    
389    /// Batch discount (if applicable)
390    pub batch_discount_usd: Option<f64>,
391    
392    /// Cache savings (if applicable)
393    pub cache_savings_usd: Option<f64>,
394    
395    /// Final cost after discounts
396    pub final_cost_usd: f64,
397    
398    /// Cost breakdown
399    pub breakdown: CostBreakdown,
400}
401
402/// Detailed cost breakdown
403#[derive(Debug, Clone, Serialize, Deserialize)]
404pub struct CostBreakdown {
405    /// Cost per input token
406    pub cost_per_input_token_usd: f64,
407    
408    /// Cost per output token
409    pub cost_per_output_token_usd: f64,
410    
411    /// Effective cost per token (total / total tokens)
412    pub effective_cost_per_token_usd: f64,
413    
414    /// Cost comparison to other models
415    pub cost_vs_alternatives: HashMap<String, f64>,
416}
417
418impl ModelListParams {
419    /// Create new model list parameters
420    pub fn new() -> Self {
421        Self::default()
422    }
423    
424    /// Set the before_id for pagination
425    pub fn before_id(mut self, before_id: impl Into<String>) -> Self {
426        self.before_id = Some(before_id.into());
427        self
428    }
429    
430    /// Set the after_id for pagination
431    pub fn after_id(mut self, after_id: impl Into<String>) -> Self {
432        self.after_id = Some(after_id.into());
433        self
434    }
435    
436    /// Set the limit for pagination
437    pub fn limit(mut self, limit: u32) -> Self {
438        self.limit = Some(limit.min(1000).max(1));
439        self
440    }
441}
442
443impl ModelRequirements {
444    /// Create new model requirements
445    pub fn new() -> Self {
446        Self::default()
447    }
448    
449    /// Set maximum input cost per token
450    pub fn max_input_cost_per_token(mut self, cost: f64) -> Self {
451        self.max_input_cost_per_token = Some(cost);
452        self
453    }
454    
455    /// Set maximum output cost per token
456    pub fn max_output_cost_per_token(mut self, cost: f64) -> Self {
457        self.max_output_cost_per_token = Some(cost);
458        self
459    }
460    
461    /// Set minimum context length requirement
462    pub fn min_context_length(mut self, length: u64) -> Self {
463        self.min_context_length = Some(length);
464        self
465    }
466    
467    /// Add required capability
468    pub fn require_capability(mut self, capability: ModelCapability) -> Self {
469        self.required_capabilities.push(capability);
470        self
471    }
472    
473    /// Set required capabilities
474    pub fn capabilities(mut self, capabilities: Vec<ModelCapability>) -> Self {
475        self.required_capabilities = capabilities;
476        self
477    }
478    
479    /// Set preferred model family
480    pub fn preferred_family(mut self, family: impl Into<String>) -> Self {
481        self.preferred_family = Some(family.into());
482        self
483    }
484    
485    /// Require vision support
486    pub fn require_vision(mut self) -> Self {
487        self.requires_vision = Some(true);
488        self
489    }
490    
491    /// Require tool use support
492    pub fn require_tools(mut self) -> Self {
493        self.requires_tools = Some(true);
494        self
495    }
496    
497    /// Set minimum quality score
498    pub fn min_quality_score(mut self, score: u8) -> Self {
499        self.min_quality_score = Some(score.min(10));
500        self
501    }
502    
503    /// Set minimum speed score
504    pub fn min_speed_score(mut self, score: u8) -> Self {
505        self.min_speed_score = Some(score.min(10));
506        self
507    }
508}
509
510impl ModelObject {
511    /// Check if this is a latest/alias model
512    pub fn is_alias(&self) -> bool {
513        self.id.contains("latest") || self.id.ends_with("-0")
514    }
515    
516    /// Get the model family (e.g., "claude-3-5" from "claude-3-5-sonnet-latest")
517    pub fn family(&self) -> String {
518        let parts: Vec<&str> = self.id.split('-').collect();
519        if parts.len() >= 3 {
520            format!("{}-{}", parts[0], parts[1])
521        } else {
522            parts[0].to_string()
523        }
524    }
525    
526    /// Check if model belongs to a specific family
527    pub fn is_family(&self, family: &str) -> bool {
528        self.id.starts_with(family)
529    }
530    
531    /// Get model size/tier (sonnet, haiku, opus)
532    pub fn model_size(&self) -> Option<String> {
533        if self.id.contains("opus") {
534            Some("opus".to_string())
535        } else if self.id.contains("sonnet") {
536            Some("sonnet".to_string())
537        } else if self.id.contains("haiku") {
538            Some("haiku".to_string())
539        } else {
540            None
541        }
542    }
543}
544
545impl ModelComparison {
546    /// Get the best model for a specific criterion
547    pub fn best_for_speed(&self) -> Option<&ModelObject> {
548        self.performance
549            .iter()
550            .max_by_key(|p| p.speed_score)
551            .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
552    }
553    
554    /// Get the best model for quality
555    pub fn best_for_quality(&self) -> Option<&ModelObject> {
556        self.performance
557            .iter()
558            .max_by_key(|p| p.quality_score)
559            .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
560    }
561    
562    /// Get the most cost-effective model
563    pub fn most_cost_effective(&self) -> Option<&ModelObject> {
564        self.performance
565            .iter()
566            .max_by_key(|p| p.cost_efficiency_score)
567            .and_then(|p| self.models.iter().find(|m| m.id == p.model_id))
568    }
569}
570
571impl CostEstimation {
572    /// Calculate cost per 1000 tokens
573    pub fn cost_per_1k_tokens(&self) -> f64 {
574        let total_tokens = self.input_tokens + self.output_tokens;
575        if total_tokens > 0 {
576            (self.final_cost_usd * 1000.0) / total_tokens as f64
577        } else {
578            0.0
579        }
580    }
581    
582    /// Get savings percentage from discounts
583    pub fn savings_percentage(&self) -> f64 {
584        let original_cost = self.input_cost_usd + self.output_cost_usd;
585        if original_cost > 0.0 {
586            ((original_cost - self.final_cost_usd) / original_cost) * 100.0
587        } else {
588            0.0
589        }
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained list-param setters store their values; the unused cursor stays empty.
    #[test]
    fn test_model_list_params_builder() {
        let params = ModelListParams::new().limit(50).after_id("model_123");

        assert_eq!(params.limit, Some(50));
        assert_eq!(params.after_id.as_deref(), Some("model_123"));
        assert!(params.before_id.is_none());
    }

    /// Chained requirement setters store their values.
    #[test]
    fn test_model_requirements_builder() {
        let requirements = ModelRequirements::new()
            .max_input_cost_per_token(0.01)
            .min_context_length(100_000)
            .require_vision()
            .require_capability(ModelCapability::ToolUse);

        assert_eq!(requirements.max_input_cost_per_token, Some(0.01));
        assert_eq!(requirements.min_context_length, Some(100_000));
        assert_eq!(requirements.requires_vision, Some(true));
        assert!(requirements
            .required_capabilities
            .iter()
            .any(|cap| *cap == ModelCapability::ToolUse));
    }

    /// ID-derived helpers on `ModelObject`.
    #[test]
    fn test_model_object_methods() {
        let model = ModelObject {
            id: String::from("claude-3-5-sonnet-latest"),
            display_name: String::from("Claude 3.5 Sonnet"),
            created_at: Utc::now(),
            object_type: String::from("model"),
        };

        assert!(model.is_alias());
        // `family()` keeps only the first two dash-separated segments.
        assert_eq!(model.family(), "claude-3");
        assert!(model.is_family("claude-3-5"));
        assert_eq!(model.model_size().as_deref(), Some("sonnet"));
    }

    /// Derived figures on `CostEstimation`.
    #[test]
    fn test_cost_estimation_calculations() {
        let breakdown = CostBreakdown {
            cost_per_input_token_usd: 0.00001,
            cost_per_output_token_usd: 0.00006,
            effective_cost_per_token_usd: 0.000023,
            cost_vs_alternatives: HashMap::new(),
        };
        let estimation = CostEstimation {
            model_id: String::from("test-model"),
            input_tokens: 1000,
            output_tokens: 500,
            input_cost_usd: 0.01,
            output_cost_usd: 0.03,
            total_cost_usd: 0.04,
            batch_discount_usd: Some(0.005),
            cache_savings_usd: None,
            final_cost_usd: 0.035,
            breakdown,
        };

        let per_1k = estimation.cost_per_1k_tokens();
        assert!((per_1k - 0.02333).abs() < 0.001);

        let savings = estimation.savings_percentage();
        assert!((savings - 12.5).abs() < 0.1);
    }

    /// Out-of-range limits are clamped rather than rejected.
    #[test]
    fn test_limit_validation() {
        // Above the maximum: pulled down to 1000.
        let too_large = ModelListParams::new().limit(2000);
        assert_eq!(too_large.limit, Some(1000));

        // Below the minimum: pulled up to 1.
        let too_small = ModelListParams::new().limit(0);
        assert_eq!(too_small.limit, Some(1));
    }

    /// Capabilities round-trip through JSON using snake_case names.
    #[test]
    fn test_model_capability_serialization() {
        let json = serde_json::to_string(&ModelCapability::Vision).unwrap();
        assert_eq!(json, "\"vision\"");

        let round_tripped: ModelCapability = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, ModelCapability::Vision);
    }
}