Skip to main content

hanzo_engine/
disk_kv_cache.rs

1use std::{
2    fs::{self, File, OpenOptions},
3    io::{self, BufWriter, Read, Write},
4    path::{Path, PathBuf},
5    time::{SystemTime, UNIX_EPOCH},
6};
7
8use sha1::{Digest, Sha1};
9
/// Three-byte signature found at the very start of every cache file.
const KVC_MAGIC: [u8; 3] = [b'K', b'V', b'C'];
/// The only on-disk format revision this module writes and accepts.
const KVC_VERSION: u8 = 1;
/// Size in bytes of the fixed header that precedes the text and payload.
const KVC_HEADER_BYTES: usize = 48;
/// Suffix appended to a key to form the entry's file name.
const FILE_SUFFIX: &str = ".kv";
/// Floor applied to the configured cache budget (64 MiB).
const MIN_BUDGET_BYTES: u64 = 64 << 20;

/// Process-wide counter that gives every temporary file written by `save` a distinct name.
static TMP_SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
17
/// Why a cache entry was written; stored as a single byte in the file header.
///
/// The discriminants are the on-disk byte values.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SaveReason {
    /// Byte value 0; also what any unrecognized byte decodes to.
    Unknown = 0,
    Cold = 1,
    Continued = 2,
    Evict = 3,
    Shutdown = 4,
}

impl SaveReason {
    /// Decodes a header byte into a variant.
    ///
    /// Bytes that match no variant (including 0) yield [`SaveReason::Unknown`].
    fn from_u8(v: u8) -> Self {
        // `Unknown` is left out on purpose: it is the fallback for every other byte.
        const KNOWN: [SaveReason; 4] = [
            SaveReason::Cold,
            SaveReason::Continued,
            SaveReason::Evict,
            SaveReason::Shutdown,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|reason| *reason as u8 == v)
            .unwrap_or(SaveReason::Unknown)
    }
}
39
40#[derive(Clone, Debug)]
41pub struct KvcHeader {
42    pub quant_bits: u8,
43    pub save_reason: SaveReason,
44    pub ext_flags: u8,
45    pub token_count: u32,
46    pub hit_count: u32,
47    pub ctx_size: u32,
48    pub created_unix: u64,
49    pub last_used_unix: u64,
50    pub payload_bytes: u64,
51}
52
53impl KvcHeader {
54    pub fn new(quant_bits: u8, save_reason: SaveReason, token_count: u32, ctx_size: u32) -> Self {
55        let now = SystemTime::now()
56            .duration_since(UNIX_EPOCH)
57            .map(|d| d.as_secs())
58            .unwrap_or(0);
59        Self {
60            quant_bits,
61            save_reason,
62            ext_flags: 0,
63            token_count,
64            hit_count: 0,
65            ctx_size,
66            created_unix: now,
67            last_used_unix: now,
68            payload_bytes: 0,
69        }
70    }
71
72    fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
73        let mut buf = [0u8; KVC_HEADER_BYTES];
74        buf[0..3].copy_from_slice(&KVC_MAGIC);
75        buf[3] = KVC_VERSION;
76        buf[4] = self.quant_bits;
77        buf[5] = self.save_reason as u8;
78        buf[6] = self.ext_flags;
79        // buf[7] reserved
80        buf[8..12].copy_from_slice(&self.token_count.to_le_bytes());
81        buf[12..16].copy_from_slice(&self.hit_count.to_le_bytes());
82        buf[16..20].copy_from_slice(&self.ctx_size.to_le_bytes());
83        // buf[20..24] reserved
84        buf[24..32].copy_from_slice(&self.created_unix.to_le_bytes());
85        buf[32..40].copy_from_slice(&self.last_used_unix.to_le_bytes());
86        buf[40..48].copy_from_slice(&self.payload_bytes.to_le_bytes());
87        w.write_all(&buf)
88    }
89
90    fn read<R: Read>(r: &mut R) -> io::Result<Self> {
91        let mut buf = [0u8; KVC_HEADER_BYTES];
92        r.read_exact(&mut buf)?;
93        if buf[0..3] != KVC_MAGIC {
94            return Err(io::Error::new(io::ErrorKind::InvalidData, "bad KVC magic"));
95        }
96        if buf[3] != KVC_VERSION {
97            return Err(io::Error::new(
98                io::ErrorKind::InvalidData,
99                format!("unsupported KVC version {}", buf[3]),
100            ));
101        }
102        Ok(Self {
103            quant_bits: buf[4],
104            save_reason: SaveReason::from_u8(buf[5]),
105            ext_flags: buf[6],
106            token_count: u32::from_le_bytes(buf[8..12].try_into().unwrap()),
107            hit_count: u32::from_le_bytes(buf[12..16].try_into().unwrap()),
108            ctx_size: u32::from_le_bytes(buf[16..20].try_into().unwrap()),
109            created_unix: u64::from_le_bytes(buf[24..32].try_into().unwrap()),
110            last_used_unix: u64::from_le_bytes(buf[32..40].try_into().unwrap()),
111            payload_bytes: u64::from_le_bytes(buf[40..48].try_into().unwrap()),
112        })
113    }
114}
115
116pub fn key_for(rendered_text: &str) -> String {
117    let mut h = Sha1::new();
118    h.update(rendered_text.as_bytes());
119    let digest = h.finalize();
120    let mut s = String::with_capacity(40);
121    for b in digest.iter() {
122        let _ = std::fmt::Write::write_fmt(&mut s, format_args!("{:02x}", b));
123    }
124    s
125}
126
127pub struct DiskKvCache {
128    dir: PathBuf,
129    budget_bytes: u64,
130}
131
132#[derive(Clone, Debug)]
133pub struct CacheHit {
134    pub header: KvcHeader,
135    pub rendered_text: Vec<u8>,
136    pub payload: Vec<u8>,
137}
138
139impl DiskKvCache {
140    pub fn new(dir: impl Into<PathBuf>, budget_mb: u64) -> io::Result<Self> {
141        let dir = dir.into();
142        fs::create_dir_all(&dir)?;
143        let budget_bytes = (budget_mb * 1024 * 1024).max(MIN_BUDGET_BYTES);
144        Ok(Self { dir, budget_bytes })
145    }
146
147    pub fn dir(&self) -> &Path {
148        &self.dir
149    }
150
151    pub fn budget_bytes(&self) -> u64 {
152        self.budget_bytes
153    }
154
155    pub fn path_for(&self, key: &str) -> PathBuf {
156        self.dir.join(format!("{key}{FILE_SUFFIX}"))
157    }
158
159    pub fn save(
160        &self,
161        key: &str,
162        mut header: KvcHeader,
163        rendered_text: &[u8],
164        payload: &[u8],
165    ) -> io::Result<()> {
166        header.payload_bytes = payload.len() as u64;
167        let path = self.path_for(key);
168        let seq = TMP_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
169        let tmp = path.with_extension(format!("{}-{}.kv.tmp", std::process::id(), seq));
170        {
171            let f = OpenOptions::new()
172                .write(true)
173                .create(true)
174                .truncate(true)
175                .open(&tmp)?;
176            let mut w = BufWriter::new(f);
177            header.write(&mut w)?;
178            let text_len = u32::try_from(rendered_text.len())
179                .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "text too large"))?;
180            w.write_all(&text_len.to_le_bytes())?;
181            w.write_all(rendered_text)?;
182            w.write_all(payload)?;
183            w.flush()?;
184        }
185        fs::rename(tmp, path)
186    }
187
188    pub fn load(&self, key: &str) -> io::Result<Option<CacheHit>> {
189        let path = self.path_for(key);
190        match File::open(&path) {
191            Ok(mut f) => {
192                let mut header = KvcHeader::read(&mut f)?;
193                let mut text_len_buf = [0u8; 4];
194                f.read_exact(&mut text_len_buf)?;
195                let text_len = u32::from_le_bytes(text_len_buf) as u64;
196                if text_len > self.budget_bytes {
197                    return Err(io::Error::new(
198                        io::ErrorKind::InvalidData,
199                        "kvc text_len exceeds cache budget",
200                    ));
201                }
202                let text_len = usize::try_from(text_len).map_err(|_| {
203                    io::Error::new(io::ErrorKind::InvalidData, "kvc text_len too large")
204                })?;
205                let mut rendered_text = vec![0u8; text_len];
206                f.read_exact(&mut rendered_text)?;
207                if header.payload_bytes > self.budget_bytes {
208                    return Err(io::Error::new(
209                        io::ErrorKind::InvalidData,
210                        "kvc payload_bytes exceeds cache budget",
211                    ));
212                }
213                let payload_len = usize::try_from(header.payload_bytes).map_err(|_| {
214                    io::Error::new(io::ErrorKind::InvalidData, "kvc payload_bytes too large")
215                })?;
216                let mut payload = vec![0u8; payload_len];
217                f.read_exact(&mut payload)?;
218                let now = SystemTime::now()
219                    .duration_since(UNIX_EPOCH)
220                    .map(|d| d.as_secs())
221                    .unwrap_or(header.last_used_unix);
222                header.last_used_unix = now;
223                header.hit_count = header.hit_count.saturating_add(1);
224                let _ = self.touch_header(key, &header);
225                Ok(Some(CacheHit {
226                    header,
227                    rendered_text,
228                    payload,
229                }))
230            }
231            Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None),
232            Err(e) => Err(e),
233        }
234    }
235
236    fn touch_header(&self, key: &str, header: &KvcHeader) -> io::Result<()> {
237        let path = self.path_for(key);
238        let mut f = OpenOptions::new().write(true).open(path)?;
239        header.write(&mut f)
240    }
241
242    pub fn evict_to_budget(&self) -> io::Result<usize> {
243        let mut entries: Vec<(PathBuf, u64, u64)> = Vec::new();
244        let mut total: u64 = 0;
245        for entry in fs::read_dir(&self.dir)? {
246            let entry = entry?;
247            let p = entry.path();
248            if p.extension().and_then(|e| e.to_str()) != Some("kv") {
249                continue;
250            }
251            let meta = entry.metadata()?;
252            let len = meta.len();
253            total += len;
254            let lru = match File::open(&p).and_then(|mut f| KvcHeader::read(&mut f)) {
255                Ok(h) => h.last_used_unix,
256                Err(_) => 0,
257            };
258            entries.push((p, len, lru));
259        }
260        if total <= self.budget_bytes {
261            return Ok(0);
262        }
263        entries.sort_by_key(|e| e.2);
264        let mut removed = 0usize;
265        let mut current = total;
266        for (path, size, _) in entries {
267            if current <= self.budget_bytes {
268                break;
269            }
270            if fs::remove_file(&path).is_ok() {
271                current = current.saturating_sub(size);
272                removed += 1;
273            }
274        }
275        Ok(removed)
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// The key is the lowercase hex SHA-1 of the text.
    #[test]
    fn key_is_stable_sha1_hex() {
        let k = key_for("hello world");
        assert_eq!(k.len(), 40);
        assert_eq!(k, "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed");
    }

    /// Every header field survives a write/read cycle.
    #[test]
    fn header_round_trip() {
        let mut h = KvcHeader::new(2, SaveReason::Cold, 12345, 32768);
        // Use non-default values so a swapped or dropped field is detected.
        h.ext_flags = 0x5a;
        h.hit_count = 7;
        h.created_unix = 1_111;
        h.last_used_unix = 2_222;
        h.payload_bytes = 0x0102_0304_0506_0708;

        let mut buf: Vec<u8> = Vec::new();
        h.write(&mut buf).unwrap();
        assert_eq!(buf.len(), KVC_HEADER_BYTES);

        let r = KvcHeader::read(&mut buf.as_slice()).unwrap();
        assert_eq!(r.quant_bits, 2);
        assert_eq!(r.save_reason, SaveReason::Cold);
        assert_eq!(r.ext_flags, 0x5a);
        assert_eq!(r.token_count, 12345);
        assert_eq!(r.hit_count, 7);
        assert_eq!(r.ctx_size, 32768);
        assert_eq!(r.created_unix, 1_111);
        assert_eq!(r.last_used_unix, 2_222);
        assert_eq!(r.payload_bytes, 0x0102_0304_0506_0708);
    }

    /// A saved entry loads back intact and the hit counter is persisted.
    #[test]
    fn save_load_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = DiskKvCache::new(tmp.path(), 64).unwrap();
        let text = b"<|user|>hi<|assistant|>";
        let payload = b"binary session payload bytes";
        let key = key_for(std::str::from_utf8(text).unwrap());
        let header = KvcHeader::new(2, SaveReason::Cold, 17, 8192);
        cache.save(&key, header, text, payload).unwrap();

        let hit = cache.load(&key).unwrap().expect("hit");
        assert_eq!(hit.rendered_text, text);
        assert_eq!(hit.payload, payload);
        assert_eq!(hit.header.token_count, 17);
        assert_eq!(hit.header.ctx_size, 8192);
        assert_eq!(hit.header.quant_bits, 2);
        assert_eq!(hit.header.save_reason, SaveReason::Cold);
        assert_eq!(hit.header.payload_bytes, payload.len() as u64);
        assert_eq!(hit.header.hit_count, 1);

        // The first load wrote the bumped counter back, so the second sees it.
        let again = cache.load(&key).unwrap().expect("second hit");
        assert_eq!(again.header.hit_count, 2);
        assert_eq!(again.payload, payload);
    }

    /// A key with no file is a miss, not an error.
    #[test]
    fn load_missing_returns_none() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = DiskKvCache::new(tmp.path(), 64).unwrap();
        assert!(cache
            .load("0000000000000000000000000000000000000000")
            .unwrap()
            .is_none());
    }

    /// A file that does not start with the magic bytes is reported as invalid data.
    #[test]
    fn load_rejects_bad_magic() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = DiskKvCache::new(tmp.path(), 64).unwrap();
        let key = key_for("corrupt");
        fs::write(cache.path_for(&key), [0u8; KVC_HEADER_BYTES]).unwrap();
        let err = cache.load(&key).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// Eviction removes the oldest entries, keeps the newest and ends within budget.
    #[test]
    fn evict_to_budget_keeps_recent() {
        let tmp = tempfile::tempdir().unwrap();
        // A 1 MiB request is raised to the minimum budget.
        let cache = DiskKvCache::new(tmp.path(), 1).unwrap();
        assert_eq!(cache.budget_bytes(), MIN_BUDGET_BYTES);

        // 400 entries of ~256 KiB each is roughly 100 MiB, well above the 64 MiB floor.
        let big = vec![0u8; 256 * 1024];
        let count: u32 = 400;
        for i in 0..count {
            let key = key_for(&format!("prefix{i}"));
            let mut header = KvcHeader::new(2, SaveReason::Cold, i, 8192);
            // Strictly increasing timestamps: entry 0 is the oldest.
            header.last_used_unix = 1_000_000 + u64::from(i);
            cache.save(&key, header, b"", &big).unwrap();
        }

        let removed = cache.evict_to_budget().unwrap();
        assert!(removed > 0, "expected eviction beyond budget");

        let oldest = cache.path_for(&key_for("prefix0"));
        let newest = cache.path_for(&key_for(&format!("prefix{}", count - 1)));
        assert!(!oldest.exists(), "least recently used entry should be evicted");
        assert!(newest.exists(), "most recently used entry should be kept");

        let remaining: u64 = fs::read_dir(cache.dir())
            .unwrap()
            .map(|e| e.unwrap().metadata().unwrap().len())
            .sum();
        assert!(remaining <= cache.budget_bytes());

        // Already within budget: a second pass has nothing to do.
        assert_eq!(cache.evict_to_budget().unwrap(), 0);
    }
}
346}