1use std::fs::{File, OpenOptions};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use arc_swap::ArcSwap;
7
/// Header bytes written at the very start of a store file.
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Upper bound, in bytes, on one framed record: 4-byte length prefix + kind byte + payload.
const MAX_RECORD_BYTES: u32 = 1_048_576; // 1 MiB
10
/// One-byte tag stored in front of every record payload.
///
/// The explicit discriminants are what `write_record` writes and what
/// `from_u8` reads back, so they are part of the file format: never
/// renumber or reuse them.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordKind {
    /// Payload built by `encode_create_entity`.
    CreateEntity = 0x00,
    /// Payload built by `encode_create_relation`.
    CreateRelation = 0x01,
    /// Payload built by `encode_add_observations`.
    AddObservations = 0x02,
    /// Payload built by `encode_delete_entity`.
    DeleteEntity = 0x03,
    /// Payload built by `encode_delete_observations`.
    DeleteObservations = 0x04,
    /// Payload built by `encode_delete_relation`.
    DeleteRelation = 0x05,
    /// No encoder in this file; presumably opens a transaction group — confirm with callers.
    TxnBegin = 0x06,
    /// No encoder in this file; presumably closes a transaction group — confirm with callers.
    TxnCommit = 0x07,
}

impl RecordKind {
    /// Converts a raw tag byte back into a `RecordKind`.
    ///
    /// Returns `None` for any byte outside `0..=7`.
    #[inline]
    pub const fn from_u8(v: u8) -> Option<RecordKind> {
        match v {
            0x00 => Some(Self::CreateEntity),
            0x01 => Some(Self::CreateRelation),
            0x02 => Some(Self::AddObservations),
            0x03 => Some(Self::DeleteEntity),
            0x04 => Some(Self::DeleteObservations),
            0x05 => Some(Self::DeleteRelation),
            0x06 => Some(Self::TxnBegin),
            0x07 => Some(Self::TxnCommit),
            _ => None,
        }
    }
}
45
46pub struct BinaryStore {
47 writer: BufWriter<File>,
48 path: PathBuf,
49 pub(crate) sync_slot: Arc<ArcSwap<File>>,
55}
56
57impl BinaryStore {
58 pub const fn path(&self) -> &PathBuf {
59 &self.path
60 }
61
62 pub fn new(path: &Path) -> std::io::Result<Self> {
63 Self::new_with_slot(path, None)
64 }
65
66 pub fn new_with_slot(
71 path: &Path,
72 slot: Option<Arc<ArcSwap<File>>>,
73 ) -> std::io::Result<Self> {
74 let exists = path.exists();
75 let file = OpenOptions::new()
76 .create(true)
77 .append(true)
78 .read(false)
79 .open(path)?;
80
81 let handle = Arc::new(file.try_clone()?);
82 let sync_slot = match slot {
83 Some(s) => {
84 s.store(handle);
85 s
86 }
87 None => Arc::new(ArcSwap::new(handle)),
88 };
89 let mut writer = BufWriter::with_capacity(65536, file);
90
91 if !exists {
92 writer.write_all(MAGIC)?;
93 writer.flush()?;
94 }
95
96 Ok(Self {
97 writer,
98 path: path.to_path_buf(),
99 sync_slot,
100 })
101 }
102
103 pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
104 let total_len = 4 + 1 + payload.len();
105 if total_len as u32 > MAX_RECORD_BYTES {
106 return Err(std::io::Error::new(
107 std::io::ErrorKind::InvalidInput,
108 "Record too large",
109 ));
110 }
111 self.writer.write_all(&(total_len as u32).to_le_bytes())?;
112 self.writer.write_all(&[kind as u8])?;
113 self.writer.write_all(payload)?;
114 Ok(())
115 }
116
117 pub fn flush(&mut self) -> std::io::Result<()> {
119 self.writer.flush()
120 }
121
122 pub fn sync(&mut self) -> std::io::Result<()> {
124 self.writer.get_ref().sync_data()
125 }
126
127 pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
128 self.flush()?;
129 self.sync()
130 }
131
132 pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
133 where
134 F: FnMut(RecordKind, &[u8]),
135 {
136 let file = match OpenOptions::new().read(true).open(&self.path) {
137 Ok(f) => f,
138 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
139 Err(e) => return Err(e),
140 };
141
142 let meta = file.metadata()?;
143 if meta.len() == 0 {
144 return Ok(());
145 }
146
147 let mut reader = BufReader::with_capacity(65536, file);
148 let mut magic = [0u8; 8];
149
150 match reader.read_exact(&mut magic) {
151 Ok(()) => {
152 if &magic != MAGIC {
153 return Ok(());
154 }
155 }
156 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
157 Err(e) => return Err(e),
158 }
159
160 let mut payload_buf = Vec::with_capacity(4096);
161
162 loop {
163 let mut len_buf = [0u8; 4];
164 match reader.read_exact(&mut len_buf) {
165 Ok(()) => {}
166 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
167 Err(e) => return Err(e),
168 }
169 let total_len = u32::from_le_bytes(len_buf) as usize;
170 if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
171 return Err(std::io::Error::new(
172 std::io::ErrorKind::InvalidData,
173 format!("Invalid record length: {total_len}"),
174 ));
175 }
176 let payload_len = total_len - 5;
177
178 let mut kind_buf = [0u8; 1];
183 match reader.read_exact(&mut kind_buf) {
184 Ok(()) => {}
185 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
186 Err(e) => return Err(e),
187 }
188 let kind_val = kind_buf[0];
189
190 payload_buf.clear();
191 payload_buf.resize(payload_len, 0);
192 if payload_len > 0 {
193 match reader.read_exact(&mut payload_buf) {
194 Ok(()) => {}
195 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
196 Err(e) => return Err(e),
197 }
198 }
199
200 if let Some(kind) = RecordKind::from_u8(kind_val) {
201 callback(kind, &payload_buf);
202 } else {
203 tracing::warn!("Unknown record kind byte {kind_val}, skipping");
204 }
205 }
206 }
207
208 pub fn close(&mut self) -> std::io::Result<()> {
209 self.flush_and_sync()
210 }
211
212 pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
215 self.writer.flush()?;
216 let file = OpenOptions::new()
217 .create(true)
218 .write(true)
219 .truncate(true)
220 .open(&self.path)?;
221 self.sync_slot.store(Arc::new(file.try_clone()?));
224 let mut writer = BufWriter::with_capacity(65536, file);
225 writer.write_all(MAGIC)?;
226 writer.flush()?;
227 self.writer = writer;
228 Ok(())
229 }
230}
231
/// Appends `s` to `buf` as a little-endian `u16` byte length followed by its UTF-8 bytes.
///
/// # Errors
/// `InvalidInput` when `s` is longer than `u16::MAX` bytes; `buf` is left
/// untouched in that case.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let len = s.len();
    if len <= usize::from(u16::MAX) {
        let prefix = (len as u16).to_le_bytes();
        buf.extend_from_slice(&prefix);
        buf.extend_from_slice(s.as_bytes());
        Ok(())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("string too long (max {} bytes, got {len})", u16::MAX),
        ))
    }
}
247
/// Reads one length-prefixed UTF-8 string (as written by `encode_str`) at `*offset`.
///
/// On success, returns a slice borrowed from `data` and moves `*offset` past
/// the string. Returns `None` when the prefix or body runs past the end of
/// `data` or the body is not valid UTF-8. Once the 2-byte prefix has been
/// read, `*offset` has already moved past it even if `None` is returned.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    let header_end = *offset + 2;
    let header = data.get(*offset..header_end)?;
    let body_len = usize::from(u16::from_le_bytes([header[0], header[1]]));
    *offset = header_end;

    let body_end = header_end + body_len;
    let body = data.get(header_end..body_end)?;
    let text = std::str::from_utf8(body).ok()?;
    *offset = body_end;
    Some(text)
}
261
/// Reads a little-endian `u32` element count at `*offset`.
///
/// Advances `*offset` by 4 on success. Returns `None`, leaving `*offset`
/// unchanged, when fewer than 4 bytes remain.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    let end = *offset + 4;
    let field = data.get(*offset..end)?;
    let mut raw = [0u8; 4];
    raw.copy_from_slice(field);
    *offset = end;
    Some(u32::from_le_bytes(raw) as usize)
}
275
276pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
277 encode_str(buf, name)?;
278 encode_str(buf, entity_type)?;
279 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
280 for obs in observations {
281 encode_str(buf, obs)?;
282 }
283 Ok(())
284}
285
286pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
287 let mut offset = 0;
288 let name = decode_str(data, &mut offset)?;
289 let entity_type = decode_str(data, &mut offset)?;
290 let count = decode_count(data, &mut offset)?;
291 let mut observations = Vec::with_capacity(count);
292 for _ in 0..count {
293 observations.push(decode_str(data, &mut offset)?);
294 }
295 Some((name, entity_type, observations))
296}
297
298pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
299 encode_str(buf, from)?;
300 encode_str(buf, to)?;
301 encode_str(buf, relation_type)
302}
303
304pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
305 let mut offset = 0;
306 let from = decode_str(data, &mut offset)?;
307 let to = decode_str(data, &mut offset)?;
308 let relation_type = decode_str(data, &mut offset)?;
309 Some((from, to, relation_type))
310}
311
312pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
313 encode_str(buf, name)?;
314 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
315 for obs in observations {
316 encode_str(buf, obs)?;
317 }
318 Ok(())
319}
320
321pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
322 let mut offset = 0;
323 let name = decode_str(data, &mut offset)?;
324 let count = decode_count(data, &mut offset)?;
325 let mut observations = Vec::with_capacity(count);
326 for _ in 0..count {
327 observations.push(decode_str(data, &mut offset)?);
328 }
329 Some((name, observations))
330}
331
332pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
333 encode_str(buf, name)
334}
335
336pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
337 let mut offset = 0;
338 decode_str(data, &mut offset)
339}
340
341pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
342 encode_str(buf, name)?;
343 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
344 for obs in observations {
345 encode_str(buf, obs)?;
346 }
347 Ok(())
348}
349
350pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
351 decode_add_observations(data)
352}
353
354pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
355 encode_str(buf, from)?;
356 encode_str(buf, to)?;
357 encode_str(buf, relation_type)
358}
359
360pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
361 decode_create_relation(data)
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// A unique store path under the system temp dir that is deleted on drop,
    /// so the file is cleaned up even when an assertion panics mid-test.
    struct TempLog(PathBuf);

    impl TempLog {
        fn new() -> Self {
            let pid = std::process::id();
            let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
            Self(std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin")))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempLog {
        fn drop(&mut self) {
            // Best effort: a test may never have created the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Opens a new store on `path` and returns every replayed record as owned data.
    fn replay_all(path: &Path) -> Vec<(RecordKind, Vec<u8>)> {
        let store = BinaryStore::new(path).unwrap();
        let mut records = Vec::new();
        store
            .replay(|kind, data| records.push((kind, data.to_vec())))
            .unwrap();
        records
    }

    #[test]
    fn test_write_and_replay() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();

        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        // Dropping the store flushes its buffered records.
        drop(store);

        let replayed = replay_all(tmp.path());
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, RecordKind::CreateEntity);
        assert_eq!(decode_create_entity(&replayed[0].1).unwrap().0, "Alice");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut buf = Vec::new();
        encode_create_entity(
            &mut buf,
            "TestEntity",
            "test_type",
            &["obs1".into(), "obs2".into()],
        )
        .unwrap();
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let tmp = TempLog::new();
        drop(BinaryStore::new(tmp.path()).unwrap());
        assert!(replay_all(tmp.path()).is_empty());
    }

    #[test]
    fn test_write_all_record_kinds() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();

        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();

        drop(store);

        let kinds: Vec<RecordKind> = replay_all(tmp.path())
            .into_iter()
            .map(|(kind, _)| kind)
            .collect();
        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
            ]
        );
    }

    #[test]
    fn test_reopen_truncated() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        let mut store2 = BinaryStore::new(tmp.path()).unwrap();
        store2.reopen_truncated().unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "E2", "t2", &[]).unwrap();
        store2.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store2);

        let names: Vec<String> = replay_all(tmp.path())
            .iter()
            .filter_map(|(_, data)| decode_create_entity(data).map(|(name, _, _)| name.to_string()))
            .collect();
        assert_eq!(names, vec!["E2"]);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        let name = decode_delete_entity(&buf).unwrap();
        assert_eq!(name, "ToDelete");
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        let (from, to, rtype) = decode_delete_relation(&buf).unwrap();
        assert_eq!(from, "A");
        assert_eq!(to, "B");
        assert_eq!(rtype, "knows");
    }

    #[test]
    fn test_sync_slot_follows_reopen_truncated() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let slot = Arc::clone(&store.sync_slot);
        let before = Arc::as_ptr(&slot.load_full());
        store.reopen_truncated().unwrap();
        let after = Arc::as_ptr(&slot.load_full());
        assert_ne!(before, after, "reopen must publish the new handle into the slot");
        assert!(Arc::ptr_eq(&slot, &store.sync_slot), "slot identity must be stable");
    }

    #[test]
    fn test_new_with_slot_reuses_shared_cell() {
        let tmp = TempLog::new();
        let store1 = BinaryStore::new(tmp.path()).unwrap();
        let slot = Arc::clone(&store1.sync_slot);
        let before = Arc::as_ptr(&slot.load_full());
        drop(store1);

        let store2 = BinaryStore::new_with_slot(tmp.path(), Some(Arc::clone(&slot))).unwrap();
        assert!(Arc::ptr_eq(&slot, &store2.sync_slot), "must reuse the passed slot");
        let after = Arc::as_ptr(&slot.load_full());
        assert_ne!(before, after, "reopened handle must be published into the slot");
    }

    #[test]
    fn test_record_too_large() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let huge = vec![0u8; (1 << 20) + 1];
        assert!(store.write_record(RecordKind::CreateEntity, &huge).is_err());
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        for i in 0..100 {
            buf.clear();
            encode_create_entity(&mut buf, &format!("E{i}"), "type", &[]).unwrap();
            store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        }
        drop(store);

        let records = replay_all(tmp.path());
        assert_eq!(records.len(), 100);
        assert!(records.iter().all(|(kind, _)| *kind == RecordKind::CreateEntity));
    }

    #[test]
    fn test_truncated_log_handling() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Keep the 8-byte header plus only 2 bytes of the first length prefix.
        let file = OpenOptions::new().write(true).open(tmp.path()).unwrap();
        file.set_len(10).unwrap();
        drop(file);

        assert!(replay_all(tmp.path()).is_empty());
    }

    #[test]
    fn test_torn_record_mid_stream_recovers_prefix() {
        let tmp = TempLog::new();
        let path = tmp.path();
        let mut store = BinaryStore::new(path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        store.flush_and_sync().unwrap();
        let good_len = std::fs::metadata(path).unwrap().len();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &["drinks tea".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        store.flush_and_sync().unwrap();
        drop(store);

        // Cut the second record in half to simulate a torn write.
        let full_len = std::fs::metadata(path).unwrap().len();
        let torn_len = good_len + (full_len - good_len) / 2;
        let file = OpenOptions::new().write(true).open(path).unwrap();
        file.set_len(torn_len).unwrap();
        drop(file);

        let replay_store = BinaryStore::new(path).unwrap();
        let mut names = Vec::new();
        replay_store
            .replay(|_, data| {
                if let Some((name, _, _)) = decode_create_entity(data) {
                    names.push(name.to_string());
                }
            })
            .expect("torn tail must not be a hard error");
        assert_eq!(names, vec!["Alice"]);
    }

    #[test]
    fn test_unknown_kind_is_skipped() {
        let tmp = TempLog::new();
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "E1").unwrap();

        let mut store = BinaryStore::new(tmp.path()).unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();
        store.flush().unwrap();

        // Hand-craft a 6-byte frame (length prefix + kind + 1 payload byte)
        // whose kind byte maps to no `RecordKind`.
        let mut raw = OpenOptions::new().append(true).open(tmp.path()).unwrap();
        raw.write_all(&6u32.to_le_bytes()).unwrap();
        raw.write_all(&[200, 0xAB]).unwrap();
        drop(raw);

        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();
        drop(store);

        let records = replay_all(tmp.path());
        assert_eq!(records.len(), 2);
        for (kind, data) in &records {
            assert_eq!(*kind, RecordKind::DeleteEntity);
            assert_eq!(decode_delete_entity(data), Some("E1"));
        }
    }

    #[test]
    fn test_decode_rejects_truncated_payload() {
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["obs".into()]).unwrap();
        assert!(decode_create_entity(&buf).is_some());
        for cut in 0..buf.len() {
            assert!(
                decode_create_entity(&buf[..cut]).is_none(),
                "a {cut}-byte prefix must not decode"
            );
        }
    }
}