Skip to main content

mcp_memory/
store.rs

1use std::fs::{File, OpenOptions};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4
/// Eight-byte signature written at the start of a freshly created or truncated log file.
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Largest `total_len` (length prefix + kind byte + payload) accepted for one record: 1 MiB.
const MAX_RECORD_BYTES: u32 = 1024 * 1024;
7
/// One-byte tag stored in front of every record payload.
///
/// The explicit discriminants are the on-disk values: records are written with
/// `kind as u8` and mapped back with [`RecordKind::from_u8`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordKind {
    CreateEntity = 0,
    CreateRelation = 1,
    AddObservations = 2,
    DeleteEntity = 3,
    DeleteObservations = 4,
    DeleteRelation = 5,
    /// Marks the start of a transaction: on replay, the records that follow are
    /// buffered and only applied once a matching [`RecordKind::TxnCommit`] is
    /// seen. A transaction still open at EOF (no commit) is discarded — this is
    /// what keeps multi-record operations like `merge_entities` crash-atomic.
    TxnBegin = 6,
    /// Marks the end of a transaction started by [`RecordKind::TxnBegin`].
    TxnCommit = 7,
}

impl RecordKind {
    /// Converts a raw tag byte into its variant.
    ///
    /// Returns `None` for any byte outside `0..=7`.
    #[inline]
    pub const fn from_u8(v: u8) -> Option<RecordKind> {
        match v {
            0 => Some(Self::CreateEntity),
            1 => Some(Self::CreateRelation),
            2 => Some(Self::AddObservations),
            3 => Some(Self::DeleteEntity),
            4 => Some(Self::DeleteObservations),
            5 => Some(Self::DeleteRelation),
            6 => Some(Self::TxnBegin),
            7 => Some(Self::TxnCommit),
            _ => None,
        }
    }
}
42
43pub struct BinaryStore {
44    writer: BufWriter<File>,
45    path: PathBuf,
46}
47
48impl BinaryStore {
49    pub const fn path(&self) -> &PathBuf {
50        &self.path
51    }
52
53    pub fn new(path: &Path) -> std::io::Result<Self> {
54        let exists = path.exists();
55        let file = OpenOptions::new()
56            .create(true)
57            .append(true)
58            .read(false)
59            .open(path)?;
60
61        let mut writer = BufWriter::with_capacity(65536, file);
62
63        if !exists {
64            writer.write_all(MAGIC)?;
65            writer.flush()?;
66        }
67
68        Ok(Self {
69            writer,
70            path: path.to_path_buf(),
71        })
72    }
73
74    pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
75        let total_len = 4 + 1 + payload.len();
76        if total_len as u32 > MAX_RECORD_BYTES {
77            return Err(std::io::Error::new(
78                std::io::ErrorKind::InvalidInput,
79                "Record too large",
80            ));
81        }
82        self.writer.write_all(&(total_len as u32).to_le_bytes())?;
83        self.writer.write_all(&[kind as u8])?;
84        self.writer.write_all(payload)?;
85        Ok(())
86    }
87
88    pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
89        self.writer.flush()?;
90        self.writer.get_ref().sync_data()
91    }
92
93    pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
94    where
95        F: FnMut(RecordKind, &[u8]),
96    {
97        let file = match OpenOptions::new().read(true).open(&self.path) {
98            Ok(f) => f,
99            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
100            Err(e) => return Err(e),
101        };
102
103        let meta = file.metadata()?;
104        if meta.len() == 0 {
105            return Ok(());
106        }
107
108        let mut reader = BufReader::with_capacity(65536, file);
109        let mut magic = [0u8; 8];
110
111        match reader.read_exact(&mut magic) {
112            Ok(()) => {
113                if &magic != MAGIC {
114                    return Ok(());
115                }
116            }
117            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
118            Err(e) => return Err(e),
119        }
120
121        let mut payload_buf = Vec::with_capacity(4096);
122
123        loop {
124            let mut len_buf = [0u8; 4];
125            match reader.read_exact(&mut len_buf) {
126                Ok(()) => {}
127                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
128                Err(e) => return Err(e),
129            }
130            let total_len = u32::from_le_bytes(len_buf) as usize;
131            if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
132                return Err(std::io::Error::new(
133                    std::io::ErrorKind::InvalidData,
134                    format!("Invalid record length: {total_len}"),
135                ));
136            }
137            let payload_len = total_len - 5;
138
139            let mut kind_buf = [0u8; 1];
140            reader.read_exact(&mut kind_buf)?;
141            let kind_val = kind_buf[0];
142
143            payload_buf.clear();
144            payload_buf.resize(payload_len, 0);
145            if payload_len > 0 {
146                reader.read_exact(&mut payload_buf)?;
147            }
148
149            if let Some(kind) = RecordKind::from_u8(kind_val) {
150                callback(kind, &payload_buf);
151            } else {
152                tracing::warn!("Unknown record kind byte {kind_val}, skipping");
153            }
154        }
155    }
156
157    pub fn close(&mut self) -> std::io::Result<()> {
158        self.flush_and_sync()
159    }
160
161    /// Reopen the file with truncation — discards all existing records.
162    /// Used by `compact` to rewrite a fresh log from in-memory state.
163    pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
164        self.writer.flush()?;
165        let file = OpenOptions::new()
166            .create(true)
167            .write(true)
168            .truncate(true)
169            .open(&self.path)?;
170        let mut writer = BufWriter::with_capacity(65536, file);
171        writer.write_all(MAGIC)?;
172        writer.flush()?;
173        self.writer = writer;
174        Ok(())
175    }
176}
177
178// --- Binary encoding helpers ---
179
/// Appends `s` to `buf` as a `u16` little-endian byte length followed by its UTF-8 bytes.
///
/// # Errors
/// Returns `InvalidInput` when `s` is longer than `u16::MAX` bytes; `buf` is
/// left untouched in that case.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let byte_len = s.len();
    if byte_len <= u16::MAX as usize {
        buf.extend_from_slice(&(byte_len as u16).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
        return Ok(());
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        format!("string too long (max {} bytes, got {byte_len})", u16::MAX),
    ))
}
193
/// Reads one length-prefixed string (`u16` LE byte length + UTF-8 bytes) from
/// `data` starting at `*offset`.
///
/// On success the cursor is moved past the string. On failure — prefix or body
/// running past the end of `data`, or invalid UTF-8 — `None` is returned and
/// `*offset` is left exactly where it was, so a failed read never leaves the
/// cursor half-advanced.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    let start = *offset;
    // Checked arithmetic plus `get` keeps out-of-range cursors from panicking.
    let body_start = start.checked_add(2)?;
    let prefix = data.get(start..body_start)?;
    let len = u16::from_le_bytes([prefix[0], prefix[1]]) as usize;

    let body_end = body_start.checked_add(len)?;
    let text = std::str::from_utf8(data.get(body_start..body_end)?).ok()?;

    // Commit the cursor only once the whole string has been validated.
    *offset = body_end;
    Some(text)
}
207
/// Reads a `u32` little-endian count from `data` at `*offset` and advances the
/// cursor by four bytes.
///
/// Returns `None`, leaving `*offset` unchanged, when fewer than four bytes remain.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    let start = *offset;
    let end = start + 4;
    if end > data.len() {
        return None;
    }
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..end]);
    *offset = end;
    Some(u32::from_le_bytes(raw) as usize)
}
221
222pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
223    encode_str(buf, name)?;
224    encode_str(buf, entity_type)?;
225    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
226    for obs in observations {
227        encode_str(buf, obs)?;
228    }
229    Ok(())
230}
231
232pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
233    let mut offset = 0;
234    let name = decode_str(data, &mut offset)?;
235    let entity_type = decode_str(data, &mut offset)?;
236    let count = decode_count(data, &mut offset)?;
237    let mut observations = Vec::with_capacity(count);
238    for _ in 0..count {
239        observations.push(decode_str(data, &mut offset)?);
240    }
241    Some((name, entity_type, observations))
242}
243
244pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
245    encode_str(buf, from)?;
246    encode_str(buf, to)?;
247    encode_str(buf, relation_type)
248}
249
250pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
251    let mut offset = 0;
252    let from = decode_str(data, &mut offset)?;
253    let to = decode_str(data, &mut offset)?;
254    let relation_type = decode_str(data, &mut offset)?;
255    Some((from, to, relation_type))
256}
257
258pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
259    encode_str(buf, name)?;
260    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
261    for obs in observations {
262        encode_str(buf, obs)?;
263    }
264    Ok(())
265}
266
267pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
268    let mut offset = 0;
269    let name = decode_str(data, &mut offset)?;
270    let count = decode_count(data, &mut offset)?;
271    let mut observations = Vec::with_capacity(count);
272    for _ in 0..count {
273        observations.push(decode_str(data, &mut offset)?);
274    }
275    Some((name, observations))
276}
277
/// Appends a delete-entity payload to `buf`: just `name` as a length-prefixed string.
///
/// # Errors
/// Propagates the `encode_str` error when `name` is too long.
pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
    encode_str(buf, name)
}
281
282pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
283    let mut offset = 0;
284    decode_str(data, &mut offset)
285}
286
287pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
288    encode_str(buf, name)?;
289    buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
290    for obs in observations {
291        encode_str(buf, obs)?;
292    }
293    Ok(())
294}
295
/// Decodes a delete-observations payload into `(name, observations)`.
///
/// Delegates to `decode_add_observations`, so both payloads are parsed with the
/// same layout and the same `None` conditions.
pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
    decode_add_observations(data)
}
299
300pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
301    encode_str(buf, from)?;
302    encode_str(buf, to)?;
303    encode_str(buf, relation_type)
304}
305
/// Decodes a delete-relation payload into `(from, to, relation_type)`.
///
/// Delegates to `decode_create_relation`, so both payloads are parsed with the
/// same layout and the same `None` conditions.
pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
    decode_create_relation(data)
}
309
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Builds a temp-file path that is unique per process (pid) and per call
    /// (monotonic counter), so tests running in parallel do not share a file.
    fn tmp_path() -> PathBuf {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin"))
    }

    #[test]
    fn test_write_and_replay() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();

        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        // Dropping the store drops its `BufWriter`, which flushes the records.
        drop(store);

        let mut replayed: Vec<(RecordKind, Vec<u8>)> = Vec::new();
        let replay_store = BinaryStore::new(&path).unwrap();
        replay_store
            .replay(|kind, data| {
                replayed.push((kind, data.to_vec()));
            })
            .unwrap();

        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, RecordKind::CreateEntity);
        assert_eq!(
            decode_create_entity(&replayed[0].1).unwrap().0,
            "Alice"
        );

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut buf = Vec::new();
        encode_create_entity(
            &mut buf,
            "TestEntity",
            "test_type",
            &["obs1".into(), "obs2".into()],
        )
        .unwrap();
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let path = tmp_path();
        let store = BinaryStore::new(&path).unwrap();
        drop(store);

        let mut count = 0;
        let replay_store = BinaryStore::new(&path).unwrap();
        replay_store.replay(|_, _| count += 1).unwrap();
        assert_eq!(count, 0);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_write_all_record_kinds() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();
        let mut buf = Vec::new();

        // Write one of each record kind, including the transaction markers.
        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();

        // Written here with an empty payload; this also covers the minimum
        // record size (length prefix + kind byte only).
        store.write_record(RecordKind::TxnBegin, &[]).unwrap();
        store.write_record(RecordKind::TxnCommit, &[]).unwrap();

        drop(store);

        let mut kinds = Vec::new();
        let replay_store = BinaryStore::new(&path).unwrap();
        replay_store
            .replay(|kind, _| {
                kinds.push(kind);
            })
            .unwrap();

        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
                RecordKind::TxnBegin,
                RecordKind::TxnCommit,
            ]
        );
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_reopen_truncated() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Reopen with truncation.
        let mut store2 = BinaryStore::new(&path).unwrap();
        store2.reopen_truncated().unwrap();

        let mut buf2 = Vec::new();
        encode_create_entity(&mut buf2, "E2", "t2", &[]).unwrap();
        store2.write_record(RecordKind::CreateEntity, &buf2).unwrap();
        drop(store2);

        let mut names = Vec::new();
        let replay_store = BinaryStore::new(&path).unwrap();
        replay_store
            .replay(|_, data| {
                if let Some((name, _, _)) = decode_create_entity(data) {
                    names.push(name.to_string());
                }
            })
            .unwrap();

        // Only E2 should remain — E1 was truncated away.
        assert_eq!(names, vec!["E2"]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        let name = decode_delete_entity(&buf).unwrap();
        assert_eq!(name, "ToDelete");
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        let (from, to, rtype) = decode_delete_relation(&buf).unwrap();
        assert_eq!(from, "A");
        assert_eq!(to, "B");
        assert_eq!(rtype, "knows");
    }

    #[test]
    fn test_record_too_large() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();
        let limit = MAX_RECORD_BYTES as usize;

        // Payload alone already exceeds the limit.
        let huge = vec![0u8; limit + 1];
        let err = store
            .write_record(RecordKind::CreateEntity, &huge)
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);

        // Smallest rejected payload: 5-byte header + payload == limit + 1.
        assert!(store
            .write_record(RecordKind::CreateEntity, &huge[..limit - 4])
            .is_err());

        // Largest accepted payload: 5-byte header + payload == limit.
        assert!(store
            .write_record(RecordKind::CreateEntity, &huge[..limit - 5])
            .is_ok());

        drop(store);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();
        for i in 0..100 {
            let mut buf = Vec::new();
            encode_create_entity(&mut buf, &format!("E{i}"), "type", &[]).unwrap();
            store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        }
        drop(store);

        let mut count = 0;
        let replay_store = BinaryStore::new(&path).unwrap();
        replay_store
            .replay(|kind, _| {
                assert_eq!(kind, RecordKind::CreateEntity);
                count += 1;
            })
            .unwrap();
        assert_eq!(count, 100);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_truncated_log_handling() {
        let path = tmp_path();
        let mut store = BinaryStore::new(&path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Truncate the file manually (simulate crash during write).
        let file = OpenOptions::new().write(true).open(&path).unwrap();
        // Keeps the 8-byte MAGIC plus only 2 of the 4 length-prefix bytes.
        file.set_len(10).unwrap();
        drop(file);

        // Replay should handle gracefully.
        let replay_store = BinaryStore::new(&path).unwrap();
        let mut count = 0;
        replay_store.replay(|_, _| count += 1).unwrap();
        assert_eq!(count, 0);
        let _ = std::fs::remove_file(&path);
    }
}