Skip to main content

punch_kernel/
metering.rs

1//! Cost tracking and quota enforcement engine.
2//!
3//! The [`MeteringEngine`] calculates costs based on model pricing tables
4//! and persists usage data through the memory substrate. It supports
5//! per-fighter and aggregate spend queries across configurable time periods.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, instrument};
13
14use punch_memory::MemorySubstrate;
15use punch_types::{FighterId, PunchResult};
16
17// ---------------------------------------------------------------------------
18// Types
19// ---------------------------------------------------------------------------
20
21/// Pricing for a specific model (cost per million tokens).
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelPrice {
24    /// Cost per million input tokens (USD).
25    pub input_per_million: f64,
26    /// Cost per million output tokens (USD).
27    pub output_per_million: f64,
28}
29
30/// Time period for spend queries.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum SpendPeriod {
34    Hour,
35    Day,
36    Month,
37}
38
39impl SpendPeriod {
40    /// Convert to a chrono [`Duration`].
41    fn to_duration(self) -> Duration {
42        match self {
43            Self::Hour => Duration::hours(1),
44            Self::Day => Duration::days(1),
45            Self::Month => Duration::days(30),
46        }
47    }
48}
49
50impl std::fmt::Display for SpendPeriod {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            Self::Hour => write!(f, "hour"),
54            Self::Day => write!(f, "day"),
55            Self::Month => write!(f, "month"),
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// MeteringEngine
62// ---------------------------------------------------------------------------
63
64/// Engine for tracking LLM costs and enforcing spend quotas.
65pub struct MeteringEngine {
66    /// Shared memory substrate for persisting usage data.
67    memory: Arc<MemorySubstrate>,
68    /// Pricing table keyed by model name (or prefix for wildcard matches).
69    model_prices: HashMap<String, ModelPrice>,
70}
71
72impl MeteringEngine {
73    /// Create a new metering engine with embedded default pricing.
74    pub fn new(memory: Arc<MemorySubstrate>) -> Self {
75        let model_prices = Self::default_price_table();
76        Self {
77            memory,
78            model_prices,
79        }
80    }
81
82    /// Create a new metering engine with custom pricing.
83    pub fn with_prices(
84        memory: Arc<MemorySubstrate>,
85        model_prices: HashMap<String, ModelPrice>,
86    ) -> Self {
87        Self {
88            memory,
89            model_prices,
90        }
91    }
92
93    /// Build the default embedded price table.
94    fn default_price_table() -> HashMap<String, ModelPrice> {
95        let mut prices = HashMap::new();
96
97        prices.insert(
98            "claude-opus".to_string(),
99            ModelPrice {
100                input_per_million: 15.0,
101                output_per_million: 75.0,
102            },
103        );
104
105        prices.insert(
106            "claude-sonnet".to_string(),
107            ModelPrice {
108                input_per_million: 3.0,
109                output_per_million: 15.0,
110            },
111        );
112
113        prices.insert(
114            "claude-haiku".to_string(),
115            ModelPrice {
116                input_per_million: 0.25,
117                output_per_million: 1.25,
118            },
119        );
120
121        prices.insert(
122            "gpt-4o".to_string(),
123            ModelPrice {
124                input_per_million: 2.50,
125                output_per_million: 10.0,
126            },
127        );
128
129        prices.insert(
130            "gpt-4o-mini".to_string(),
131            ModelPrice {
132                input_per_million: 0.15,
133                output_per_million: 0.60,
134            },
135        );
136
137        // Ollama (local) models are free.
138        prices.insert(
139            "ollama/".to_string(),
140            ModelPrice {
141                input_per_million: 0.0,
142                output_per_million: 0.0,
143            },
144        );
145
146        prices
147    }
148
149    /// Look up the price for a model, using prefix matching and a default fallback.
150    fn get_price(&self, model: &str) -> &ModelPrice {
151        // Exact match first.
152        if let Some(price) = self.model_prices.get(model) {
153            return price;
154        }
155
156        // Prefix match (e.g. "claude-sonnet" matches "claude-sonnet-4-20250514").
157        for (key, price) in &self.model_prices {
158            if model.starts_with(key) {
159                return price;
160            }
161        }
162
163        // Default fallback pricing.
164        // We use a static leak-free approach with a const reference.
165        static DEFAULT_PRICE: ModelPrice = ModelPrice {
166            input_per_million: 1.0,
167            output_per_million: 3.0,
168        };
169        &DEFAULT_PRICE
170    }
171
172    /// Calculate the cost for a given model and token counts.
173    pub fn estimate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
174        let price = self.get_price(model);
175        let input_cost = (input_tokens as f64 / 1_000_000.0) * price.input_per_million;
176        let output_cost = (output_tokens as f64 / 1_000_000.0) * price.output_per_million;
177        input_cost + output_cost
178    }
179
180    /// Record usage for a fighter, calculating cost automatically.
181    #[instrument(skip(self), fields(%fighter_id, %model, input_tokens, output_tokens))]
182    pub async fn record_usage(
183        &self,
184        fighter_id: &FighterId,
185        model: &str,
186        input_tokens: u64,
187        output_tokens: u64,
188    ) -> PunchResult<f64> {
189        let cost = self.estimate_cost(model, input_tokens, output_tokens);
190
191        self.memory
192            .record_usage(fighter_id, model, input_tokens, output_tokens, cost)
193            .await?;
194
195        debug!(cost_usd = cost, "usage recorded with cost");
196        Ok(cost)
197    }
198
199    /// Get total spend for a specific fighter over a time period.
200    pub async fn get_spend(&self, fighter_id: &FighterId, period: SpendPeriod) -> PunchResult<f64> {
201        let since = Utc::now() - period.to_duration();
202        let summary = self.memory.get_usage_summary(fighter_id, since).await?;
203        Ok(summary.total_cost_usd)
204    }
205
206    /// Get total spend across all fighters over a time period.
207    pub async fn get_total_spend(&self, period: SpendPeriod) -> PunchResult<f64> {
208        let since = Utc::now() - period.to_duration();
209        let summary = self.memory.get_total_usage_summary(since).await?;
210        Ok(summary.total_cost_usd)
211    }
212}
213
214// ---------------------------------------------------------------------------
215// Tests
216// ---------------------------------------------------------------------------
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory substrate shared by the test helpers below.
    fn in_memory_substrate() -> Arc<MemorySubstrate> {
        Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"))
    }

    /// Engine using the embedded default price table.
    fn default_engine() -> MeteringEngine {
        MeteringEngine::new(in_memory_substrate())
    }

    /// Assert that two floats differ by strictly less than `tolerance`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn estimate_cost_claude_sonnet() {
        // claude-sonnet: $3/M in, $15/M out
        let cost = default_engine().estimate_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000);
        assert_close(cost, 18.0, 1e-9);
    }

    #[test]
    fn estimate_cost_gpt4o_mini() {
        // gpt-4o-mini: $0.15/M in, $0.60/M out
        let cost = default_engine().estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000);
        assert_close(cost, 0.75, 1e-9);
    }

    #[test]
    fn estimate_cost_ollama_free() {
        let cost = default_engine().estimate_cost("ollama/llama3", 1_000_000, 1_000_000);
        assert_close(cost, 0.0, 1e-9);
    }

    #[test]
    fn estimate_cost_unknown_model_uses_fallback() {
        // Default fallback: $1/M in, $3/M out
        let cost = default_engine().estimate_cost("some-unknown-model", 1_000_000, 1_000_000);
        assert_close(cost, 4.0, 1e-9);
    }

    #[test]
    fn estimate_cost_small_usage() {
        // 1000 input tokens, 500 output tokens with claude-sonnet
        let cost = default_engine().estimate_cost("claude-sonnet-4-20250514", 1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 3.0 + (500.0 / 1_000_000.0) * 15.0;
        assert_close(cost, expected, 1e-12);
    }

    #[tokio::test]
    async fn record_and_query_usage() {
        use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

        let memory = in_memory_substrate();
        let engine = MeteringEngine::new(Arc::clone(&memory));
        let fighter_id = FighterId::new();

        // The fighter must exist before usage can reference it (FK constraint).
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let manifest = FighterManifest {
            name: "metering-test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        };
        memory
            .save_fighter(&fighter_id, &manifest, FighterStatus::Idle)
            .await
            .unwrap();

        let cost = engine
            .record_usage(&fighter_id, "claude-sonnet-4-20250514", 5000, 2000)
            .await
            .unwrap();

        // claude-sonnet: $3/M in, $15/M out
        let expected = (5000.0 / 1_000_000.0) * 3.0 + (2000.0 / 1_000_000.0) * 15.0;
        assert_close(cost, expected, 1e-12);

        // The recorded cost must show up in the fighter's hourly spend.
        let spend = engine
            .get_spend(&fighter_id, SpendPeriod::Hour)
            .await
            .unwrap();
        assert_close(spend, expected, 1e-9);
    }

    #[test]
    fn spend_period_display() {
        let cases = [
            (SpendPeriod::Hour, "hour"),
            (SpendPeriod::Day, "day"),
            (SpendPeriod::Month, "month"),
        ];
        for (period, label) in cases {
            assert_eq!(period.to_string(), label);
        }
    }

    #[test]
    fn estimate_cost_zero_tokens() {
        let cost = default_engine().estimate_cost("claude-sonnet-4-20250514", 0, 0);
        assert_close(cost, 0.0, 1e-12);
    }

    #[test]
    fn estimate_cost_claude_opus() {
        // claude-opus: $15/M in, $75/M out
        let cost = default_engine().estimate_cost("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert_close(cost, 90.0, 1e-9);
    }

    #[test]
    fn estimate_cost_claude_haiku() {
        // claude-haiku: $0.25/M in, $1.25/M out
        let cost = default_engine().estimate_cost("claude-haiku-3.5", 1_000_000, 1_000_000);
        assert_close(cost, 1.5, 1e-9);
    }

    #[test]
    fn estimate_cost_gpt4o() {
        // gpt-4o: $2.50/M in, $10/M out
        let cost = default_engine().estimate_cost("gpt-4o", 1_000_000, 1_000_000);
        assert_close(cost, 12.5, 1e-9);
    }

    #[test]
    fn with_custom_prices() {
        let custom = ModelPrice {
            input_per_million: 5.0,
            output_per_million: 10.0,
        };
        let prices = HashMap::from([("custom-model".to_string(), custom)]);

        let engine = MeteringEngine::with_prices(in_memory_substrate(), prices);
        let cost = engine.estimate_cost("custom-model", 1_000_000, 1_000_000);
        assert_close(cost, 15.0, 1e-9);
    }

    #[test]
    fn custom_prices_missing_model_uses_default_fallback() {
        let engine = MeteringEngine::with_prices(in_memory_substrate(), HashMap::new());
        // Default: $1/M in, $3/M out
        let cost = engine.estimate_cost("anything", 1_000_000, 1_000_000);
        assert_close(cost, 4.0, 1e-9);
    }
}