1use parking_lot::Mutex;
4
5use crate::llm::types::{CompletionResponse, Message};
6use crate::util::fnv1a_hash;
7
/// A fixed-capacity, thread-safe LRU cache mapping request-hash keys to
/// completion responses.
pub struct ResponseCache {
    // Ordered most-recently-used first; eviction pops from the back.
    // The parking_lot Mutex makes the cache shareable across threads.
    entries: Mutex<Vec<(u64, CompletionResponse)>>,
    // Maximum number of entries retained at once.
    capacity: usize,
}
24
25impl std::fmt::Debug for ResponseCache {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("ResponseCache")
28 .field("capacity", &self.capacity)
29 .field("len", &self.entries.lock().len())
30 .finish()
31 }
32}
33
34impl ResponseCache {
35 pub fn new(capacity: usize) -> Self {
37 Self {
38 entries: Mutex::new(Vec::with_capacity(capacity)),
39 capacity,
40 }
41 }
42
43 pub fn get(&self, key: u64) -> Option<CompletionResponse> {
45 let mut entries = self.entries.lock();
46 if let Some(pos) = entries.iter().position(|(k, _)| *k == key) {
47 let entry = entries.remove(pos);
48 let response = entry.1.clone();
49 entries.insert(0, entry);
50 Some(response)
51 } else {
52 None
53 }
54 }
55
56 pub fn put(&self, key: u64, response: CompletionResponse) {
59 let mut entries = self.entries.lock();
60 if let Some(pos) = entries.iter().position(|(k, _)| *k == key) {
62 entries.remove(pos);
63 }
64 if entries.len() >= self.capacity {
66 entries.pop();
67 }
68 entries.insert(0, (key, response));
69 }
70
71 pub fn compute_key(system_prompt: &str, messages: &[Message], tool_names: &[&str]) -> u64 {
79 Self::compute_key_scoped(system_prompt, messages, tool_names, None)
80 }
81
82 pub fn compute_key_scoped(
91 system_prompt: &str,
92 messages: &[Message],
93 tool_names: &[&str],
94 namespace: Option<&str>,
95 ) -> u64 {
96 let mut sorted_names: Vec<&str> = tool_names.to_vec();
97 sorted_names.sort();
98
99 let messages_json = serde_json::to_string(messages).expect("messages serialize infallibly");
100
101 let mut data = Vec::new();
102 if let Some(ns) = namespace {
103 data.extend_from_slice(b"ns=");
104 data.extend_from_slice(ns.as_bytes());
105 data.push(0);
106 }
107 data.extend_from_slice(system_prompt.as_bytes());
108 data.push(0); data.extend_from_slice(messages_json.as_bytes());
110 data.push(0); for name in &sorted_names {
112 data.extend_from_slice(name.as_bytes());
113 data.push(0);
114 }
115
116 fnv1a_hash(&data)
117 }
118
119 pub fn clear(&self) {
121 self.entries.lock().clear();
122 }
123
124 pub fn len(&self) -> usize {
126 self.entries.lock().len()
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.entries.lock().is_empty()
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};

    /// Builds a minimal response wrapping a single text content block.
    fn make_response(text: &str) -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::Text { text: text.into() }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let cache = ResponseCache::new(10);
        cache.put(42, make_response("hello"));

        let hit = cache.get(42).expect("entry was just inserted");
        assert_eq!(hit.text(), "hello");
    }

    #[test]
    fn cache_miss_returns_none() {
        assert!(ResponseCache::new(10).get(999).is_none());
    }

    #[test]
    fn lru_eviction() {
        let cache = ResponseCache::new(3);
        for (key, label) in [(1, "one"), (2, "two"), (3, "three")] {
            cache.put(key, make_response(label));
        }

        // A fourth insert overflows the capacity of three.
        cache.put(4, make_response("four"));

        assert_eq!(cache.len(), 3);
        // Key 1 was least recently used, so it is the evicted one.
        assert!(cache.get(1).is_none());
        for key in [2, 3, 4] {
            assert!(cache.get(key).is_some());
        }
    }

    #[test]
    fn lru_access_refreshes_order() {
        let cache = ResponseCache::new(3);
        for (key, label) in [(1, "one"), (2, "two"), (3, "three")] {
            cache.put(key, make_response(label));
        }

        // Touching key 1 promotes it, leaving key 2 as the LRU candidate.
        let _ = cache.get(1);
        cache.put(4, make_response("four"));

        assert!(cache.get(2).is_none());
        for key in [1, 3, 4] {
            assert!(cache.get(key).is_some());
        }
    }

    #[test]
    fn compute_key_deterministic() {
        let msgs = vec![Message::user("hello")];
        let tools = ["search", "read"];

        assert_eq!(
            ResponseCache::compute_key("system", &msgs, &tools),
            ResponseCache::compute_key("system", &msgs, &tools),
        );
    }

    #[test]
    fn compute_key_different_for_different_inputs() {
        let hello = vec![Message::user("hello")];
        let world = vec![Message::user("world")];

        let base = ResponseCache::compute_key("system", &hello, &["search"]);

        // Changing the messages, the system prompt, or the tool set must
        // each perturb the key.
        assert_ne!(base, ResponseCache::compute_key("system", &world, &["search"]));
        assert_ne!(base, ResponseCache::compute_key("other", &hello, &["search"]));
        assert_ne!(base, ResponseCache::compute_key("system", &hello, &["write"]));
    }

    #[test]
    fn compute_key_tool_order_independent() {
        let msgs = vec![Message::user("hi")];
        assert_eq!(
            ResponseCache::compute_key("sys", &msgs, &["a", "b", "c"]),
            ResponseCache::compute_key("sys", &msgs, &["c", "a", "b"]),
        );
    }

    #[test]
    fn clear_empties_cache() {
        let cache = ResponseCache::new(10);
        cache.put(1, make_response("one"));
        cache.put(2, make_response("two"));
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn put_overwrites_existing_key() {
        let cache = ResponseCache::new(10);
        cache.put(1, make_response("first"));
        cache.put(1, make_response("second"));

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(1).unwrap().text(), "second");
    }

    #[test]
    fn thread_safety() {
        use std::sync::Arc;

        let cache = Arc::new(ResponseCache::new(100));

        // Ten writers, 100 distinct keys each; every worker also reads back
        // its own writes to exercise get/put interleaving.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                std::thread::spawn(move || {
                    for j in 0..100 {
                        let key = (i * 100 + j) as u64;
                        cache.put(key, make_response(&format!("resp-{key}")));
                        let _ = cache.get(key);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        assert!(cache.len() <= 100);
    }

    #[test]
    fn compute_key_scoped_differs_per_tenant() {
        let msgs = vec![Message::user("hello")];
        let tenant_a =
            ResponseCache::compute_key_scoped("sys", &msgs, &["a"], Some("tenant-a:user1"));
        let tenant_b =
            ResponseCache::compute_key_scoped("sys", &msgs, &["a"], Some("tenant-b:user1"));
        let unscoped = ResponseCache::compute_key("sys", &msgs, &["a"]);

        assert_ne!(
            tenant_a, tenant_b,
            "different tenants must produce different keys"
        );
        assert_ne!(
            tenant_a, unscoped,
            "scoped key must differ from unscoped key"
        );
    }

    #[test]
    fn compute_key_scoped_stable_for_same_tenant() {
        let msgs = vec![Message::user("hello")];
        let ns = Some("tenant-a:user1");
        assert_eq!(
            ResponseCache::compute_key_scoped("sys", &msgs, &["a"], ns),
            ResponseCache::compute_key_scoped("sys", &msgs, &["a"], ns),
        );
    }
}