Skip to main content

adk_eval/
cost_tracker.rs

//! Cost and latency tracking for evaluation runs.
//!
//! This module provides [`CostTracker`], which extracts token usage from agent
//! event streams and computes estimated dollar costs using configurable
//! per-model pricing tables.
//!
//! # Example
//!
//! ```rust
//! use adk_eval::cost_tracker::{CostTracker, CostMetrics};
//!
//! let tracker = CostTracker::new();
//!
//! // Compute cost for a known model
//! let cost = tracker.compute_cost("gpt-4o", 1000, 500);
//! assert!(cost.is_some());
//!
//! // Unknown models return None
//! let cost = tracker.compute_cost("unknown-model", 100, 50);
//! assert!(cost.is_none());
//! ```
22
23use std::collections::HashMap;
24use std::time::Duration;
25
26use serde::{Deserialize, Serialize};
27
28use adk_core::Event;
29
30use crate::pricing::{ModelPricing, default_pricing};
31
32/// Cost and latency metrics for a single evaluation turn.
33///
34/// Captures token usage, estimated cost, and wall-clock latency for
35/// a set of events produced during agent execution.
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CostMetrics {
38    /// Number of prompt/input tokens used.
39    pub prompt_tokens: u64,
40    /// Number of completion/output tokens generated.
41    pub completion_tokens: u64,
42    /// Total token count (prompt + completion).
43    pub total_tokens: u64,
44    /// Estimated cost in USD (None if pricing unavailable for the model).
45    pub cost_usd: Option<f64>,
46    /// Wall-clock latency in milliseconds.
47    pub latency_ms: u64,
48}
49
50/// Tracks cost and latency metrics from agent event streams.
51///
52/// Uses per-model pricing tables to compute estimated USD costs from
53/// token counts extracted from [`Event`] streams.
54///
55/// # Example
56///
57/// ```rust
58/// use adk_eval::cost_tracker::CostTracker;
59/// use adk_eval::pricing::ModelPricing;
60///
61/// // Use default pricing
62/// let tracker = CostTracker::new();
63///
64/// // Or provide custom pricing
65/// let custom_pricing = vec![
66///     ModelPricing::new("my-model", 0.001, 0.002),
67/// ];
68/// let tracker = CostTracker::with_pricing(custom_pricing);
69/// ```
70pub struct CostTracker {
71    pricing: HashMap<String, ModelPricing>,
72}
73
74impl CostTracker {
75    /// Creates a new `CostTracker` with default pricing for common models.
76    ///
77    /// Default pricing includes Google Gemini, OpenAI GPT, and Anthropic
78    /// Claude model families.
79    pub fn new() -> Self {
80        Self::with_pricing(default_pricing())
81    }
82
83    /// Creates a new `CostTracker` with the specified pricing table.
84    ///
85    /// # Arguments
86    ///
87    /// * `pricing` - A list of [`ModelPricing`] entries to use for cost computation.
88    pub fn with_pricing(pricing: Vec<ModelPricing>) -> Self {
89        let pricing_map = pricing.into_iter().map(|p| (p.model_name.clone(), p)).collect();
90        Self { pricing: pricing_map }
91    }
92
93    /// Extract cost metrics from an event stream.
94    ///
95    /// Iterates over events looking for [`UsageMetadata`](adk_core::UsageMetadata)
96    /// on LLM responses. Token counts are summed across all events that contain
97    /// usage metadata. If no usage metadata is found, token counts default to zero.
98    ///
99    /// The `duration` parameter is converted to milliseconds for the `latency_ms` field.
100    ///
101    /// Note: The `cost_usd` field will be `None` because the model name is not
102    /// available on the Event struct. Use [`compute_cost`](Self::compute_cost)
103    /// separately when the model name is known.
104    ///
105    /// # Arguments
106    ///
107    /// * `events` - Slice of events from an agent execution.
108    /// * `duration` - Wall-clock duration of the execution.
109    ///
110    /// # Returns
111    ///
112    /// A [`CostMetrics`] struct with aggregated token counts and latency.
113    pub fn extract_metrics(&self, events: &[Event], duration: Duration) -> CostMetrics {
114        let mut prompt_tokens: u64 = 0;
115        let mut completion_tokens: u64 = 0;
116        let mut total_tokens: u64 = 0;
117
118        for event in events {
119            if let Some(usage) = &event.llm_response.usage_metadata {
120                // Accumulate token counts, treating negative values as zero
121                prompt_tokens += u64::try_from(usage.prompt_token_count.max(0)).unwrap_or(0);
122                completion_tokens +=
123                    u64::try_from(usage.candidates_token_count.max(0)).unwrap_or(0);
124                total_tokens += u64::try_from(usage.total_token_count.max(0)).unwrap_or(0);
125            }
126        }
127
128        CostMetrics {
129            prompt_tokens,
130            completion_tokens,
131            total_tokens,
132            cost_usd: None,
133            latency_ms: duration.as_millis() as u64,
134        }
135    }
136
137    /// Compute cost from token counts and model name.
138    ///
139    /// Uses the formula:
140    /// ```text
141    /// (prompt_tokens / 1000.0) * input_cost_per_1k + (completion_tokens / 1000.0) * output_cost_per_1k
142    /// ```
143    ///
144    /// Returns `None` if the model is not found in the pricing table.
145    ///
146    /// # Arguments
147    ///
148    /// * `model` - Model identifier to look up pricing for.
149    /// * `prompt_tokens` - Number of input tokens.
150    /// * `completion_tokens` - Number of output tokens.
151    pub fn compute_cost(
152        &self,
153        model: &str,
154        prompt_tokens: u64,
155        completion_tokens: u64,
156    ) -> Option<f64> {
157        self.pricing.get(model).map(|p| {
158            (prompt_tokens as f64 / 1000.0) * p.input_cost_per_1k
159                + (completion_tokens as f64 / 1000.0) * p.output_cost_per_1k
160        })
161    }
162}
163
164impl Default for CostTracker {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::UsageMetadata;

    /// Tracker whose only pricing entry is `test-model`, with input/output
    /// rates of 0.001 and 0.002 per 1k tokens.
    fn priced_tracker() -> CostTracker {
        CostTracker::with_pricing(vec![ModelPricing::new("test-model", 0.001, 0.002)])
    }

    /// Builds an `inv-1` event whose LLM response carries `usage`.
    fn event_with_usage(usage: UsageMetadata) -> Event {
        let mut event = Event::new("inv-1");
        event.llm_response.usage_metadata = Some(usage);
        event
    }

    /// Checks all three token counters of `metrics` at once.
    fn assert_tokens(metrics: &CostMetrics, prompt: u64, completion: u64, total: u64) {
        assert_eq!(metrics.prompt_tokens, prompt);
        assert_eq!(metrics.completion_tokens, completion);
        assert_eq!(metrics.total_tokens, total);
    }

    #[test]
    fn test_cost_tracker_new_has_default_pricing() {
        let tracker = CostTracker::new();
        let table = &tracker.pricing;
        assert!(!table.is_empty());
        for model in ["gpt-4o", "gemini-2.5-flash"] {
            assert!(table.contains_key(model));
        }
    }

    #[test]
    fn test_cost_tracker_with_custom_pricing() {
        let tracker =
            CostTracker::with_pricing(vec![ModelPricing::new("custom-model", 0.01, 0.02)]);
        assert!(tracker.pricing.contains_key("custom-model"));
        assert_eq!(tracker.pricing.len(), 1);
    }

    #[test]
    fn test_compute_cost_known_model() {
        let cost = priced_tracker()
            .compute_cost("test-model", 1000, 500)
            .expect("test-model is in the pricing table");
        // (1000/1000) * 0.001 + (500/1000) * 0.002 = 0.001 + 0.001 = 0.002
        let expected = (1000.0 / 1000.0) * 0.001 + (500.0 / 1000.0) * 0.002;
        assert!((cost - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_compute_cost_unknown_model() {
        let tracker = CostTracker::with_pricing(Vec::new());
        assert!(tracker.compute_cost("unknown", 100, 50).is_none());
    }

    #[test]
    fn test_compute_cost_zero_tokens() {
        assert_eq!(priced_tracker().compute_cost("test-model", 0, 0), Some(0.0));
    }

    #[test]
    fn test_extract_metrics_empty_events() {
        let metrics = CostTracker::new().extract_metrics(&[], Duration::from_millis(500));

        assert_tokens(&metrics, 0, 0, 0);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 500);
    }

    #[test]
    fn test_extract_metrics_with_usage() {
        let event = event_with_usage(UsageMetadata {
            prompt_token_count: 100,
            candidates_token_count: 50,
            total_token_count: 150,
            ..Default::default()
        });

        let metrics = priced_tracker().extract_metrics(&[event], Duration::from_secs(2));

        assert_tokens(&metrics, 100, 50, 150);
        // Events do not say which model ran, so no cost can be attached.
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 2000);
    }

    #[test]
    fn test_extract_metrics_no_usage_metadata() {
        let metrics = CostTracker::new()
            .extract_metrics(&[Event::new("inv-1")], Duration::from_millis(100));

        assert_tokens(&metrics, 0, 0, 0);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 100);
    }

    #[test]
    fn test_extract_metrics_multiple_events_accumulate() {
        let events = [
            event_with_usage(UsageMetadata {
                prompt_token_count: 50,
                candidates_token_count: 25,
                total_token_count: 75,
                ..Default::default()
            }),
            event_with_usage(UsageMetadata {
                prompt_token_count: 60,
                candidates_token_count: 30,
                total_token_count: 90,
                ..Default::default()
            }),
        ];

        let metrics = priced_tracker().extract_metrics(&events, Duration::from_millis(300));

        assert_tokens(&metrics, 110, 55, 165);
        // Events do not say which model ran, so no cost can be attached.
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.latency_ms, 300);
    }

    #[test]
    fn test_default_impl() {
        assert!(!CostTracker::default().pricing.is_empty());
    }
}