Skip to main content

cognis_trace/
cost.rs

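//! USD cost computation: turns token usage into dollar amounts using
//! per-million-token rates keyed by model id.
//!
//! The built-in defaults are a dated price snapshot — the capture date is part
//! of the function name ([`default_pricing_2026_05`]). Callers can override
//! stale entries or add new models with [`PriceTable::insert`]. See spec §8.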
5
6use std::collections::HashMap;
7
8use crate::span::{CostDetails, TokenUsage};
9
/// Pricing for a single model.
///
/// Every field is a USD rate per one million tokens of the corresponding
/// kind. `Default` yields an all-zero price.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct ModelPrice {
    /// Rate for input tokens (USD per 1M).
    pub input: f64,
    /// Rate for output tokens (USD per 1M).
    pub output: f64,
    /// Rate for tokens read from the cache (USD per 1M).
    pub cache_read: f64,
    /// Rate for tokens written to the cache (USD per 1M).
    pub cache_write: f64,
}
22
23/// Map of model id to per-million-token rates.
24#[derive(Debug, Clone, Default)]
25pub struct PriceTable {
26    inner: HashMap<String, ModelPrice>,
27}
28
29impl PriceTable {
30    /// Empty table.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Pre-populated with `default_pricing_2026_05`.
36    pub fn with_defaults() -> Self {
37        default_pricing_2026_05()
38    }
39
40    /// Insert or replace a price entry.
41    pub fn insert(&mut self, model_id: impl Into<String>, price: ModelPrice) -> &mut Self {
42        self.inner.insert(model_id.into(), price);
43        self
44    }
45
46    /// Lookup, exact match first, then prefix match (e.g. "gpt-4o" matches
47    /// "gpt-4o-2024-08-06"). Returns `None` if no match.
48    pub fn get(&self, model_id: &str) -> Option<&ModelPrice> {
49        if let Some(p) = self.inner.get(model_id) {
50            return Some(p);
51        }
52        self.inner
53            .iter()
54            .filter(|(k, _)| model_id.starts_with(k.as_str()))
55            .max_by_key(|(k, _)| k.len())
56            .map(|(_, v)| v)
57    }
58
59    /// Compute structured cost from token usage. Returns `None` when the
60    /// model is unknown.
61    pub fn compute(&self, model_id: &str, usage: TokenUsage) -> Option<CostDetails> {
62        let p = self.get(model_id)?;
63        let scale = 1_000_000.0;
64        let input = (usage.input as f64) * p.input / scale;
65        let output = (usage.output as f64) * p.output / scale;
66        let cache_read = (usage.cache_read as f64) * p.cache_read / scale;
67        let cache_write = (usage.cache_write as f64) * p.cache_write / scale;
68        Some(CostDetails {
69            input,
70            output,
71            cache_read,
72            cache_write,
73            total: input + output + cache_read + cache_write,
74        })
75    }
76}
77
78/// Snapshot of mainstream provider prices captured 2026-05-06.
79/// Stale entries can be overridden with `PriceTable::insert`.
80pub fn default_pricing_2026_05() -> PriceTable {
81    let mut t = PriceTable::new();
82
83    // OpenAI (USD per 1M tokens, 2026-05 snapshot).
84    t.insert(
85        "gpt-4o",
86        ModelPrice {
87            input: 2.50,
88            output: 10.00,
89            cache_read: 1.25,
90            cache_write: 0.0,
91        },
92    );
93    t.insert(
94        "gpt-4o-mini",
95        ModelPrice {
96            input: 0.15,
97            output: 0.60,
98            cache_read: 0.075,
99            cache_write: 0.0,
100        },
101    );
102    t.insert(
103        "o1",
104        ModelPrice {
105            input: 15.00,
106            output: 60.00,
107            cache_read: 7.50,
108            cache_write: 0.0,
109        },
110    );
111    t.insert(
112        "o1-mini",
113        ModelPrice {
114            input: 3.00,
115            output: 12.00,
116            cache_read: 1.50,
117            cache_write: 0.0,
118        },
119    );
120
121    // Anthropic.
122    t.insert(
123        "claude-opus-4",
124        ModelPrice {
125            input: 15.00,
126            output: 75.00,
127            cache_read: 1.50,
128            cache_write: 18.75,
129        },
130    );
131    t.insert(
132        "claude-sonnet-4",
133        ModelPrice {
134            input: 3.00,
135            output: 15.00,
136            cache_read: 0.30,
137            cache_write: 3.75,
138        },
139    );
140    t.insert(
141        "claude-haiku-4",
142        ModelPrice {
143            input: 0.80,
144            output: 4.00,
145            cache_read: 0.08,
146            cache_write: 1.00,
147        },
148    );
149
150    // Google.
151    t.insert(
152        "gemini-2.0-flash",
153        ModelPrice {
154            input: 0.10,
155            output: 0.40,
156            cache_read: 0.025,
157            cache_write: 0.0,
158        },
159    );
160    t.insert(
161        "gemini-1.5-pro",
162        ModelPrice {
163            input: 1.25,
164            output: 5.00,
165            cache_read: 0.3125,
166            cache_write: 0.0,
167        },
168    );
169
170    t
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_table_returns_none() {
        let t = PriceTable::new();
        assert!(t.compute("gpt-4o", TokenUsage::default()).is_none());
    }

    #[test]
    fn exact_match_used_first() {
        let mut t = PriceTable::new();
        t.insert(
            "gpt-4o-2024-08-06",
            ModelPrice {
                input: 1.0,
                output: 2.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
        );
        t.insert(
            "gpt-4o",
            ModelPrice {
                input: 99.0,
                output: 99.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
        );
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 1.0);
    }

    #[test]
    fn prefix_match_falls_back() {
        let mut t = PriceTable::new();
        t.insert(
            "gpt-4o",
            ModelPrice {
                input: 2.50,
                output: 10.00,
                cache_read: 1.25,
                cache_write: 0.0,
            },
        );
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 2.50);
    }

    #[test]
    fn longest_prefix_wins() {
        let mut t = PriceTable::new();
        t.insert(
            "gpt",
            ModelPrice {
                input: 1.0,
                output: 1.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
        );
        t.insert(
            "gpt-4o",
            ModelPrice {
                input: 2.50,
                output: 10.00,
                cache_read: 0.0,
                cache_write: 0.0,
            },
        );
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 2.50);
    }

    #[test]
    fn prefix_match_is_one_directional() {
        // The requested id must extend the key; a shorter id never matches a
        // longer key.
        let mut t = PriceTable::new();
        t.insert("gpt-4o", ModelPrice::default());
        assert!(t.get("gpt").is_none());
        assert!(t.get("gpt-4").is_none());
    }

    #[test]
    fn insert_replaces_existing_entry_and_chains() {
        let mut t = PriceTable::new();
        t.insert(
            "m",
            ModelPrice {
                input: 1.0,
                ..ModelPrice::default()
            },
        )
        .insert(
            "m",
            ModelPrice {
                input: 2.0,
                ..ModelPrice::default()
            },
        );
        assert_eq!(t.get("m").unwrap().input, 2.0);
    }

    #[test]
    fn compute_scales_per_million_tokens() {
        let mut t = PriceTable::new();
        t.insert(
            "gpt-4o",
            ModelPrice {
                input: 2.50,
                output: 10.00,
                cache_read: 1.25,
                cache_write: 0.0,
            },
        );
        let usage = TokenUsage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 0,
            cache_write: 0,
        };
        let c = t.compute("gpt-4o", usage).unwrap();
        assert!((c.input - 2.50).abs() < 1e-9);
        assert!((c.output - 5.00).abs() < 1e-9);
        assert!((c.total - 7.50).abs() < 1e-9);
    }

    #[test]
    fn compute_includes_cache_components() {
        let mut t = PriceTable::new();
        t.insert(
            "m",
            ModelPrice {
                input: 3.00,
                output: 15.00,
                cache_read: 0.30,
                cache_write: 3.75,
            },
        );
        let usage = TokenUsage {
            input: 2_000_000,
            output: 1_000_000,
            cache_read: 10_000_000,
            cache_write: 400_000,
        };
        let c = t.compute("m", usage).unwrap();
        assert!((c.input - 6.00).abs() < 1e-9);
        assert!((c.output - 15.00).abs() < 1e-9);
        assert!((c.cache_read - 3.00).abs() < 1e-9);
        assert!((c.cache_write - 1.50).abs() < 1e-9);
        assert!((c.total - 25.50).abs() < 1e-9);
    }

    #[test]
    fn zero_usage_costs_nothing() {
        let t = PriceTable::with_defaults();
        let c = t.compute("gpt-4o", TokenUsage::default()).unwrap();
        assert_eq!(c.total, 0.0);
    }

    #[test]
    fn unknown_model_in_populated_table_returns_none() {
        let t = PriceTable::with_defaults();
        assert!(t.get("llama-3-70b").is_none());
        assert!(t.compute("llama-3-70b", TokenUsage::default()).is_none());
    }

    #[test]
    fn defaults_contains_mainstream_models() {
        let t = default_pricing_2026_05();
        assert!(t.get("gpt-4o").is_some());
        assert!(t.get("claude-sonnet-4").is_some());
        assert!(t.get("gemini-2.0-flash").is_some());
    }

    #[test]
    fn with_defaults_matches_snapshot() {
        let a = PriceTable::with_defaults();
        let b = default_pricing_2026_05();
        assert_eq!(a.get("o1"), b.get("o1"));
        assert_eq!(a.get("gemini-1.5-pro"), b.get("gemini-1.5-pro"));
    }

    #[test]
    fn defaults_prefer_most_specific_prefix() {
        let t = default_pricing_2026_05();
        // Both "gpt-4o" and "gpt-4o-mini" prefix this id; the longer key wins.
        assert_eq!(t.get("gpt-4o-mini-2024-07-18").unwrap().input, 0.15);
        assert_eq!(t.get("gpt-4o-2024-08-06").unwrap().input, 2.50);
        assert_eq!(t.get("claude-sonnet-4-20250514").unwrap().output, 15.00);
    }
}