Skip to main content

punch_kernel/
metering.rs

1//! Cost tracking and quota enforcement engine.
2//!
3//! The [`MeteringEngine`] calculates costs based on model pricing tables
4//! and persists usage data through the memory substrate. It supports
5//! per-fighter and aggregate spend queries across configurable time periods.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, instrument};
13
14use punch_memory::MemorySubstrate;
15use punch_types::{FighterId, PunchResult};
16
17// ---------------------------------------------------------------------------
18// Types
19// ---------------------------------------------------------------------------
20
21/// Pricing for a specific model (cost per million tokens).
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelPrice {
24    /// Cost per million input tokens (USD).
25    pub input_per_million: f64,
26    /// Cost per million output tokens (USD).
27    pub output_per_million: f64,
28}
29
30/// Time period for spend queries.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum SpendPeriod {
34    Hour,
35    Day,
36    Month,
37}
38
39impl SpendPeriod {
40    /// Convert to a chrono [`Duration`].
41    fn to_duration(self) -> Duration {
42        match self {
43            Self::Hour => Duration::hours(1),
44            Self::Day => Duration::days(1),
45            Self::Month => Duration::days(30),
46        }
47    }
48}
49
50impl std::fmt::Display for SpendPeriod {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            Self::Hour => write!(f, "hour"),
54            Self::Day => write!(f, "day"),
55            Self::Month => write!(f, "month"),
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// MeteringEngine
62// ---------------------------------------------------------------------------
63
64/// Engine for tracking LLM costs and enforcing spend quotas.
65pub struct MeteringEngine {
66    /// Shared memory substrate for persisting usage data.
67    memory: Arc<MemorySubstrate>,
68    /// Pricing table keyed by model name (or prefix for wildcard matches).
69    model_prices: HashMap<String, ModelPrice>,
70}
71
72impl MeteringEngine {
73    /// Create a new metering engine with embedded default pricing.
74    pub fn new(memory: Arc<MemorySubstrate>) -> Self {
75        let model_prices = Self::default_price_table();
76        Self {
77            memory,
78            model_prices,
79        }
80    }
81
82    /// Create a new metering engine with custom pricing.
83    pub fn with_prices(
84        memory: Arc<MemorySubstrate>,
85        model_prices: HashMap<String, ModelPrice>,
86    ) -> Self {
87        Self {
88            memory,
89            model_prices,
90        }
91    }
92
93    /// Build the default embedded price table.
94    fn default_price_table() -> HashMap<String, ModelPrice> {
95        let mut prices = HashMap::new();
96
97        prices.insert(
98            "claude-opus".to_string(),
99            ModelPrice {
100                input_per_million: 15.0,
101                output_per_million: 75.0,
102            },
103        );
104
105        prices.insert(
106            "claude-sonnet".to_string(),
107            ModelPrice {
108                input_per_million: 3.0,
109                output_per_million: 15.0,
110            },
111        );
112
113        prices.insert(
114            "claude-haiku".to_string(),
115            ModelPrice {
116                input_per_million: 0.25,
117                output_per_million: 1.25,
118            },
119        );
120
121        prices.insert(
122            "gpt-4o".to_string(),
123            ModelPrice {
124                input_per_million: 2.50,
125                output_per_million: 10.0,
126            },
127        );
128
129        prices.insert(
130            "gpt-4o-mini".to_string(),
131            ModelPrice {
132                input_per_million: 0.15,
133                output_per_million: 0.60,
134            },
135        );
136
137        // Ollama (local) models are free.
138        prices.insert(
139            "ollama/".to_string(),
140            ModelPrice {
141                input_per_million: 0.0,
142                output_per_million: 0.0,
143            },
144        );
145
146        prices
147    }
148
149    /// Look up the price for a model, using prefix matching and a default fallback.
150    fn get_price(&self, model: &str) -> &ModelPrice {
151        // Exact match first.
152        if let Some(price) = self.model_prices.get(model) {
153            return price;
154        }
155
156        // Prefix match (e.g. "claude-sonnet" matches "claude-sonnet-4-20250514").
157        for (key, price) in &self.model_prices {
158            if model.starts_with(key) {
159                return price;
160            }
161        }
162
163        // Default fallback pricing.
164        // We use a static leak-free approach with a const reference.
165        static DEFAULT_PRICE: ModelPrice = ModelPrice {
166            input_per_million: 1.0,
167            output_per_million: 3.0,
168        };
169        &DEFAULT_PRICE
170    }
171
172    /// Calculate the cost for a given model and token counts.
173    pub fn estimate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
174        let price = self.get_price(model);
175        let input_cost = (input_tokens as f64 / 1_000_000.0) * price.input_per_million;
176        let output_cost = (output_tokens as f64 / 1_000_000.0) * price.output_per_million;
177        input_cost + output_cost
178    }
179
180    /// Record usage for a fighter, calculating cost automatically.
181    #[instrument(skip(self), fields(%fighter_id, %model, input_tokens, output_tokens))]
182    pub async fn record_usage(
183        &self,
184        fighter_id: &FighterId,
185        model: &str,
186        input_tokens: u64,
187        output_tokens: u64,
188    ) -> PunchResult<f64> {
189        let cost = self.estimate_cost(model, input_tokens, output_tokens);
190
191        self.memory
192            .record_usage(fighter_id, model, input_tokens, output_tokens, cost)
193            .await?;
194
195        debug!(cost_usd = cost, "usage recorded with cost");
196        Ok(cost)
197    }
198
199    /// Get total spend for a specific fighter over a time period.
200    pub async fn get_spend(&self, fighter_id: &FighterId, period: SpendPeriod) -> PunchResult<f64> {
201        let since = Utc::now() - period.to_duration();
202        let summary = self.memory.get_usage_summary(fighter_id, since).await?;
203        Ok(summary.total_cost_usd)
204    }
205
206    /// Get total spend across all fighters over a time period.
207    pub async fn get_total_spend(&self, period: SpendPeriod) -> PunchResult<f64> {
208        let since = Utc::now() - period.to_duration();
209        let summary = self.memory.get_total_usage_summary(since).await?;
210        Ok(summary.total_cost_usd)
211    }
212
213    /// Get the usage summary for a fighter over a time period.
214    pub async fn get_fighter_summary(
215        &self,
216        fighter_id: &FighterId,
217        period: SpendPeriod,
218    ) -> PunchResult<punch_memory::UsageSummary> {
219        let since = Utc::now() - period.to_duration();
220        self.memory.get_usage_summary(fighter_id, since).await
221    }
222
223    /// Get the total usage summary across all fighters over a time period.
224    pub async fn get_total_summary(
225        &self,
226        period: SpendPeriod,
227    ) -> PunchResult<punch_memory::UsageSummary> {
228        let since = Utc::now() - period.to_duration();
229        self.memory.get_total_usage_summary(since).await
230    }
231
232    /// Get per-model usage breakdown for a fighter over a time period.
233    pub async fn get_model_breakdown(
234        &self,
235        fighter_id: &FighterId,
236        period: SpendPeriod,
237    ) -> PunchResult<Vec<punch_memory::ModelUsageBreakdown>> {
238        let since = Utc::now() - period.to_duration();
239        self.memory.get_model_breakdown(fighter_id, since).await
240    }
241
242    /// Get per-model usage breakdown across all fighters over a time period.
243    pub async fn get_total_model_breakdown(
244        &self,
245        period: SpendPeriod,
246    ) -> PunchResult<Vec<punch_memory::ModelUsageBreakdown>> {
247        let since = Utc::now() - period.to_duration();
248        self.memory.get_total_model_breakdown(since).await
249    }
250
251    /// Get per-fighter usage breakdown over a time period.
252    pub async fn get_fighter_breakdown(
253        &self,
254        period: SpendPeriod,
255    ) -> PunchResult<Vec<punch_memory::FighterUsageBreakdown>> {
256        let since = Utc::now() - period.to_duration();
257        self.memory.get_fighter_breakdown(since).await
258    }
259}
260
261// ---------------------------------------------------------------------------
262// Tests
263// ---------------------------------------------------------------------------
264
265#[cfg(test)]
266mod tests {
267    use super::*;
268
269    #[test]
270    fn estimate_cost_claude_sonnet() {
271        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
272        let engine = MeteringEngine::new(memory);
273
274        // claude-sonnet: $3/M in, $15/M out
275        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000);
276        assert!((cost - 18.0).abs() < 1e-9);
277    }
278
279    #[test]
280    fn estimate_cost_gpt4o_mini() {
281        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
282        let engine = MeteringEngine::new(memory);
283
284        // gpt-4o-mini: $0.15/M in, $0.60/M out
285        let cost = engine.estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000);
286        assert!((cost - 0.75).abs() < 1e-9);
287    }
288
289    #[test]
290    fn estimate_cost_ollama_free() {
291        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
292        let engine = MeteringEngine::new(memory);
293
294        let cost = engine.estimate_cost("ollama/llama3", 1_000_000, 1_000_000);
295        assert!((cost - 0.0).abs() < 1e-9);
296    }
297
298    #[test]
299    fn estimate_cost_unknown_model_uses_fallback() {
300        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
301        let engine = MeteringEngine::new(memory);
302
303        // Default fallback: $1/M in, $3/M out
304        let cost = engine.estimate_cost("some-unknown-model", 1_000_000, 1_000_000);
305        assert!((cost - 4.0).abs() < 1e-9);
306    }
307
308    #[test]
309    fn estimate_cost_small_usage() {
310        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
311        let engine = MeteringEngine::new(memory);
312
313        // 1000 input tokens, 500 output tokens with claude-sonnet
314        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 1000, 500);
315        let expected = (1000.0 / 1_000_000.0) * 3.0 + (500.0 / 1_000_000.0) * 15.0;
316        assert!((cost - expected).abs() < 1e-12);
317    }
318
319    #[tokio::test]
320    async fn record_and_query_usage() {
321        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
322        let engine = MeteringEngine::new(Arc::clone(&memory));
323
324        let fighter_id = FighterId::new();
325
326        // Save fighter first (FK constraint).
327        use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};
328        let manifest = FighterManifest {
329            name: "metering-test".into(),
330            description: "test".into(),
331            model: ModelConfig {
332                provider: Provider::Anthropic,
333                model: "claude-sonnet-4-20250514".into(),
334                api_key_env: None,
335                base_url: None,
336                max_tokens: Some(4096),
337                temperature: Some(0.7),
338            },
339            system_prompt: "test".into(),
340            capabilities: Vec::new(),
341            weight_class: WeightClass::Featherweight,
342            tenant_id: None,
343        };
344        memory
345            .save_fighter(&fighter_id, &manifest, FighterStatus::Idle)
346            .await
347            .unwrap();
348
349        let cost = engine
350            .record_usage(&fighter_id, "claude-sonnet-4-20250514", 5000, 2000)
351            .await
352            .unwrap();
353
354        // claude-sonnet: $3/M in, $15/M out
355        let expected = (5000.0 / 1_000_000.0) * 3.0 + (2000.0 / 1_000_000.0) * 15.0;
356        assert!((cost - expected).abs() < 1e-12);
357
358        // Query the spend.
359        let spend = engine
360            .get_spend(&fighter_id, SpendPeriod::Hour)
361            .await
362            .unwrap();
363        assert!((spend - expected).abs() < 1e-9);
364    }
365
366    #[test]
367    fn spend_period_display() {
368        assert_eq!(SpendPeriod::Hour.to_string(), "hour");
369        assert_eq!(SpendPeriod::Day.to_string(), "day");
370        assert_eq!(SpendPeriod::Month.to_string(), "month");
371    }
372
373    #[test]
374    fn estimate_cost_zero_tokens() {
375        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
376        let engine = MeteringEngine::new(memory);
377        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 0, 0);
378        assert!((cost - 0.0).abs() < 1e-12);
379    }
380
381    #[test]
382    fn estimate_cost_claude_opus() {
383        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
384        let engine = MeteringEngine::new(memory);
385        // claude-opus: $15/M in, $75/M out
386        let cost = engine.estimate_cost("claude-opus-4-20250514", 1_000_000, 1_000_000);
387        assert!((cost - 90.0).abs() < 1e-9);
388    }
389
390    #[test]
391    fn estimate_cost_claude_haiku() {
392        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
393        let engine = MeteringEngine::new(memory);
394        // claude-haiku: $0.25/M in, $1.25/M out
395        let cost = engine.estimate_cost("claude-haiku-3.5", 1_000_000, 1_000_000);
396        assert!((cost - 1.5).abs() < 1e-9);
397    }
398
399    #[test]
400    fn estimate_cost_gpt4o() {
401        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
402        let engine = MeteringEngine::new(memory);
403        // gpt-4o: $2.50/M in, $10/M out
404        let cost = engine.estimate_cost("gpt-4o", 1_000_000, 1_000_000);
405        assert!((cost - 12.5).abs() < 1e-9);
406    }
407
408    #[test]
409    fn with_custom_prices() {
410        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
411        let mut prices = HashMap::new();
412        prices.insert(
413            "custom-model".to_string(),
414            ModelPrice {
415                input_per_million: 5.0,
416                output_per_million: 10.0,
417            },
418        );
419        let engine = MeteringEngine::with_prices(memory, prices);
420        let cost = engine.estimate_cost("custom-model", 1_000_000, 1_000_000);
421        assert!((cost - 15.0).abs() < 1e-9);
422    }
423
424    #[test]
425    fn custom_prices_missing_model_uses_default_fallback() {
426        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
427        let prices = HashMap::new();
428        let engine = MeteringEngine::with_prices(memory, prices);
429        // Default: $1/M in, $3/M out
430        let cost = engine.estimate_cost("anything", 1_000_000, 1_000_000);
431        assert!((cost - 4.0).abs() < 1e-9);
432    }
433}