Skip to main content

kaizen/core/
cost.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Model cost estimation from bundled price table.
3//! All values in cost_usd_e6 (1 USD = 1_000_000 units).
4
5use anyhow::Result;
6use serde::Deserialize;
7use std::sync::OnceLock;
8
/// Raw text of the bundled price table, embedded into the binary at compile time
/// via `include_str!`; parsed by [`CostTable::load`].
const COST_TOML: &str = include_str!("../../assets/cost.toml");
10
11#[derive(Debug, Deserialize)]
12pub struct ModelPrice {
13    pub id: String,
14    pub input_per_mtok: i64,
15    pub output_per_mtok: i64,
16    #[serde(default)]
17    pub cache_read_per_mtok: Option<i64>,
18    #[serde(default)]
19    pub cache_create_per_mtok: Option<i64>,
20    #[serde(default)]
21    pub avg_tokens_per_turn: u32,
22}
23
/// Token counters for a single model call or transcript row.
///
/// All counters are plain counts; a zero means "none" (the struct has no
/// separate notion of an absent field).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TokenUsage {
    /// Prompt tokens, billed at the input rate.
    pub input: u32,
    /// Completion tokens, billed at the output rate.
    pub output: u32,
    /// Reasoning tokens, added to `output` and billed at the output rate.
    pub reasoning: u32,
    /// Tokens served from the prompt cache.
    pub cache_read: u32,
    /// Tokens written into the prompt cache.
    pub cache_create: u32,
}
32
33#[derive(Debug, Deserialize)]
34struct CostFile {
35    models: Vec<ModelPrice>,
36}
37
38pub struct CostTable {
39    models: Vec<ModelPrice>,
40}
41
42impl CostTable {
43    /// Load from bundled `assets/cost.toml`.
44    pub fn load() -> Result<Self> {
45        let f: CostFile = toml::from_str(COST_TOML)?;
46        Ok(Self { models: f.models })
47    }
48
49    /// Estimate cost in cost_usd_e6 units.
50    /// If model is None or not found, falls back to "cursor" heuristic entry.
51    /// If tokens_in == 0 and tokens_out == 0 (Cursor: no native tokens),
52    /// uses avg_tokens_per_turn from the matched entry.
53    pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
54        let entry = model
55            .and_then(|m| self.models.iter().find(|p| p.id == m))
56            .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
57
58        let Some(price) = entry else { return 0 };
59
60        let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
61            let avg = price.avg_tokens_per_turn as i64;
62            (avg * 4 / 5, avg / 5)
63        } else {
64            (tokens_in as i64, tokens_out as i64)
65        };
66
67        tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
68    }
69
70    pub fn estimate_usage_cost_usd_e6(&self, model: Option<&str>, usage: TokenUsage) -> i64 {
71        let Some(price) = self.price_for(model) else {
72            return 0;
73        };
74        usage_cost(price, usage)
75    }
76
77    pub fn find(&self, model: &str) -> Option<&ModelPrice> {
78        self.models.iter().find(|p| p.id == model)
79    }
80
81    /// Transcript tail rows: charge only when at least one usage field is set **and**
82    /// prompt + output (including reasoning) are not all zero. Omits proxy-style
83    /// `avg_tokens_per_turn` fallback so thousands of tool lines do not each get a heuristic charge.
84    pub fn estimate_tail_event_cost_usd_e6(
85        &self,
86        model: Option<&str>,
87        tokens_in: Option<u32>,
88        tokens_out: Option<u32>,
89        reasoning_tokens: Option<u32>,
90    ) -> Option<i64> {
91        let any_field = tokens_in.is_some() || tokens_out.is_some() || reasoning_tokens.is_some();
92        if !any_field {
93            return None;
94        }
95        let tin = tokens_in.unwrap_or(0);
96        let tout = tokens_out
97            .unwrap_or(0)
98            .saturating_add(reasoning_tokens.unwrap_or(0));
99        if tin == 0 && tout == 0 {
100            return None;
101        }
102        Some(self.estimate(model, tin, tout))
103    }
104
105    pub fn estimate_tail_usage_cost_usd_e6(
106        &self,
107        model: Option<&str>,
108        usage: TokenUsage,
109    ) -> Option<i64> {
110        usage
111            .any()
112            .then(|| self.estimate_usage_cost_usd_e6(model, usage))
113    }
114
115    fn price_for(&self, model: Option<&str>) -> Option<&ModelPrice> {
116        model
117            .and_then(|m| self.models.iter().find(|p| p.id == m))
118            .or_else(|| self.models.iter().find(|p| p.id == "cursor"))
119    }
120}
121
122impl TokenUsage {
123    pub fn from_tail(input: Option<u32>, output: Option<u32>, reasoning: Option<u32>) -> Self {
124        Self {
125            input: input.unwrap_or(0),
126            output: output.unwrap_or(0),
127            reasoning: reasoning.unwrap_or(0),
128            cache_read: 0,
129            cache_create: 0,
130        }
131    }
132
133    fn any(self) -> bool {
134        self.input > 0
135            || self.output > 0
136            || self.reasoning > 0
137            || self.cache_read > 0
138            || self.cache_create > 0
139    }
140}
141
142fn usage_cost(price: &ModelPrice, usage: TokenUsage) -> i64 {
143    let out = usage.output.saturating_add(usage.reasoning) as i64;
144    let cache_read = price.cache_read_per_mtok.unwrap_or(price.input_per_mtok);
145    let cache_create = price.cache_create_per_mtok.unwrap_or(price.input_per_mtok);
146    usage.input as i64 * price.input_per_mtok / 1_000_000
147        + out * price.output_per_mtok / 1_000_000
148        + usage.cache_read as i64 * cache_read / 1_000_000
149        + usage.cache_create as i64 * cache_create / 1_000_000
150}
151
152static BUNDLED_COST: OnceLock<CostTable> = OnceLock::new();
153
154fn bundled_cost_table() -> &'static CostTable {
155    BUNDLED_COST.get_or_init(|| CostTable::load().expect("bundled assets/cost.toml"))
156}
157
158/// [`CostTable::estimate_tail_event_cost_usd_e6`] on the bundled table.
159pub fn estimate_tail_event_cost_usd_e6(
160    model: Option<&str>,
161    tokens_in: Option<u32>,
162    tokens_out: Option<u32>,
163    reasoning_tokens: Option<u32>,
164) -> Option<i64> {
165    bundled_cost_table().estimate_tail_event_cost_usd_e6(
166        model,
167        tokens_in,
168        tokens_out,
169        reasoning_tokens,
170    )
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn load_succeeds() {
        CostTable::load().unwrap();
    }

    #[test]
    fn known_model_cost() {
        let table = CostTable::load().unwrap();
        // claude-sonnet-4: $3/MTok in, $15/MTok out
        // 1000 in + 500 out → 3000 + 7500 = 10500 cost_usd_e6
        let cost = table.estimate(Some("claude-sonnet-4"), 1000, 500);
        assert_eq!(cost, 10500);
    }

    #[test]
    fn cursor_heuristic_nonzero() {
        let table = CostTable::load().unwrap();
        // model=None, no tokens → cursor heuristic
        let cost = table.estimate(None, 0, 0);
        assert!(cost > 0, "cursor heuristic should produce nonzero cost");
    }

    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = CostTable::load().unwrap();
        let cost = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let cost2 = table.estimate(None, 1000, 500);
        assert_eq!(cost, cost2);
    }

    #[test]
    fn tail_estimate_none_without_usage_fields() {
        let table = CostTable::load().unwrap();
        assert!(table
            .estimate_tail_event_cost_usd_e6(None, None, None, None)
            .is_none());
    }

    #[test]
    fn tail_estimate_none_when_fields_present_but_all_zero() {
        let table = CostTable::load().unwrap();
        assert!(table
            .estimate_tail_event_cost_usd_e6(
                Some("claude-sonnet-4"),
                Some(0),
                Some(0),
                Some(0),
            )
            .is_none());
    }

    #[test]
    fn tail_estimate_adds_reasoning_to_output_side() {
        let table = CostTable::load().unwrap();
        let with_reasoning = table
            .estimate_tail_event_cost_usd_e6(
                Some("claude-sonnet-4"),
                Some(1000),
                Some(100),
                Some(400),
            )
            .expect("cost");
        let output_only = table
            .estimate_tail_event_cost_usd_e6(Some("claude-sonnet-4"), Some(1000), Some(500), None)
            .expect("cost");
        assert_eq!(with_reasoning, output_only);
    }

    #[test]
    fn usage_estimate_prices_cache_tokens_separately() {
        let table = CostTable::load().unwrap();
        let cached = table.estimate_usage_cost_usd_e6(
            Some("claude-sonnet-4"),
            TokenUsage {
                input: 1000,
                output: 500,
                reasoning: 0,
                cache_read: 1000,
                cache_create: 1000,
            },
        );
        let plain = table.estimate(Some("claude-sonnet-4"), 3000, 500);
        assert!(cached < plain);
    }

    #[test]
    fn tail_estimate_includes_cache_usage() {
        let table = CostTable::load().unwrap();
        let cached = table
            .estimate_tail_usage_cost_usd_e6(
                Some("claude-sonnet-4"),
                TokenUsage {
                    input: 0,
                    output: 0,
                    reasoning: 0,
                    cache_read: 1000,
                    cache_create: 0,
                },
            )
            .expect("cache read cost");
        assert!(cached > 0);
    }

    #[test]
    fn tail_usage_estimate_none_when_all_counters_zero() {
        let table = CostTable::load().unwrap();
        assert!(table
            .estimate_tail_usage_cost_usd_e6(Some("claude-sonnet-4"), TokenUsage::default())
            .is_none());
    }

    #[test]
    fn usage_estimate_matches_plain_estimate_without_cache() {
        let table = CostTable::load().unwrap();
        let usage = TokenUsage {
            input: 1000,
            output: 500,
            ..TokenUsage::default()
        };
        assert_eq!(
            table.estimate_usage_cost_usd_e6(Some("claude-sonnet-4"), usage),
            table.estimate(Some("claude-sonnet-4"), 1000, 500)
        );
    }

    #[test]
    fn from_tail_defaults_missing_fields_to_zero() {
        let usage = TokenUsage::from_tail(Some(7), None, Some(3));
        assert_eq!(
            usage,
            TokenUsage {
                input: 7,
                output: 0,
                reasoning: 3,
                cache_read: 0,
                cache_create: 0,
            }
        );
    }

    #[test]
    fn bundled_tail_estimate_matches_loaded_table() {
        let table = CostTable::load().unwrap();
        assert_eq!(
            estimate_tail_event_cost_usd_e6(Some("claude-sonnet-4"), Some(1000), Some(500), None),
            table.estimate_tail_event_cost_usd_e6(
                Some("claude-sonnet-4"),
                Some(1000),
                Some(500),
                None,
            )
        );
    }
}