Skip to main content

brainwires_call_policy/
cache.rs

1//! Response caching decorator.
2//!
3//! Wraps a [`Provider`] in a content-addressed cache so deterministic eval
4//! runs are byte-reproducible and local development stops burning real
5//! tokens. The cache key is `SHA-256(messages_json || tools_json ||
6//! options_json)` — any change to inputs produces a miss.
7//!
8//! Only the non-streaming [`Provider::chat`] path is cached. Streaming
9//! passes through unchanged (reconstructing a replayable stream from a
10//! recorded response is out of scope for this decorator).
11//!
12//! The in-memory [`MemoryCache`] is the default backend. A SQLite-backed
13//! `SqliteCache` lives behind the `cache` feature flag for runs that need
14//! persistence across process restarts.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use anyhow::Result;
20use async_trait::async_trait;
21use brainwires_core::message::{ChatResponse, Message, MessageContent, Role, StreamChunk, Usage};
22use brainwires_core::provider::{ChatOptions, Provider};
23use brainwires_core::tool::Tool;
24use futures::stream::BoxStream;
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27use tokio::sync::Mutex;
28
29/// Key used to address a cached response.
30#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
31pub struct CacheKey(pub String);
32
33/// Wire representation of a cached response, suitable for JSON storage.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CachedResponse {
36    /// Serialised [`Role`] of the produced message.
37    pub role: Role,
38    /// Message payload as plain text (blocks are rendered to a single string
39    /// since cache hits are only useful for deterministic evals where tool
40    /// calls produce the same output each time anyway).
41    pub text: String,
42    /// Usage counters at record time.
43    pub usage: Usage,
44    /// Original `finish_reason`, when the provider supplied one.
45    pub finish_reason: Option<String>,
46}
47
48impl CachedResponse {
49    fn from_chat(resp: &ChatResponse) -> Self {
50        let text = match &resp.message.content {
51            MessageContent::Text(t) => t.clone(),
52            MessageContent::Blocks(_) => resp
53                .message
54                .text()
55                .map(|s| s.to_string())
56                .unwrap_or_default(),
57        };
58        Self {
59            role: resp.message.role.clone(),
60            text,
61            usage: resp.usage.clone(),
62            finish_reason: resp.finish_reason.clone(),
63        }
64    }
65
66    fn into_chat(self) -> ChatResponse {
67        let msg = match self.role {
68            Role::Assistant => Message::assistant(self.text.clone()),
69            Role::System => Message::system(self.text.clone()),
70            _ => Message::user(self.text.clone()),
71        };
72        ChatResponse {
73            message: msg,
74            usage: self.usage,
75            finish_reason: self.finish_reason,
76        }
77    }
78}
79
80/// Pluggable storage backend behind [`CachedProvider`].
81#[async_trait]
82pub trait CacheBackend: Send + Sync {
83    /// Return the cached response for `key`, if any.
84    async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>>;
85    /// Persist `resp` under `key`. Overwrites any previous value.
86    async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()>;
87}
88
89/// In-memory cache — the default backend. Cheap to `Arc`-clone.
90#[derive(Clone, Default)]
91pub struct MemoryCache {
92    inner: Arc<Mutex<HashMap<CacheKey, CachedResponse>>>,
93}
94
95impl MemoryCache {
96    /// Create an empty cache.
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Number of cached responses.
102    pub async fn len(&self) -> usize {
103        self.inner.lock().await.len()
104    }
105
106    /// `true` if no responses are cached yet.
107    pub async fn is_empty(&self) -> bool {
108        self.len().await == 0
109    }
110}
111
112#[async_trait]
113impl CacheBackend for MemoryCache {
114    async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>> {
115        Ok(self.inner.lock().await.get(key).cloned())
116    }
117    async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()> {
118        self.inner.lock().await.insert(key.clone(), resp);
119        Ok(())
120    }
121}
122
123/// Compute a stable cache key from the inputs to a `chat()` call.
124///
125/// Tools are name-sorted before hashing so reordering them doesn't break
126/// cache hits. Options and messages are serialised verbatim.
127pub fn cache_key_for(
128    messages: &[Message],
129    tools: Option<&[Tool]>,
130    options: &ChatOptions,
131) -> CacheKey {
132    let mut hasher = Sha256::new();
133    // Serialising with serde_json gives us a canonical representation.
134    let msgs = serde_json::to_vec(messages).unwrap_or_default();
135    hasher.update(&msgs);
136
137    if let Some(ts) = tools {
138        let mut names: Vec<&str> = ts.iter().map(|t| t.name.as_str()).collect();
139        names.sort();
140        for n in names {
141            hasher.update(b"\x00tool:");
142            hasher.update(n.as_bytes());
143        }
144    }
145
146    let opts = serde_json::to_vec(options).unwrap_or_default();
147    hasher.update(b"\x00opts:");
148    hasher.update(&opts);
149
150    let digest = hasher.finalize();
151    CacheKey(hex_encode(&digest))
152}
153
/// Render `bytes` as lowercase hexadecimal: two digits per byte, high nibble
/// first.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| char::from(DIGITS[usize::from(nibble)]))
        .collect()
}
163
164/// A [`Provider`] decorator that deduplicates identical `chat()` calls.
165pub struct CachedProvider<P: Provider + ?Sized> {
166    inner: Arc<P>,
167    backend: Arc<dyn CacheBackend>,
168}
169
170impl<P: Provider + ?Sized> CachedProvider<P> {
171    /// Wrap `inner` with the given cache backend.
172    pub fn new(inner: Arc<P>, backend: Arc<dyn CacheBackend>) -> Self {
173        Self { inner, backend }
174    }
175
176    /// Convenience constructor using the in-memory backend.
177    pub fn with_memory_cache(inner: Arc<P>) -> (Self, MemoryCache) {
178        let cache = MemoryCache::new();
179        let me = Self::new(inner, Arc::new(cache.clone()));
180        (me, cache)
181    }
182}
183
184#[async_trait]
185impl<P: Provider + ?Sized + 'static> Provider for CachedProvider<P> {
186    fn name(&self) -> &str {
187        self.inner.name()
188    }
189
190    fn max_output_tokens(&self) -> Option<u32> {
191        self.inner.max_output_tokens()
192    }
193
194    async fn chat(
195        &self,
196        messages: &[Message],
197        tools: Option<&[Tool]>,
198        options: &ChatOptions,
199    ) -> Result<ChatResponse> {
200        let key = cache_key_for(messages, tools, options);
201        if let Some(cached) = self.backend.get(&key).await? {
202            tracing::debug!(provider = self.inner.name(), key = %key.0, "cache hit");
203            return Ok(cached.into_chat());
204        }
205        let resp = self.inner.chat(messages, tools, options).await?;
206        self.backend
207            .put(&key, CachedResponse::from_chat(&resp))
208            .await
209            .ok(); // caching failures are non-fatal
210        Ok(resp)
211    }
212
213    fn stream_chat<'a>(
214        &'a self,
215        messages: &'a [Message],
216        tools: Option<&'a [Tool]>,
217        options: &'a ChatOptions,
218    ) -> BoxStream<'a, Result<StreamChunk>> {
219        // Streaming bypasses the cache — reconstructing a replayable event
220        // stream from a single recorded response would fabricate data a
221        // caller can't distinguish from real model output.
222        self.inner.stream_chat(messages, tools, options)
223    }
224}
225
#[cfg(feature = "cache")]
mod sqlite_backend {
    use super::{CacheBackend, CacheKey, CachedResponse};
    use anyhow::{Context, Result};
    use async_trait::async_trait;
    use rusqlite::{Connection, OptionalExtension, params};
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Disk-backed [`CacheBackend`] storing each response as a JSON payload in
    /// a single SQLite table, `responses`, keyed by the cache key string.
    ///
    /// All operations share one connection behind an async mutex, so they are
    /// serialised. NOTE(review): the rusqlite calls are synchronous and run on
    /// the calling async task while the lock is held.
    pub struct SqliteCache {
        conn: Arc<Mutex<Connection>>,
        path: PathBuf,
    }

    impl SqliteCache {
        /// Open the database at `path`, creating the file and the `responses`
        /// table when they do not exist yet.
        ///
        /// # Errors
        /// Fails when the database cannot be opened (the error names the
        /// path) or when the schema statement fails.
        pub fn open(path: impl AsRef<Path>) -> Result<Self> {
            let db_path = path.as_ref().to_path_buf();
            let connection = Connection::open(&db_path)
                .with_context(|| format!("opening cache at {}", db_path.display()))?;
            connection.execute_batch(
                "CREATE TABLE IF NOT EXISTS responses (
                    key TEXT PRIMARY KEY,
                    payload TEXT NOT NULL
                );",
            )?;
            let conn = Arc::new(Mutex::new(connection));
            Ok(Self {
                conn,
                path: db_path,
            })
        }

        /// Filesystem location of the backing database.
        pub fn path(&self) -> &Path {
            self.path.as_path()
        }
    }

    #[async_trait]
    impl CacheBackend for SqliteCache {
        /// Fetch and JSON-decode the payload for `key`; a payload that fails
        /// to decode is reported as an error.
        async fn get(&self, key: &CacheKey) -> Result<Option<CachedResponse>> {
            let stored: Option<String> = {
                let db = self.conn.lock().await;
                let lookup = db.query_row(
                    "SELECT payload FROM responses WHERE key = ?1",
                    params![&key.0],
                    |row| row.get(0),
                );
                lookup.optional()?
            };
            let decoded: Option<CachedResponse> = stored
                .map(|json| serde_json::from_str(&json))
                .transpose()?;
            Ok(decoded)
        }

        /// JSON-encode `resp` and upsert it under `key`.
        async fn put(&self, key: &CacheKey, resp: CachedResponse) -> Result<()> {
            let payload = serde_json::to_string(&resp)?;
            self.conn.lock().await.execute(
                "INSERT INTO responses (key, payload) VALUES (?1, ?2)
                 ON CONFLICT(key) DO UPDATE SET payload = excluded.payload",
                params![&key.0, payload],
            )?;
            Ok(())
        }
    }
}
294
295#[cfg(feature = "cache")]
296pub use sqlite_backend::SqliteCache;
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests_util::EchoProvider;

    #[tokio::test]
    async fn miss_populates_cache_then_hits_match() {
        let echo = Arc::new(EchoProvider::ok("p"));
        let (provider, store) = CachedProvider::with_memory_cache(Arc::clone(&echo));
        let history = vec![Message::user("hello")];
        let options = ChatOptions::default();

        // Cold cache: the wrapped provider runs once and one entry is recorded.
        let first = provider.chat(&history, None, &options).await.unwrap();
        assert_eq!(echo.calls(), 1);
        assert_eq!(store.len().await, 1);

        // Warm cache: the identical request is answered without `echo`.
        let second = provider.chat(&history, None, &options).await.unwrap();
        assert_eq!(
            echo.calls(),
            1,
            "cache hit must not call the inner provider again"
        );
        assert_eq!(first.message.text(), second.message.text());
    }

    #[tokio::test]
    async fn different_messages_miss() {
        let echo = Arc::new(EchoProvider::ok("p"));
        let (provider, _store) = CachedProvider::with_memory_cache(Arc::clone(&echo));
        let options = ChatOptions::default();

        for prompt in ["a", "b"] {
            provider
                .chat(&[Message::user(prompt)], None, &options)
                .await
                .unwrap();
        }
        assert_eq!(echo.calls(), 2);
    }

    #[test]
    fn key_stable_across_reorderings() {
        let options = ChatOptions::default();
        let history = vec![Message::user("x")];
        let alpha = Tool {
            name: "alpha".into(),
            ..Default::default()
        };
        let beta = Tool {
            name: "beta".into(),
            ..Default::default()
        };

        let forward = [alpha.clone(), beta.clone()];
        let reversed = [beta, alpha];
        assert_eq!(
            cache_key_for(&history, Some(&forward[..]), &options),
            cache_key_for(&history, Some(&reversed[..]), &options),
            "tool order must not affect the key"
        );
    }

    #[cfg(feature = "cache")]
    #[tokio::test]
    async fn sqlite_cache_persists() {
        let dir = tempfile::tempdir().unwrap();
        let db = dir.path().join("cache.db");

        // First session: populate the on-disk cache, then drop everything.
        {
            let writer_echo = Arc::new(EchoProvider::ok("p"));
            let backend: Arc<dyn CacheBackend> = Arc::new(SqliteCache::open(&db).unwrap());
            let provider = CachedProvider::new(writer_echo, backend);
            provider
                .chat(&[Message::user("persist")], None, &ChatOptions::default())
                .await
                .unwrap();
        }

        // Second session with a brand-new provider: the answer must come from disk.
        let reader_echo = Arc::new(EchoProvider::ok("p"));
        let backend: Arc<dyn CacheBackend> = Arc::new(SqliteCache::open(&db).unwrap());
        let provider = CachedProvider::new(Arc::clone(&reader_echo), backend);
        let replayed = provider
            .chat(&[Message::user("persist")], None, &ChatOptions::default())
            .await
            .unwrap();
        assert_eq!(replayed.message.text(), Some("ok"));
        assert_eq!(
            reader_echo.calls(),
            0,
            "cached response must come from the sqlite store"
        );
    }
}