1use std::collections::HashMap;
24use std::time::Duration;
25
26use serde::{Deserialize, Serialize};
27
28use adk_core::Event;
29
30use crate::pricing::{ModelPricing, default_pricing};
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CostMetrics {
38 pub prompt_tokens: u64,
40 pub completion_tokens: u64,
42 pub total_tokens: u64,
44 pub cost_usd: Option<f64>,
46 pub latency_ms: u64,
48}
49
50pub struct CostTracker {
71 pricing: HashMap<String, ModelPricing>,
72}
73
74impl CostTracker {
75 pub fn new() -> Self {
80 Self::with_pricing(default_pricing())
81 }
82
83 pub fn with_pricing(pricing: Vec<ModelPricing>) -> Self {
89 let pricing_map = pricing.into_iter().map(|p| (p.model_name.clone(), p)).collect();
90 Self { pricing: pricing_map }
91 }
92
93 pub fn extract_metrics(&self, events: &[Event], duration: Duration) -> CostMetrics {
114 let mut prompt_tokens: u64 = 0;
115 let mut completion_tokens: u64 = 0;
116 let mut total_tokens: u64 = 0;
117
118 for event in events {
119 if let Some(usage) = &event.llm_response.usage_metadata {
120 prompt_tokens += u64::try_from(usage.prompt_token_count.max(0)).unwrap_or(0);
122 completion_tokens +=
123 u64::try_from(usage.candidates_token_count.max(0)).unwrap_or(0);
124 total_tokens += u64::try_from(usage.total_token_count.max(0)).unwrap_or(0);
125 }
126 }
127
128 CostMetrics {
129 prompt_tokens,
130 completion_tokens,
131 total_tokens,
132 cost_usd: None,
133 latency_ms: duration.as_millis() as u64,
134 }
135 }
136
137 pub fn compute_cost(
152 &self,
153 model: &str,
154 prompt_tokens: u64,
155 completion_tokens: u64,
156 ) -> Option<f64> {
157 self.pricing.get(model).map(|p| {
158 (prompt_tokens as f64 / 1000.0) * p.input_cost_per_1k
159 + (completion_tokens as f64 / 1000.0) * p.output_cost_per_1k
160 })
161 }
162}
163
164impl Default for CostTracker {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cost_tracker_new_has_default_pricing() {
        let tracker = CostTracker::new();
        assert!(!tracker.pricing.is_empty());
        assert!(tracker.pricing.contains_key("gpt-4o"));
        assert!(tracker.pricing.contains_key("gemini-2.5-flash"));
    }

    #[test]
    fn test_cost_tracker_with_custom_pricing() {
        let pricing = vec![ModelPricing::new("custom-model", 0.01, 0.02)];
        let tracker = CostTracker::with_pricing(pricing);
        assert_eq!(tracker.pricing.len(), 1);
        assert!(tracker.pricing.contains_key("custom-model"));
    }

    #[test]
    fn test_with_pricing_duplicate_model_last_wins() {
        let pricing = vec![
            ModelPricing::new("dup-model", 0.001, 0.002),
            ModelPricing::new("dup-model", 0.003, 0.004),
        ];
        let tracker = CostTracker::with_pricing(pricing);
        assert_eq!(tracker.pricing.len(), 1);
        assert_eq!(tracker.compute_cost("dup-model", 1000, 0), Some(0.003));
    }

    #[test]
    fn test_compute_cost_known_model() {
        let pricing = vec![ModelPricing::new("test-model", 0.001, 0.002)];
        let tracker = CostTracker::with_pricing(pricing);

        let cost = tracker.compute_cost("test-model", 1000, 500).expect("test-model is priced");
        let expected = (1000.0 / 1000.0) * 0.001 + (500.0 / 1000.0) * 0.002;
        assert!((cost - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_compute_cost_unknown_model() {
        let tracker = CostTracker::with_pricing(vec![]);
        let cost = tracker.compute_cost("unknown", 100, 50);
        assert!(cost.is_none());
    }

    #[test]
    fn test_compute_cost_requires_exact_model_name() {
        let tracker =
            CostTracker::with_pricing(vec![ModelPricing::new("test-model", 0.001, 0.002)]);
        assert!(tracker.compute_cost("Test-Model", 100, 50).is_none());
        assert!(tracker.compute_cost("test-model-v2", 100, 50).is_none());
    }

    #[test]
    fn test_compute_cost_zero_tokens() {
        let pricing = vec![ModelPricing::new("test-model", 0.001, 0.002)];
        let tracker = CostTracker::with_pricing(pricing);

        let cost = tracker.compute_cost("test-model", 0, 0);
        assert_eq!(cost, Some(0.0));
    }

    #[test]
    fn test_extract_metrics_empty_events() {
        let tracker = CostTracker::new();
        let metrics = tracker.extract_metrics(&[], Duration::from_millis(500));

        assert_eq!(metrics.prompt_tokens, 0);
        assert_eq!(metrics.completion_tokens, 0);
        assert_eq!(metrics.total_tokens, 0);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 500);
    }

    #[test]
    fn test_extract_metrics_with_usage() {
        let tracker =
            CostTracker::with_pricing(vec![ModelPricing::new("test-model", 0.001, 0.002)]);

        let mut event = Event::new("inv-1");
        event.llm_response.usage_metadata = Some(adk_core::UsageMetadata {
            prompt_token_count: 100,
            candidates_token_count: 50,
            total_token_count: 150,
            ..Default::default()
        });

        let metrics = tracker.extract_metrics(&[event], Duration::from_secs(2));

        assert_eq!(metrics.prompt_tokens, 100);
        assert_eq!(metrics.completion_tokens, 50);
        assert_eq!(metrics.total_tokens, 150);
        // extract_metrics never prices; cost is filled in separately.
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 2000);
    }

    #[test]
    fn test_extract_metrics_no_usage_metadata() {
        let tracker = CostTracker::new();
        let event = Event::new("inv-1");

        let metrics = tracker.extract_metrics(&[event], Duration::from_millis(100));

        assert_eq!(metrics.prompt_tokens, 0);
        assert_eq!(metrics.completion_tokens, 0);
        assert_eq!(metrics.total_tokens, 0);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 100);
    }

    #[test]
    fn test_extract_metrics_multiple_events_accumulate() {
        let tracker =
            CostTracker::with_pricing(vec![ModelPricing::new("test-model", 0.001, 0.002)]);

        let mut event1 = Event::new("inv-1");
        event1.llm_response.usage_metadata = Some(adk_core::UsageMetadata {
            prompt_token_count: 50,
            candidates_token_count: 25,
            total_token_count: 75,
            ..Default::default()
        });

        let mut event2 = Event::new("inv-1");
        event2.llm_response.usage_metadata = Some(adk_core::UsageMetadata {
            prompt_token_count: 60,
            candidates_token_count: 30,
            total_token_count: 90,
            ..Default::default()
        });

        let metrics = tracker.extract_metrics(&[event1, event2], Duration::from_millis(300));

        assert_eq!(metrics.prompt_tokens, 110);
        assert_eq!(metrics.completion_tokens, 55);
        assert_eq!(metrics.total_tokens, 165);
        // extract_metrics never prices; cost is filled in separately.
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 300);
    }

    #[test]
    fn test_default_impl() {
        let tracker = CostTracker::default();
        assert!(!tracker.pricing.is_empty());
    }
}