1use std::collections::HashMap;
8use std::sync::Arc;
9
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, instrument};
13
14use punch_memory::MemorySubstrate;
15use punch_types::{FighterId, PunchResult};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelPrice {
24 pub input_per_million: f64,
26 pub output_per_million: f64,
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum SpendPeriod {
34 Hour,
35 Day,
36 Month,
37}
38
39impl SpendPeriod {
40 fn to_duration(self) -> Duration {
42 match self {
43 Self::Hour => Duration::hours(1),
44 Self::Day => Duration::days(1),
45 Self::Month => Duration::days(30),
46 }
47 }
48}
49
50impl std::fmt::Display for SpendPeriod {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 Self::Hour => write!(f, "hour"),
54 Self::Day => write!(f, "day"),
55 Self::Month => write!(f, "month"),
56 }
57 }
58}
59
60pub struct MeteringEngine {
66 memory: Arc<MemorySubstrate>,
68 model_prices: HashMap<String, ModelPrice>,
70}
71
72impl MeteringEngine {
73 pub fn new(memory: Arc<MemorySubstrate>) -> Self {
75 let model_prices = Self::default_price_table();
76 Self {
77 memory,
78 model_prices,
79 }
80 }
81
82 pub fn with_prices(
84 memory: Arc<MemorySubstrate>,
85 model_prices: HashMap<String, ModelPrice>,
86 ) -> Self {
87 Self {
88 memory,
89 model_prices,
90 }
91 }
92
93 fn default_price_table() -> HashMap<String, ModelPrice> {
95 let mut prices = HashMap::new();
96
97 prices.insert(
98 "claude-opus".to_string(),
99 ModelPrice {
100 input_per_million: 15.0,
101 output_per_million: 75.0,
102 },
103 );
104
105 prices.insert(
106 "claude-sonnet".to_string(),
107 ModelPrice {
108 input_per_million: 3.0,
109 output_per_million: 15.0,
110 },
111 );
112
113 prices.insert(
114 "claude-haiku".to_string(),
115 ModelPrice {
116 input_per_million: 0.25,
117 output_per_million: 1.25,
118 },
119 );
120
121 prices.insert(
122 "gpt-4o".to_string(),
123 ModelPrice {
124 input_per_million: 2.50,
125 output_per_million: 10.0,
126 },
127 );
128
129 prices.insert(
130 "gpt-4o-mini".to_string(),
131 ModelPrice {
132 input_per_million: 0.15,
133 output_per_million: 0.60,
134 },
135 );
136
137 prices.insert(
139 "ollama/".to_string(),
140 ModelPrice {
141 input_per_million: 0.0,
142 output_per_million: 0.0,
143 },
144 );
145
146 prices
147 }
148
149 fn get_price(&self, model: &str) -> &ModelPrice {
151 if let Some(price) = self.model_prices.get(model) {
153 return price;
154 }
155
156 for (key, price) in &self.model_prices {
158 if model.starts_with(key) {
159 return price;
160 }
161 }
162
163 static DEFAULT_PRICE: ModelPrice = ModelPrice {
166 input_per_million: 1.0,
167 output_per_million: 3.0,
168 };
169 &DEFAULT_PRICE
170 }
171
172 pub fn estimate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
174 let price = self.get_price(model);
175 let input_cost = (input_tokens as f64 / 1_000_000.0) * price.input_per_million;
176 let output_cost = (output_tokens as f64 / 1_000_000.0) * price.output_per_million;
177 input_cost + output_cost
178 }
179
180 #[instrument(skip(self), fields(%fighter_id, %model, input_tokens, output_tokens))]
182 pub async fn record_usage(
183 &self,
184 fighter_id: &FighterId,
185 model: &str,
186 input_tokens: u64,
187 output_tokens: u64,
188 ) -> PunchResult<f64> {
189 let cost = self.estimate_cost(model, input_tokens, output_tokens);
190
191 self.memory
192 .record_usage(fighter_id, model, input_tokens, output_tokens, cost)
193 .await?;
194
195 debug!(cost_usd = cost, "usage recorded with cost");
196 Ok(cost)
197 }
198
199 pub async fn get_spend(&self, fighter_id: &FighterId, period: SpendPeriod) -> PunchResult<f64> {
201 let since = Utc::now() - period.to_duration();
202 let summary = self.memory.get_usage_summary(fighter_id, since).await?;
203 Ok(summary.total_cost_usd)
204 }
205
206 pub async fn get_total_spend(&self, period: SpendPeriod) -> PunchResult<f64> {
208 let since = Utc::now() - period.to_duration();
209 let summary = self.memory.get_total_usage_summary(since).await?;
210 Ok(summary.total_cost_usd)
211 }
212
213 pub async fn get_fighter_summary(
215 &self,
216 fighter_id: &FighterId,
217 period: SpendPeriod,
218 ) -> PunchResult<punch_memory::UsageSummary> {
219 let since = Utc::now() - period.to_duration();
220 self.memory.get_usage_summary(fighter_id, since).await
221 }
222
223 pub async fn get_total_summary(
225 &self,
226 period: SpendPeriod,
227 ) -> PunchResult<punch_memory::UsageSummary> {
228 let since = Utc::now() - period.to_duration();
229 self.memory.get_total_usage_summary(since).await
230 }
231
232 pub async fn get_model_breakdown(
234 &self,
235 fighter_id: &FighterId,
236 period: SpendPeriod,
237 ) -> PunchResult<Vec<punch_memory::ModelUsageBreakdown>> {
238 let since = Utc::now() - period.to_duration();
239 self.memory.get_model_breakdown(fighter_id, since).await
240 }
241
242 pub async fn get_total_model_breakdown(
244 &self,
245 period: SpendPeriod,
246 ) -> PunchResult<Vec<punch_memory::ModelUsageBreakdown>> {
247 let since = Utc::now() - period.to_duration();
248 self.memory.get_total_model_breakdown(since).await
249 }
250
251 pub async fn get_fighter_breakdown(
253 &self,
254 period: SpendPeriod,
255 ) -> PunchResult<Vec<punch_memory::FighterUsageBreakdown>> {
256 let since = Utc::now() - period.to_duration();
257 self.memory.get_fighter_breakdown(since).await
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory substrate shared by the tests below.
    fn in_memory_substrate() -> Arc<MemorySubstrate> {
        Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"))
    }

    /// Engine with the built-in price table on a fresh in-memory substrate.
    fn default_engine() -> MeteringEngine {
        MeteringEngine::new(in_memory_substrate())
    }

    #[test]
    fn estimate_cost_claude_sonnet() {
        let engine = default_engine();

        // 3.0 (input) + 15.0 (output) for one million tokens each.
        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 1_000_000, 1_000_000);
        assert!((cost - 18.0).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_gpt4o_mini() {
        let engine = default_engine();

        // 0.15 (input) + 0.60 (output) for one million tokens each.
        let cost = engine.estimate_cost("gpt-4o-mini", 1_000_000, 1_000_000);
        assert!((cost - 0.75).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_ollama_free() {
        let engine = default_engine();

        let cost = engine.estimate_cost("ollama/llama3", 1_000_000, 1_000_000);
        assert!((cost - 0.0).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_unknown_model_uses_fallback() {
        let engine = default_engine();

        // Fallback price: 1.0 (input) + 3.0 (output).
        let cost = engine.estimate_cost("some-unknown-model", 1_000_000, 1_000_000);
        assert!((cost - 4.0).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_small_usage() {
        let engine = default_engine();

        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 3.0 + (500.0 / 1_000_000.0) * 15.0;
        assert!((cost - expected).abs() < 1e-12);
    }

    #[tokio::test]
    async fn record_and_query_usage() {
        use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

        let memory = in_memory_substrate();
        let engine = MeteringEngine::new(Arc::clone(&memory));

        let fighter_id = FighterId::new();

        // Save the fighter before recording usage against its id.
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let manifest = FighterManifest {
            name: "metering-test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        };
        memory
            .save_fighter(&fighter_id, &manifest, FighterStatus::Idle)
            .await
            .unwrap();

        let cost = engine
            .record_usage(&fighter_id, "claude-sonnet-4-20250514", 5000, 2000)
            .await
            .unwrap();

        let expected = (5000.0 / 1_000_000.0) * 3.0 + (2000.0 / 1_000_000.0) * 15.0;
        assert!((cost - expected).abs() < 1e-12);

        // The recorded usage must be visible when querying the last hour.
        let spend = engine
            .get_spend(&fighter_id, SpendPeriod::Hour)
            .await
            .unwrap();
        assert!((spend - expected).abs() < 1e-9);
    }

    #[test]
    fn spend_period_display() {
        assert_eq!(SpendPeriod::Hour.to_string(), "hour");
        assert_eq!(SpendPeriod::Day.to_string(), "day");
        assert_eq!(SpendPeriod::Month.to_string(), "month");
    }

    #[test]
    fn estimate_cost_zero_tokens() {
        let engine = default_engine();
        let cost = engine.estimate_cost("claude-sonnet-4-20250514", 0, 0);
        assert!((cost - 0.0).abs() < 1e-12);
    }

    #[test]
    fn estimate_cost_claude_opus() {
        let engine = default_engine();
        // 15.0 (input) + 75.0 (output).
        let cost = engine.estimate_cost("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!((cost - 90.0).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_claude_haiku() {
        let engine = default_engine();
        // 0.25 (input) + 1.25 (output).
        let cost = engine.estimate_cost("claude-haiku-3.5", 1_000_000, 1_000_000);
        assert!((cost - 1.5).abs() < 1e-9);
    }

    #[test]
    fn estimate_cost_gpt4o() {
        let engine = default_engine();
        // 2.50 (input) + 10.0 (output).
        let cost = engine.estimate_cost("gpt-4o", 1_000_000, 1_000_000);
        assert!((cost - 12.5).abs() < 1e-9);
    }

    #[test]
    fn with_custom_prices() {
        let custom = ModelPrice {
            input_per_million: 5.0,
            output_per_million: 10.0,
        };
        let mut prices = HashMap::new();
        prices.insert("custom-model".to_string(), custom);

        let engine = MeteringEngine::with_prices(in_memory_substrate(), prices);
        let cost = engine.estimate_cost("custom-model", 1_000_000, 1_000_000);
        assert!((cost - 15.0).abs() < 1e-9);
    }

    #[test]
    fn custom_prices_missing_model_uses_default_fallback() {
        let engine = MeteringEngine::with_prices(in_memory_substrate(), HashMap::new());
        // An empty table still resolves to the 1.0 / 3.0 fallback price.
        let cost = engine.estimate_cost("anything", 1_000_000, 1_000_000);
        assert!((cost - 4.0).abs() < 1e-9);
    }
}