1use std::fs::{File, OpenOptions};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4
/// Header written at the start of every log file (8 ASCII bytes).
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Largest allowed record frame in bytes (length field + kind byte + payload): 1 MiB.
const MAX_RECORD_BYTES: u32 = 1024 * 1024;
7
/// Tag byte stored right after each record's length field.
///
/// The discriminants are the on-disk encoding, so they are spelled out
/// explicitly and must stay in sync with `RecordKind::from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordKind {
    /// Payload produced by `encode_create_entity`.
    CreateEntity = 0,
    /// Payload produced by `encode_create_relation`.
    CreateRelation = 1,
    /// Payload produced by `encode_add_observations`.
    AddObservations = 2,
    /// Payload produced by `encode_delete_entity`.
    DeleteEntity = 3,
    /// Payload produced by `encode_delete_observations`.
    DeleteObservations = 4,
    /// Payload produced by `encode_delete_relation`.
    DeleteRelation = 5,
    /// Transaction start marker (payload format not defined in this file).
    TxnBegin = 6,
    /// Transaction end marker (payload format not defined in this file).
    TxnCommit = 7,
}
25
26impl RecordKind {
27 #[inline]
28 pub const fn from_u8(v: u8) -> Option<RecordKind> {
29 Some(match v {
30 0 => RecordKind::CreateEntity,
31 1 => RecordKind::CreateRelation,
32 2 => RecordKind::AddObservations,
33 3 => RecordKind::DeleteEntity,
34 4 => RecordKind::DeleteObservations,
35 5 => RecordKind::DeleteRelation,
36 6 => RecordKind::TxnBegin,
37 7 => RecordKind::TxnCommit,
38 _ => return None,
39 })
40 }
41}
42
/// Append-only log of binary records backed by a single file.
///
/// On disk the file is the 8-byte `MAGIC` header followed by framed records
/// (see [`BinaryStore::write_record`]).
pub struct BinaryStore {
    /// Buffered handle that all writes go through.
    writer: BufWriter<File>,
    /// Location of the log file; `replay` opens its own read handle on it.
    path: PathBuf,
}
47
48impl BinaryStore {
49 pub const fn path(&self) -> &PathBuf {
50 &self.path
51 }
52
53 pub fn new(path: &Path) -> std::io::Result<Self> {
54 let exists = path.exists();
55 let file = OpenOptions::new()
56 .create(true)
57 .append(true)
58 .read(false)
59 .open(path)?;
60
61 let mut writer = BufWriter::with_capacity(65536, file);
62
63 if !exists {
64 writer.write_all(MAGIC)?;
65 writer.flush()?;
66 }
67
68 Ok(Self {
69 writer,
70 path: path.to_path_buf(),
71 })
72 }
73
74 pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
75 let total_len = 4 + 1 + payload.len();
76 if total_len as u32 > MAX_RECORD_BYTES {
77 return Err(std::io::Error::new(
78 std::io::ErrorKind::InvalidInput,
79 "Record too large",
80 ));
81 }
82 self.writer.write_all(&(total_len as u32).to_le_bytes())?;
83 self.writer.write_all(&[kind as u8])?;
84 self.writer.write_all(payload)?;
85 Ok(())
86 }
87
88 pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
89 self.writer.flush()?;
90 self.writer.get_ref().sync_data()
91 }
92
93 pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
94 where
95 F: FnMut(RecordKind, &[u8]),
96 {
97 let file = match OpenOptions::new().read(true).open(&self.path) {
98 Ok(f) => f,
99 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
100 Err(e) => return Err(e),
101 };
102
103 let meta = file.metadata()?;
104 if meta.len() == 0 {
105 return Ok(());
106 }
107
108 let mut reader = BufReader::with_capacity(65536, file);
109 let mut magic = [0u8; 8];
110
111 match reader.read_exact(&mut magic) {
112 Ok(()) => {
113 if &magic != MAGIC {
114 return Ok(());
115 }
116 }
117 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
118 Err(e) => return Err(e),
119 }
120
121 let mut payload_buf = Vec::with_capacity(4096);
122
123 loop {
124 let mut len_buf = [0u8; 4];
125 match reader.read_exact(&mut len_buf) {
126 Ok(()) => {}
127 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
128 Err(e) => return Err(e),
129 }
130 let total_len = u32::from_le_bytes(len_buf) as usize;
131 if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
132 return Err(std::io::Error::new(
133 std::io::ErrorKind::InvalidData,
134 format!("Invalid record length: {total_len}"),
135 ));
136 }
137 let payload_len = total_len - 5;
138
139 let mut kind_buf = [0u8; 1];
140 reader.read_exact(&mut kind_buf)?;
141 let kind_val = kind_buf[0];
142
143 payload_buf.clear();
144 payload_buf.resize(payload_len, 0);
145 if payload_len > 0 {
146 reader.read_exact(&mut payload_buf)?;
147 }
148
149 if let Some(kind) = RecordKind::from_u8(kind_val) {
150 callback(kind, &payload_buf);
151 } else {
152 tracing::warn!("Unknown record kind byte {kind_val}, skipping");
153 }
154 }
155 }
156
157 pub fn close(&mut self) -> std::io::Result<()> {
158 self.flush_and_sync()
159 }
160
161 pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
164 self.writer.flush()?;
165 let file = OpenOptions::new()
166 .create(true)
167 .write(true)
168 .truncate(true)
169 .open(&self.path)?;
170 let mut writer = BufWriter::with_capacity(65536, file);
171 writer.write_all(MAGIC)?;
172 writer.flush()?;
173 self.writer = writer;
174 Ok(())
175 }
176}
177
/// Appends `s` to `buf` as a little-endian `u16` byte count followed by the
/// string's UTF-8 bytes.
///
/// # Errors
/// Returns `InvalidInput`, leaving `buf` untouched, if `s` is longer than
/// `u16::MAX` bytes.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let len = s.len();
    if len <= u16::MAX as usize {
        let prefix = (len as u16).to_le_bytes();
        buf.extend_from_slice(&prefix);
        buf.extend_from_slice(s.as_bytes());
        Ok(())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("string too long (max {} bytes, got {len})", u16::MAX),
        ))
    }
}
193
/// Reads a `u16`-length-prefixed UTF-8 string from `data` at `*offset`.
///
/// On success the string is borrowed from `data` and `*offset` moves past it.
/// Returns `None` in these cases:
/// - the 2-byte prefix does not fit: `*offset` is unchanged;
/// - the string bytes run past the end of `data`: `*offset` has already
///   moved past the prefix;
/// - the bytes are not valid UTF-8: `*offset` has already moved past the
///   prefix.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    let start = *offset;
    let prefix = data.get(start..start + 2)?;
    let len = usize::from(u16::from_le_bytes([prefix[0], prefix[1]]));

    let body_start = start + 2;
    *offset = body_start;
    let body = data.get(body_start..body_start + len)?;
    let text = std::str::from_utf8(body).ok()?;
    *offset = body_start + len;
    Some(text)
}
207
/// Reads a little-endian `u32` element count from `data` at `*offset` and
/// widens it to `usize`.
///
/// Advances `*offset` by 4 on success. Returns `None`, leaving `*offset`
/// unchanged, if fewer than 4 bytes remain.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    let end = *offset + 4;
    let raw = data.get(*offset..end)?;
    let mut word = [0u8; 4];
    word.copy_from_slice(raw);
    *offset = end;
    Some(u32::from_le_bytes(word) as usize)
}
221
222pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
223 encode_str(buf, name)?;
224 encode_str(buf, entity_type)?;
225 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
226 for obs in observations {
227 encode_str(buf, obs)?;
228 }
229 Ok(())
230}
231
232pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
233 let mut offset = 0;
234 let name = decode_str(data, &mut offset)?;
235 let entity_type = decode_str(data, &mut offset)?;
236 let count = decode_count(data, &mut offset)?;
237 let mut observations = Vec::with_capacity(count);
238 for _ in 0..count {
239 observations.push(decode_str(data, &mut offset)?);
240 }
241 Some((name, entity_type, observations))
242}
243
244pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
245 encode_str(buf, from)?;
246 encode_str(buf, to)?;
247 encode_str(buf, relation_type)
248}
249
250pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
251 let mut offset = 0;
252 let from = decode_str(data, &mut offset)?;
253 let to = decode_str(data, &mut offset)?;
254 let relation_type = decode_str(data, &mut offset)?;
255 Some((from, to, relation_type))
256}
257
258pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
259 encode_str(buf, name)?;
260 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
261 for obs in observations {
262 encode_str(buf, obs)?;
263 }
264 Ok(())
265}
266
267pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
268 let mut offset = 0;
269 let name = decode_str(data, &mut offset)?;
270 let count = decode_count(data, &mut offset)?;
271 let mut observations = Vec::with_capacity(count);
272 for _ in 0..count {
273 observations.push(decode_str(data, &mut offset)?);
274 }
275 Some((name, observations))
276}
277
278pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
279 encode_str(buf, name)
280}
281
282pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
283 let mut offset = 0;
284 decode_str(data, &mut offset)
285}
286
287pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
288 encode_str(buf, name)?;
289 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
290 for obs in observations {
291 encode_str(buf, obs)?;
292 }
293 Ok(())
294}
295
296pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
297 decode_add_observations(data)
298}
299
300pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
301 encode_str(buf, from)?;
302 encode_str(buf, to)?;
303 encode_str(buf, relation_type)
304}
305
306pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
307 decode_create_relation(data)
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Uniquely named file path in the system temp dir that is removed on
    /// drop, so a failing assertion does not leak test files.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new() -> Self {
            let pid = std::process::id();
            let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
            TempPath(std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin")))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Opens a fresh store on `path`, replays it and collects every record
    /// as owned data.
    fn replay_all(path: &Path) -> Vec<(RecordKind, Vec<u8>)> {
        let mut out = Vec::new();
        BinaryStore::new(path)
            .unwrap()
            .replay(|kind, data| out.push((kind, data.to_vec())))
            .unwrap();
        out
    }

    #[test]
    fn test_write_and_replay() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();

        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        let replayed = replay_all(tmp.path());
        assert_eq!(replayed.len(), 2);
        assert!(replayed.iter().all(|(kind, _)| *kind == RecordKind::CreateEntity));

        let (name, etype, obs) = decode_create_entity(&replayed[0].1).unwrap();
        assert_eq!((name, etype), ("Alice", "person"));
        assert_eq!(obs, vec!["likes coffee"]);

        let (name, etype, obs) = decode_create_entity(&replayed[1].1).unwrap();
        assert_eq!((name, etype), ("Bob", "person"));
        assert!(obs.is_empty());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut buf = Vec::new();
        encode_create_entity(
            &mut buf,
            "TestEntity",
            "test_type",
            &["obs1".into(), "obs2".into()],
        )
        .unwrap();
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let tmp = TempPath::new();
        drop(BinaryStore::new(tmp.path()).unwrap());
        assert!(replay_all(tmp.path()).is_empty());
    }

    #[test]
    fn test_reopening_does_not_duplicate_header() {
        let tmp = TempPath::new();
        drop(BinaryStore::new(tmp.path()).unwrap());
        drop(BinaryStore::new(tmp.path()).unwrap());
        assert_eq!(std::fs::read(tmp.path()).unwrap(), MAGIC.to_vec());
    }

    #[test]
    fn test_write_all_record_kinds() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();

        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();
        drop(store);

        let kinds: Vec<RecordKind> = replay_all(tmp.path())
            .into_iter()
            .map(|(kind, _)| kind)
            .collect();
        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
            ]
        );
    }

    #[test]
    fn test_reopen_truncated() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        let mut store2 = BinaryStore::new(tmp.path()).unwrap();
        store2.reopen_truncated().unwrap();

        let mut buf2 = Vec::new();
        encode_create_entity(&mut buf2, "E2", "t2", &[]).unwrap();
        store2.write_record(RecordKind::CreateEntity, &buf2).unwrap();
        drop(store2);

        let names: Vec<String> = replay_all(tmp.path())
            .iter()
            .filter_map(|(_, data)| decode_create_entity(data).map(|(name, _, _)| name.to_string()))
            .collect();
        assert_eq!(names, vec!["E2"]);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        assert_eq!(decode_delete_entity(&buf), Some("ToDelete"));
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        assert_eq!(decode_delete_relation(&buf), Some(("A", "B", "knows")));
    }

    #[test]
    fn test_decode_rejects_truncated_payloads() {
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        assert!(decode_create_entity(&buf).is_some());
        for cut in 0..buf.len() {
            assert!(
                decode_create_entity(&buf[..cut]).is_none(),
                "prefix of {cut} bytes decoded"
            );
        }
    }

    #[test]
    fn test_string_length_limit() {
        let max = u16::MAX as usize;
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, &"x".repeat(max)).unwrap();
        assert_eq!(decode_delete_entity(&buf).map(str::len), Some(max));

        buf.clear();
        let err = encode_delete_entity(&mut buf, &"x".repeat(max + 1)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_record_too_large() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let limit = MAX_RECORD_BYTES as usize;

        // Largest payload that still fits: 4 (length) + 1 (kind) + payload == limit.
        store
            .write_record(RecordKind::CreateEntity, &vec![0u8; limit - 5])
            .unwrap();

        let err = store
            .write_record(RecordKind::CreateEntity, &vec![0u8; limit - 4])
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        drop(store);

        // The rejected record must not have left anything in the log.
        let replayed = replay_all(tmp.path());
        assert_eq!(replayed.len(), 1);
        assert_eq!(replayed[0].1.len(), limit - 5);
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        for i in 0..100 {
            buf.clear();
            encode_create_entity(&mut buf, &format!("E{i}"), "type", &[]).unwrap();
            store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        }
        drop(store);

        let replayed = replay_all(tmp.path());
        assert_eq!(replayed.len(), 100);
        for (i, (kind, data)) in replayed.iter().enumerate() {
            assert_eq!(*kind, RecordKind::CreateEntity);
            assert_eq!(decode_create_entity(data).unwrap().0, format!("E{i}"));
        }
    }

    #[test]
    fn test_unknown_record_kind_is_skipped() {
        let tmp = TempPath::new();
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "A").unwrap();

        let mut store = BinaryStore::new(tmp.path()).unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();
        drop(store);

        {
            // Hand-craft a frame with kind byte 0xFF and a 1-byte payload:
            // length = 4 (length field) + 1 (kind) + 1 (payload) = 6.
            let mut raw = OpenOptions::new().append(true).open(tmp.path()).unwrap();
            raw.write_all(&6u32.to_le_bytes()).unwrap();
            raw.write_all(&[0xFF, 0xAB]).unwrap();
        }

        let mut store = BinaryStore::new(tmp.path()).unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();
        drop(store);

        let replayed = replay_all(tmp.path());
        assert_eq!(replayed.len(), 2);
        for (kind, data) in &replayed {
            assert_eq!(*kind, RecordKind::DeleteEntity);
            assert_eq!(decode_delete_entity(data), Some("A"));
        }
    }

    #[test]
    fn test_truncated_log_handling() {
        let tmp = TempPath::new();
        let mut store = BinaryStore::new(tmp.path()).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Cut the log inside the first record's 4-byte length field
        // (the header occupies bytes 0..8).
        let file = OpenOptions::new().write(true).open(tmp.path()).unwrap();
        file.set_len(10).unwrap();
        drop(file);

        assert!(replay_all(tmp.path()).is_empty());
    }
}
558}