//! Append-only binary record log for the MCP memory store.
//!
//! Source path: `mcp_memory/store.rs`.

use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
4
/// Eight-byte signature stored at the very start of a store file; `replay`
/// yields no records for a file that does not begin with it.
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Largest accepted value (1 MiB) of a record's length field, which counts
/// the 4-byte length and the 1-byte kind in addition to the payload.
const MAX_RECORD_BYTES: u32 = 1 << 20;
7
/// One-byte tag stored in every record to say which operation it describes.
///
/// The discriminants are part of the on-disk format (`write_record` stores
/// `kind as u8`), so existing values must not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordKind {
    /// Tag `0`; payload built by `encode_create_entity`.
    CreateEntity = 0,
    /// Tag `1`; payload built by `encode_create_relation`.
    CreateRelation = 1,
    /// Tag `2`; payload built by `encode_add_observations`.
    AddObservations = 2,
    /// Tag `3`; payload built by `encode_delete_entity`.
    DeleteEntity = 3,
    /// Tag `4`; payload built by `encode_delete_observations`.
    DeleteObservations = 4,
    /// Tag `5`; payload built by `encode_delete_relation`.
    DeleteRelation = 5,
}

impl RecordKind {
    /// Converts a raw tag byte back into its variant.
    ///
    /// Returns `None` for any byte that is not one of the discriminants
    /// `0..=5`.
    #[inline]
    pub const fn from_u8(v: u8) -> Option<RecordKind> {
        match v {
            0 => Some(Self::CreateEntity),
            1 => Some(Self::CreateRelation),
            2 => Some(Self::AddObservations),
            3 => Some(Self::DeleteEntity),
            4 => Some(Self::DeleteObservations),
            5 => Some(Self::DeleteRelation),
            _ => None,
        }
    }
}
33
34pub struct BinaryStore {
35    writer: BufWriter<File>,
36    path: PathBuf,
37}
38
39impl BinaryStore {
40    pub const fn path(&self) -> &PathBuf {
41        &self.path
42    }
43
44    pub fn new(path: &Path) -> std::io::Result<Self> {
45        let exists = path.exists();
46        let file = OpenOptions::new()
47            .create(true)
48            .append(true)
49            .read(false)
50            .open(path)?;
51
52        let mut writer = BufWriter::with_capacity(65536, file);
53
54        if !exists {
55            writer.write_all(MAGIC)?;
56            writer.flush()?;
57        }
58
59        Ok(Self {
60            writer,
61            path: path.to_path_buf(),
62        })
63    }
64
65    pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
66        let total_len = 4 + 1 + payload.len();
67        if total_len as u32 > MAX_RECORD_BYTES {
68            return Err(std::io::Error::new(
69                std::io::ErrorKind::InvalidInput,
70                "Record too large",
71            ));
72        }
73        self.writer.write_all(&(total_len as u32).to_le_bytes())?;
74        self.writer.write_all(&[kind as u8])?;
75        self.writer.write_all(payload)?;
76        Ok(())
77    }
78
79    pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
80        self.writer.flush()?;
81        self.writer.get_ref().sync_data()
82    }
83
84    pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
85    where
86        F: FnMut(RecordKind, &[u8]),
87    {
88        let file = match OpenOptions::new().read(true).open(&self.path) {
89            Ok(f) => f,
90            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
91            Err(e) => return Err(e),
92        };
93
94        let meta = file.metadata()?;
95        if meta.len() == 0 {
96            return Ok(());
97        }
98
99        let mut reader = BufReader::with_capacity(65536, file);
100        let mut magic = [0u8; 8];
101
102        match reader.read_exact(&mut magic) {
103            Ok(()) => {
104                if &magic != MAGIC {
105                    return Ok(());
106                }
107            }
108            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
109            Err(e) => return Err(e),
110        }
111
112        let mut payload_buf = Vec::with_capacity(4096);
113
114        loop {
115            let mut len_buf = [0u8; 4];
116            match reader.read_exact(&mut len_buf) {
117                Ok(()) => {}
118                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
119                Err(e) => return Err(e),
120            }
121            let total_len = u32::from_le_bytes(len_buf) as usize;
122            if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
123                return Err(std::io::Error::new(
124                    std::io::ErrorKind::InvalidData,
125                    format!("Invalid record length: {total_len}"),
126                ));
127            }
128            let payload_len = total_len - 5;
129
130            let mut kind_buf = [0u8; 1];
131            reader.read_exact(&mut kind_buf)?;
132            let kind_val = kind_buf[0];
133
134            payload_buf.clear();
135            payload_buf.resize(payload_len, 0);
136            if payload_len > 0 {
137                reader.read_exact(&mut payload_buf)?;
138            }
139
140            if let Some(kind) = RecordKind::from_u8(kind_val) {
141                callback(kind, &payload_buf);
142            } else {
143                tracing::warn!("Unknown record kind byte {kind_val}, skipping");
144            }
145        }
146    }
147
148    pub fn close(&mut self) -> std::io::Result<()> {
149        self.flush_and_sync()
150    }
151
152    /// Reopen the file with truncation — discards all existing records.
153    /// Used by `compact` to rewrite a fresh log from in-memory state.
154    pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
155        self.writer.flush()?;
156        let file = OpenOptions::new()
157            .create(true)
158            .write(true)
159            .truncate(true)
160            .open(&self.path)?;
161        let mut writer = BufWriter::with_capacity(65536, file);
162        writer.write_all(MAGIC)?;
163        writer.flush()?;
164        self.writer = writer;
165        Ok(())
166    }
167}
168
// --- Binary encoding helpers ---
170
/// Appends `s` to `buf` as a little-endian `u16` byte length followed by the
/// string's UTF-8 bytes.
///
/// # Errors
/// Returns `InvalidInput`, leaving `buf` untouched, when `s` is longer than
/// `u16::MAX` bytes.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let bytes = s.as_bytes();
    let len = bytes.len();
    if len > usize::from(u16::MAX) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("string too long (max {} bytes, got {len})", u16::MAX),
        ));
    }

    // Grow once for prefix + body instead of once per `extend_from_slice`.
    buf.reserve(2 + len);
    // Lossless: `len` was just checked against `u16::MAX`.
    buf.extend_from_slice(&(len as u16).to_le_bytes());
    buf.extend_from_slice(bytes);
    Ok(())
}
184
/// Reads one string written by `encode_str` from `data` at `*offset`.
///
/// On success `*offset` is moved just past the string. Returns `None`, with
/// `*offset` left unchanged, when the length prefix or body runs past the end
/// of `data`, or when the body is not valid UTF-8.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    // Checked arithmetic and `get` keep a bad offset or length from
    // overflowing or panicking.
    let body_start = offset.checked_add(2)?;
    let len = match data.get(*offset..body_start)? {
        &[lo, hi] => usize::from(u16::from_le_bytes([lo, hi])),
        _ => return None,
    };

    let body_end = body_start.checked_add(len)?;
    let text = std::str::from_utf8(data.get(body_start..body_end)?).ok()?;

    *offset = body_end;
    Some(text)
}
198
/// Reads a little-endian `u32` element count from `data` at `*offset`.
///
/// On success `*offset` is advanced by four bytes. Returns `None`, with
/// `*offset` left unchanged, when fewer than four bytes are available.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    // Checked arithmetic and `get` keep a bad offset from overflowing or
    // panicking.
    let end = offset.checked_add(4)?;
    match data.get(*offset..end)? {
        &[b0, b1, b2, b3] => {
            *offset = end;
            Some(u32::from_le_bytes([b0, b1, b2, b3]) as usize)
        }
        _ => None,
    }
}
212
213pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
214    encode_str(buf, name)?;
215    encode_str(buf, entity_type)?;
216    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
217    for obs in observations {
218        encode_str(buf, obs)?;
219    }
220    Ok(())
221}
222
223pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
224    let mut offset = 0;
225    let name = decode_str(data, &mut offset)?;
226    let entity_type = decode_str(data, &mut offset)?;
227    let count = decode_count(data, &mut offset)?;
228    let mut observations = Vec::with_capacity(count);
229    for _ in 0..count {
230        observations.push(decode_str(data, &mut offset)?);
231    }
232    Some((name, entity_type, observations))
233}
234
235pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
236    encode_str(buf, from)?;
237    encode_str(buf, to)?;
238    encode_str(buf, relation_type)
239}
240
241pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
242    let mut offset = 0;
243    let from = decode_str(data, &mut offset)?;
244    let to = decode_str(data, &mut offset)?;
245    let relation_type = decode_str(data, &mut offset)?;
246    Some((from, to, relation_type))
247}
248
249pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
250    encode_str(buf, name)?;
251    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
252    for obs in observations {
253        encode_str(buf, obs)?;
254    }
255    Ok(())
256}
257
258pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
259    let mut offset = 0;
260    let name = decode_str(data, &mut offset)?;
261    let count = decode_count(data, &mut offset)?;
262    let mut observations = Vec::with_capacity(count);
263    for _ in 0..count {
264        observations.push(decode_str(data, &mut offset)?);
265    }
266    Some((name, observations))
267}
268
269pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
270    encode_str(buf, name)
271}
272
273pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
274    let mut offset = 0;
275    decode_str(data, &mut offset)
276}
277
278pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
279    encode_str(buf, name)?;
280    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
281    for obs in observations {
282        encode_str(buf, obs)?;
283    }
284    Ok(())
285}
286
287pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
288    decode_add_observations(data)
289}
290
291pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
292    encode_str(buf, from)?;
293    encode_str(buf, to)?;
294    encode_str(buf, relation_type)
295}
296
297pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
298    decode_create_relation(data)
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Per-process sequence number that keeps scratch file names distinct.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Scratch log path that is deleted when the guard is dropped, so a
    /// failing assertion does not leave the file behind in the temp directory.
    struct TempLog(PathBuf);

    impl TempLog {
        fn new() -> Self {
            let pid = std::process::id();
            let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
            let path = std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin"));
            // Clear any leftover from an earlier process that reused this pid.
            let _ = std::fs::remove_file(&path);
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempLog {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Encodes a create-entity payload into a fresh buffer.
    fn entity_payload(name: &str, entity_type: &str, observations: &[String]) -> Vec<u8> {
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, name, entity_type, observations).unwrap();
        buf
    }

    /// Opens the log at `path` and collects every replayed record.
    fn replay_all(path: &Path) -> Vec<(RecordKind, Vec<u8>)> {
        let store = BinaryStore::new(path).unwrap();
        let mut records = Vec::new();
        store
            .replay(|kind, data| records.push((kind, data.to_vec())))
            .unwrap();
        records
    }

    #[test]
    fn test_write_and_replay() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();

        let alice = entity_payload("Alice", "person", &["likes coffee".into()]);
        store.write_record(RecordKind::CreateEntity, &alice).unwrap();
        let bob = entity_payload("Bob", "person", &[]);
        store.write_record(RecordKind::CreateEntity, &bob).unwrap();

        // Dropping the store flushes its buffered records.
        drop(store);

        let replayed = replay_all(log.path());
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, RecordKind::CreateEntity);
        assert_eq!(decode_create_entity(&replayed[0].1).unwrap().0, "Alice");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let buf = entity_payload("TestEntity", "test_type", &["obs1".into(), "obs2".into()]);
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let log = TempLog::new();
        drop(BinaryStore::new(log.path()).unwrap());

        assert!(replay_all(log.path()).is_empty());
    }

    #[test]
    fn test_write_all_record_kinds() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();
        let mut buf = Vec::new();

        // Write one of each record kind.
        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();

        drop(store);

        let kinds: Vec<RecordKind> = replay_all(log.path())
            .into_iter()
            .map(|(kind, _)| kind)
            .collect();
        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
            ]
        );
    }

    #[test]
    fn test_reopen_truncated() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();
        store
            .write_record(RecordKind::CreateEntity, &entity_payload("E1", "t1", &[]))
            .unwrap();
        drop(store);

        // Reopen with truncation.
        let mut store = BinaryStore::new(log.path()).unwrap();
        store.reopen_truncated().unwrap();
        store
            .write_record(RecordKind::CreateEntity, &entity_payload("E2", "t2", &[]))
            .unwrap();
        drop(store);

        let names: Vec<String> = replay_all(log.path())
            .iter()
            .filter_map(|(_, data)| decode_create_entity(data).map(|(name, _, _)| name.to_string()))
            .collect();

        // Only E2 should remain — E1 was truncated away.
        assert_eq!(names, vec!["E2"]);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        assert_eq!(decode_delete_entity(&buf).unwrap(), "ToDelete");
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        let (from, to, rtype) = decode_delete_relation(&buf).unwrap();
        assert_eq!(from, "A");
        assert_eq!(to, "B");
        assert_eq!(rtype, "knows");
    }

    #[test]
    fn test_record_too_large() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();
        let huge = vec![0u8; (1 << 20) + 1];
        assert!(store.write_record(RecordKind::CreateEntity, &huge).is_err());

        // Boundary: the 5 framing bytes count toward the 1 MiB limit.
        assert!(store
            .write_record(RecordKind::CreateEntity, &huge[..(1 << 20) - 4])
            .is_err());
        assert!(store
            .write_record(RecordKind::CreateEntity, &huge[..(1 << 20) - 5])
            .is_ok());
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();
        for i in 0..100 {
            let payload = entity_payload(&format!("E{i}"), "type", &[]);
            store.write_record(RecordKind::CreateEntity, &payload).unwrap();
        }
        drop(store);

        let replayed = replay_all(log.path());
        assert_eq!(replayed.len(), 100);
        assert!(replayed
            .iter()
            .all(|(kind, _)| *kind == RecordKind::CreateEntity));
    }

    #[test]
    fn test_truncated_log_handling() {
        let log = TempLog::new();
        let mut store = BinaryStore::new(log.path()).unwrap();
        store
            .write_record(RecordKind::CreateEntity, &entity_payload("Alice", "person", &[]))
            .unwrap();
        drop(store);

        // Truncate the file manually (simulate crash during write): 10 bytes
        // keeps the 8-byte magic plus only 2 of the 4 length-prefix bytes.
        let file = OpenOptions::new().write(true).open(log.path()).unwrap();
        file.set_len(10).unwrap();
        drop(file);

        // Replay should handle gracefully.
        assert!(replay_all(log.path()).is_empty());
    }
}