// claude_agent/models/spec.rs

1use super::family::ModelFamily;
2use super::provider::{CloudProvider, ProviderIds};
3use serde::{Deserialize, Serialize};
4
/// Identifier of a model; a plain alias for `String`, so any string value is accepted.
pub type ModelId = String;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ModelSpec {
9    pub id: ModelId,
10    pub family: ModelFamily,
11    pub version: ModelVersion,
12    pub capabilities: Capabilities,
13    pub pricing: Pricing,
14    pub provider_ids: ProviderIds,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ModelVersion {
19    pub version: String,
20    pub snapshot: Option<String>,
21    pub knowledge_cutoff: Option<String>,
22}
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
25pub struct Capabilities {
26    pub context_window: u64,
27    pub extended_context_window: Option<u64>,
28    pub max_output_tokens: u64,
29    pub thinking: bool,
30    pub vision: bool,
31    pub tool_use: bool,
32    pub caching: bool,
33}
34
35impl Capabilities {
36    pub fn effective_context(&self, extended_enabled: bool) -> u64 {
37        if extended_enabled {
38            self.extended_context_window.unwrap_or(self.context_window)
39        } else {
40            self.context_window
41        }
42    }
43
44    pub fn supports_extended_context(&self) -> bool {
45        self.extended_context_window.is_some()
46    }
47}
48
/// Context size (input + cache-read + cache-write tokens) that a request must
/// strictly exceed before [`Pricing::calculate`] applies the long-context multiplier.
pub const LONG_CONTEXT_THRESHOLD: u64 = 200_000;
50
51#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
52pub struct Pricing {
53    pub input: f64,
54    pub output: f64,
55    pub cache_read: f64,
56    pub cache_write: f64,
57    pub long_context_multiplier: f64,
58}
59
60impl Pricing {
61    pub fn new(input: f64, output: f64) -> Self {
62        Self {
63            input,
64            output,
65            cache_read: input * 0.1,
66            cache_write: input * 1.25,
67            long_context_multiplier: 2.0,
68        }
69    }
70
71    pub fn calculate(
72        &self,
73        input_tokens: u64,
74        output_tokens: u64,
75        cache_read: u64,
76        cache_write: u64,
77    ) -> f64 {
78        let context = input_tokens + cache_read + cache_write;
79        let multiplier = if context > LONG_CONTEXT_THRESHOLD {
80            self.long_context_multiplier
81        } else {
82            1.0
83        };
84
85        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input * multiplier;
86        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output;
87        let cache_read_cost = (cache_read as f64 / 1_000_000.0) * self.cache_read * multiplier;
88        let cache_write_cost = (cache_write as f64 / 1_000_000.0) * self.cache_write * multiplier;
89
90        input_cost + output_cost + cache_read_cost + cache_write_cost
91    }
92}
93
94impl ModelSpec {
95    pub fn provider_id(&self, provider: CloudProvider) -> Option<&str> {
96        self.provider_ids.for_provider(provider)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed costs.
    const TOLERANCE: f64 = 0.01;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_cost(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Context of 100K stays below the 200K threshold, so no multiplier applies.
    #[test]
    fn test_pricing_standard() {
        let rates = Pricing::new(3.0, 15.0);
        // input: 0.1 * 3.0 = 0.3, output: 0.1 * 15.0 = 1.5
        assert_cost(rates.calculate(100_000, 100_000, 0, 0), 1.8);
    }

    /// 1M input tokens exceed the threshold: input is doubled, output is not.
    #[test]
    fn test_pricing_large_volume() {
        let rates = Pricing::new(3.0, 15.0);
        // input: 1.0 * 3.0 * 2.0 = 6.0, output: 1.0 * 15.0 = 15.0
        assert_cost(rates.calculate(1_000_000, 1_000_000, 0, 0), 21.0);
    }

    /// Input-only request above the threshold is billed at twice the input rate.
    #[test]
    fn test_pricing_long_context() {
        let rates = Pricing::new(3.0, 15.0);
        let expected = (250_000.0 / 1_000_000.0) * 3.0 * 2.0;
        assert_cost(rates.calculate(250_000, 0, 0, 0), expected);
    }

    /// The extended window is used only when explicitly enabled.
    #[test]
    fn test_effective_context() {
        let capabilities = Capabilities {
            context_window: 200_000,
            extended_context_window: Some(1_000_000),
            max_output_tokens: 64_000,
            thinking: true,
            vision: true,
            tool_use: true,
            caching: true,
        };

        let standard = capabilities.effective_context(false);
        let extended = capabilities.effective_context(true);

        assert_eq!(standard, 200_000);
        assert_eq!(extended, 1_000_000);
    }
}