//! Model specification types (`claude_agent/models/spec.rs`).

1use serde::{Deserialize, Serialize};
2
3use super::family::ModelFamily;
4use super::provider::{ProviderIds, ProviderKind};
5use crate::budget::ModelPricing;
6
/// Canonical identifier of a model; this is the type of the `id` field of [`ModelSpec`].
pub type ModelId = std::string::String;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ModelSpec {
11    pub id: ModelId,
12    pub family: ModelFamily,
13    pub version: ModelVersion,
14    pub capabilities: Capabilities,
15    pub pricing: ModelPricing,
16    pub provider_ids: ProviderIds,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ModelVersion {
21    pub version: String,
22    pub snapshot: Option<String>,
23    pub knowledge_cutoff: Option<String>,
24}
25
26#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
27pub struct Capabilities {
28    pub context_window: u64,
29    pub extended_context_window: Option<u64>,
30    pub max_output_tokens: u64,
31    pub thinking: bool,
32    pub vision: bool,
33    pub tool_use: bool,
34    pub caching: bool,
35}
36
37impl Capabilities {
38    pub fn effective_context(&self, extended_enabled: bool) -> u64 {
39        if extended_enabled {
40            self.extended_context_window.unwrap_or(self.context_window)
41        } else {
42            self.context_window
43        }
44    }
45
46    pub fn supports_extended_context(&self) -> bool {
47        self.extended_context_window.is_some()
48    }
49}
50
/// Token count separating standard-context from long-context requests.
///
/// NOTE(review): not referenced in this file. The pricing tests in this module
/// treat inputs above 200K tokens as long-context (premium input rate). Confirm
/// in `crate::budget` whether the boundary is exclusive (`>`) or inclusive (`>=`).
pub const LONG_CONTEXT_THRESHOLD: u64 = 200 * 1_000;
52
53impl ModelSpec {
54    pub fn provider_id(&self, provider: ProviderKind) -> Option<&str> {
55        self.provider_ids.for_provider(provider)
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    // The three pricing tests below exercise `crate::budget::ModelPricing`
    // rather than anything defined in this file.

    #[test]
    fn test_pricing_standard() {
        let pricing = ModelPricing::from_base(dec!(3), dec!(15));
        // Standard pricing: context < 200K
        let cost = pricing.calculate_raw(100_000, 100_000, 0, 0);
        // input: 0.1 * 3.0 = 0.3, output: 0.1 * 15.0 = 1.5
        assert_eq!(cost, dec!(1.8));
    }

    #[test]
    fn test_pricing_large_volume() {
        let pricing = ModelPricing::from_base(dec!(3), dec!(15));
        // 1M tokens each, context = 1M > 200K, so 2x multiplier on input
        let cost = pricing.calculate_raw(1_000_000, 1_000_000, 0, 0);
        // input: 1.0 * 3.0 * 2.0 = 6.0, output: 1.0 * 15.0 = 15.0
        assert_eq!(cost, dec!(21));
    }

    #[test]
    fn test_pricing_long_context() {
        let pricing = ModelPricing::from_base(dec!(3), dec!(15));
        let cost = pricing.calculate_raw(250_000, 0, 0, 0);
        // 0.25 * 3.0 * 2.0 = 1.5
        assert_eq!(cost, dec!(1.5));
    }

    #[test]
    fn test_effective_context() {
        let caps = Capabilities {
            context_window: 200_000,
            extended_context_window: Some(1_000_000),
            max_output_tokens: 64_000,
            thinking: true,
            vision: true,
            tool_use: true,
            caching: true,
        };

        assert!(caps.supports_extended_context());
        assert_eq!(caps.effective_context(false), 200_000);
        assert_eq!(caps.effective_context(true), 1_000_000);
    }

    #[test]
    fn test_effective_context_without_extension() {
        let caps = Capabilities {
            context_window: 200_000,
            extended_context_window: None,
            max_output_tokens: 64_000,
            thinking: false,
            vision: false,
            tool_use: true,
            caching: false,
        };

        // Requesting extended context on a model without it falls back to the
        // default window instead of failing.
        assert!(!caps.supports_extended_context());
        assert_eq!(caps.effective_context(false), 200_000);
        assert_eq!(caps.effective_context(true), 200_000);
    }
}