1use anyhow::Result;
6use serde::Deserialize;
7use std::sync::OnceLock;
8
/// Raw text of `assets/cost.toml`, embedded into the binary at compile time by `include_str!`
/// (the path is resolved relative to this source file).
const COST_TOML: &str = include_str!("../../assets/cost.toml");
10
/// Pricing entry for a single model, deserialized from one element of the `models` list in
/// `assets/cost.toml`.
#[derive(Debug, Deserialize)]
pub struct ModelPrice {
    /// Model identifier; lookups compare it for exact string equality.
    pub id: String,
    /// Price per one million input tokens. Presumably micro-USD (USD x 1e6), given the
    /// `_usd_e6` naming of the tail-event estimator — TODO confirm against the asset.
    pub input_per_mtok: i64,
    /// Price per one million output tokens, in the same unit as `input_per_mtok`.
    pub output_per_mtok: i64,
    /// Average tokens per turn, used by the estimator only when a turn reports zero input and
    /// zero output tokens. Defaults to `0` when the key is absent, which disables that fallback.
    #[serde(default)]
    pub avg_tokens_per_turn: u32,
}
19
/// Deserialization shape of `assets/cost.toml`: a top-level `models` list of [`ModelPrice`]
/// entries.
#[derive(Debug, Deserialize)]
struct CostFile {
    models: Vec<ModelPrice>,
}
24
25pub struct CostTable {
26 models: Vec<ModelPrice>,
27}
28
29impl CostTable {
30 pub fn load() -> Result<Self> {
32 let f: CostFile = toml::from_str(COST_TOML)?;
33 Ok(Self { models: f.models })
34 }
35
36 pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
41 let entry = model
42 .and_then(|m| self.models.iter().find(|p| p.id == m))
43 .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
44
45 let Some(price) = entry else { return 0 };
46
47 let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
48 let avg = price.avg_tokens_per_turn as i64;
49 (avg * 4 / 5, avg / 5)
50 } else {
51 (tokens_in as i64, tokens_out as i64)
52 };
53
54 tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
55 }
56
57 pub fn find(&self, model: &str) -> Option<&ModelPrice> {
58 self.models.iter().find(|p| p.id == model)
59 }
60
61 pub fn estimate_tail_event_cost_usd_e6(
65 &self,
66 model: Option<&str>,
67 tokens_in: Option<u32>,
68 tokens_out: Option<u32>,
69 reasoning_tokens: Option<u32>,
70 ) -> Option<i64> {
71 let any_field = tokens_in.is_some() || tokens_out.is_some() || reasoning_tokens.is_some();
72 if !any_field {
73 return None;
74 }
75 let tin = tokens_in.unwrap_or(0);
76 let tout = tokens_out
77 .unwrap_or(0)
78 .saturating_add(reasoning_tokens.unwrap_or(0));
79 if tin == 0 && tout == 0 {
80 return None;
81 }
82 Some(self.estimate(model, tin, tout))
83 }
84}
85
86static BUNDLED_COST: OnceLock<CostTable> = OnceLock::new();
87
88fn bundled_cost_table() -> &'static CostTable {
89 BUNDLED_COST.get_or_init(|| CostTable::load().expect("bundled assets/cost.toml"))
90}
91
92pub fn estimate_tail_event_cost_usd_e6(
94 model: Option<&str>,
95 tokens_in: Option<u32>,
96 tokens_out: Option<u32>,
97 reasoning_tokens: Option<u32>,
98) -> Option<i64> {
99 bundled_cost_table().estimate_tail_event_cost_usd_e6(
100 model,
101 tokens_in,
102 tokens_out,
103 reasoning_tokens,
104 )
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the embedded pricing table, failing the calling test if it is malformed.
    fn bundled() -> CostTable {
        CostTable::load().unwrap()
    }

    #[test]
    fn load_succeeds() {
        let _ = bundled();
    }

    #[test]
    fn known_model_cost() {
        let sonnet = bundled().estimate(Some("claude-sonnet-4"), 1000, 500);
        assert_eq!(sonnet, 10500);
    }

    #[test]
    fn cursor_heuristic_nonzero() {
        let heuristic = bundled().estimate(None, 0, 0);
        assert!(heuristic > 0, "cursor heuristic should produce nonzero cost");
    }

    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = bundled();
        let via_unknown = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let via_default = table.estimate(None, 1000, 500);
        assert_eq!(via_unknown, via_default);
    }

    #[test]
    fn tail_estimate_none_without_usage_fields() {
        let outcome = bundled().estimate_tail_event_cost_usd_e6(None, None, None, None);
        assert!(outcome.is_none());
    }

    #[test]
    fn tail_estimate_none_when_fields_present_but_all_zero() {
        let outcome = bundled().estimate_tail_event_cost_usd_e6(
            Some("claude-sonnet-4"),
            Some(0),
            Some(0),
            Some(0),
        );
        assert!(outcome.is_none());
    }

    #[test]
    fn tail_estimate_adds_reasoning_to_output_side() {
        let table = bundled();
        let model = Some("claude-sonnet-4");
        let split = table
            .estimate_tail_event_cost_usd_e6(model, Some(1000), Some(100), Some(400))
            .expect("cost");
        let combined = table
            .estimate_tail_event_cost_usd_e6(model, Some(1000), Some(500), None)
            .expect("cost");
        assert_eq!(split, combined);
    }
}