Skip to main content

atomcode_core/
pricing.rs

//! Model pricing — maps model names to per-million-token USD costs.

/// Returns `(input_price, output_price)` in USD per million tokens.
///
/// Matching is case-insensitive and substring-based, so version/date
/// suffixes and provider prefixes (e.g. `"claude-sonnet-4-20250514"`,
/// `"ollama/llama3.1"`) still resolve to the right family.
///
/// Rows are tested top to bottom and the FIRST match wins. Any pattern that
/// contains another pattern must therefore sit ABOVE it. The order below
/// enforces:
/// - `ollama` before `llama` (every `"ollama…"` name also contains `"llama"`),
/// - `gpt-4o-mini` before `gpt-4o`,
/// - `gpt-4.1-nano` / `gpt-4.1-mini` before `gpt-4.1`,
/// - `o1-mini` / `o3-mini` before `o1` / `o3`.
///
/// Any name containing `"ollama"` is treated as a local model and priced at
/// `(0.0, 0.0)`, whatever model family it also names.
///
/// Unknown models get a conservative `(1.0, 3.0)` fallback so cost tracking
/// never silently returns zero for a billable model.
pub fn cost_per_million(model: &str) -> (f64, f64) {
    // Ordered `(substrings, (input, output))` rows. A row matches when the
    // lowercased model name contains ANY of its substrings.
    const TABLE: &[(&[&str], (f64, f64))] = &[
        // Local / Ollama — free. Must come first. "ollama" contains "llama",
        // so checking it after the llama row made it unreachable. Checking
        // it here also keeps Ollama-served models named after hosted
        // families (e.g. "ollama/qwen2.5") from being billed.
        (&["ollama"], (0.0, 0.0)),
        // Claude
        (&["opus"], (15.0, 75.0)),
        (&["sonnet"], (3.0, 15.0)),
        (&["haiku"], (0.25, 1.25)),
        // OpenAI — each smaller variant sits above the family it contains.
        (&["gpt-4o-mini"], (0.15, 0.6)),
        (&["gpt-4o"], (2.5, 10.0)),
        (&["gpt-4.1-nano"], (0.1, 0.4)),
        (&["gpt-4.1-mini"], (0.4, 1.6)),
        (&["gpt-4.1"], (2.0, 8.0)),
        (&["o1-mini", "o3-mini"], (3.0, 12.0)),
        // NOTE(review): "o1"/"o3" are very short substrings and could match
        // unrelated names that reach this row; consider anchoring them.
        (&["o1"], (15.0, 60.0)),
        (&["o3"], (10.0, 40.0)),
        // DeepSeek
        (&["deepseek"], (0.27, 1.1)),
        // Qwen
        (&["qwen"], (0.5, 2.0)),
        // GLM / Zhipu
        (&["glm"], (0.5, 2.0)),
        // SiliconFlow / open models
        (&["llama", "mistral"], (0.3, 0.6)),
        // MiniMax
        (&["minimax", "m2.7"], (0.5, 2.0)),
    ];
    // Unknown — conservative estimate.
    const UNKNOWN: (f64, f64) = (1.0, 3.0);

    let m = model.to_lowercase();
    TABLE
        .iter()
        .find(|(needles, _)| needles.iter().any(|n| m.contains(*n)))
        .map_or(UNKNOWN, |&(_, price)| price)
}

42/// Calculate cost in USD from token counts and model name.
43pub fn calculate_cost(model: &str, prompt_tokens: usize, completion_tokens: usize, cached_tokens: usize) -> f64 {
44    let (input_price, output_price) = cost_per_million(model);
45    // Cached tokens are typically 90% cheaper (Anthropic) or free (OpenAI)
46    let cached_price = input_price * 0.1;
47
48    let input_cost = (prompt_tokens as f64 / 1_000_000.0) * input_price;
49    let cached_cost = (cached_tokens as f64 / 1_000_000.0) * cached_price;
50    let output_cost = (completion_tokens as f64 / 1_000_000.0) * output_price;
51
52    input_cost + cached_cost + output_cost
53}
54
/// Format a USD amount as a `$`-prefixed string.
///
/// Amounts below one cent (including zero and negative values) are shown
/// with four decimal places, because two would render them all as `$0.00`.
/// All other amounts use two decimal places.
pub fn format_cost(cost: f64) -> String {
    let decimals = if cost < 0.01 { 4 } else { 2 };
    format!("${:.*}", decimals, cost)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert two dollar amounts agree to within floating-point noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "cost mismatch: got {}, expected {}",
            actual,
            expected
        );
    }

    #[test]
    fn test_claude_sonnet_pricing() {
        // A date suffix must not prevent the "sonnet" family match.
        assert_eq!(cost_per_million("claude-sonnet-4-20250514"), (3.0, 15.0));
    }

    #[test]
    fn test_deepseek_pricing() {
        assert_eq!(cost_per_million("deepseek-chat"), (0.27, 1.1));
    }

    #[test]
    fn test_gpt4o_pricing() {
        assert_eq!(cost_per_million("gpt-4o"), (2.5, 10.0));
    }

    #[test]
    fn test_unknown_model() {
        // Unrecognized names fall back to the conservative default.
        assert_eq!(cost_per_million("some-unknown-model"), (1.0, 3.0));
    }

    #[test]
    fn test_calculate_cost() {
        // deepseek: 1000 prompt + 500 completion tokens, nothing cached.
        let expected = (1000.0 / 1_000_000.0) * 0.27 + (500.0 / 1_000_000.0) * 1.1;
        assert_close(calculate_cost("deepseek-chat", 1000, 500, 0), expected);
    }

    #[test]
    fn test_calculate_cost_with_cache() {
        // Sonnet: prompt at $3/M, 800 cached at 10% of that ($0.30/M),
        // completion at $15/M.
        let prompt = 1000.0 * 3.0 / 1e6;
        let cached = 800.0 * 0.3 / 1e6;
        let completion = 500.0 * 15.0 / 1e6;
        assert_close(
            calculate_cost("claude-sonnet-4-20250514", 1000, 500, 800),
            prompt + cached + completion,
        );
    }

    #[test]
    fn test_format_cost() {
        let cases = [(0.42, "$0.42"), (0.001, "$0.0010"), (12.5, "$12.50")];
        for &(amount, rendered) in cases.iter() {
            assert_eq!(format_cost(amount), rendered);
        }
    }
}