Skip to main content

kaizen/core/
cost.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
//! Model cost estimation from the bundled price table.
//! All costs are expressed in cost_usd_e6 units (1 USD = 1_000_000 units).
4
5use anyhow::Result;
6use serde::Deserialize;
7use std::sync::OnceLock;
8
/// Raw text of the price table, embedded into the binary at compile time by `include_str!`.
/// Parsed by [`CostTable::load`].
const COST_TOML: &str = include_str!("../../assets/cost.toml");
10
/// One entry of the price table: the per-model rates consumed by [`CostTable::estimate`].
#[derive(Debug, Deserialize)]
pub struct ModelPrice {
    /// Model identifier; lookups compare it for exact equality with the requested name.
    pub id: String,
    /// Price of one million input tokens, in cost_usd_e6 units.
    pub input_per_mtok: i64,
    /// Price of one million output tokens, in cost_usd_e6 units.
    pub output_per_mtok: i64,
    /// Token count assumed for a turn that reports zero input and zero output tokens.
    /// Defaults to 0 when the field is missing, which disables that fallback.
    #[serde(default)]
    pub avg_tokens_per_turn: u32,
}
19
/// Deserialization target for the whole bundled file; only used by [`CostTable::load`],
/// which moves `models` into the public [`CostTable`].
#[derive(Debug, Deserialize)]
struct CostFile {
    /// Every price entry listed under the file's `models` key.
    models: Vec<ModelPrice>,
}
24
25pub struct CostTable {
26    models: Vec<ModelPrice>,
27}
28
29impl CostTable {
30    /// Load from bundled `assets/cost.toml`.
31    pub fn load() -> Result<Self> {
32        let f: CostFile = toml::from_str(COST_TOML)?;
33        Ok(Self { models: f.models })
34    }
35
36    /// Estimate cost in cost_usd_e6 units.
37    /// If model is None or not found, falls back to "cursor" heuristic entry.
38    /// If tokens_in == 0 and tokens_out == 0 (Cursor: no native tokens),
39    /// uses avg_tokens_per_turn from the matched entry.
40    pub fn estimate(&self, model: Option<&str>, tokens_in: u32, tokens_out: u32) -> i64 {
41        let entry = model
42            .and_then(|m| self.models.iter().find(|p| p.id == m))
43            .or_else(|| self.models.iter().find(|p| p.id == "cursor"));
44
45        let Some(price) = entry else { return 0 };
46
47        let (tin, tout) = if tokens_in == 0 && tokens_out == 0 && price.avg_tokens_per_turn > 0 {
48            let avg = price.avg_tokens_per_turn as i64;
49            (avg * 4 / 5, avg / 5)
50        } else {
51            (tokens_in as i64, tokens_out as i64)
52        };
53
54        tin * price.input_per_mtok / 1_000_000 + tout * price.output_per_mtok / 1_000_000
55    }
56
57    pub fn find(&self, model: &str) -> Option<&ModelPrice> {
58        self.models.iter().find(|p| p.id == model)
59    }
60
61    /// Transcript tail rows: charge only when at least one usage field is set **and**
62    /// prompt + output (including reasoning) are not all zero. Omits proxy-style
63    /// `avg_tokens_per_turn` fallback so thousands of tool lines do not each get a heuristic charge.
64    pub fn estimate_tail_event_cost_usd_e6(
65        &self,
66        model: Option<&str>,
67        tokens_in: Option<u32>,
68        tokens_out: Option<u32>,
69        reasoning_tokens: Option<u32>,
70    ) -> Option<i64> {
71        let any_field = tokens_in.is_some() || tokens_out.is_some() || reasoning_tokens.is_some();
72        if !any_field {
73            return None;
74        }
75        let tin = tokens_in.unwrap_or(0);
76        let tout = tokens_out
77            .unwrap_or(0)
78            .saturating_add(reasoning_tokens.unwrap_or(0));
79        if tin == 0 && tout == 0 {
80            return None;
81        }
82        Some(self.estimate(model, tin, tout))
83    }
84}
85
86static BUNDLED_COST: OnceLock<CostTable> = OnceLock::new();
87
88fn bundled_cost_table() -> &'static CostTable {
89    BUNDLED_COST.get_or_init(|| CostTable::load().expect("bundled assets/cost.toml"))
90}
91
92/// [`CostTable::estimate_tail_event_cost_usd_e6`] on the bundled table.
93pub fn estimate_tail_event_cost_usd_e6(
94    model: Option<&str>,
95    tokens_in: Option<u32>,
96    tokens_out: Option<u32>,
97    reasoning_tokens: Option<u32>,
98) -> Option<i64> {
99    bundled_cost_table().estimate_tail_event_cost_usd_e6(
100        model,
101        tokens_in,
102        tokens_out,
103        reasoning_tokens,
104    )
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Model id used by every test that needs a priced, non-fallback entry.
    const SONNET: &str = "claude-sonnet-4";

    /// Parses the bundled table, failing the calling test if it cannot be loaded.
    fn bundled() -> CostTable {
        CostTable::load().unwrap()
    }

    /// The embedded price file must parse.
    #[test]
    fn load_succeeds() {
        bundled();
    }

    /// A listed model is priced from its own rates.
    #[test]
    fn known_model_cost() {
        // claude-sonnet-4: $3/MTok in, $15/MTok out
        // 1000 in + 500 out → 3000 + 7500 = 10500 cost_usd_e6
        let cost = bundled().estimate(Some(SONNET), 1000, 500);
        assert_eq!(cost, 10500);
    }

    /// With no model and no token counts, the cursor heuristic still yields a charge.
    #[test]
    fn cursor_heuristic_nonzero() {
        // model=None, no tokens → cursor heuristic
        let cost = bundled().estimate(None, 0, 0);
        assert!(cost > 0, "cursor heuristic should produce nonzero cost");
    }

    /// An unlisted model id is priced exactly like a missing model.
    #[test]
    fn unknown_model_falls_back_to_cursor() {
        let table = bundled();
        let for_unknown = table.estimate(Some("unknown-model-xyz"), 1000, 500);
        let for_missing = table.estimate(None, 1000, 500);
        assert_eq!(for_unknown, for_missing);
    }

    /// A tail row without any usage field is not charged.
    #[test]
    fn tail_estimate_none_without_usage_fields() {
        let outcome = bundled().estimate_tail_event_cost_usd_e6(None, None, None, None);
        assert!(outcome.is_none());
    }

    /// A tail row whose usage fields are all zero is not charged either.
    #[test]
    fn tail_estimate_none_when_fields_present_but_all_zero() {
        let outcome =
            bundled().estimate_tail_event_cost_usd_e6(Some(SONNET), Some(0), Some(0), Some(0));
        assert!(outcome.is_none());
    }

    /// Reasoning tokens are billed as output: 100 out + 400 reasoning equals 500 out.
    #[test]
    fn tail_estimate_adds_reasoning_to_output_side() {
        let table = bundled();
        let with_reasoning = table
            .estimate_tail_event_cost_usd_e6(Some(SONNET), Some(1000), Some(100), Some(400))
            .expect("cost");
        let output_only = table
            .estimate_tail_event_cost_usd_e6(Some(SONNET), Some(1000), Some(500), None)
            .expect("cost");
        assert_eq!(with_reasoning, output_only);
    }
}