Skip to main content

mcp_memory/
store.rs

1use std::fs::{File, OpenOptions};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
/// File signature expected in the first 8 bytes of every log. `BinaryStore::new`
/// writes it into a fresh file; `BinaryStore::replay` yields nothing for a file
/// that does not start with it.
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Largest encoded record accepted, in bytes (1 MiB). The limit covers the whole
/// record, including its 4-byte length prefix and 1-byte kind tag.
const MAX_RECORD_BYTES: u32 = 1024 * 1024;
8
/// Tag byte identifying what a log record's payload encodes.
///
/// The numeric discriminants are what `BinaryStore::write_record` stores on disk
/// (`kind as u8`), so existing values must never be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordKind {
    /// Payload: entity name, entity type, observation list (`encode_create_entity`).
    CreateEntity = 0,
    /// Payload: from, to, relation type (`encode_create_relation`).
    CreateRelation = 1,
    /// Payload: entity name, observation list (`encode_add_observations`).
    AddObservations = 2,
    /// Payload: entity name (`encode_delete_entity`).
    DeleteEntity = 3,
    /// Payload: same layout as [`RecordKind::AddObservations`].
    DeleteObservations = 4,
    /// Payload: same layout as [`RecordKind::CreateRelation`].
    DeleteRelation = 5,
    /// Opens a transaction. The intended contract: records that follow are
    /// buffered on replay and applied only once a matching
    /// [`RecordKind::TxnCommit`] is seen, and an unclosed transaction (no commit
    /// before EOF) is discarded. This is how multi-record operations like
    /// `merge_entities` stay crash-atomic. `BinaryStore::replay` forwards this
    /// marker to its callback unchanged, so that buffering is the callback's job.
    TxnBegin = 6,
    /// Closes a transaction opened by [`RecordKind::TxnBegin`].
    TxnCommit = 7,
}

impl RecordKind {
    /// Decodes a tag byte read from the log, or returns `None` for a byte that
    /// names no known kind.
    #[inline]
    pub const fn from_u8(v: u8) -> Option<RecordKind> {
        // Indexed by discriminant: entries must stay in discriminant order.
        const BY_TAG: [RecordKind; 8] = [
            RecordKind::CreateEntity,
            RecordKind::CreateRelation,
            RecordKind::AddObservations,
            RecordKind::DeleteEntity,
            RecordKind::DeleteObservations,
            RecordKind::DeleteRelation,
            RecordKind::TxnBegin,
            RecordKind::TxnCommit,
        ];
        match v {
            0..=7 => Some(BY_TAG[v as usize]),
            _ => None,
        }
    }
}
43
/// Append-only binary log of [`RecordKind`] records.
///
/// Layout: the 8-byte [`MAGIC`] header, then records of the form
/// `[u32 LE total length][u8 kind][payload]`, where the total length also counts
/// the 5 header bytes. Records are written through an in-process buffer and only
/// reach the file on flush.
pub struct BinaryStore {
    /// Buffered handle that records are appended through.
    writer: BufWriter<File>,
    /// Location of the log file; replay reopens it for reading.
    path: PathBuf,
    /// Second handle onto the same open file, obtained with `File::try_clone()`
    /// whenever the store is (re)opened. It lives in an `Arc` so other crate code
    /// can hold it independently of the store. NOTE(review): the original design
    /// has a background sync thread (not visible in this file) calling `fsync`
    /// through it without holding any lock.
    pub(crate) sync_file: Arc<File>,
}
52
53impl BinaryStore {
54    pub const fn path(&self) -> &PathBuf {
55        &self.path
56    }
57
58    pub fn new(path: &Path) -> std::io::Result<Self> {
59        let exists = path.exists();
60        let file = OpenOptions::new()
61            .create(true)
62            .append(true)
63            .read(false)
64            .open(path)?;
65
66        let sync_file = Arc::new(file.try_clone()?);
67        let mut writer = BufWriter::with_capacity(65536, file);
68
69        if !exists {
70            writer.write_all(MAGIC)?;
71            writer.flush()?;
72        }
73
74        Ok(Self {
75            writer,
76            path: path.to_path_buf(),
77            sync_file,
78        })
79    }
80
81    pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
82        let total_len = 4 + 1 + payload.len();
83        if total_len as u32 > MAX_RECORD_BYTES {
84            return Err(std::io::Error::new(
85                std::io::ErrorKind::InvalidInput,
86                "Record too large",
87            ));
88        }
89        self.writer.write_all(&(total_len as u32).to_le_bytes())?;
90        self.writer.write_all(&[kind as u8])?;
91        self.writer.write_all(payload)?;
92        Ok(())
93    }
94
95    /// Flush the `BufWriter` to the kernel buffer (no `fsync`).
96    pub fn flush(&mut self) -> std::io::Result<()> {
97        self.writer.flush()
98    }
99
100    /// `fsync` the underlying file (kernel buffer → disk).
101    pub fn sync(&mut self) -> std::io::Result<()> {
102        self.writer.get_ref().sync_data()
103    }
104
105    pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
106        self.flush()?;
107        self.sync()
108    }
109
110    pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
111    where
112        F: FnMut(RecordKind, &[u8]),
113    {
114        let file = match OpenOptions::new().read(true).open(&self.path) {
115            Ok(f) => f,
116            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
117            Err(e) => return Err(e),
118        };
119
120        let meta = file.metadata()?;
121        if meta.len() == 0 {
122            return Ok(());
123        }
124
125        let mut reader = BufReader::with_capacity(65536, file);
126        let mut magic = [0u8; 8];
127
128        match reader.read_exact(&mut magic) {
129            Ok(()) => {
130                if &magic != MAGIC {
131                    return Ok(());
132                }
133            }
134            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
135            Err(e) => return Err(e),
136        }
137
138        let mut payload_buf = Vec::with_capacity(4096);
139
140        loop {
141            let mut len_buf = [0u8; 4];
142            match reader.read_exact(&mut len_buf) {
143                Ok(()) => {}
144                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
145                Err(e) => return Err(e),
146            }
147            let total_len = u32::from_le_bytes(len_buf) as usize;
148            if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
149                return Err(std::io::Error::new(
150                    std::io::ErrorKind::InvalidData,
151                    format!("Invalid record length: {total_len}"),
152                ));
153            }
154            let payload_len = total_len - 5;
155
156            let mut kind_buf = [0u8; 1];
157            reader.read_exact(&mut kind_buf)?;
158            let kind_val = kind_buf[0];
159
160            payload_buf.clear();
161            payload_buf.resize(payload_len, 0);
162            if payload_len > 0 {
163                reader.read_exact(&mut payload_buf)?;
164            }
165
166            if let Some(kind) = RecordKind::from_u8(kind_val) {
167                callback(kind, &payload_buf);
168            } else {
169                tracing::warn!("Unknown record kind byte {kind_val}, skipping");
170            }
171        }
172    }
173
174    pub fn close(&mut self) -> std::io::Result<()> {
175        self.flush_and_sync()
176    }
177
178    /// Reopen the file with truncation — discards all existing records.
179    /// Used by `compact` to rewrite a fresh log from in-memory state.
180    pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
181        self.writer.flush()?;
182        let file = OpenOptions::new()
183            .create(true)
184            .write(true)
185            .truncate(true)
186            .open(&self.path)?;
187        self.sync_file = Arc::new(file.try_clone()?);
188        let mut writer = BufWriter::with_capacity(65536, file);
189        writer.write_all(MAGIC)?;
190        writer.flush()?;
191        self.writer = writer;
192        Ok(())
193    }
194}
195
196// --- Binary encoding helpers ---
197
/// Appends `s` to `buf` as a little-endian `u16` byte length followed by its UTF-8 bytes.
///
/// # Errors
///
/// `InvalidInput` if `s` is longer than `u16::MAX` bytes; `buf` is left untouched
/// in that case.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let raw = s.as_bytes();
    let len = raw.len();
    if len <= usize::from(u16::MAX) {
        buf.extend_from_slice(&(len as u16).to_le_bytes());
        buf.extend_from_slice(raw);
        Ok(())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("string too long (max {} bytes, got {len})", u16::MAX),
        ))
    }
}
211
/// Reads one `u16`-length-prefixed UTF-8 string from `data` at `*offset`.
///
/// On success returns a slice borrowed from `data` and moves `*offset` past the
/// string. Returns `None` if the prefix or body runs past the end of `data`, or if
/// the body is not valid UTF-8. Once the 2-byte prefix has been read, `*offset` is
/// already advanced past it even if the body then fails.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    let start = *offset;
    let prefix = data.get(start..start + 2)?;
    let len = usize::from(u16::from_le_bytes([prefix[0], prefix[1]]));

    let body_start = start + 2;
    *offset = body_start;

    let body = data.get(body_start..body_start + len)?;
    let text = std::str::from_utf8(body).ok()?;
    *offset = body_start + len;
    Some(text)
}
225
/// Reads a little-endian `u32` count from `data` at `*offset`.
///
/// Advances `*offset` by 4 on success. Returns `None` (offset unchanged) if fewer
/// than 4 bytes remain. The value is raw on-disk data and is not validated.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    let at = *offset;
    let raw = data.get(at..at + 4)?;
    *offset = at + 4;
    Some(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as usize)
}
239
240pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
241    encode_str(buf, name)?;
242    encode_str(buf, entity_type)?;
243    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
244    for obs in observations {
245        encode_str(buf, obs)?;
246    }
247    Ok(())
248}
249
250pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
251    let mut offset = 0;
252    let name = decode_str(data, &mut offset)?;
253    let entity_type = decode_str(data, &mut offset)?;
254    let count = decode_count(data, &mut offset)?;
255    let mut observations = Vec::with_capacity(count);
256    for _ in 0..count {
257        observations.push(decode_str(data, &mut offset)?);
258    }
259    Some((name, entity_type, observations))
260}
261
262pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
263    encode_str(buf, from)?;
264    encode_str(buf, to)?;
265    encode_str(buf, relation_type)
266}
267
268pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
269    let mut offset = 0;
270    let from = decode_str(data, &mut offset)?;
271    let to = decode_str(data, &mut offset)?;
272    let relation_type = decode_str(data, &mut offset)?;
273    Some((from, to, relation_type))
274}
275
276pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
277    encode_str(buf, name)?;
278    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
279    for obs in observations {
280        encode_str(buf, obs)?;
281    }
282    Ok(())
283}
284
285pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
286    let mut offset = 0;
287    let name = decode_str(data, &mut offset)?;
288    let count = decode_count(data, &mut offset)?;
289    let mut observations = Vec::with_capacity(count);
290    for _ in 0..count {
291        observations.push(decode_str(data, &mut offset)?);
292    }
293    Some((name, observations))
294}
295
296pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
297    encode_str(buf, name)
298}
299
300pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
301    let mut offset = 0;
302    decode_str(data, &mut offset)
303}
304
305pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
306    encode_str(buf, name)?;
307    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
308    for obs in observations {
309        encode_str(buf, obs)?;
310    }
311    Ok(())
312}
313
314pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
315    decode_add_observations(data)
316}
317
318pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
319    encode_str(buf, from)?;
320    encode_str(buf, to)?;
321    encode_str(buf, relation_type)
322}
323
324pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
325    decode_create_relation(data)
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Per-test scratch file in the system temp dir, deleted when the guard drops.
    ///
    /// Deleting on drop, instead of at the end of each test body, means a failing
    /// assertion no longer leaves the file behind. A leftover file with the same name
    /// is also removed on creation. Names are only unique per process ID, so a file
    /// leaked by an earlier run whose PID gets reused would otherwise be opened by
    /// `BinaryStore::new` and replay stale records into the new test.
    struct TmpFile(PathBuf);

    impl TmpFile {
        fn new() -> Self {
            let pid = std::process::id();
            let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
            let path = std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin"));
            let _ = std::fs::remove_file(&path);
            TmpFile(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TmpFile {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_write_and_replay() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();

        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        // Dropping the store drops its BufWriter, which flushes the buffered records.
        drop(store);

        let mut replayed: Vec<(RecordKind, Vec<u8>)> = Vec::new();
        let replay_store = BinaryStore::new(path).unwrap();
        replay_store
            .replay(|kind, data| {
                replayed.push((kind, data.to_vec()));
            })
            .unwrap();

        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, RecordKind::CreateEntity);
        assert_eq!(decode_create_entity(&replayed[0].1).unwrap().0, "Alice");
        assert_eq!(decode_create_entity(&replayed[1].1).unwrap().0, "Bob");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut buf = Vec::new();
        encode_create_entity(
            &mut buf,
            "TestEntity",
            "test_type",
            &["obs1".into(), "obs2".into()],
        )
        .unwrap();
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let store = BinaryStore::new(path).unwrap();
        drop(store);

        let mut count = 0;
        let replay_store = BinaryStore::new(path).unwrap();
        replay_store.replay(|_, _| count += 1).unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_write_all_record_kinds() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();
        let mut buf = Vec::new();

        // Write one of each data record kind.
        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();

        drop(store);

        let mut kinds = Vec::new();
        let replay_store = BinaryStore::new(path).unwrap();
        replay_store
            .replay(|kind, _| {
                kinds.push(kind);
            })
            .unwrap();

        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
            ]
        );
    }

    #[test]
    fn test_reopen_truncated() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Reopen with truncation.
        let mut store2 = BinaryStore::new(path).unwrap();
        store2.reopen_truncated().unwrap();

        let mut buf2 = Vec::new();
        encode_create_entity(&mut buf2, "E2", "t2", &[]).unwrap();
        store2.write_record(RecordKind::CreateEntity, &buf2).unwrap();
        drop(store2);

        let mut names = Vec::new();
        let replay_store = BinaryStore::new(path).unwrap();
        replay_store
            .replay(|_, data| {
                if let Some((name, _, _)) = decode_create_entity(data) {
                    names.push(name.to_string());
                }
            })
            .unwrap();

        // Only E2 should remain — E1 was truncated away.
        assert_eq!(names, vec!["E2"]);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        let name = decode_delete_entity(&buf).unwrap();
        assert_eq!(name, "ToDelete");
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        let (from, to, rtype) = decode_delete_relation(&buf).unwrap();
        assert_eq!(from, "A");
        assert_eq!(to, "B");
        assert_eq!(rtype, "knows");
    }

    #[test]
    fn test_record_too_large() {
        let tmp = TmpFile::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let huge = vec![0u8; MAX_RECORD_BYTES as usize + 1];
        let result = store.write_record(RecordKind::CreateEntity, &huge);
        assert!(result.is_err());
    }

    #[test]
    fn test_record_at_size_limit_is_accepted() {
        let tmp = TmpFile::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        // 5 header bytes + payload == MAX_RECORD_BYTES exactly: the limit is inclusive.
        let payload = vec![0u8; MAX_RECORD_BYTES as usize - 5];
        store.write_record(RecordKind::CreateEntity, &payload).unwrap();
        drop(store);

        let mut lens = Vec::new();
        let replay_store = BinaryStore::new(tmp.path()).unwrap();
        replay_store.replay(|_, data| lens.push(data.len())).unwrap();
        assert_eq!(lens, vec![MAX_RECORD_BYTES as usize - 5]);
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();
        for i in 0..100 {
            let mut buf = Vec::new();
            encode_create_entity(&mut buf, &format!("E{i}"), "type", &[]).unwrap();
            store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        }
        drop(store);

        let mut count = 0;
        let replay_store = BinaryStore::new(path).unwrap();
        replay_store
            .replay(|kind, _| {
                assert_eq!(kind, RecordKind::CreateEntity);
                count += 1;
            })
            .unwrap();
        assert_eq!(count, 100);
    }

    #[test]
    fn test_truncated_log_handling() {
        let tmp = TmpFile::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Truncate the file manually (simulate a crash during write): keep the
        // 8-byte MAGIC plus 2 bytes of the first record's length prefix.
        let file = OpenOptions::new().write(true).open(path).unwrap();
        file.set_len(10).unwrap();
        drop(file);

        // Replay should handle gracefully.
        let replay_store = BinaryStore::new(path).unwrap();
        let mut count = 0;
        replay_store.replay(|_, _| count += 1).unwrap();
        assert_eq!(count, 0);
    }
}