1use std::collections::HashMap;
8use std::sync::Arc;
9
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, instrument};
13
14use punch_memory::MemorySubstrate;
15use punch_types::{FighterId, PunchResult};
16
/// Token pricing for a single model (or model-name prefix).
///
/// Both rates are expressed per one million tokens; `MeteringEngine::estimate_cost`
/// scales them by `tokens / 1_000_000`. The resulting cost is logged as `cost_usd`,
/// so the rates are in US dollars.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPrice {
    /// Price per one million input tokens.
    pub input_per_million: f64,
    /// Price per one million output tokens.
    pub output_per_million: f64,
}
29
/// Look-back window used when querying accumulated spend.
///
/// Serialized with `snake_case` variant names (`"hour"`, `"day"`, `"month"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SpendPeriod {
    /// The last hour.
    Hour,
    /// The last day (24 hours).
    Day,
    /// The last 30 days (`to_duration` uses a fixed 30-day span, not a calendar month).
    Month,
}
38
39impl SpendPeriod {
40 fn to_duration(self) -> Duration {
42 match self {
43 Self::Hour => Duration::hours(1),
44 Self::Day => Duration::days(1),
45 Self::Month => Duration::days(30),
46 }
47 }
48}
49
50impl std::fmt::Display for SpendPeriod {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 Self::Hour => write!(f, "hour"),
54 Self::Day => write!(f, "day"),
55 Self::Month => write!(f, "month"),
56 }
57 }
58}
59
/// Estimates the cost of model usage and records/queries it through a
/// [`MemorySubstrate`].
pub struct MeteringEngine {
    /// Shared substrate that usage records are written to and summaries are read from.
    memory: Arc<MemorySubstrate>,
    /// Price table keyed by model name; keys are also matched as prefixes of the
    /// requested model name in `get_price`.
    model_prices: HashMap<String, ModelPrice>,
}
71
72impl MeteringEngine {
73 pub fn new(memory: Arc<MemorySubstrate>) -> Self {
75 let model_prices = Self::default_price_table();
76 Self {
77 memory,
78 model_prices,
79 }
80 }
81
82 pub fn with_prices(
84 memory: Arc<MemorySubstrate>,
85 model_prices: HashMap<String, ModelPrice>,
86 ) -> Self {
87 Self {
88 memory,
89 model_prices,
90 }
91 }
92
93 fn default_price_table() -> HashMap<String, ModelPrice> {
95 let mut prices = HashMap::new();
96
97 prices.insert(
98 "claude-opus".to_string(),
99 ModelPrice {
100 input_per_million: 15.0,
101 output_per_million: 75.0,
102 },
103 );
104
105 prices.insert(
106 "claude-sonnet".to_string(),
107 ModelPrice {
108 input_per_million: 3.0,
109 output_per_million: 15.0,
110 },
111 );
112
113 prices.insert(
114 "claude-haiku".to_string(),
115 ModelPrice {
116 input_per_million: 0.25,
117 output_per_million: 1.25,
118 },
119 );
120
121 prices.insert(
122 "gpt-4o".to_string(),
123 ModelPrice {
124 input_per_million: 2.50,
125 output_per_million: 10.0,
126 },
127 );
128
129 prices.insert(
130 "gpt-4o-mini".to_string(),
131 ModelPrice {
132 input_per_million: 0.15,
133 output_per_million: 0.60,
134 },
135 );
136
137 prices.insert(
139 "ollama/".to_string(),
140 ModelPrice {
141 input_per_million: 0.0,
142 output_per_million: 0.0,
143 },
144 );
145
146 prices
147 }
148
149 fn get_price(&self, model: &str) -> &ModelPrice {
151 if let Some(price) = self.model_prices.get(model) {
153 return price;
154 }
155
156 for (key, price) in &self.model_prices {
158 if model.starts_with(key) {
159 return price;
160 }
161 }
162
163 static DEFAULT_PRICE: ModelPrice = ModelPrice {
166 input_per_million: 1.0,
167 output_per_million: 3.0,
168 };
169 &DEFAULT_PRICE
170 }
171
172 pub fn estimate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
174 let price = self.get_price(model);
175 let input_cost = (input_tokens as f64 / 1_000_000.0) * price.input_per_million;
176 let output_cost = (output_tokens as f64 / 1_000_000.0) * price.output_per_million;
177 input_cost + output_cost
178 }
179
180 #[instrument(skip(self), fields(%fighter_id, %model, input_tokens, output_tokens))]
182 pub async fn record_usage(
183 &self,
184 fighter_id: &FighterId,
185 model: &str,
186 input_tokens: u64,
187 output_tokens: u64,
188 ) -> PunchResult<f64> {
189 let cost = self.estimate_cost(model, input_tokens, output_tokens);
190
191 self.memory
192 .record_usage(fighter_id, model, input_tokens, output_tokens, cost)
193 .await?;
194
195 debug!(cost_usd = cost, "usage recorded with cost");
196 Ok(cost)
197 }
198
199 pub async fn get_spend(&self, fighter_id: &FighterId, period: SpendPeriod) -> PunchResult<f64> {
201 let since = Utc::now() - period.to_duration();
202 let summary = self.memory.get_usage_summary(fighter_id, since).await?;
203 Ok(summary.total_cost_usd)
204 }
205
206 pub async fn get_total_spend(&self, period: SpendPeriod) -> PunchResult<f64> {
208 let since = Utc::now() - period.to_duration();
209 let summary = self.memory.get_total_usage_summary(since).await?;
210 Ok(summary.total_cost_usd)
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// One million tokens, the unit the price table is expressed in.
    const MILLION: u64 = 1_000_000;

    /// Creates a fresh in-memory substrate for a test.
    fn in_memory_substrate() -> Arc<MemorySubstrate> {
        Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"))
    }

    /// Creates an engine that uses the default price table.
    fn default_engine() -> MeteringEngine {
        MeteringEngine::new(in_memory_substrate())
    }

    /// Cost of one million input plus one million output tokens under default prices.
    fn cost_of_million_each(model: &str) -> f64 {
        default_engine().estimate_cost(model, MILLION, MILLION)
    }

    /// Returns `true` when `actual` is within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn estimate_cost_claude_sonnet() {
        // 3.0 (input) + 15.0 (output), resolved through the `claude-sonnet` prefix.
        let cost = cost_of_million_each("claude-sonnet-4-20250514");
        assert!(close_to(cost, 18.0, 1e-9));
    }

    #[test]
    fn estimate_cost_gpt4o_mini() {
        // 0.15 (input) + 0.60 (output), exact key match.
        let cost = cost_of_million_each("gpt-4o-mini");
        assert!(close_to(cost, 0.75, 1e-9));
    }

    #[test]
    fn estimate_cost_ollama_free() {
        let cost = cost_of_million_each("ollama/llama3");
        assert!(close_to(cost, 0.0, 1e-9));
    }

    #[test]
    fn estimate_cost_unknown_model_uses_fallback() {
        // Fallback price: 1.0 (input) + 3.0 (output).
        let cost = cost_of_million_each("some-unknown-model");
        assert!(close_to(cost, 4.0, 1e-9));
    }

    #[test]
    fn estimate_cost_small_usage() {
        let cost = default_engine().estimate_cost("claude-sonnet-4-20250514", 1000, 500);

        let input_cost = (1000.0 / 1_000_000.0) * 3.0;
        let output_cost = (500.0 / 1_000_000.0) * 15.0;
        let expected = input_cost + output_cost;
        assert!(close_to(cost, expected, 1e-12));
    }

    #[tokio::test]
    async fn record_and_query_usage() {
        use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

        let memory = in_memory_substrate();
        let engine = MeteringEngine::new(Arc::clone(&memory));
        let fighter_id = FighterId::new();

        // Store the fighter before recording usage against it
        // (presumably the substrate requires the fighter to exist — confirm).
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let manifest = FighterManifest {
            name: "metering-test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        };
        memory
            .save_fighter(&fighter_id, &manifest, FighterStatus::Idle)
            .await
            .unwrap();

        let recorded = engine
            .record_usage(&fighter_id, "claude-sonnet-4-20250514", 5000, 2000)
            .await
            .unwrap();

        let input_cost = (5000.0 / 1_000_000.0) * 3.0;
        let output_cost = (2000.0 / 1_000_000.0) * 15.0;
        let expected = input_cost + output_cost;
        assert!(close_to(recorded, expected, 1e-12));

        // The amount just recorded must be visible in the last hour's spend.
        let hourly_spend = engine
            .get_spend(&fighter_id, SpendPeriod::Hour)
            .await
            .unwrap();
        assert!(close_to(hourly_spend, expected, 1e-9));
    }

    #[test]
    fn spend_period_display() {
        let cases = [
            (SpendPeriod::Hour, "hour"),
            (SpendPeriod::Day, "day"),
            (SpendPeriod::Month, "month"),
        ];
        for (period, label) in cases {
            assert_eq!(period.to_string(), label);
        }
    }

    #[test]
    fn estimate_cost_zero_tokens() {
        let cost = default_engine().estimate_cost("claude-sonnet-4-20250514", 0, 0);
        assert!(close_to(cost, 0.0, 1e-12));
    }

    #[test]
    fn estimate_cost_claude_opus() {
        // 15.0 (input) + 75.0 (output).
        let cost = cost_of_million_each("claude-opus-4-20250514");
        assert!(close_to(cost, 90.0, 1e-9));
    }

    #[test]
    fn estimate_cost_claude_haiku() {
        // 0.25 (input) + 1.25 (output).
        let cost = cost_of_million_each("claude-haiku-3.5");
        assert!(close_to(cost, 1.5, 1e-9));
    }

    #[test]
    fn estimate_cost_gpt4o() {
        // 2.50 (input) + 10.0 (output).
        let cost = cost_of_million_each("gpt-4o");
        assert!(close_to(cost, 12.5, 1e-9));
    }

    #[test]
    fn with_custom_prices() {
        let prices = HashMap::from([(
            "custom-model".to_string(),
            ModelPrice {
                input_per_million: 5.0,
                output_per_million: 10.0,
            },
        )]);
        let engine = MeteringEngine::with_prices(in_memory_substrate(), prices);

        let cost = engine.estimate_cost("custom-model", MILLION, MILLION);
        assert!(close_to(cost, 15.0, 1e-9));
    }

    #[test]
    fn custom_prices_missing_model_uses_default_fallback() {
        // An empty table leaves only the built-in fallback price (1.0 + 3.0).
        let engine = MeteringEngine::with_prices(in_memory_substrate(), HashMap::new());

        let cost = engine.estimate_cost("anything", MILLION, MILLION);
        assert!(close_to(cost, 4.0, 1e-9));
    }
}