Skip to main content

llm_git/
llm_cache.rs

1//! Content-addressed cache for parsed LLM responses.
2//!
3//! Every successful one-shot LLM call writes its parsed payload here keyed on
4//! the canonical request material (operation, model, prompts, schema,
5//! temperature, …). Subsequent calls with byte-identical inputs short-circuit
6//! the network round-trip and replay the parsed value, which is the cheapest
//! possible recovery when the caller (e.g. `lgit --compose`) is rerun after a
8//! transient failure or unrelated edit.
9//!
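//! Backed by `SQLite` for atomic upserts and TTL-based eviction. The cache is
//! best-effort and never fatal: failing to open the database prints a warning
//! and disables caching, while read/write failures are silently ignored (a
//! failed read is reported as a miss).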
12
13use std::{
14   path::{Path, PathBuf},
15   sync::{Arc, OnceLock},
16   time::{Duration, SystemTime, UNIX_EPOCH},
17};
18
19use parking_lot::Mutex;
20use rusqlite::{Connection, OptionalExtension, params};
21
22use crate::{
23   config::CommitConfig,
24   error::{CommitGenError, Result},
25};
26
/// Version tag stored with every row. Bump it whenever the on-disk row format
/// or the key-hashing scheme changes: `get` only returns rows whose
/// `schema_version` equals this value, so rows written under any other
/// version are treated as misses (a later `put` for the same key replaces
/// them, and TTL pruning eventually removes the rest).
const SCHEMA_VERSION: i32 = 1;

/// Controls how often `put` prunes expired rows. A prune runs when the current
/// Unix time in whole seconds is a multiple of this value, i.e. during roughly
/// one second in 64 (every `put` issued during such a second prunes). The
/// prune is attempted whether or not the preceding insert succeeded and is
/// skipped entirely when the TTL is zero. This keeps the cache bounded without
/// scheduling background work.
const PRUNE_DIVISOR: u64 = 64;
34
35/// Holds the process-wide cache. Initialized from runtime config in `main`,
36/// hence `OnceLock` rather than `LazyLock` (the value depends on user config
37/// loaded at startup, not on a static initializer).
38static GLOBAL: OnceLock<Option<Arc<LlmCache>>> = OnceLock::new();
39
40/// Initialize the global LLM response cache from `config`. Idempotent: only
41/// the first call wins.
42pub fn init(config: &CommitConfig) {
43   let _ = GLOBAL.set(build_from_config(config));
44}
45
46/// Get the active cache handle, if any. Cheap clone of an `Arc`.
47pub fn global() -> Option<Arc<LlmCache>> {
48   GLOBAL.get().and_then(Option::clone)
49}
50
51fn build_from_config(config: &CommitConfig) -> Option<Arc<LlmCache>> {
52   if !config.cache_enabled {
53      return None;
54   }
55   let dir = resolve_cache_dir(config)?;
56   let path = dir.join("responses.sqlite");
57   let ttl = Duration::from_secs(u64::from(config.cache_ttl_days).saturating_mul(86_400));
58   match LlmCache::open(&path, ttl) {
59      Ok(cache) => Some(Arc::new(cache)),
60      Err(err) => {
61         crate::style::warn(&format!(
62            "LLM response cache disabled (failed to open {}): {err}",
63            path.display()
64         ));
65         None
66      },
67   }
68}
69
70fn resolve_cache_dir(config: &CommitConfig) -> Option<PathBuf> {
71   if let Some(dir) = config.cache_dir.as_deref()
72      && !dir.is_empty()
73   {
74      return Some(PathBuf::from(dir));
75   }
76   if let Ok(xdg) = std::env::var("XDG_CACHE_HOME")
77      && !xdg.is_empty()
78   {
79      return Some(PathBuf::from(xdg).join("llm-git"));
80   }
81   if let Ok(home) = std::env::var("HOME") {
82      return Some(PathBuf::from(home).join(".cache").join("llm-git"));
83   }
84   if let Ok(home) = std::env::var("USERPROFILE") {
85      return Some(PathBuf::from(home).join(".cache").join("llm-git"));
86   }
87   None
88}
89
90/// SQLite-backed cache of LLM responses. Cheap to clone via `Arc`.
91pub struct LlmCache {
92   conn:     Mutex<Connection>,
93   ttl_secs: u64,
94}
95
96impl LlmCache {
97   /// Open (or create) the cache at `path` with the given TTL. A TTL of zero
98   /// disables expiration.
99   pub fn open(path: &Path, ttl: Duration) -> Result<Self> {
100      if let Some(parent) = path.parent() {
101         std::fs::create_dir_all(parent).map_err(|err| {
102            CommitGenError::Other(format!("create cache dir {}: {err}", parent.display()))
103         })?;
104      }
105      let conn = Connection::open(path)
106         .map_err(|err| CommitGenError::Other(format!("open llm cache db: {err}")))?;
107      conn
108         .pragma_update(None, "journal_mode", "WAL")
109         .map_err(|err| CommitGenError::Other(format!("pragma WAL: {err}")))?;
110      conn
111         .pragma_update(None, "synchronous", "NORMAL")
112         .map_err(|err| CommitGenError::Other(format!("pragma synchronous: {err}")))?;
113      conn
114         .execute_batch(
115            "CREATE TABLE IF NOT EXISTS responses (
116                key            TEXT    PRIMARY KEY,
117                schema_version INTEGER NOT NULL,
118                model          TEXT    NOT NULL,
119                operation      TEXT    NOT NULL,
120                response       TEXT    NOT NULL,
121                created_at     INTEGER NOT NULL,
122                accessed_at    INTEGER NOT NULL
123             );
124             CREATE INDEX IF NOT EXISTS idx_responses_created_at
125                ON responses(created_at);",
126         )
127         .map_err(|err| CommitGenError::Other(format!("create cache schema: {err}")))?;
128      Ok(Self { conn: Mutex::new(conn), ttl_secs: ttl.as_secs() })
129   }
130
131   /// Look up the stored payload string for `key`. Returns `None` on miss,
132   /// expired entry, or any underlying error (cache failures are silent).
133   pub fn get(&self, key: &str) -> Option<String> {
134      let conn = self.conn.lock();
135      let now = now_unix();
136      let row: Option<(String, i64)> = conn
137         .query_row(
138            "SELECT response, created_at FROM responses
139             WHERE key = ?1 AND schema_version = ?2",
140            params![key, SCHEMA_VERSION],
141            |row| Ok((row.get(0)?, row.get(1)?)),
142         )
143         .optional()
144         .ok()
145         .flatten();
146      let (response, created_at) = row?;
147      if self.ttl_secs > 0 {
148         let cutoff = now.saturating_sub(self.ttl_secs);
149         if (created_at as u64) < cutoff {
150            let _ = conn.execute("DELETE FROM responses WHERE key = ?1", params![key]);
151            return None;
152         }
153      }
154      let _ = conn
155         .execute("UPDATE responses SET accessed_at = ?1 WHERE key = ?2", params![now as i64, key]);
156      Some(response)
157   }
158
159   /// Insert (or replace) a cached payload. Failures are silently swallowed —
160   /// the cache must never break the actual operation.
161   pub fn put(&self, key: &str, model: &str, operation: &str, response: &str) {
162      let conn = self.conn.lock();
163      let now = now_unix();
164      let _ = conn.execute(
165         "INSERT OR REPLACE INTO responses
166          (key, schema_version, model, operation, response, created_at, accessed_at)
167          VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?6)",
168         params![key, SCHEMA_VERSION, model, operation, response, now as i64],
169      );
170      if self.ttl_secs > 0 && now.is_multiple_of(PRUNE_DIVISOR) {
171         let cutoff = now.saturating_sub(self.ttl_secs);
172         let _ =
173            conn.execute("DELETE FROM responses WHERE created_at < ?1", params![cutoff as i64]);
174      }
175   }
176}
177
/// Current wall-clock time as whole seconds since the Unix epoch, or `0` if
/// the system clock reports a time before the epoch.
fn now_unix() -> u64 {
   match SystemTime::now().duration_since(UNIX_EPOCH) {
      Ok(since_epoch) => since_epoch.as_secs(),
      Err(_) => 0,
   }
}
184
185/// Material that uniquely identifies a one-shot LLM call. Hashed into the
186/// cache key.
187pub struct CacheMaterial<'a> {
188   pub operation:        &'a str,
189   pub model:            &'a str,
190   pub tool_name:        &'a str,
191   pub tool_description: &'a str,
192   pub system_prompt:    &'a str,
193   pub user_prompt:      &'a str,
194   pub schema:           &'a serde_json::Value,
195   pub temperature:      f32,
196   pub max_tokens:       u32,
197   pub api_mode:         &'a str,
198}
199
200/// Compute a content-addressed cache key over `material`. Stable across runs
201/// for byte-identical inputs.
202pub fn compute_key(material: &CacheMaterial<'_>) -> String {
203   let mut hasher = blake3::Hasher::new();
204   hasher.update(b"llm-cache/v1\n");
205   write_field(&mut hasher, "operation", material.operation);
206   write_field(&mut hasher, "model", material.model);
207   write_field(&mut hasher, "api_mode", material.api_mode);
208   write_field(&mut hasher, "tool_name", material.tool_name);
209   write_field(&mut hasher, "tool_description", material.tool_description);
210   write_field(&mut hasher, "system", material.system_prompt);
211   write_field(&mut hasher, "user", material.user_prompt);
212   // serde_json::Value uses BTreeMap by default → keys serialize in stable
213   // order without preserve_order, giving a canonical schema string.
214   let schema_canonical = serde_json::to_string(material.schema).unwrap_or_else(|_| String::new());
215   write_field(&mut hasher, "schema", &schema_canonical);
216   hasher.update(b"temperature\x00");
217   hasher.update(&material.temperature.to_bits().to_le_bytes());
218   hasher.update(b"\nmax_tokens\x00");
219   hasher.update(&material.max_tokens.to_le_bytes());
220   hasher.update(b"\n");
221   hasher.finalize().to_hex().to_string()
222}
223
224fn write_field(hasher: &mut blake3::Hasher, name: &str, value: &str) {
225   hasher.update(name.as_bytes());
226   hasher.update(b"\x00");
227   hasher.update(value.as_bytes());
228   hasher.update(b"\n");
229}
230
#[cfg(test)]
mod tests {
   use std::sync::{Arc, LazyLock};

   use serde_json::json;
   use tempfile::tempdir;

   use super::*;

   /// Schema used by the baseline material.
   static SCHEMA: LazyLock<serde_json::Value> = LazyLock::new(|| json!({"foo": "bar"}));

   /// A different schema, used to check that the schema feeds into the key.
   static OTHER_SCHEMA: LazyLock<serde_json::Value> = LazyLock::new(|| json!({"foo": "baz"}));

   /// Baseline material for the key tests. Every field is `'static`, so the
   /// result can be used at any lifetime.
   fn material<'a>() -> CacheMaterial<'a> {
      CacheMaterial {
         operation:        "test",
         model:            "test-model",
         tool_name:        "tool",
         tool_description: "desc",
         system_prompt:    "system",
         user_prompt:      "user",
         schema:           &SCHEMA,
         temperature:      0.0,
         max_tokens:       100,
         api_mode:         "ChatCompletions",
      }
   }

   /// Shift the stored `created_at` of `key` back by `secs` seconds, which
   /// simulates an old entry without sleeping.
   fn backdate(cache: &LlmCache, key: &str, secs: i64) {
      cache
         .conn
         .lock()
         .execute(
            "UPDATE responses SET created_at = created_at - ?1 WHERE key = ?2",
            params![secs, key],
         )
         .unwrap();
   }

   #[test]
   fn key_is_stable_and_collision_resistant() {
      let m = material();
      let k1 = compute_key(&m);
      let k2 = compute_key(&m);
      assert_eq!(k1, k2);

      let mut other = material();
      other.user_prompt = "different";
      assert_ne!(k1, compute_key(&other));
   }

   #[test]
   fn every_field_contributes_to_key() {
      let base = compute_key(&material());
      let mutations: [(&str, fn(&mut CacheMaterial<'static>)); 10] = [
         ("operation", |m| m.operation = "other-op"),
         ("model", |m| m.model = "other-model"),
         ("tool_name", |m| m.tool_name = "other-tool"),
         ("tool_description", |m| m.tool_description = "other-desc"),
         ("system_prompt", |m| m.system_prompt = "other-system"),
         ("user_prompt", |m| m.user_prompt = "other-user"),
         ("schema", |m| m.schema = LazyLock::force(&OTHER_SCHEMA)),
         ("temperature", |m| m.temperature = 0.5),
         ("max_tokens", |m| m.max_tokens = 200),
         ("api_mode", |m| m.api_mode = "Responses"),
      ];
      for (field, mutate) in mutations {
         let mut changed = material();
         mutate(&mut changed);
         assert_ne!(base, compute_key(&changed), "changing `{field}` must change the key");
      }
   }

   #[test]
   fn roundtrip_get_put() {
      let dir = tempdir().unwrap();
      let cache =
         Arc::new(LlmCache::open(&dir.path().join("c.sqlite"), Duration::from_secs(60)).unwrap());
      assert!(cache.get("k").is_none());
      cache.put("k", "model", "op", "{\"x\":1}");
      assert_eq!(cache.get("k").as_deref(), Some("{\"x\":1}"));
      cache.put("k", "model", "op", "{\"x\":2}");
      assert_eq!(cache.get("k").as_deref(), Some("{\"x\":2}"));
   }

   #[test]
   fn expired_entries_are_misses_and_deleted() {
      let dir = tempdir().unwrap();
      let cache = LlmCache::open(&dir.path().join("c.sqlite"), Duration::from_secs(60)).unwrap();
      cache.put("k", "model", "op", "v");
      backdate(&cache, "k", 3_600);
      assert!(cache.get("k").is_none());
      let remaining: i64 = cache
         .conn
         .lock()
         .query_row("SELECT COUNT(*) FROM responses", params![], |row| row.get(0))
         .unwrap();
      assert_eq!(remaining, 0, "an expired row should be deleted on read");
   }

   #[test]
   fn ttl_zero_disables_expiry() {
      let dir = tempdir().unwrap();
      let cache = LlmCache::open(&dir.path().join("c.sqlite"), Duration::from_secs(0)).unwrap();
      cache.put("k", "model", "op", "v");
      // Age the entry by roughly ten years; with expiry disabled it must still hit.
      backdate(&cache, "k", 10 * 365 * 86_400);
      assert_eq!(cache.get("k").as_deref(), Some("v"));
   }

   #[test]
   fn rows_from_other_schema_versions_are_misses() {
      let dir = tempdir().unwrap();
      let cache = LlmCache::open(&dir.path().join("c.sqlite"), Duration::from_secs(60)).unwrap();
      cache.put("k", "model", "op", "v");
      cache
         .conn
         .lock()
         .execute(
            "UPDATE responses SET schema_version = ?1 WHERE key = ?2",
            params![SCHEMA_VERSION + 1, "k"],
         )
         .unwrap();
      assert!(cache.get("k").is_none());
   }
}