// agent_code_lib/services/cache_tracking.rs

//! Prompt cache tracking and break detection.
//!
//! Monitors cache hit/miss patterns across API calls to identify when the
//! prompt cache is breaking and why. Tracks cache-creation (write) versus
//! cache-read tokens to compute effective cache utilization.

use std::hash::{DefaultHasher, Hash, Hasher};

use crate::llm::message::Usage;

/// Running statistics about prompt-cache behavior across multiple API calls.
///
/// The counters are updated by [`CacheTracker::record`]. The request-prefix
/// fingerprint is maintained by [`CacheTracker::update_fingerprint`].
#[derive(Debug, Default, Clone)]
pub struct CacheTracker {
    /// Sum of cache-creation (write) tokens over all recorded calls.
    pub total_cache_writes: u64,
    /// Sum of cache-read (hit) tokens over all recorded calls.
    pub total_cache_reads: u64,
    /// Number of API calls recorded so far.
    pub call_count: u64,
    /// Number of recorded calls that read at least one token from the cache.
    pub hit_count: u64,
    /// Number of recorded calls that [`CacheTracker::record`] classified as a
    /// cache break.
    pub break_count: u64,
    /// Cache-creation tokens reported by the most recently recorded call.
    last_write: u64,
    /// Cache-read tokens reported by the most recently recorded call.
    last_read: u64,
    /// Fingerprint of the cacheable request prefix (derived from the system
    /// prompt and tool count) seen by the previous `update_fingerprint` call.
    /// `0` (the `Default` value) means no fingerprint has been recorded yet.
    /// Used to detect when prompt changes cause cache invalidation.
    last_fingerprint: u64,
}

30impl CacheTracker {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Update the fingerprint of the cacheable request prefix.
36    /// Call before each API request. Returns true if the fingerprint changed
37    /// (indicating the cache will likely break).
38    pub fn update_fingerprint(&mut self, system_prompt: &str, tool_count: usize) -> bool {
39        let mut hasher = std::hash::DefaultHasher::new();
40        std::hash::Hash::hash(&system_prompt.len(), &mut hasher);
41        // Hash first and last 200 chars for speed (full hash unnecessary).
42        let prefix = &system_prompt[..system_prompt.len().min(200)];
43        std::hash::Hash::hash(prefix, &mut hasher);
44        std::hash::Hash::hash(&tool_count, &mut hasher);
45        let fp = std::hash::Hasher::finish(&hasher);
46
47        let changed = self.last_fingerprint != 0 && self.last_fingerprint != fp;
48        self.last_fingerprint = fp;
49        changed
50    }
51
52    /// Record usage from an API call and detect cache breaks.
53    pub fn record(&mut self, usage: &Usage) -> CacheEvent {
54        self.call_count += 1;
55        self.total_cache_writes += usage.cache_creation_input_tokens;
56        self.total_cache_reads += usage.cache_read_input_tokens;
57
58        let had_reads = usage.cache_read_input_tokens > 0;
59        let had_writes = usage.cache_creation_input_tokens > 0;
60
61        if had_reads {
62            self.hit_count += 1;
63        }
64
65        let event = if !had_reads && had_writes && self.call_count > 1 {
66            // Cache miss on a non-first call — likely a break.
67            self.break_count += 1;
68            CacheEvent::Break {
69                write_tokens: usage.cache_creation_input_tokens,
70                reason: if self.last_read > 0 {
71                    "Cache invalidated since last call".to_string()
72                } else {
73                    "No cache hits — content may have changed".to_string()
74                },
75            }
76        } else if had_reads && !had_writes {
77            // Pure cache hit — ideal.
78            CacheEvent::Hit {
79                read_tokens: usage.cache_read_input_tokens,
80            }
81        } else if had_reads && had_writes {
82            // Partial hit — some content cached, some new.
83            CacheEvent::Partial {
84                read_tokens: usage.cache_read_input_tokens,
85                write_tokens: usage.cache_creation_input_tokens,
86            }
87        } else {
88            // First call or no caching configured.
89            CacheEvent::Miss
90        };
91
92        self.last_write = usage.cache_creation_input_tokens;
93        self.last_read = usage.cache_read_input_tokens;
94
95        event
96    }
97
98    /// Cache hit rate as a percentage (0-100).
99    pub fn hit_rate(&self) -> f64 {
100        if self.call_count == 0 {
101            return 0.0;
102        }
103        (self.hit_count as f64 / self.call_count as f64) * 100.0
104    }
105
106    /// Estimated cost savings from cache hits.
107    /// Cache reads are ~10% the cost of cache writes.
108    pub fn estimated_savings(&self) -> f64 {
109        // Savings = (cache_reads * 0.9 * cost_per_token)
110        // Approximate: saved tokens * 90% discount
111        self.total_cache_reads as f64 * 0.9
112    }
113}
114
/// Classification of a single API call's prompt-cache behavior, as returned by
/// [`CacheTracker::record`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheEvent {
    /// The call read cached content and wrote nothing new — the ideal case.
    /// `read_tokens` is the number of cache-read tokens reported for the call.
    Hit { read_tokens: u64 },
    /// Cache break: the call only wrote to the cache on a call where reuse was
    /// expected, so previously cached content was not reused. `write_tokens`
    /// is the number of cache-creation tokens. `reason` is a human-readable
    /// explanation.
    Break { write_tokens: u64, reason: String },
    /// Partial hit: the call both read cached content and wrote new entries.
    Partial { read_tokens: u64, write_tokens: u64 },
    /// No cache read and not classified as a break. Either caching was not in
    /// play (e.g. disabled), or the call created the initial cache entry
    /// (e.g. on the first call).
    Miss,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Usage` that reports only cache reads and cache writes.
    fn usage(read: u64, write: u64) -> Usage {
        Usage {
            cache_read_input_tokens: read,
            cache_creation_input_tokens: write,
            ..Default::default()
        }
    }

    #[test]
    fn test_new_tracker() {
        let tracker = CacheTracker::new();
        assert_eq!(tracker.call_count, 0);
        assert_eq!(tracker.hit_rate(), 0.0);
    }

    #[test]
    fn test_first_call_miss() {
        let mut tracker = CacheTracker::new();
        let outcome = tracker.record(&usage(0, 1000));
        assert!(matches!(outcome, CacheEvent::Miss));
        assert_eq!(tracker.call_count, 1);
    }

    #[test]
    fn test_cache_hit() {
        let mut tracker = CacheTracker::new();
        // Prime the cache, then reuse it on the next call.
        tracker.record(&usage(0, 1000));
        let outcome = tracker.record(&usage(900, 0));
        assert!(matches!(outcome, CacheEvent::Hit { .. }));
        assert_eq!(tracker.hit_count, 1);
    }

    #[test]
    fn test_cache_break() {
        let mut tracker = CacheTracker::new();
        // A read shows that a cache exists...
        tracker.record(&usage(500, 0));
        // ...so a following write-only call counts as a break.
        let outcome = tracker.record(&usage(0, 1000));
        assert!(matches!(outcome, CacheEvent::Break { .. }));
        assert_eq!(tracker.break_count, 1);
    }

    #[test]
    fn test_partial_hit() {
        let mut tracker = CacheTracker::new();
        tracker.record(&usage(0, 0));
        let outcome = tracker.record(&usage(500, 200));
        assert!(matches!(outcome, CacheEvent::Partial { .. }));
    }

    #[test]
    fn test_hit_rate() {
        let mut tracker = CacheTracker::new();
        // One call without cache reads, one with: 50% hit rate.
        for (read, write) in [(0, 0), (100, 0)] {
            tracker.record(&usage(read, write));
        }
        assert!((tracker.hit_rate() - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_fingerprint_change_detection() {
        let mut tracker = CacheTracker::new();
        let steps = [
            ("system prompt v1", false), // nothing to compare against yet
            ("system prompt v1", false), // unchanged
            ("system prompt v2", true),  // prompt edited
        ];
        for (prompt, expect_changed) in steps {
            assert_eq!(
                tracker.update_fingerprint(prompt, 10),
                expect_changed,
                "prompt {prompt:?}"
            );
        }
    }

    #[test]
    fn test_fingerprint_tool_count_change() {
        let mut tracker = CacheTracker::new();
        tracker.update_fingerprint("prompt", 10);
        // Same prompt, different tool count.
        assert!(tracker.update_fingerprint("prompt", 15));
    }
}