1use std::{
16 path::{Path, PathBuf},
17 sync::{Arc, OnceLock},
18 time::{Duration, SystemTime, UNIX_EPOCH},
19};
20
21use parking_lot::Mutex;
22use rusqlite::{Connection, OptionalExtension, params};
23
24use crate::{
25 config::CommitConfig,
26 error::{CommitGenError, Result},
27};
28
/// Version tag written into every cached row.
///
/// Lookups only accept rows carrying exactly this value, so bumping it makes
/// every previously stored entry unreachable.
const SCHEMA_VERSION: i32 = 2;

/// Expired-row sweep cadence: when a TTL is configured, a write whose Unix
/// timestamp (in seconds) is a multiple of this value also deletes every row
/// older than the TTL.
const PRUNE_DIVISOR: u64 = 64;
36
37static GLOBAL: OnceLock<Option<Arc<LlmCache>>> = OnceLock::new();
41
42pub fn init(config: &CommitConfig) {
45 let _ = GLOBAL.set(build_from_config(config));
46}
47
48pub fn global() -> Option<Arc<LlmCache>> {
50 GLOBAL.get().and_then(Option::clone)
51}
52
53fn build_from_config(config: &CommitConfig) -> Option<Arc<LlmCache>> {
54 if !config.cache_enabled {
55 return None;
56 }
57 let dir = resolve_cache_dir(config)?;
58 let path = dir.join("responses.sqlite");
59 let ttl = Duration::from_secs(u64::from(config.cache_ttl_days).saturating_mul(86_400));
60 match LlmCache::open(&path, ttl) {
61 Ok(cache) => Some(Arc::new(cache)),
62 Err(err) => {
63 crate::style::warn(&format!(
64 "LLM response cache disabled (failed to open {}): {err}",
65 path.display()
66 ));
67 None
68 },
69 }
70}
71
72fn resolve_cache_dir(config: &CommitConfig) -> Option<PathBuf> {
73 if let Some(dir) = config.cache_dir.as_deref()
74 && !dir.is_empty()
75 {
76 return Some(PathBuf::from(dir));
77 }
78 if let Ok(xdg) = std::env::var("XDG_CACHE_HOME")
79 && !xdg.is_empty()
80 {
81 return Some(PathBuf::from(xdg).join("llm-git"));
82 }
83 if let Ok(home) = std::env::var("HOME") {
84 return Some(PathBuf::from(home).join(".cache").join("llm-git"));
85 }
86 if let Ok(home) = std::env::var("USERPROFILE") {
87 return Some(PathBuf::from(home).join(".cache").join("llm-git"));
88 }
89 None
90}
91
92pub struct LlmCache {
94 conn: Mutex<Connection>,
95 ttl_secs: u64,
96}
97
/// A cache hit: the stored request text together with the response it produced.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct CachedLlmResponse {
  /// Request text recorded alongside the response by `LlmCache::put`.
  pub request: String,
  /// Response text recorded by `LlmCache::put`.
  pub response: String,
}
103
104impl LlmCache {
105 pub fn open(path: &Path, ttl: Duration) -> Result<Self> {
108 if let Some(parent) = path.parent() {
109 std::fs::create_dir_all(parent).map_err(|err| {
110 CommitGenError::Other(format!("create cache dir {}: {err}", parent.display()))
111 })?;
112 }
113 let conn = Connection::open(path)
114 .map_err(|err| CommitGenError::Other(format!("open llm cache db: {err}")))?;
115 conn
116 .pragma_update(None, "journal_mode", "WAL")
117 .map_err(|err| CommitGenError::Other(format!("pragma WAL: {err}")))?;
118 conn
119 .pragma_update(None, "synchronous", "NORMAL")
120 .map_err(|err| CommitGenError::Other(format!("pragma synchronous: {err}")))?;
121 conn
122 .execute_batch(
123 "CREATE TABLE IF NOT EXISTS responses (
124 key TEXT PRIMARY KEY,
125 schema_version INTEGER NOT NULL,
126 model TEXT NOT NULL,
127 operation TEXT NOT NULL,
128 request TEXT NOT NULL,
129 response TEXT NOT NULL,
130 created_at INTEGER NOT NULL,
131 accessed_at INTEGER NOT NULL
132 );
133 CREATE INDEX IF NOT EXISTS idx_responses_created_at
134 ON responses(created_at);",
135 )
136 .map_err(|err| CommitGenError::Other(format!("create cache schema: {err}")))?;
137 conn
138 .execute(
139 "ALTER TABLE responses ADD COLUMN request TEXT NOT NULL DEFAULT ''",
140 [],
141 )
142 .or_else(|err| {
143 if matches!(err, rusqlite::Error::SqliteFailure(_, Some(ref message)) if message.contains("duplicate column name"))
144 {
145 Ok(0)
146 } else {
147 Err(err)
148 }
149 })
150 .map_err(|err| CommitGenError::Other(format!("migrate cache schema: {err}")))?;
151 Ok(Self { conn: Mutex::new(conn), ttl_secs: ttl.as_secs() })
152 }
153
154 pub fn get_entry(&self, key: &str) -> Option<CachedLlmResponse> {
158 let conn = self.conn.lock();
159 let now = now_unix();
160 let row: Option<(String, String, i64)> = conn
161 .query_row(
162 "SELECT request, response, created_at FROM responses
163 WHERE key = ?1 AND schema_version = ?2",
164 params![key, SCHEMA_VERSION],
165 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
166 )
167 .optional()
168 .ok()
169 .flatten();
170 let (request, response, created_at) = row?;
171 if self.ttl_secs > 0 {
172 let cutoff = now.saturating_sub(self.ttl_secs);
173 if (created_at as u64) < cutoff {
174 let _ = conn.execute("DELETE FROM responses WHERE key = ?1", params![key]);
175 return None;
176 }
177 }
178 let _ = conn
179 .execute("UPDATE responses SET accessed_at = ?1 WHERE key = ?2", params![now as i64, key]);
180 Some(CachedLlmResponse { request, response })
181 }
182
183 pub fn get(&self, key: &str) -> Option<String> {
186 self.get_entry(key).map(|entry| entry.response)
187 }
188
189 pub fn put(&self, key: &str, model: &str, operation: &str, request: &str, response: &str) {
192 let conn = self.conn.lock();
193 let now = now_unix();
194 let _ = conn.execute(
195 "INSERT OR REPLACE INTO responses
196 (key, schema_version, model, operation, request, response, created_at, accessed_at)
197 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7)",
198 params![key, SCHEMA_VERSION, model, operation, request, response, now as i64],
199 );
200 if self.ttl_secs > 0 && now.is_multiple_of(PRUNE_DIVISOR) {
201 let cutoff = now.saturating_sub(self.ttl_secs);
202 let _ =
203 conn.execute("DELETE FROM responses WHERE created_at < ?1", params![cutoff as i64]);
204 }
205 }
206}
207
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than an error.
fn now_unix() -> u64 {
  match SystemTime::now().duration_since(UNIX_EPOCH) {
    Ok(elapsed) => elapsed.as_secs(),
    Err(_) => 0,
  }
}
213
214pub struct CacheMaterial<'a> {
217 pub operation: &'a str,
218 pub model: &'a str,
219 pub tool_name: &'a str,
220 pub tool_description: &'a str,
221 pub system_prompt: &'a str,
222 pub user_prompt: &'a str,
223 pub schema: &'a serde_json::Value,
224 pub temperature: f32,
225 pub max_tokens: u32,
226 pub api_mode: &'a str,
227}
228
229pub fn compute_key(material: &CacheMaterial<'_>) -> String {
232 let mut hasher = blake3::Hasher::new();
233 hasher.update(b"llm-cache/v1\n");
234 write_field(&mut hasher, "operation", material.operation);
235 write_field(&mut hasher, "model", material.model);
236 write_field(&mut hasher, "api_mode", material.api_mode);
237 write_field(&mut hasher, "tool_name", material.tool_name);
238 write_field(&mut hasher, "tool_description", material.tool_description);
239 write_field(&mut hasher, "system", material.system_prompt);
240 write_field(&mut hasher, "user", material.user_prompt);
241 let schema_canonical = serde_json::to_string(material.schema).unwrap_or_else(|_| String::new());
244 write_field(&mut hasher, "schema", &schema_canonical);
245 hasher.update(b"temperature\x00");
246 hasher.update(&material.temperature.to_bits().to_le_bytes());
247 hasher.update(b"\nmax_tokens\x00");
248 hasher.update(&material.max_tokens.to_le_bytes());
249 hasher.update(b"\n");
250 hasher.finalize().to_hex().to_string()
251}
252
253fn write_field(hasher: &mut blake3::Hasher, name: &str, value: &str) {
254 hasher.update(name.as_bytes());
255 hasher.update(b"\x00");
256 hasher.update(value.as_bytes());
257 hasher.update(b"\n");
258}
259
#[cfg(test)]
mod tests {
  use std::sync::Arc;

  use serde_json::json;
  use tempfile::tempdir;

  use super::*;

  /// Baseline material shared by the key tests.
  fn material<'a>() -> CacheMaterial<'a> {
    // `CacheMaterial` borrows the schema, so it lives in a static that outlives
    // every caller.
    static SCHEMA: std::sync::LazyLock<serde_json::Value> =
      std::sync::LazyLock::new(|| json!({"foo": "bar"}));
    CacheMaterial {
      operation: "test",
      model: "test-model",
      tool_name: "tool",
      tool_description: "desc",
      system_prompt: "system",
      user_prompt: "user",
      schema: &SCHEMA,
      temperature: 0.0,
      max_tokens: 100,
      api_mode: "ChatCompletions",
    }
  }

  #[test]
  fn key_is_stable_and_collision_resistant() {
    let m = material();
    let k1 = compute_key(&m);
    let k2 = compute_key(&m);
    assert_eq!(k1, k2);

    let mut other = material();
    other.user_prompt = "different";
    assert_ne!(k1, compute_key(&other));
  }

  #[test]
  fn roundtrip_get_put() {
    let dir = tempdir().unwrap();
    let cache =
      Arc::new(LlmCache::open(&dir.path().join("c.sqlite"), Duration::from_secs(60)).unwrap());
    assert!(cache.get("k").is_none());
    cache.put("k", "model", "op", "{\"request\":1}", "{\"x\":1}");
    assert_eq!(cache.get("k").as_deref(), Some("{\"x\":1}"));
    assert_eq!(
      cache.get_entry("k"),
      Some(CachedLlmResponse {
        request: "{\"request\":1}".to_string(),
        response: "{\"x\":1}".to_string(),
      })
    );
    cache.put("k", "model", "op", "{\"request\":2}", "{\"x\":2}");
    assert_eq!(cache.get("k").as_deref(), Some("{\"x\":2}"));
    assert_eq!(
      cache.get_entry("k").map(|entry| entry.request),
      Some("{\"request\":2}".to_string())
    );
  }

  #[test]
  fn open_migrates_old_schema_before_storing_requests() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("c.sqlite");
    {
      let conn = Connection::open(&path).unwrap();
      conn
        .execute_batch(
          "CREATE TABLE responses (
            key TEXT PRIMARY KEY,
            schema_version INTEGER NOT NULL,
            model TEXT NOT NULL,
            operation TEXT NOT NULL,
            response TEXT NOT NULL,
            created_at INTEGER NOT NULL,
            accessed_at INTEGER NOT NULL
          );",
        )
        .unwrap();
    }

    let cache = LlmCache::open(&path, Duration::from_secs(60)).unwrap();
    cache.put("k", "model", "op", "{\"request\":true}", "{\"response\":true}");

    assert_eq!(
      cache.get_entry("k"),
      Some(CachedLlmResponse {
        request: "{\"request\":true}".to_string(),
        response: "{\"response\":true}".to_string(),
      })
    );
  }

  #[test]
  fn ttl_zero_disables_expiry() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("c.sqlite");
    let cache = LlmCache::open(&path, Duration::from_secs(0)).unwrap();
    cache.put("k", "model", "op", "request", "v");
    // Backdate the row to the epoch; with expiry disabled it must still be served.
    Connection::open(&path)
      .unwrap()
      .execute("UPDATE responses SET created_at = 0", [])
      .unwrap();
    assert_eq!(cache.get("k").as_deref(), Some("v"));
  }

  #[test]
  fn expired_entries_are_evicted_on_read() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("c.sqlite");
    let cache = LlmCache::open(&path, Duration::from_secs(60)).unwrap();
    cache.put("k", "model", "op", "request", "v");

    // Backdate the row far past the 60s TTL through a second connection.
    let raw = Connection::open(&path).unwrap();
    raw
      .execute("UPDATE responses SET created_at = 0 WHERE key = 'k'", [])
      .unwrap();

    assert!(cache.get("k").is_none());
    let remaining: i64 = raw
      .query_row("SELECT COUNT(*) FROM responses WHERE key = 'k'", [], |row| row.get(0))
      .unwrap();
    assert_eq!(remaining, 0, "expired row should be deleted on read");
  }

  #[test]
  fn rows_from_other_schema_versions_are_ignored() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("c.sqlite");
    let cache = LlmCache::open(&path, Duration::from_secs(0)).unwrap();
    cache.put("k", "model", "op", "request", "v");

    Connection::open(&path)
      .unwrap()
      .execute("UPDATE responses SET schema_version = schema_version - 1", [])
      .unwrap();

    assert!(cache.get("k").is_none());
  }
}