//! Context-cache lifecycle management (`CacheManager`) and cache
//! performance metrics (`CacheMetrics`, `CachePerformanceAnalyzer`).
1use adk_core::{ContextCacheConfig, Event};
2use serde::{Deserialize, Serialize};
3
/// Internal cache lifecycle manager.
///
/// Tracks the active cache name, invocation count, and determines
/// when caching should be attempted or refreshed based on
/// [`ContextCacheConfig`] settings.
pub(crate) struct CacheManager {
    // Thresholds governing the lifecycle: min_tokens, ttl_seconds,
    // and cache_intervals (invocations between refreshes).
    config: ContextCacheConfig,
    // Provider-assigned cache identifier (e.g. "cachedContents/abc123");
    // `None` until `set_active_cache` is called, and again after
    // `clear_active_cache`.
    active_cache_name: Option<String>,
    // Invocations recorded since the cache was last set or cleared;
    // compared against `config.cache_intervals` by `needs_refresh`.
    invocation_count: u32,
}
14
15impl CacheManager {
16    pub(crate) fn new(config: ContextCacheConfig) -> Self {
17        Self { config, active_cache_name: None, invocation_count: 0 }
18    }
19
20    /// Check if caching should be attempted based on config.
21    ///
22    /// Returns `false` when `min_tokens` or `ttl_seconds` is zero,
23    /// effectively disabling the cache lifecycle.
24    ///
25    /// Note: The `min_tokens` threshold is enforced server-side by the
26    /// provider (e.g., Gemini rejects cache creation for small contexts).
27    /// A zero value here acts as a kill-switch for the entire lifecycle.
28    pub(crate) fn is_enabled(&self) -> bool {
29        self.config.min_tokens > 0 && self.config.ttl_seconds > 0
30    }
31
32    /// Return the active cache name, if any.
33    pub(crate) fn active_cache_name(&self) -> Option<&str> {
34        self.active_cache_name.as_deref()
35    }
36
37    /// Check if the cache needs refresh based on invocation count.
38    ///
39    /// Returns `true` when the number of recorded invocations has
40    /// reached or exceeded `cache_intervals`.
41    pub(crate) fn needs_refresh(&self) -> bool {
42        self.invocation_count >= self.config.cache_intervals
43    }
44
45    /// Record an invocation and return the current cache name (if any).
46    pub(crate) fn record_invocation(&mut self) -> Option<&str> {
47        self.invocation_count += 1;
48        self.active_cache_name.as_deref()
49    }
50
51    /// Set the active cache name after creation, resetting the
52    /// invocation counter.
53    pub(crate) fn set_active_cache(&mut self, name: String) {
54        self.active_cache_name = Some(name);
55        self.invocation_count = 0;
56    }
57
58    /// Clear the active cache (after deletion or on error),
59    /// resetting the invocation counter.
60    ///
61    /// Returns the previously active cache name, if any.
62    pub(crate) fn clear_active_cache(&mut self) -> Option<String> {
63        self.invocation_count = 0;
64        self.active_cache_name.take()
65    }
66}
67
/// Metrics computed from session event history.
///
/// All ratio fields are percentages in the range `[0.0, 100.0]`.
/// When there are no events with usage metadata, all fields are zero.
/// The struct is serde-(de)serializable for export/telemetry use.
///
/// # Example
///
/// ```rust,ignore
/// use adk_runner::CachePerformanceAnalyzer;
///
/// let events = session.events();
/// let metrics = CachePerformanceAnalyzer::analyze(&events);
/// println!("Cache hit ratio: {:.1}%", metrics.cache_hit_ratio);
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheMetrics {
    /// Total requests with `UsageMetadata`.
    pub total_requests: u32,
    /// Requests where `cache_read_input_token_count > 0`.
    pub requests_with_cache_hits: u32,
    /// Sum of all `prompt_token_count` values.
    pub total_prompt_tokens: i64,
    /// Sum of all `cache_read_input_token_count` values.
    pub total_cache_read_tokens: i64,
    /// Sum of all `cache_creation_input_token_count` values.
    pub total_cache_creation_tokens: i64,
    /// `total_cache_read_tokens / total_prompt_tokens * 100`.
    /// Remains `0.0` when `total_prompt_tokens` is zero.
    pub cache_hit_ratio: f64,
    /// `requests_with_cache_hits / total_requests * 100`.
    /// Remains `0.0` when `total_requests` is zero.
    pub cache_utilization_ratio: f64,
    /// `total_cache_read_tokens / total_requests`.
    /// Remains `0.0` when `total_requests` is zero.
    pub avg_cached_tokens_per_request: f64,
}
101
/// Utility for computing cache effectiveness metrics from session events.
///
/// This is a stateless analyzer — call [`CachePerformanceAnalyzer::analyze`]
/// with any slice of events to get a [`CacheMetrics`] snapshot. The type is
/// a zero-sized struct carrying only associated functions; it is never
/// instantiated.
///
/// # Example
///
/// ```rust,ignore
/// use adk_runner::CachePerformanceAnalyzer;
///
/// let metrics = CachePerformanceAnalyzer::analyze(&events);
/// println!("Hit ratio: {:.1}%, Utilization: {:.1}%",
///     metrics.cache_hit_ratio, metrics.cache_utilization_ratio);
/// ```
pub struct CachePerformanceAnalyzer;
117
118impl CachePerformanceAnalyzer {
119    /// Analyze cache performance from a slice of events.
120    ///
121    /// Iterates over all events, extracts `usage_metadata` from LLM responses,
122    /// and computes aggregate cache metrics. Events without `usage_metadata`
123    /// are skipped. An empty slice returns zeroed metrics.
124    pub fn analyze(events: &[Event]) -> CacheMetrics {
125        let mut metrics = CacheMetrics::default();
126
127        for event in events {
128            let Some(ref usage) = event.llm_response.usage_metadata else {
129                continue;
130            };
131
132            metrics.total_requests += 1;
133            metrics.total_prompt_tokens += i64::from(usage.prompt_token_count);
134
135            let cache_read = usage.cache_read_input_token_count.unwrap_or(0);
136            metrics.total_cache_read_tokens += i64::from(cache_read);
137
138            if cache_read > 0 {
139                metrics.requests_with_cache_hits += 1;
140            }
141
142            let cache_creation = usage.cache_creation_input_token_count.unwrap_or(0);
143            metrics.total_cache_creation_tokens += i64::from(cache_creation);
144        }
145
146        if metrics.total_prompt_tokens > 0 {
147            metrics.cache_hit_ratio =
148                metrics.total_cache_read_tokens as f64 / metrics.total_prompt_tokens as f64 * 100.0;
149        }
150        if metrics.total_requests > 0 {
151            metrics.cache_utilization_ratio =
152                metrics.requests_with_cache_hits as f64 / metrics.total_requests as f64 * 100.0;
153            metrics.avg_cached_tokens_per_request =
154                metrics.total_cache_read_tokens as f64 / metrics.total_requests as f64;
155        }
156
157        metrics
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Config used by most lifecycle tests: refresh every 3 invocations.
    fn default_config() -> ContextCacheConfig {
        ContextCacheConfig { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
    }

    #[test]
    fn test_new_manager_has_no_active_cache() {
        let cm = CacheManager::new(default_config());
        assert!(cm.active_cache_name.is_none());
        assert_eq!(cm.invocation_count, 0);
    }

    #[test]
    fn test_is_enabled_with_valid_config() {
        let cm = CacheManager::new(default_config());
        assert!(cm.is_enabled());
    }

    #[test]
    fn test_is_enabled_false_when_min_tokens_zero() {
        let config = ContextCacheConfig { min_tokens: 0, ttl_seconds: 600, cache_intervals: 3 };
        let cm = CacheManager::new(config);
        assert!(!cm.is_enabled());
    }

    #[test]
    fn test_is_enabled_false_when_ttl_zero() {
        let config = ContextCacheConfig { min_tokens: 4096, ttl_seconds: 0, cache_intervals: 3 };
        let cm = CacheManager::new(config);
        assert!(!cm.is_enabled());
    }

    #[test]
    fn test_is_enabled_false_when_both_zero() {
        let config = ContextCacheConfig { min_tokens: 0, ttl_seconds: 0, cache_intervals: 3 };
        let cm = CacheManager::new(config);
        assert!(!cm.is_enabled());
    }

    #[test]
    fn test_needs_refresh_false_initially() {
        let cm = CacheManager::new(default_config());
        assert!(!cm.needs_refresh());
    }

    #[test]
    fn test_needs_refresh_true_after_n_invocations() {
        let mut cm = CacheManager::new(default_config());
        // cache_intervals = 3, so after 3 invocations needs_refresh should be true
        cm.record_invocation();
        assert!(!cm.needs_refresh());
        cm.record_invocation();
        assert!(!cm.needs_refresh());
        cm.record_invocation();
        assert!(cm.needs_refresh());
    }

    #[test]
    fn test_record_invocation_returns_none_without_active_cache() {
        let mut cm = CacheManager::new(default_config());
        assert!(cm.record_invocation().is_none());
    }

    #[test]
    fn test_record_invocation_returns_cache_name() {
        let mut cm = CacheManager::new(default_config());
        cm.set_active_cache("cachedContents/abc123".to_string());
        let name = cm.record_invocation();
        assert_eq!(name, Some("cachedContents/abc123"));
    }

    #[test]
    fn test_set_active_cache_resets_invocation_count() {
        let mut cm = CacheManager::new(default_config());
        cm.record_invocation();
        cm.record_invocation();
        assert_eq!(cm.invocation_count, 2);

        cm.set_active_cache("cachedContents/new".to_string());
        assert_eq!(cm.invocation_count, 0);
        assert_eq!(cm.active_cache_name.as_deref(), Some("cachedContents/new"));
    }

    #[test]
    fn test_clear_active_cache_returns_old_name() {
        let mut cm = CacheManager::new(default_config());
        cm.set_active_cache("cachedContents/old".to_string());
        cm.record_invocation();

        let old = cm.clear_active_cache();
        assert_eq!(old.as_deref(), Some("cachedContents/old"));
        assert!(cm.active_cache_name.is_none());
        assert_eq!(cm.invocation_count, 0);
    }

    #[test]
    fn test_clear_active_cache_returns_none_when_empty() {
        let mut cm = CacheManager::new(default_config());
        let old = cm.clear_active_cache();
        assert!(old.is_none());
    }

    // End-to-end walk through create → use → refresh of a cache.
    #[test]
    fn test_full_lifecycle() {
        let mut cm = CacheManager::new(ContextCacheConfig {
            min_tokens: 1024,
            ttl_seconds: 300,
            cache_intervals: 2,
        });

        assert!(cm.is_enabled());
        assert!(!cm.needs_refresh());

        // No cache yet
        assert!(cm.record_invocation().is_none());

        // Set a cache
        cm.set_active_cache("cachedContents/v1".to_string());
        assert_eq!(cm.invocation_count, 0);

        // First invocation returns cache name
        assert_eq!(cm.record_invocation(), Some("cachedContents/v1"));
        assert!(!cm.needs_refresh());

        // Second invocation triggers refresh
        assert_eq!(cm.record_invocation(), Some("cachedContents/v1"));
        assert!(cm.needs_refresh());

        // Refresh: clear old, set new
        let old = cm.clear_active_cache();
        assert_eq!(old.as_deref(), Some("cachedContents/v1"));
        cm.set_active_cache("cachedContents/v2".to_string());
        assert!(!cm.needs_refresh());
        assert_eq!(cm.record_invocation(), Some("cachedContents/v2"));
    }

    // --- CachePerformanceAnalyzer tests ---

    use adk_core::{LlmResponse, UsageMetadata};

    // Build an event carrying full usage metadata; `None` cache counters
    // model providers that omit the cache fields.
    fn event_with_usage(
        prompt: i32,
        candidates: i32,
        cache_read: Option<i32>,
        cache_creation: Option<i32>,
    ) -> Event {
        let mut event = Event::new("test-invocation");
        event.llm_response = LlmResponse {
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: prompt,
                candidates_token_count: candidates,
                total_token_count: prompt + candidates,
                cache_read_input_token_count: cache_read,
                cache_creation_input_token_count: cache_creation,
                ..Default::default()
            }),
            ..Default::default()
        };
        event
    }

    // Event with no usage metadata at all — the analyzer must skip it.
    fn event_without_usage() -> Event {
        Event::new("test-invocation")
    }

    #[test]
    fn test_analyze_empty_events() {
        let metrics = CachePerformanceAnalyzer::analyze(&[]);
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.requests_with_cache_hits, 0);
        assert_eq!(metrics.total_prompt_tokens, 0);
        assert_eq!(metrics.total_cache_read_tokens, 0);
        assert_eq!(metrics.total_cache_creation_tokens, 0);
        assert_eq!(metrics.cache_hit_ratio, 0.0);
        assert_eq!(metrics.cache_utilization_ratio, 0.0);
        assert_eq!(metrics.avg_cached_tokens_per_request, 0.0);
    }

    #[test]
    fn test_analyze_events_without_usage_metadata() {
        let events = vec![event_without_usage(), event_without_usage()];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.cache_hit_ratio, 0.0);
    }

    #[test]
    fn test_analyze_single_event_no_cache() {
        let events = vec![event_with_usage(1000, 200, None, None)];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.requests_with_cache_hits, 0);
        assert_eq!(metrics.total_prompt_tokens, 1000);
        assert_eq!(metrics.total_cache_read_tokens, 0);
        assert_eq!(metrics.total_cache_creation_tokens, 0);
        assert_eq!(metrics.cache_hit_ratio, 0.0);
        assert_eq!(metrics.cache_utilization_ratio, 0.0);
        assert_eq!(metrics.avg_cached_tokens_per_request, 0.0);
    }

    #[test]
    fn test_analyze_single_event_with_cache_hit() {
        let events = vec![event_with_usage(1000, 200, Some(500), None)];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.requests_with_cache_hits, 1);
        assert_eq!(metrics.total_prompt_tokens, 1000);
        assert_eq!(metrics.total_cache_read_tokens, 500);
        assert_eq!(metrics.cache_hit_ratio, 50.0);
        assert_eq!(metrics.cache_utilization_ratio, 100.0);
        assert_eq!(metrics.avg_cached_tokens_per_request, 500.0);
    }

    #[test]
    fn test_analyze_mixed_events() {
        let events = vec![
            event_with_usage(1000, 200, Some(800), Some(200)),
            event_with_usage(1000, 300, None, None),
            event_with_usage(1000, 100, Some(600), None),
            event_without_usage(), // skipped
        ];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 3);
        assert_eq!(metrics.requests_with_cache_hits, 2);
        assert_eq!(metrics.total_prompt_tokens, 3000);
        assert_eq!(metrics.total_cache_read_tokens, 1400);
        assert_eq!(metrics.total_cache_creation_tokens, 200);
        // cache_hit_ratio = 1400 / 3000 * 100 ≈ 46.67
        assert!((metrics.cache_hit_ratio - 46.666_666_666_666_664).abs() < 1e-10);
        // cache_utilization_ratio = 2 / 3 * 100 ≈ 66.67
        assert!((metrics.cache_utilization_ratio - 66.666_666_666_666_66).abs() < 1e-10);
        // avg_cached_tokens_per_request = 1400 / 3 ≈ 466.67
        assert!((metrics.avg_cached_tokens_per_request - 466.666_666_666_666_7).abs() < 1e-10);
    }

    #[test]
    fn test_analyze_all_cache_hits() {
        let events = vec![
            event_with_usage(500, 100, Some(500), None),
            event_with_usage(500, 100, Some(500), None),
        ];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.requests_with_cache_hits, 2);
        assert_eq!(metrics.cache_hit_ratio, 100.0);
        assert_eq!(metrics.cache_utilization_ratio, 100.0);
        assert_eq!(metrics.avg_cached_tokens_per_request, 500.0);
    }

    #[test]
    fn test_analyze_zero_prompt_tokens() {
        // Edge case: usage_metadata present but prompt_token_count is 0
        let events = vec![event_with_usage(0, 100, None, None)];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.total_prompt_tokens, 0);
        // cache_hit_ratio stays 0.0 (no division by zero)
        assert_eq!(metrics.cache_hit_ratio, 0.0);
        assert_eq!(metrics.cache_utilization_ratio, 0.0);
    }

    #[test]
    fn test_analyze_cache_creation_only() {
        let events = vec![event_with_usage(2000, 500, None, Some(1500))];
        let metrics = CachePerformanceAnalyzer::analyze(&events);
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.requests_with_cache_hits, 0);
        assert_eq!(metrics.total_cache_creation_tokens, 1500);
        assert_eq!(metrics.cache_hit_ratio, 0.0);
        assert_eq!(metrics.cache_utilization_ratio, 0.0);
    }
}
435}