Skip to main content

mur_core/
cost.rs

1//! Cost tracking and estimation for AI model API calls.
2//!
3//! Tracks actual spending and provides pre-execution cost estimates
4//! so users know what a workflow run will cost before committing.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use crate::model::ModelRole;
11
12/// Cost per million tokens for a model (input and output).
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ModelPricing {
15    /// Cost per million input tokens (USD).
16    pub input_per_million: f64,
17    /// Cost per million output tokens (USD).
18    pub output_per_million: f64,
19}
20
21/// A single recorded cost event.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CostEvent {
24    pub model_role: ModelRole,
25    pub model_name: String,
26    pub input_tokens: u64,
27    pub output_tokens: u64,
28    pub cost: f64,
29    pub timestamp: chrono::DateTime<chrono::Utc>,
30}
31
32/// Tracks API costs across model invocations.
33#[derive(Debug, Clone)]
34pub struct CostTracker {
35    /// Known pricing for models.
36    pricing: HashMap<String, ModelPricing>,
37    /// Recorded cost events (shared across threads).
38    events: Arc<Mutex<Vec<CostEvent>>>,
39}
40
41impl CostTracker {
42    /// Create a new cost tracker with default model pricing.
43    pub fn new() -> Self {
44        let mut pricing = HashMap::new();
45
46        // Anthropic models
47        pricing.insert(
48            "claude-opus-4-5".into(),
49            ModelPricing {
50                input_per_million: 15.0,
51                output_per_million: 75.0,
52            },
53        );
54        pricing.insert(
55            "claude-sonnet-4".into(),
56            ModelPricing {
57                input_per_million: 3.0,
58                output_per_million: 15.0,
59            },
60        );
61
62        // OpenRouter / Google
63        pricing.insert(
64            "gemini-2.5-flash".into(),
65            ModelPricing {
66                input_per_million: 0.15,
67                output_per_million: 0.60,
68            },
69        );
70
71        // Local models (free)
72        pricing.insert(
73            "ollama-local".into(),
74            ModelPricing {
75                input_per_million: 0.0,
76                output_per_million: 0.0,
77            },
78        );
79
80        Self {
81            pricing,
82            events: Arc::new(Mutex::new(Vec::new())),
83        }
84    }
85
86    /// Estimate the cost for a given model and token count.
87    pub fn estimate_cost(
88        &self,
89        model_name: &str,
90        input_tokens: u64,
91        output_tokens: u64,
92    ) -> f64 {
93        if let Some(pricing) = self.pricing.get(model_name) {
94            let input_cost = (input_tokens as f64 / 1_000_000.0) * pricing.input_per_million;
95            let output_cost = (output_tokens as f64 / 1_000_000.0) * pricing.output_per_million;
96            input_cost + output_cost
97        } else {
98            // Unknown model — return a conservative estimate
99            let input_cost = (input_tokens as f64 / 1_000_000.0) * 3.0;
100            let output_cost = (output_tokens as f64 / 1_000_000.0) * 15.0;
101            input_cost + output_cost
102        }
103    }
104
105    /// Record an actual cost event.
106    pub fn record(
107        &self,
108        model_role: ModelRole,
109        model_name: &str,
110        input_tokens: u64,
111        output_tokens: u64,
112    ) {
113        let cost = self.estimate_cost(model_name, input_tokens, output_tokens);
114        let event = CostEvent {
115            model_role,
116            model_name: model_name.to_string(),
117            input_tokens,
118            output_tokens,
119            cost,
120            timestamp: chrono::Utc::now(),
121        };
122        if let Ok(mut events) = self.events.lock() {
123            events.push(event);
124        }
125    }
126
127    /// Get total cost across all recorded events.
128    pub fn total_cost(&self) -> f64 {
129        self.events
130            .lock()
131            .map(|events| events.iter().map(|e| e.cost).sum())
132            .unwrap_or(0.0)
133    }
134
135    /// Get cost breakdown by model role.
136    pub fn cost_by_role(&self) -> HashMap<ModelRole, f64> {
137        let mut breakdown = HashMap::new();
138        if let Ok(events) = self.events.lock() {
139            for event in events.iter() {
140                *breakdown.entry(event.model_role.clone()).or_insert(0.0) += event.cost;
141            }
142        }
143        breakdown
144    }
145
146    /// Get all recorded events.
147    pub fn events(&self) -> Vec<CostEvent> {
148        self.events
149            .lock()
150            .map(|events| events.clone())
151            .unwrap_or_default()
152    }
153
154    /// Reset the tracker (clear all events).
155    pub fn reset(&self) {
156        if let Ok(mut events) = self.events.lock() {
157            events.clear();
158        }
159    }
160}
161
162impl Default for CostTracker {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing USD amounts computed with `f64`.
    const EPSILON: f64 = 1e-9;

    /// Assert that `actual` is within `EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_estimate_cost_known_model() {
        let tracker = CostTracker::new();
        // claude-opus-4-5: $15/M input, $75/M output
        let cost = tracker.estimate_cost("claude-opus-4-5", 1_000_000, 1_000_000);
        assert_close(cost, 90.0);
    }

    #[test]
    fn test_estimate_cost_cheap_model() {
        let tracker = CostTracker::new();
        // gemini-2.5-flash: $0.15/M input, $0.60/M output
        let cost = tracker.estimate_cost("gemini-2.5-flash", 1_000_000, 1_000_000);
        assert_close(cost, 0.75);
    }

    #[test]
    fn test_estimate_cost_local_model() {
        let tracker = CostTracker::new();
        let cost = tracker.estimate_cost("ollama-local", 1_000_000, 1_000_000);
        assert_close(cost, 0.0);
    }

    #[test]
    fn test_estimate_cost_unknown_model() {
        let tracker = CostTracker::new();
        // Unknown models fall back to $3/M input + $15/M output.
        let cost = tracker.estimate_cost("unknown-model", 1_000_000, 1_000_000);
        assert_close(cost, 18.0);
    }

    #[test]
    fn test_estimate_cost_zero_tokens() {
        let tracker = CostTracker::new();
        assert_close(tracker.estimate_cost("claude-opus-4-5", 0, 0), 0.0);
        assert_close(tracker.estimate_cost("unknown-model", 0, 0), 0.0);
    }

    #[test]
    fn test_empty_tracker() {
        let tracker = CostTracker::new();
        assert_close(tracker.total_cost(), 0.0);
        assert!(tracker.cost_by_role().is_empty());
        assert!(tracker.events().is_empty());
    }

    #[test]
    fn test_record_and_total() {
        let tracker = CostTracker::new();
        // opus: 0.5M * $15 + 0.1M * $75 = $15.00
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 500_000, 100_000);
        // flash: 0.2M * $0.15 + 0.05M * $0.60 = $0.06
        tracker.record(ModelRole::Task, "gemini-2.5-flash", 200_000, 50_000);

        assert_close(tracker.total_cost(), 15.06);
        assert_eq!(tracker.events().len(), 2);
    }

    #[test]
    fn test_recorded_event_fields() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Task, "claude-sonnet-4", 2_000, 1_000);

        let events = tracker.events();
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.model_role, ModelRole::Task);
        assert_eq!(event.model_name, "claude-sonnet-4");
        assert_eq!(event.input_tokens, 2_000);
        assert_eq!(event.output_tokens, 1_000);
        // 0.002M * $3 + 0.001M * $15 = $0.021
        assert_close(event.cost, 0.021);
    }

    #[test]
    fn test_cost_by_role() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 1_000_000, 0);
        tracker.record(ModelRole::Task, "gemini-2.5-flash", 1_000_000, 0);
        tracker.record(ModelRole::Task, "gemini-2.5-flash", 1_000_000, 0);

        let breakdown = tracker.cost_by_role();
        assert_eq!(breakdown.len(), 2);
        assert_close(breakdown[&ModelRole::Thinking], 15.0);
        // Two $0.15 events accumulate under the same role.
        assert_close(breakdown[&ModelRole::Task], 0.30);
    }

    #[test]
    fn test_reset() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 1_000_000, 0);
        assert_close(tracker.total_cost(), 15.0);

        tracker.reset();
        assert_close(tracker.total_cost(), 0.0);
        assert!(tracker.events().is_empty());
    }

    #[test]
    fn test_clones_share_event_log() {
        let tracker = CostTracker::new();
        let clone = tracker.clone();
        clone.record(ModelRole::Thinking, "claude-sonnet-4", 1_000_000, 0);

        // Events recorded through a clone are visible through the original...
        assert_eq!(tracker.events().len(), 1);
        assert_close(tracker.total_cost(), 3.0);

        // ...and resetting either one clears the shared log.
        tracker.reset();
        assert!(clone.events().is_empty());
    }
}