Skip to main content

kaizen/core/
cost.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Model cost estimation from bundled price table.
3//! All values in cost_usd_e6 (1 USD = 1_000_000 units).
4
5use anyhow::Result;
6use serde::Deserialize;
7
/// Raw TOML text of the bundled model price table, embedded into the binary at compile time.
/// `include_str!` resolves the path relative to this source file.
const COST_TOML: &str = include_str!("../../assets/cost.toml");
9
10#[derive(Debug, Deserialize)]
11pub struct ModelPrice {
12    pub id: String,
13    pub input_per_mtok: i64,
14    pub output_per_mtok: i64,
15    #[serde(default)]
16    pub avg_tokens_per_turn: u32,
17}
18
19#[derive(Debug, Deserialize)]
20struct CostFile {
21    models: Vec<ModelPrice>,
22}
23
24pub struct CostTable {
25    models: Vec<ModelPrice>,
26}
27
28impl CostTable {
29    /// Load from bundled `assets/cost.toml`.
30    pub fn load() -> Result<Self> {
31        let f: CostFile = toml::from_str(COST_TOML)?;
32        Ok(Self { models: f.models })
33    }
34
35    /// Estimate cost in cost_usd_e6 units.
36    /// If model is None or not found, falls back to "cursor" heuristic entry.
37    /// If tokens_in == 0 and tokens_out == 0 (Cursor: no native tokens),
38    /// uses avg_tokens_per_turn from the matched entry.
39    pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
40        let entry = model
41            .and_then(|m| self.models.iter().find(|p| p.id == m))
42            .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
43
44        let Some(price) = entry else { return 0 };
45
46        let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
47            let avg = price.avg_tokens_per_turn as i64;
48            (avg * 4 / 5, avg / 5)
49        } else {
50            (tokens_in as i64, tokens_out as i64)
51        };
52
53        tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
54    }
55
56    pub fn find(&self, model: &str) -> Option<&ModelPrice> {
57        self.models.iter().find(|p| p.id == model)
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a price entry directly, without going through the bundled TOML.
    fn price(id: &str, input_per_mtok: i64, output_per_mtok: i64, avg: u32) -> ModelPrice {
        ModelPrice {
            id: id.to_string(),
            input_per_mtok,
            output_per_mtok,
            avg_tokens_per_turn: avg,
        }
    }

    /// A small fixed table, so these assertions do not depend on `assets/cost.toml` values.
    fn fixture() -> CostTable {
        CostTable {
            models: vec![
                price("cursor", 1_000_000, 5_000_000, 1000),
                price("m", 2_000_000, 2_000_000, 500),
                price("z", 2_000_000, 2_000_000, 0),
            ],
        }
    }

    #[test]
    fn load_succeeds() {
        CostTable::load().unwrap();
    }

    #[test]
    fn known_model_cost() {
        let table = CostTable::load().unwrap();
        // claude-sonnet-4: $3/MTok in, $15/MTok out
        // 1000 in + 500 out → 3000 + 7500 = 10500 cost_usd_e6
        let cost = table.estimate(Some("claude-sonnet-4"), 1000, 500);
        assert_eq!(cost, 10500);
    }

    #[test]
    fn cursor_heuristic_nonzero() {
        let table = CostTable::load().unwrap();
        // model=None, no tokens → cursor heuristic
        let cost = table.estimate(None, 0, 0);
        assert!(cost > 0, "cursor heuristic should produce nonzero cost");
    }

    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = CostTable::load().unwrap();
        let cost = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let cost2 = table.estimate(None, 1000, 500);
        assert_eq!(cost, cost2);
    }

    #[test]
    fn heuristic_splits_avg_tokens_80_20() {
        // avg 1000 → 800 in @ $1/MTok + 200 out @ $5/MTok = 800 + 1000
        assert_eq!(fixture().estimate(None, 0, 0), 1800);
    }

    #[test]
    fn matched_model_uses_its_own_average() {
        // avg 500 → 400 in + 100 out, both @ $2/MTok = 800 + 200 (not the cursor entry)
        assert_eq!(fixture().estimate(Some("m"), 0, 0), 1000);
    }

    #[test]
    fn explicit_tokens_bypass_heuristic() {
        // 1000 in @ $2/MTok; the avg of 500 must be ignored because tokens were given.
        assert_eq!(fixture().estimate(Some("m"), 1000, 0), 2000);
    }

    #[test]
    fn zero_average_without_tokens_costs_nothing() {
        assert_eq!(fixture().estimate(Some("z"), 0, 0), 0);
    }

    #[test]
    fn empty_table_costs_nothing() {
        let table = CostTable { models: Vec::new() };
        assert_eq!(table.estimate(None, 1000, 500), 0);
        assert_eq!(table.estimate(Some("m"), 1000, 500), 0);
    }

    #[test]
    fn find_is_exact_match() {
        let table = fixture();
        assert_eq!(table.find("m").map(|p| p.input_per_mtok), Some(2_000_000));
        assert!(table.find("M").is_none());
        assert!(table.find("missing").is_none());
    }
}
90        let cost = table.estimate(Some("unknown-model-xyz"), 1000, 500);
91        let cost2 = table.estimate(None, 1000, 500);
92        assert_eq!(cost, cost2);
93    }
94}