1use std::fs;
24use std::io;
25use std::path::{Path, PathBuf};
26
27use serde::de::DeserializeOwned;
28use serde::Serialize;
29use sha2::{Digest, Sha256};
30
31#[derive(Debug, thiserror::Error)]
33pub enum CacheError {
34 #[error("ai-cache I/O error: {0}")]
36 Io(#[from] io::Error),
37 #[error("ai-cache serialize error: {0}")]
39 Serialize(serde_json::Error),
40}
41
/// Aggregate size of the on-disk cache, as reported by [`AiCache::stats`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheStats {
    /// Number of `.json` entry files found under the cache root.
    pub entries: usize,
    /// Combined length, in bytes, of those entry files.
    pub total_bytes: u64,
}
50
/// File-backed cache of JSON-serialized responses, keyed by the SHA-256 of the request.
///
/// Entries live under `root`, sharded by the first two hex digits of their key.
#[derive(Clone, Debug)]
pub struct AiCache {
    /// Directory that holds the shard subdirectories.
    root: PathBuf,
}
56
57impl AiCache {
58 #[must_use]
63 pub fn new(project_root: &Path) -> Self {
64 Self {
65 root: project_root.join(".bock").join("ai-cache"),
66 }
67 }
68
69 #[must_use]
71 pub fn with_root(root: PathBuf) -> Self {
72 Self { root }
73 }
74
75 #[must_use]
77 pub fn root(&self) -> &Path {
78 &self.root
79 }
80
81 #[must_use]
88 pub fn get<R: Serialize, S: DeserializeOwned>(&self, request: &R) -> Option<S> {
89 self.get_strict(request).ok().flatten()
90 }
91
92 pub fn get_strict<R: Serialize, S: DeserializeOwned>(
99 &self,
100 request: &R,
101 ) -> Result<Option<S>, CacheError> {
102 let key = compute_key(request)?;
103 let path = self.path_for_key(&key);
104 if !path.exists() {
105 return Ok(None);
106 }
107 let bytes = fs::read(&path)?;
108 let value = serde_json::from_slice(&bytes).map_err(CacheError::Serialize)?;
109 Ok(Some(value))
110 }
111
112 pub fn put<R: Serialize, S: Serialize>(
122 &self,
123 request: &R,
124 response: &S,
125 ) -> Result<(), CacheError> {
126 let key = compute_key(request)?;
127 let path = self.path_for_key(&key);
128 if let Some(parent) = path.parent() {
129 fs::create_dir_all(parent)?;
130 }
131 let bytes = serde_json::to_vec(response).map_err(CacheError::Serialize)?;
132 fs::write(&path, bytes)?;
133 Ok(())
134 }
135
136 #[must_use]
138 pub fn contains<R: Serialize>(&self, request: &R) -> bool {
139 match compute_key(request) {
140 Ok(key) => self.path_for_key(&key).exists(),
141 Err(_) => false,
142 }
143 }
144
145 pub fn clear(&self) -> io::Result<()> {
151 if self.root.exists() {
152 fs::remove_dir_all(&self.root)?;
153 }
154 Ok(())
155 }
156
157 pub fn stats(&self) -> io::Result<CacheStats> {
162 let mut stats = CacheStats::default();
163 if self.root.exists() {
164 walk(&self.root, &mut stats)?;
165 }
166 Ok(stats)
167 }
168
169 fn path_for_key(&self, key: &str) -> PathBuf {
170 let shard = &key[..2];
171 self.root.join(shard).join(format!("{key}.json"))
172 }
173}
174
175pub fn compute_key<R: Serialize>(request: &R) -> Result<String, CacheError> {
184 let value = serde_json::to_value(request).map_err(CacheError::Serialize)?;
185 let canonical = serde_json::to_vec(&value).map_err(CacheError::Serialize)?;
186 let mut hasher = Sha256::new();
187 hasher.update(&canonical);
188 Ok(hex_encode(&hasher.finalize()))
189}
190
191fn walk(dir: &Path, stats: &mut CacheStats) -> io::Result<()> {
192 for entry in fs::read_dir(dir)? {
193 let entry = entry?;
194 let file_type = entry.file_type()?;
195 if file_type.is_dir() {
196 walk(&entry.path(), stats)?;
197 } else if entry.path().extension().and_then(|e| e.to_str()) == Some("json") {
198 stats.entries += 1;
199 stats.total_bytes = stats.total_bytes.saturating_add(entry.metadata()?.len());
200 }
201 }
202 Ok(())
203}
204
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // Push both nibbles directly instead of formatting a temporary String per byte.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
212
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::collections::HashMap;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct Req {
        kind: String,
        params: HashMap<String, String>,
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct Resp {
        body: String,
    }

    /// Request of the given kind with a fixed two-entry parameter map.
    fn req(kind: &str) -> Req {
        let mut params = HashMap::new();
        params.insert("a".into(), "1".into());
        params.insert("b".into(), "2".into());
        Req {
            kind: kind.into(),
            params,
        }
    }

    /// Response carrying the given body.
    fn resp(body: &str) -> Resp {
        Resp { body: body.into() }
    }

    /// Expected on-disk location of the entry for `request`: `<root>/<2 hex>/<key>.json`.
    fn entry_path(cache: &AiCache, request: &Req) -> PathBuf {
        let key = compute_key(request).unwrap();
        cache.root().join(&key[..2]).join(format!("{key}.json"))
    }

    #[test]
    fn put_then_get_round_trips() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        cache.put(&r, &resp("code")).unwrap();
        let got: Resp = cache.get(&r).expect("hit");
        assert_eq!(got, resp("code"));
    }

    #[test]
    fn miss_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let got: Option<Resp> = cache.get(&req("generate"));
        assert!(got.is_none());
    }

    #[test]
    fn get_strict_miss_is_ok_none() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let got: Option<Resp> = cache.get_strict(&req("generate")).unwrap();
        assert!(got.is_none());
    }

    #[test]
    fn key_is_stable_across_hashmap_iteration_order() {
        let mut a = HashMap::new();
        a.insert("x".to_string(), "1".to_string());
        a.insert("y".to_string(), "2".to_string());
        let mut b = HashMap::new();
        b.insert("y".to_string(), "2".to_string());
        b.insert("x".to_string(), "1".to_string());
        let ra = Req {
            kind: "k".into(),
            params: a,
        };
        let rb = Req {
            kind: "k".into(),
            params: b,
        };
        assert_eq!(compute_key(&ra).unwrap(), compute_key(&rb).unwrap());
    }

    #[test]
    fn key_differs_for_different_input() {
        let r1 = req("generate");
        let r2 = req("repair");
        assert_ne!(compute_key(&r1).unwrap(), compute_key(&r2).unwrap());
    }

    #[test]
    fn key_is_lowercase_sha256_hex() {
        let key = compute_key(&req("generate")).unwrap();
        assert_eq!(key.len(), 64);
        assert!(key.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')));
    }

    #[test]
    fn hex_encode_pads_each_byte_to_two_digits() {
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }

    #[test]
    fn contains_reflects_state() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        assert!(!cache.contains(&r));
        cache.put(&r, &resp("x")).unwrap();
        assert!(cache.contains(&r));
    }

    #[test]
    fn sharded_storage_layout() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        cache.put(&r, &resp("x")).unwrap();

        let key = compute_key(&r).unwrap();
        let shard_dir = cache.root().join(&key[..2]);
        assert!(shard_dir.is_dir(), "expected shard dir at {shard_dir:?}");
        let entry = shard_dir.join(format!("{key}.json"));
        assert!(entry.exists(), "expected entry file at {entry:?}");
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        cache.put(&r, &resp("old")).unwrap();
        cache.put(&r, &resp("new")).unwrap();
        let got: Resp = cache.get(&r).expect("hit");
        assert_eq!(got, resp("new"));
        assert_eq!(cache.stats().unwrap().entries, 1);
    }

    #[test]
    fn put_leaves_only_the_entry_file_in_its_shard() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        cache.put(&r, &resp("x")).unwrap();
        let path = entry_path(&cache, &r);
        let shard = path.parent().unwrap().to_path_buf();
        let files: Vec<PathBuf> = fs::read_dir(&shard)
            .unwrap()
            .map(|e| e.unwrap().path())
            .collect();
        assert_eq!(files, vec![path]);
    }

    #[test]
    fn corrupt_entry_is_a_miss_for_get_and_an_error_for_get_strict() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        let r = req("generate");
        cache.put(&r, &resp("x")).unwrap();
        fs::write(entry_path(&cache, &r), b"not json").unwrap();
        assert!(cache.get::<_, Resp>(&r).is_none());
        assert!(matches!(
            cache.get_strict::<_, Resp>(&r),
            Err(CacheError::Serialize(_))
        ));
    }

    #[test]
    fn stats_count_entries_and_bytes() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        cache.put(&req("a"), &resp("one")).unwrap();
        cache.put(&req("b"), &resp("two")).unwrap();
        let stats = cache.stats().unwrap();
        assert_eq!(stats.entries, 2);
        assert!(stats.total_bytes > 0);
    }

    #[test]
    fn stats_on_missing_root_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        assert_eq!(cache.stats().unwrap(), CacheStats::default());
    }

    #[test]
    fn clear_removes_all_entries() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        cache.put(&req("a"), &resp("x")).unwrap();
        cache.clear().unwrap();
        assert_eq!(cache.stats().unwrap().entries, 0);
        assert!(!cache.root().exists());
    }

    #[test]
    fn clear_on_missing_root_is_ok() {
        let dir = tempfile::tempdir().unwrap();
        let cache = AiCache::new(dir.path());
        cache.clear().expect("no-op on missing dir");
    }

    #[test]
    fn with_root_stores_entries_under_the_given_root() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().join("custom");
        let cache = AiCache::with_root(root.clone());
        assert_eq!(cache.root(), root.as_path());
        let r = req("generate");
        cache.put(&r, &resp("x")).unwrap();
        let path = entry_path(&cache, &r);
        assert!(path.starts_with(&root));
        assert!(path.exists());
    }
}