claude_agent/models/
spec.rs1use super::family::ModelFamily;
2use super::provider::{CloudProvider, ProviderIds};
3use serde::{Deserialize, Serialize};
4
/// Identifier of a model, stored as an owned string.
pub type ModelId = String;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ModelSpec {
9 pub id: ModelId,
10 pub family: ModelFamily,
11 pub version: ModelVersion,
12 pub capabilities: Capabilities,
13 pub pricing: Pricing,
14 pub provider_ids: ProviderIds,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ModelVersion {
19 pub version: String,
20 pub snapshot: Option<String>,
21 pub knowledge_cutoff: Option<String>,
22}
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
25pub struct Capabilities {
26 pub context_window: u64,
27 pub extended_context_window: Option<u64>,
28 pub max_output_tokens: u64,
29 pub thinking: bool,
30 pub vision: bool,
31 pub tool_use: bool,
32 pub caching: bool,
33}
34
35impl Capabilities {
36 pub fn effective_context(&self, extended_enabled: bool) -> u64 {
37 if extended_enabled {
38 self.extended_context_window.unwrap_or(self.context_window)
39 } else {
40 self.context_window
41 }
42 }
43
44 pub fn supports_extended_context(&self) -> bool {
45 self.extended_context_window.is_some()
46 }
47}
48
/// Context size above which [`Pricing::calculate`] applies the long-context
/// multiplier. The comparison is strict: a context of exactly this size is
/// still billed at the standard rate.
pub const LONG_CONTEXT_THRESHOLD: u64 = 200_000;
50
51#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
52pub struct Pricing {
53 pub input: f64,
54 pub output: f64,
55 pub cache_read: f64,
56 pub cache_write: f64,
57 pub long_context_multiplier: f64,
58}
59
60impl Pricing {
61 pub fn new(input: f64, output: f64) -> Self {
62 Self {
63 input,
64 output,
65 cache_read: input * 0.1,
66 cache_write: input * 1.25,
67 long_context_multiplier: 2.0,
68 }
69 }
70
71 pub fn calculate(
72 &self,
73 input_tokens: u64,
74 output_tokens: u64,
75 cache_read: u64,
76 cache_write: u64,
77 ) -> f64 {
78 let context = input_tokens + cache_read + cache_write;
79 let multiplier = if context > LONG_CONTEXT_THRESHOLD {
80 self.long_context_multiplier
81 } else {
82 1.0
83 };
84
85 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input * multiplier;
86 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output;
87 let cache_read_cost = (cache_read as f64 / 1_000_000.0) * self.cache_read * multiplier;
88 let cache_write_cost = (cache_write as f64 / 1_000_000.0) * self.cache_write * multiplier;
89
90 input_cost + output_cost + cache_read_cost + cache_write_cost
91 }
92}
93
94impl ModelSpec {
95 pub fn provider_id(&self, provider: CloudProvider) -> Option<&str> {
96 self.provider_ids.for_provider(provider)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point cost comparisons.
    const EPSILON: f64 = 1e-9;

    /// Asserts that two costs agree to within `EPSILON`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Capabilities fixture with a 200k standard window and the given extended window.
    fn caps_with_extended(extended: Option<u64>) -> Capabilities {
        Capabilities {
            context_window: 200_000,
            extended_context_window: extended,
            max_output_tokens: 64_000,
            thinking: true,
            vision: true,
            tool_use: true,
            caching: true,
        }
    }

    #[test]
    fn test_pricing_new_derives_rates() {
        let pricing = Pricing::new(3.0, 15.0);
        assert_close(pricing.cache_read, 0.3);
        assert_close(pricing.cache_write, 3.75);
        assert_close(pricing.long_context_multiplier, 2.0);
    }

    #[test]
    fn test_pricing_standard() {
        let pricing = Pricing::new(3.0, 15.0);
        let cost = pricing.calculate(100_000, 100_000, 0, 0);
        assert_close(cost, 1.8);
    }

    #[test]
    fn test_pricing_large_volume() {
        let pricing = Pricing::new(3.0, 15.0);
        // 1M input tokens exceed the threshold, so input is doubled; output is not.
        let cost = pricing.calculate(1_000_000, 1_000_000, 0, 0);
        assert_close(cost, 21.0);
    }

    #[test]
    fn test_pricing_long_context() {
        let pricing = Pricing::new(3.0, 15.0);
        let cost = pricing.calculate(250_000, 0, 0, 0);
        let expected = (250_000.0 / 1_000_000.0) * 3.0 * 2.0;
        assert_close(cost, expected);
    }

    #[test]
    fn test_pricing_threshold_is_exclusive() {
        let pricing = Pricing::new(3.0, 15.0);
        // A context of exactly LONG_CONTEXT_THRESHOLD is billed at the standard rate.
        let cost = pricing.calculate(LONG_CONTEXT_THRESHOLD, 0, 0, 0);
        assert_close(cost, 0.6);
    }

    #[test]
    fn test_pricing_cache_tokens_count_toward_context() {
        let pricing = Pricing::new(3.0, 15.0);
        // 100k input + 100k cache read sits exactly on the threshold: no multiplier.
        assert_close(pricing.calculate(100_000, 0, 100_000, 0), 0.33);
        // One extra cache-write token tips the context over the threshold.
        assert_close(pricing.calculate(100_000, 0, 100_000, 1), 0.6600075);
    }

    #[test]
    fn test_effective_context() {
        let caps = caps_with_extended(Some(1_000_000));

        assert_eq!(caps.effective_context(false), 200_000);
        assert_eq!(caps.effective_context(true), 1_000_000);
        assert!(caps.supports_extended_context());
    }

    #[test]
    fn test_effective_context_without_extended_window() {
        let caps = caps_with_extended(None);

        assert_eq!(caps.effective_context(false), 200_000);
        assert_eq!(caps.effective_context(true), 200_000);
        assert!(!caps.supports_extended_context());
    }
}