1use anyhow::Result;
6use serde::Deserialize;
7
/// Pricing table source, embedded into the binary at compile time from
/// `assets/cost.toml` (path is relative to this source file).
const COST_TOML: &str = include_str!("../../assets/cost.toml");
9
10#[derive(Debug, Deserialize)]
11pub struct ModelPrice {
12 pub id: String,
13 pub input_per_mtok: i64,
14 pub output_per_mtok: i64,
15 #[serde(default)]
16 pub avg_tokens_per_turn: u32,
17}
18
/// Top-level shape of the embedded cost TOML: a `models` list of
/// [`ModelPrice`] entries. Only used as an intermediate when loading.
#[derive(Debug, Deserialize)]
struct CostFile {
    models: Vec<ModelPrice>,
}
23
24pub struct CostTable {
25 models: Vec<ModelPrice>,
26}
27
28impl CostTable {
29 pub fn load() -> Result<Self> {
31 let f: CostFile = toml::from_str(COST_TOML)?;
32 Ok(Self { models: f.models })
33 }
34
35 pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
40 let entry = model
41 .and_then(|m| self.models.iter().find(|p| p.id == m))
42 .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
43
44 let Some(price) = entry else { return 0 };
45
46 let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
47 let avg = price.avg_tokens_per_turn as i64;
48 (avg * 4 / 5, avg / 5)
49 } else {
50 (tokens_in as i64, tokens_out as i64)
51 };
52
53 tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
54 }
55
56 pub fn find(&self, model: &str) -> Option<&ModelPrice> {
57 self.models.iter().find(|p| p.id == model)
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn load_succeeds() {
        CostTable::load().unwrap();
    }

    #[test]
    fn known_model_cost() {
        let table = CostTable::load().unwrap();
        let cost = table.estimate(Some("claude-sonnet-4"), 1000, 500);
        assert_eq!(cost, 10500);
    }

    #[test]
    fn cursor_heuristic_nonzero() {
        let table = CostTable::load().unwrap();
        // Both counts zero: `estimate` should fall back to the "cursor"
        // entry's average-tokens-per-turn heuristic.
        let cost = table.estimate(None, 0, 0);
        assert!(cost > 0, "cursor heuristic should produce nonzero cost");
    }

    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = CostTable::load().unwrap();
        let cost = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let cost2 = table.estimate(None, 1000, 500);
        assert_eq!(cost, cost2);
    }

    #[test]
    fn cursor_fallback_entry_exists() {
        // `estimate` silently returns 0 without this entry, so pin it here.
        let table = CostTable::load().unwrap();
        assert!(table.find("cursor").is_some());
    }

    #[test]
    fn find_matches_exact_id_only() {
        let table = CostTable::load().unwrap();
        assert_eq!(
            table.find("claude-sonnet-4").map(|p| p.id.as_str()),
            Some("claude-sonnet-4")
        );
        // `find` has no fallback, unlike `estimate`.
        assert!(table.find("unknown-model-xyz").is_none());
    }
}