1use anyhow::Result;
6use serde::Deserialize;
7use std::sync::OnceLock;
8
/// Text of `../../assets/cost.toml`, embedded into the binary at compile time.
///
/// [`CostTable::load`] parses this string, so no file is read at runtime.
const COST_TOML: &str = include_str!("../../assets/cost.toml");
10
/// Pricing entry for one model, deserialized from the `models` list of the cost file.
///
/// Every `*_per_mtok` value is multiplied by a token count and divided by
/// 1_000_000 when a cost is computed, i.e. it is a price per million tokens.
/// NOTE(review): the `_usd_e6` suffix on the estimator methods suggests the
/// prices are micro-USD — confirm against the cost file.
#[derive(Debug, Deserialize)]
pub struct ModelPrice {
    /// Identifier that requested model names are compared against (exact string equality).
    pub id: String,
    /// Price per million input tokens.
    pub input_per_mtok: i64,
    /// Price per million output tokens; reasoning tokens are billed at this rate too.
    pub output_per_mtok: i64,
    /// Price per million cache-read tokens; when absent, `input_per_mtok` is used instead.
    #[serde(default)]
    pub cache_read_per_mtok: Option<i64>,
    /// Price per million cache-creation tokens; when absent, `input_per_mtok` is used instead.
    #[serde(default)]
    pub cache_create_per_mtok: Option<i64>,
    /// Token total assumed by `CostTable::estimate` when it is given no token counts.
    ///
    /// Defaults to `0` when missing from the file, which disables that heuristic.
    #[serde(default)]
    pub avg_tokens_per_turn: u32,
}
23
/// Token counts for one usage record, split by billing category.
///
/// The derived `Default` is the all-zero record.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct TokenUsage {
    /// Input tokens, priced at `ModelPrice::input_per_mtok`.
    pub input: u32,
    /// Output tokens, priced at `ModelPrice::output_per_mtok`.
    pub output: u32,
    /// Reasoning tokens; added to `output` and priced at the output rate.
    pub reasoning: u32,
    /// Cache-read tokens, priced at `cache_read_per_mtok` (input rate when that is unset).
    pub cache_read: u32,
    /// Cache-creation tokens, priced at `cache_create_per_mtok` (input rate when that is unset).
    pub cache_create: u32,
}
32
/// Top-level shape of the cost TOML document: a `models` list of [`ModelPrice`] entries.
///
/// Used only as the deserialization target in `CostTable::load`.
#[derive(Debug, Deserialize)]
struct CostFile {
    /// Every pricing entry declared in the file.
    models: Vec<ModelPrice>,
}
37
/// Table of per-model prices used to turn token counts into cost estimates.
pub struct CostTable {
    /// Pricing entries as deserialized; lookups scan this list and take the first `id` match.
    models: Vec<ModelPrice>,
}
41
42impl CostTable {
43 pub fn load() -> Result<Self> {
45 let f: CostFile = toml::from_str(COST_TOML)?;
46 Ok(Self { models: f.models })
47 }
48
49 pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
54 let entry = model
55 .and_then(|m| self.models.iter().find(|p| p.id == m))
56 .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
57
58 let Some(price) = entry else { return 0 };
59
60 let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
61 let avg = price.avg_tokens_per_turn as i64;
62 (avg * 4 / 5, avg / 5)
63 } else {
64 (tokens_in as i64, tokens_out as i64)
65 };
66
67 tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
68 }
69
70 pub fn estimate_usage_cost_usd_e6(&self, model: Option<&str>, usage: TokenUsage) -> i64 {
71 let Some(price) = self.price_for(model) else {
72 return 0;
73 };
74 usage_cost(price, usage)
75 }
76
77 pub fn find(&self, model: &str) -> Option<&ModelPrice> {
78 self.models.iter().find(|p| p.id == model)
79 }
80
81 pub fn estimate_tail_event_cost_usd_e6(
85 &self,
86 model: Option<&str>,
87 tokens_in: Option<u32>,
88 tokens_out: Option<u32>,
89 reasoning_tokens: Option<u32>,
90 ) -> Option<i64> {
91 let any_field = tokens_in.is_some() || tokens_out.is_some() || reasoning_tokens.is_some();
92 if !any_field {
93 return None;
94 }
95 let tin = tokens_in.unwrap_or(0);
96 let tout = tokens_out
97 .unwrap_or(0)
98 .saturating_add(reasoning_tokens.unwrap_or(0));
99 if tin == 0 && tout == 0 {
100 return None;
101 }
102 Some(self.estimate(model, tin, tout))
103 }
104
105 pub fn estimate_tail_usage_cost_usd_e6(
106 &self,
107 model: Option<&str>,
108 usage: TokenUsage,
109 ) -> Option<i64> {
110 usage
111 .any()
112 .then(|| self.estimate_usage_cost_usd_e6(model, usage))
113 }
114
115 fn price_for(&self, model: Option<&str>) -> Option<&ModelPrice> {
116 model
117 .and_then(|m| self.models.iter().find(|p| p.id == m))
118 .or_else(|| self.models.iter().find(|p| p.id == "cursor"))
119 }
120}
121
122impl TokenUsage {
123 pub fn from_tail(input: Option<u32>, output: Option<u32>, reasoning: Option<u32>) -> Self {
124 Self {
125 input: input.unwrap_or(0),
126 output: output.unwrap_or(0),
127 reasoning: reasoning.unwrap_or(0),
128 cache_read: 0,
129 cache_create: 0,
130 }
131 }
132
133 fn any(self) -> bool {
134 self.input > 0
135 || self.output > 0
136 || self.reasoning > 0
137 || self.cache_read > 0
138 || self.cache_create > 0
139 }
140}
141
142fn usage_cost(price: &ModelPrice, usage: TokenUsage) -> i64 {
143 let out = usage.output.saturating_add(usage.reasoning) as i64;
144 let cache_read = price.cache_read_per_mtok.unwrap_or(price.input_per_mtok);
145 let cache_create = price.cache_create_per_mtok.unwrap_or(price.input_per_mtok);
146 usage.input as i64 * price.input_per_mtok / 1_000_000
147 + out * price.output_per_mtok / 1_000_000
148 + usage.cache_read as i64 * cache_read / 1_000_000
149 + usage.cache_create as i64 * cache_create / 1_000_000
150}
151
152static BUNDLED_COST: OnceLock<CostTable> = OnceLock::new();
153
154fn bundled_cost_table() -> &'static CostTable {
155 BUNDLED_COST.get_or_init(|| CostTable::load().expect("bundled assets/cost.toml"))
156}
157
158pub fn estimate_tail_event_cost_usd_e6(
160 model: Option<&str>,
161 tokens_in: Option<u32>,
162 tokens_out: Option<u32>,
163 reasoning_tokens: Option<u32>,
164) -> Option<i64> {
165 bundled_cost_table().estimate_tail_event_cost_usd_e6(
166 model,
167 tokens_in,
168 tokens_out,
169 reasoning_tokens,
170 )
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Id of a model the assertions below expect to find in the bundled cost file.
    const SONNET: &str = "claude-sonnet-4";

    /// Parses the bundled cost file, failing the calling test if it is invalid.
    fn bundled() -> CostTable {
        CostTable::load().unwrap()
    }

    #[test]
    fn load_succeeds() {
        let _table = bundled();
    }

    #[test]
    fn known_model_cost() {
        // The expected figure depends on the prices listed for this model in the bundled file.
        let estimated = bundled().estimate(Some(SONNET), 1000, 500);
        assert_eq!(estimated, 10500);
    }

    #[test]
    fn cursor_heuristic_nonzero() {
        // No model and no counts: the fallback entry's per-turn average applies.
        let estimated = bundled().estimate(None, 0, 0);
        assert!(estimated > 0, "cursor heuristic should produce nonzero cost");
    }

    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = bundled();
        let for_unknown = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let for_default = table.estimate(None, 1000, 500);
        assert_eq!(for_unknown, for_default);
    }

    #[test]
    fn tail_estimate_none_without_usage_fields() {
        let estimated = bundled().estimate_tail_event_cost_usd_e6(None, None, None, None);
        assert!(estimated.is_none());
    }

    #[test]
    fn tail_estimate_none_when_fields_present_but_all_zero() {
        let estimated =
            bundled().estimate_tail_event_cost_usd_e6(Some(SONNET), Some(0), Some(0), Some(0));
        assert!(estimated.is_none());
    }

    #[test]
    fn tail_estimate_adds_reasoning_to_output_side() {
        let table = bundled();
        // 100 output + 400 reasoning must cost the same as 500 plain output tokens.
        let with_reasoning = table
            .estimate_tail_event_cost_usd_e6(Some(SONNET), Some(1000), Some(100), Some(400))
            .expect("cost");
        let output_only = table
            .estimate_tail_event_cost_usd_e6(Some(SONNET), Some(1000), Some(500), None)
            .expect("cost");
        assert_eq!(with_reasoning, output_only);
    }

    #[test]
    fn usage_estimate_prices_cache_tokens_separately() {
        let table = bundled();
        let usage = TokenUsage {
            input: 1000,
            output: 500,
            reasoning: 0,
            cache_read: 1000,
            cache_create: 1000,
        };
        let cached = table.estimate_usage_cost_usd_e6(Some(SONNET), usage);
        // Same 3000 input-side tokens, all billed as plain input.
        let plain = table.estimate(Some(SONNET), 3000, 500);
        assert!(cached < plain);
    }

    #[test]
    fn tail_estimate_includes_cache_usage() {
        let usage = TokenUsage {
            input: 0,
            output: 0,
            reasoning: 0,
            cache_read: 1000,
            cache_create: 0,
        };
        let cached = bundled()
            .estimate_tail_usage_cost_usd_e6(Some(SONNET), usage)
            .expect("cache read cost");
        assert!(cached > 0);
    }
}