1use std::collections::HashMap;
17use std::sync::Arc;
18
19use anyhow::Result;
20use async_trait::async_trait;
21use brainwires_core::message::{ChatResponse, Message, MessageContent, Role, StreamChunk, Usage};
22use brainwires_core::provider::{ChatOptions, Provider};
23use brainwires_core::tool::Tool;
24use futures::stream::BoxStream;
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27use tokio::sync::Mutex;
28
29#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
31pub struct CacheKey(pub String);
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CachedResponse {
36 pub role: Role,
38 pub text: String,
42 pub usage: Usage,
44 pub finish_reason: Option<String>,
46}
47
48impl CachedResponse {
49 fn from_chat(resp: &ChatResponse) -> Self {
50 let text = match &resp.message.content {
51 MessageContent::Text(t) => t.clone(),
52 MessageContent::Blocks(_) => resp
53 .message
54 .text()
55 .map(|s| s.to_string())
56 .unwrap_or_default(),
57 };
58 Self {
59 role: resp.message.role.clone(),
60 text,
61 usage: resp.usage.clone(),
62 finish_reason: resp.finish_reason.clone(),
63 }
64 }
65
66 fn into_chat(self) -> ChatResponse {
67 let msg = match self.role {
68 Role::Assistant => Message::assistant(self.text.clone()),
69 Role::System => Message::system(self.text.clone()),
70 _ => Message::user(self.text.clone()),
71 };
72 ChatResponse {
73 message: msg,
74 usage: self.usage,
75 finish_reason: self.finish_reason,
76 }
77 }
78}
79
80#[async_trait]
82pub trait CacheBackend: Send + Sync {
83 async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>>;
85 async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()>;
87}
88
89#[derive(Clone, Default)]
91pub struct MemoryCache {
92 inner: Arc<Mutex<HashMap<CacheKey, CachedResponse>>>,
93}
94
95impl MemoryCache {
96 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub async fn len(&self) -> usize {
103 self.inner.lock().await.len()
104 }
105
106 pub async fn is_empty(&self) -> bool {
108 self.len().await == 0
109 }
110}
111
112#[async_trait]
113impl CacheBackend for MemoryCache {
114 async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>> {
115 Ok(self.inner.lock().await.get(key).cloned())
116 }
117 async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()> {
118 self.inner.lock().await.insert(key.clone(), resp);
119 Ok(())
120 }
121}
122
123pub fn cache_key_for(
128 messages: &[Message],
129 tools: Option<&[Tool]>,
130 options: &ChatOptions,
131) -> CacheKey {
132 let mut hasher = Sha256::new();
133 let msgs = serde_json::to_vec(messages).unwrap_or_default();
135 hasher.update(&msgs);
136
137 if let Some(ts) = tools {
138 let mut names: Vec<&str> = ts.iter().map(|t| t.name.as_str()).collect();
139 names.sort();
140 for n in names {
141 hasher.update(b"\x00tool:");
142 hasher.update(n.as_bytes());
143 }
144 }
145
146 let opts = serde_json::to_vec(options).unwrap_or_default();
147 hasher.update(b"\x00opts:");
148 hasher.update(&opts);
149
150 let digest = hasher.finalize();
151 CacheKey(hex_encode(&digest))
152}
153
/// Renders `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(bytes.len() * 2);
    hex.extend(
        bytes
            .iter()
            // High nibble first, then low nibble.
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(DIGITS[usize::from(nibble)])),
    );
    hex
}
163
164pub struct CachedProvider<P: Provider + ?Sized> {
166 inner: Arc<P>,
167 backend: Arc<dyn CacheBackend>,
168}
169
170impl<P: Provider + ?Sized> CachedProvider<P> {
171 pub fn new(inner: Arc<P>, backend: Arc<dyn CacheBackend>) -> Self {
173 Self { inner, backend }
174 }
175
176 pub fn with_memory_cache(inner: Arc<P>) -> (Self, MemoryCache) {
178 let cache = MemoryCache::new();
179 let me = Self::new(inner, Arc::new(cache.clone()));
180 (me, cache)
181 }
182}
183
184#[async_trait]
185impl<P: Provider + ?Sized + 'static> Provider for CachedProvider<P> {
186 fn name(&self) -> &str {
187 self.inner.name()
188 }
189
190 fn max_output_tokens(&self) -> Option<u32> {
191 self.inner.max_output_tokens()
192 }
193
194 async fn chat(
195 &self,
196 messages: &[Message],
197 tools: Option<&[Tool]>,
198 options: &ChatOptions,
199 ) -> Result<ChatResponse> {
200 let key = cache_key_for(messages, tools, options);
201 if let Some(cached) = self.backend.get(&key).await? {
202 tracing::debug!(provider = self.inner.name(), key = %key.0, "cache hit");
203 return Ok(cached.into_chat());
204 }
205 let resp = self.inner.chat(messages, tools, options).await?;
206 self.backend
207 .put(&key, CachedResponse::from_chat(&resp))
208 .await
209 .ok(); Ok(resp)
211 }
212
213 fn stream_chat<'a>(
214 &'a self,
215 messages: &'a [Message],
216 tools: Option<&'a [Tool]>,
217 options: &'a ChatOptions,
218 ) -> BoxStream<'a, Result<StreamChunk>> {
219 self.inner.stream_chat(messages, tools, options)
223 }
224}
225
#[cfg(feature = "cache")]
mod sqlite_backend {
    use super::{CacheBackend, CacheKey, CachedResponse};
    use anyhow::{Context, Result};
    use async_trait::async_trait;
    use rusqlite::{Connection, OptionalExtension, params};
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// [`CacheBackend`] that persists responses in a SQLite database file.
    ///
    /// Each response is stored as a JSON string in the `payload` column of the
    /// `responses` table, keyed by the cache key string.
    pub struct SqliteCache {
        /// The open connection, serialized behind an async mutex.
        conn: Arc<Mutex<Connection>>,
        /// Location the database was opened from.
        path: PathBuf,
    }

    impl SqliteCache {
        /// Opens the database at `path` and creates the `responses` table if
        /// it does not exist yet.
        ///
        /// # Errors
        ///
        /// Fails if the database cannot be opened or the table cannot be
        /// created.
        pub fn open(path: impl AsRef<Path>) -> Result<Self> {
            let db_path = PathBuf::from(path.as_ref());
            let connection = Connection::open(&db_path)
                .with_context(|| format!("opening cache at {}", db_path.display()))?;
            connection.execute_batch(
                "CREATE TABLE IF NOT EXISTS responses (
                    key TEXT PRIMARY KEY,
                    payload TEXT NOT NULL
                );",
            )?;
            Ok(SqliteCache {
                conn: Arc::new(Mutex::new(connection)),
                path: db_path,
            })
        }

        /// Path of the database file backing this cache.
        pub fn path(&self) -> &Path {
            self.path.as_path()
        }
    }

    #[async_trait]
    impl CacheBackend for SqliteCache {
        /// Reads and deserializes the payload stored under `key`.
        ///
        /// Returns `Ok(None)` when no row matches; a query failure or an
        /// undecodable payload is returned as an error.
        // NOTE(review): the rusqlite calls here run synchronously inside an
        // async fn while the mutex is held — confirm that is acceptable for the
        // runtime this is used on.
        async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>> {
            let db = self.conn.lock().await;
            let stored: Option<String> = db
                .query_row(
                    "SELECT payload FROM responses WHERE key = ?1",
                    params![&key.0],
                    |row| row.get(0),
                )
                .optional()?;
            let decoded = stored
                .map(|json| serde_json::from_str::<CachedResponse>(&json))
                .transpose()?;
            Ok(decoded)
        }

        /// Serializes `resp` to JSON and upserts it under `key`.
        async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()> {
            // Serialize before taking the lock so the guard is held only for
            // the statement itself.
            let payload = serde_json::to_string(&resp)?;
            let db = self.conn.lock().await;
            db.execute(
                "INSERT INTO responses (key, payload) VALUES (?1, ?2)
                 ON CONFLICT(key) DO UPDATE SET payload = excluded.payload",
                params![&key.0, payload],
            )?;
            Ok(())
        }
    }
}
294
295#[cfg(feature = "cache")]
296pub use sqlite_backend::SqliteCache;
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests_util::EchoProvider;

    /// The first call goes to the provider and fills the cache; the second
    /// identical call is served from the cache with the same text.
    #[tokio::test]
    async fn miss_populates_cache_then_hits_match() {
        let provider = Arc::new(EchoProvider::ok("p"));
        let (cached, store) = CachedProvider::with_memory_cache(provider.clone());

        let conversation = vec![Message::user("hello")];
        let options = ChatOptions::default();

        let first = cached.chat(&conversation, None, &options).await.unwrap();
        assert_eq!(provider.calls(), 1);
        assert_eq!(store.len().await, 1);

        let second = cached.chat(&conversation, None, &options).await.unwrap();
        assert_eq!(
            provider.calls(),
            1,
            "cache hit must not call the inner provider again"
        );
        assert_eq!(first.message.text(), second.message.text());
    }

    /// Two different conversations must each reach the provider.
    #[tokio::test]
    async fn different_messages_miss() {
        let provider = Arc::new(EchoProvider::ok("p"));
        let (cached, _store) = CachedProvider::with_memory_cache(provider.clone());
        let options = ChatOptions::default();

        for prompt in ["a", "b"] {
            cached
                .chat(&[Message::user(prompt)], None, &options)
                .await
                .unwrap();
        }
        assert_eq!(provider.calls(), 2);
    }

    /// Swapping the order of the tool list leaves the key unchanged.
    #[test]
    fn key_stable_across_reorderings() {
        let options = ChatOptions::default();
        let conversation = vec![Message::user("x")];
        let named = |name: &str| Tool {
            name: name.into(),
            ..Default::default()
        };
        let alpha = named("alpha");
        let beta = named("beta");

        let forward = cache_key_for(
            &conversation,
            Some(&[alpha.clone(), beta.clone()]),
            &options,
        );
        let reversed = cache_key_for(&conversation, Some(&[beta, alpha]), &options);
        assert_eq!(forward, reversed, "tool order must not affect the key");
    }

    /// A response written through one `SqliteCache` is served by a second one
    /// opened on the same file, without touching the new provider.
    #[cfg(feature = "cache")]
    #[tokio::test]
    async fn sqlite_cache_persists() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("cache.db");

        // First session: populate the on-disk cache, then drop everything.
        let writer_provider = Arc::new(EchoProvider::ok("p"));
        {
            let backend: Arc<dyn CacheBackend> = Arc::new(SqliteCache::open(&db_path).unwrap());
            let cached = CachedProvider::new(writer_provider.clone(), backend);
            cached
                .chat(&[Message::user("persist")], None, &ChatOptions::default())
                .await
                .unwrap();
        }

        // Second session: a fresh provider and a reopened database.
        let reader_provider = Arc::new(EchoProvider::ok("p"));
        let backend: Arc<dyn CacheBackend> = Arc::new(SqliteCache::open(&db_path).unwrap());
        let cached = CachedProvider::new(reader_provider.clone(), backend);
        let reply = cached
            .chat(&[Message::user("persist")], None, &ChatOptions::default())
            .await
            .unwrap();
        assert_eq!(reply.message.text(), Some("ok"));
        assert_eq!(
            reader_provider.calls(),
            0,
            "cached response must come from the sqlite store"
        );
    }
}