1use std::{
14 path::{Path, PathBuf},
15 sync::{Arc, OnceLock},
16 time::{Duration, SystemTime, UNIX_EPOCH},
17};
18
19use parking_lot::Mutex;
20use rusqlite::{Connection, OptionalExtension, params};
21
22use crate::{
23 config::CommitConfig,
24 error::{CommitGenError, Result},
25};
26
/// Version tag written to the `schema_version` column by `LlmCache::put` and required to match
/// by `LlmCache::get`; rows stored under a different value are not returned.
const SCHEMA_VERSION: i32 = 1;

/// `LlmCache::put` sweeps expired rows only when the current Unix time (in seconds) is a
/// multiple of this value.
const PRUNE_DIVISOR: u64 = 64;

/// Process-wide cache slot written by [`init`] and read by [`global`].
///
/// The inner `Option` is `None` when `build_from_config` produced no cache.
static GLOBAL: OnceLock<Option<Arc<LlmCache>>> = OnceLock::new();
39
40pub fn init(config: &CommitConfig) {
43 let _ = GLOBAL.set(build_from_config(config));
44}
45
46pub fn global() -> Option<Arc<LlmCache>> {
48 GLOBAL.get().and_then(Option::clone)
49}
50
51fn build_from_config(config: &CommitConfig) -> Option<Arc<LlmCache>> {
52 if !config.cache_enabled {
53 return None;
54 }
55 let dir = resolve_cache_dir(config)?;
56 let path = dir.join("responses.sqlite");
57 let ttl = Duration::from_secs(u64::from(config.cache_ttl_days).saturating_mul(86_400));
58 match LlmCache::open(&path, ttl) {
59 Ok(cache) => Some(Arc::new(cache)),
60 Err(err) => {
61 crate::style::warn(&format!(
62 "LLM response cache disabled (failed to open {}): {err}",
63 path.display()
64 ));
65 None
66 },
67 }
68}
69
70fn resolve_cache_dir(config: &CommitConfig) -> Option<PathBuf> {
71 if let Some(dir) = config.cache_dir.as_deref()
72 && !dir.is_empty()
73 {
74 return Some(PathBuf::from(dir));
75 }
76 if let Ok(xdg) = std::env::var("XDG_CACHE_HOME")
77 && !xdg.is_empty()
78 {
79 return Some(PathBuf::from(xdg).join("llm-git"));
80 }
81 if let Ok(home) = std::env::var("HOME") {
82 return Some(PathBuf::from(home).join(".cache").join("llm-git"));
83 }
84 if let Ok(home) = std::env::var("USERPROFILE") {
85 return Some(PathBuf::from(home).join(".cache").join("llm-git"));
86 }
87 None
88}
89
/// SQLite-backed store of LLM responses, looked up by string key.
pub struct LlmCache {
    /// Database connection; `get` and `put` each hold this lock for their whole body.
    conn: Mutex<Connection>,
    /// Entry lifetime in whole seconds; `0` disables expiry.
    ttl_secs: u64,
}
95
96impl LlmCache {
97 pub fn open(path: &Path, ttl: Duration) -> Result<Self> {
100 if let Some(parent) = path.parent() {
101 std::fs::create_dir_all(parent).map_err(|err| {
102 CommitGenError::Other(format!("create cache dir {}: {err}", parent.display()))
103 })?;
104 }
105 let conn = Connection::open(path)
106 .map_err(|err| CommitGenError::Other(format!("open llm cache db: {err}")))?;
107 conn
108 .pragma_update(None, "journal_mode", "WAL")
109 .map_err(|err| CommitGenError::Other(format!("pragma WAL: {err}")))?;
110 conn
111 .pragma_update(None, "synchronous", "NORMAL")
112 .map_err(|err| CommitGenError::Other(format!("pragma synchronous: {err}")))?;
113 conn
114 .execute_batch(
115 "CREATE TABLE IF NOT EXISTS responses (
116 key TEXT PRIMARY KEY,
117 schema_version INTEGER NOT NULL,
118 model TEXT NOT NULL,
119 operation TEXT NOT NULL,
120 response TEXT NOT NULL,
121 created_at INTEGER NOT NULL,
122 accessed_at INTEGER NOT NULL
123 );
124 CREATE INDEX IF NOT EXISTS idx_responses_created_at
125 ON responses(created_at);",
126 )
127 .map_err(|err| CommitGenError::Other(format!("create cache schema: {err}")))?;
128 Ok(Self { conn: Mutex::new(conn), ttl_secs: ttl.as_secs() })
129 }
130
131 pub fn get(&self, key: &str) -> Option<String> {
134 let conn = self.conn.lock();
135 let now = now_unix();
136 let row: Option<(String, i64)> = conn
137 .query_row(
138 "SELECT response, created_at FROM responses
139 WHERE key = ?1 AND schema_version = ?2",
140 params![key, SCHEMA_VERSION],
141 |row| Ok((row.get(0)?, row.get(1)?)),
142 )
143 .optional()
144 .ok()
145 .flatten();
146 let (response, created_at) = row?;
147 if self.ttl_secs > 0 {
148 let cutoff = now.saturating_sub(self.ttl_secs);
149 if (created_at as u64) < cutoff {
150 let _ = conn.execute("DELETE FROM responses WHERE key = ?1", params![key]);
151 return None;
152 }
153 }
154 let _ = conn
155 .execute("UPDATE responses SET accessed_at = ?1 WHERE key = ?2", params![now as i64, key]);
156 Some(response)
157 }
158
159 pub fn put(&self, key: &str, model: &str, operation: &str, response: &str) {
162 let conn = self.conn.lock();
163 let now = now_unix();
164 let _ = conn.execute(
165 "INSERT OR REPLACE INTO responses
166 (key, schema_version, model, operation, response, created_at, accessed_at)
167 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?6)",
168 params![key, SCHEMA_VERSION, model, operation, response, now as i64],
169 );
170 if self.ttl_secs > 0 && now.is_multiple_of(PRUNE_DIVISOR) {
171 let cutoff = now.saturating_sub(self.ttl_secs);
172 let _ =
173 conn.execute("DELETE FROM responses WHERE created_at < ?1", params![cutoff as i64]);
174 }
175 }
176}
177
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
184
/// Request parameters that [`compute_key`] hashes into a cache key.
///
/// Every field below is fed into the hash.
pub struct CacheMaterial<'a> {
    /// Hashed under the `operation` tag.
    pub operation: &'a str,
    /// Hashed under the `model` tag.
    pub model: &'a str,
    /// Hashed under the `tool_name` tag.
    pub tool_name: &'a str,
    /// Hashed under the `tool_description` tag.
    pub tool_description: &'a str,
    /// Hashed under the `system` tag.
    pub system_prompt: &'a str,
    /// Hashed under the `user` tag.
    pub user_prompt: &'a str,
    /// JSON value; hashed via its `serde_json::to_string` serialization.
    pub schema: &'a serde_json::Value,
    /// Hashed by exact bit pattern (`to_bits`), not by numeric comparison.
    pub temperature: f32,
    /// Hashed as little-endian bytes.
    pub max_tokens: u32,
    /// Hashed under the `api_mode` tag.
    pub api_mode: &'a str,
}
199
200pub fn compute_key(material: &CacheMaterial<'_>) -> String {
203 let mut hasher = blake3::Hasher::new();
204 hasher.update(b"llm-cache/v1\n");
205 write_field(&mut hasher, "operation", material.operation);
206 write_field(&mut hasher, "model", material.model);
207 write_field(&mut hasher, "api_mode", material.api_mode);
208 write_field(&mut hasher, "tool_name", material.tool_name);
209 write_field(&mut hasher, "tool_description", material.tool_description);
210 write_field(&mut hasher, "system", material.system_prompt);
211 write_field(&mut hasher, "user", material.user_prompt);
212 let schema_canonical = serde_json::to_string(material.schema).unwrap_or_else(|_| String::new());
215 write_field(&mut hasher, "schema", &schema_canonical);
216 hasher.update(b"temperature\x00");
217 hasher.update(&material.temperature.to_bits().to_le_bytes());
218 hasher.update(b"\nmax_tokens\x00");
219 hasher.update(&material.max_tokens.to_le_bytes());
220 hasher.update(b"\n");
221 hasher.finalize().to_hex().to_string()
222}
223
224fn write_field(hasher: &mut blake3::Hasher, name: &str, value: &str) {
225 hasher.update(name.as_bytes());
226 hasher.update(b"\x00");
227 hasher.update(value.as_bytes());
228 hasher.update(b"\n");
229}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serde_json::json;
    use tempfile::tempdir;

    use super::*;

    /// Baseline key material shared by the tests; the schema lives in a lazily built static.
    fn material<'a>() -> CacheMaterial<'a> {
        static SCHEMA: std::sync::LazyLock<serde_json::Value> =
            std::sync::LazyLock::new(|| json!({"foo": "bar"}));
        CacheMaterial {
            operation: "test",
            model: "test-model",
            tool_name: "tool",
            tool_description: "desc",
            system_prompt: "system",
            user_prompt: "user",
            schema: &SCHEMA,
            temperature: 0.0,
            max_tokens: 100,
            api_mode: "ChatCompletions",
        }
    }

    /// Same material hashes to the same key; changing the user prompt changes the key.
    #[test]
    fn key_is_stable_and_collision_resistant() {
        let baseline = material();
        let first = compute_key(&baseline);
        assert_eq!(first, compute_key(&baseline));

        let changed = CacheMaterial { user_prompt: "different", ..material() };
        assert_ne!(first, compute_key(&changed));
    }

    /// A miss before any write, then each `put` is visible to the following `get`.
    #[test]
    fn roundtrip_get_put() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("c.sqlite");
        let cache = Arc::new(LlmCache::open(&db_path, Duration::from_secs(60)).unwrap());

        assert!(cache.get("k").is_none());
        for payload in ["{\"x\":1}", "{\"x\":2}"] {
            cache.put("k", "model", "op", payload);
            assert_eq!(cache.get("k").as_deref(), Some(payload));
        }
    }

    /// With a zero TTL a freshly written entry is still readable.
    #[test]
    fn ttl_zero_disables_expiry() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("c.sqlite");
        let cache = LlmCache::open(&db_path, Duration::from_secs(0)).unwrap();

        cache.put("k", "model", "op", "v");
        assert_eq!(cache.get("k").as_deref(), Some("v"));
    }
}