1use std::fs::{File, OpenOptions};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4
/// Eight-byte signature written at the very start of a store file and checked before replay.
const MAGIC: &[u8; 8] = b"MCPMEMV1";
/// Largest permitted on-disk record size in bytes (1 MiB), counting the 4-byte length
/// prefix and the 1-byte kind tag together with the payload.
const MAX_RECORD_BYTES: u32 = 1024 * 1024;
7
/// One-byte tag identifying the operation stored in a log record.
///
/// The discriminants are the exact byte values written to and read from disk, so they
/// must stay stable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordKind {
    /// Tag byte `0`.
    CreateEntity = 0,
    /// Tag byte `1`.
    CreateRelation = 1,
    /// Tag byte `2`.
    AddObservations = 2,
    /// Tag byte `3`.
    DeleteEntity = 3,
    /// Tag byte `4`.
    DeleteObservations = 4,
    /// Tag byte `5`.
    DeleteRelation = 5,
}

impl RecordKind {
    /// Maps a raw tag byte back to its [`RecordKind`].
    ///
    /// Returns `None` for any byte that is not one of the discriminants `0..=5`.
    #[inline]
    pub const fn from_u8(v: u8) -> Option<RecordKind> {
        match v {
            0 => Some(Self::CreateEntity),
            1 => Some(Self::CreateRelation),
            2 => Some(Self::AddObservations),
            3 => Some(Self::DeleteEntity),
            4 => Some(Self::DeleteObservations),
            5 => Some(Self::DeleteRelation),
            _ => None,
        }
    }
}
33
34pub struct BinaryStore {
35 writer: BufWriter<File>,
36 path: PathBuf,
37}
38
39impl BinaryStore {
40 pub const fn path(&self) -> &PathBuf {
41 &self.path
42 }
43
44 pub fn new(path: &Path) -> std::io::Result<Self> {
45 let exists = path.exists();
46 let file = OpenOptions::new()
47 .create(true)
48 .append(true)
49 .read(false)
50 .open(path)?;
51
52 let mut writer = BufWriter::with_capacity(65536, file);
53
54 if !exists {
55 writer.write_all(MAGIC)?;
56 writer.flush()?;
57 }
58
59 Ok(Self {
60 writer,
61 path: path.to_path_buf(),
62 })
63 }
64
65 pub fn write_record(&mut self, kind: RecordKind, payload: &[u8]) -> std::io::Result<()> {
66 let total_len = 4 + 1 + payload.len();
67 if total_len as u32 > MAX_RECORD_BYTES {
68 return Err(std::io::Error::new(
69 std::io::ErrorKind::InvalidInput,
70 "Record too large",
71 ));
72 }
73 self.writer.write_all(&(total_len as u32).to_le_bytes())?;
74 self.writer.write_all(&[kind as u8])?;
75 self.writer.write_all(payload)?;
76 Ok(())
77 }
78
79 pub fn flush_and_sync(&mut self) -> std::io::Result<()> {
80 self.writer.flush()?;
81 self.writer.get_ref().sync_data()
82 }
83
84 pub fn replay<F>(&self, mut callback: F) -> std::io::Result<()>
85 where
86 F: FnMut(RecordKind, &[u8]),
87 {
88 let file = match OpenOptions::new().read(true).open(&self.path) {
89 Ok(f) => f,
90 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
91 Err(e) => return Err(e),
92 };
93
94 let meta = file.metadata()?;
95 if meta.len() == 0 {
96 return Ok(());
97 }
98
99 let mut reader = BufReader::with_capacity(65536, file);
100 let mut magic = [0u8; 8];
101
102 match reader.read_exact(&mut magic) {
103 Ok(()) => {
104 if &magic != MAGIC {
105 return Ok(());
106 }
107 }
108 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
109 Err(e) => return Err(e),
110 }
111
112 let mut payload_buf = Vec::with_capacity(4096);
113
114 loop {
115 let mut len_buf = [0u8; 4];
116 match reader.read_exact(&mut len_buf) {
117 Ok(()) => {}
118 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
119 Err(e) => return Err(e),
120 }
121 let total_len = u32::from_le_bytes(len_buf) as usize;
122 if total_len < 5 || total_len > MAX_RECORD_BYTES as usize {
123 return Err(std::io::Error::new(
124 std::io::ErrorKind::InvalidData,
125 format!("Invalid record length: {total_len}"),
126 ));
127 }
128 let payload_len = total_len - 5;
129
130 let mut kind_buf = [0u8; 1];
131 reader.read_exact(&mut kind_buf)?;
132 let kind_val = kind_buf[0];
133
134 payload_buf.clear();
135 payload_buf.resize(payload_len, 0);
136 if payload_len > 0 {
137 reader.read_exact(&mut payload_buf)?;
138 }
139
140 if let Some(kind) = RecordKind::from_u8(kind_val) {
141 callback(kind, &payload_buf);
142 } else {
143 tracing::warn!("Unknown record kind byte {kind_val}, skipping");
144 }
145 }
146 }
147
148 pub fn close(&mut self) -> std::io::Result<()> {
149 self.flush_and_sync()
150 }
151
152 pub fn reopen_truncated(&mut self) -> std::io::Result<()> {
155 self.writer.flush()?;
156 let file = OpenOptions::new()
157 .create(true)
158 .write(true)
159 .truncate(true)
160 .open(&self.path)?;
161 let mut writer = BufWriter::with_capacity(65536, file);
162 writer.write_all(MAGIC)?;
163 writer.flush()?;
164 self.writer = writer;
165 Ok(())
166 }
167}
168
/// Appends `s` to `buf` as a little-endian `u16` byte length followed by its UTF-8 bytes.
///
/// # Errors
///
/// Returns `InvalidInput` when `s` is longer than `u16::MAX` bytes; `buf` is left untouched
/// in that case.
fn encode_str(buf: &mut Vec<u8>, s: &str) -> std::io::Result<()> {
    let raw = s.as_bytes();
    let len = raw.len();

    if len <= usize::from(u16::MAX) {
        // Lossless: `len` was just checked to fit in a `u16`.
        let prefix = (len as u16).to_le_bytes();
        buf.extend_from_slice(&prefix);
        buf.extend_from_slice(raw);
        return Ok(());
    }

    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        format!("string too long (max {} bytes, got {len})", u16::MAX),
    ))
}
184
/// Reads one length-prefixed string (little-endian `u16` length, then UTF-8 bytes) from
/// `data` starting at `*offset`.
///
/// On success `*offset` is moved past the string. Returns `None` when the prefix or the
/// body runs past the end of `data`, or when the body is not valid UTF-8. If the prefix
/// was readable but the body was not, `*offset` is left just after the prefix.
fn decode_str<'a>(data: &'a [u8], offset: &mut usize) -> Option<&'a str> {
    let prefix_start = *offset;
    let body_start = prefix_start + 2;
    if body_start > data.len() {
        return None;
    }

    let len = usize::from(u16::from_le_bytes([
        data[prefix_start],
        data[prefix_start + 1],
    ]));
    *offset = body_start;

    let body_end = body_start + len;
    if body_end > data.len() {
        return None;
    }

    let text = std::str::from_utf8(&data[body_start..body_end]).ok()?;
    *offset = body_end;
    Some(text)
}
198
/// Reads a little-endian `u32` element count from `data` at `*offset`.
///
/// On success `*offset` is moved forward by four bytes. Returns `None`, leaving `*offset`
/// unchanged, when fewer than four bytes remain.
fn decode_count(data: &[u8], offset: &mut usize) -> Option<usize> {
    let start = *offset;
    let end = start + 4;
    if end > data.len() {
        return None;
    }

    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..end]);
    *offset = end;
    Some(u32::from_le_bytes(raw) as usize)
}
212
213pub fn encode_create_entity(buf: &mut Vec<u8>, name: &str, entity_type: &str, observations: &[String]) -> std::io::Result<()> {
214 encode_str(buf, name)?;
215 encode_str(buf, entity_type)?;
216 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
217 for obs in observations {
218 encode_str(buf, obs)?;
219 }
220 Ok(())
221}
222
223pub fn decode_create_entity(data: &[u8]) -> Option<(&str, &str, Vec<&str>)> {
224 let mut offset = 0;
225 let name = decode_str(data, &mut offset)?;
226 let entity_type = decode_str(data, &mut offset)?;
227 let count = decode_count(data, &mut offset)?;
228 let mut observations = Vec::with_capacity(count);
229 for _ in 0..count {
230 observations.push(decode_str(data, &mut offset)?);
231 }
232 Some((name, entity_type, observations))
233}
234
235pub fn encode_create_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
236 encode_str(buf, from)?;
237 encode_str(buf, to)?;
238 encode_str(buf, relation_type)
239}
240
241pub fn decode_create_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
242 let mut offset = 0;
243 let from = decode_str(data, &mut offset)?;
244 let to = decode_str(data, &mut offset)?;
245 let relation_type = decode_str(data, &mut offset)?;
246 Some((from, to, relation_type))
247}
248
249pub fn encode_add_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
250 encode_str(buf, name)?;
251 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
252 for obs in observations {
253 encode_str(buf, obs)?;
254 }
255 Ok(())
256}
257
258pub fn decode_add_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
259 let mut offset = 0;
260 let name = decode_str(data, &mut offset)?;
261 let count = decode_count(data, &mut offset)?;
262 let mut observations = Vec::with_capacity(count);
263 for _ in 0..count {
264 observations.push(decode_str(data, &mut offset)?);
265 }
266 Some((name, observations))
267}
268
/// Appends a delete-entity payload to `buf`: just `name`, written as a `u16` byte length
/// followed by its UTF-8 bytes.
///
/// # Errors
///
/// Returns `InvalidInput` when `name` is longer than `u16::MAX` bytes.
pub fn encode_delete_entity(buf: &mut Vec<u8>, name: &str) -> std::io::Result<()> {
    encode_str(buf, name)
}
272
273pub fn decode_delete_entity(data: &[u8]) -> Option<&str> {
274 let mut offset = 0;
275 decode_str(data, &mut offset)
276}
277
278pub fn encode_delete_observations(buf: &mut Vec<u8>, name: &str, observations: &[String]) -> std::io::Result<()> {
279 encode_str(buf, name)?;
280 buf.extend_from_slice(&(observations.len() as u32).to_le_bytes());
281 for obs in observations {
282 encode_str(buf, obs)?;
283 }
284 Ok(())
285}
286
/// Decodes a delete-observations payload into `(name, observations)`.
///
/// Delete-observations payloads use the add-observations layout, so decoding is delegated
/// to `decode_add_observations`; `None` is returned in the same cases.
pub fn decode_delete_observations(data: &[u8]) -> Option<(&str, Vec<&str>)> {
    decode_add_observations(data)
}
290
291pub fn encode_delete_relation(buf: &mut Vec<u8>, from: &str, to: &str, relation_type: &str) -> std::io::Result<()> {
292 encode_str(buf, from)?;
293 encode_str(buf, to)?;
294 encode_str(buf, relation_type)
295}
296
/// Decodes a delete-relation payload into `(from, to, relation_type)`.
///
/// Delete-relation payloads use the create-relation layout, so decoding is delegated to
/// `decode_create_relation`; `None` is returned in the same cases.
pub fn decode_delete_relation(data: &[u8]) -> Option<(&str, &str, &str)> {
    decode_create_relation(data)
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Distinguishes temp files created by tests running in the same process.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Builds a temp-file path unique to this process and call.
    fn tmp_path() -> PathBuf {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        std::env::temp_dir().join(format!("mcp_store_test_{pid}_{seq}.bin"))
    }

    /// Temp log file that is deleted when the guard is dropped, including when a test
    /// panics on a failed assertion.
    struct TempLog {
        path: PathBuf,
    }

    impl TempLog {
        fn new() -> Self {
            let path = tmp_path();
            // A leftover file from an earlier run that happened to get the same pid would
            // otherwise leak old records into this test.
            let _ = std::fs::remove_file(&path);
            Self { path }
        }
    }

    impl Drop for TempLog {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Opens a fresh store on `path` and collects every replayed record.
    fn replay_all(path: &Path) -> Vec<(RecordKind, Vec<u8>)> {
        let store = BinaryStore::new(path).unwrap();
        let mut records = Vec::new();
        store
            .replay(|kind, data| records.push((kind, data.to_vec())))
            .unwrap();
        records
    }

    #[test]
    fn test_write_and_replay() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();

        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &["likes coffee".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_entity(&mut buf, "Bob", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        drop(store);

        let replayed = replay_all(&tmp.path);
        assert_eq!(replayed.len(), 2);

        assert_eq!(replayed[0].0, RecordKind::CreateEntity);
        let (name, etype, obs) = decode_create_entity(&replayed[0].1).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(etype, "person");
        assert_eq!(obs, vec!["likes coffee"]);

        assert_eq!(replayed[1].0, RecordKind::CreateEntity);
        let (name, etype, obs) = decode_create_entity(&replayed[1].1).unwrap();
        assert_eq!(name, "Bob");
        assert_eq!(etype, "person");
        assert!(obs.is_empty());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut buf = Vec::new();
        encode_create_entity(
            &mut buf,
            "TestEntity",
            "test_type",
            &["obs1".into(), "obs2".into()],
        )
        .unwrap();
        let (name, etype, obs) = decode_create_entity(&buf).unwrap();
        assert_eq!(name, "TestEntity");
        assert_eq!(etype, "test_type");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_empty_file() {
        let tmp = TempLog::new();
        let store = BinaryStore::new(&tmp.path).unwrap();
        drop(store);

        assert!(replay_all(&tmp.path).is_empty());
    }

    #[test]
    fn test_write_all_record_kinds() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();
        let mut buf = Vec::new();

        encode_create_entity(&mut buf, "E1", "t1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();

        buf.clear();
        encode_create_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::CreateRelation, &buf).unwrap();

        buf.clear();
        encode_add_observations(&mut buf, "E1", &["o2".into()]).unwrap();
        store.write_record(RecordKind::AddObservations, &buf).unwrap();

        buf.clear();
        encode_delete_entity(&mut buf, "E1").unwrap();
        store.write_record(RecordKind::DeleteEntity, &buf).unwrap();

        buf.clear();
        encode_delete_observations(&mut buf, "E1", &["o1".into()]).unwrap();
        store.write_record(RecordKind::DeleteObservations, &buf).unwrap();

        buf.clear();
        encode_delete_relation(&mut buf, "E1", "E2", "knows").unwrap();
        store.write_record(RecordKind::DeleteRelation, &buf).unwrap();

        drop(store);

        let kinds: Vec<RecordKind> = replay_all(&tmp.path)
            .into_iter()
            .map(|(kind, _)| kind)
            .collect();
        assert_eq!(
            kinds,
            vec![
                RecordKind::CreateEntity,
                RecordKind::CreateRelation,
                RecordKind::AddObservations,
                RecordKind::DeleteEntity,
                RecordKind::DeleteObservations,
                RecordKind::DeleteRelation,
            ]
        );
    }

    #[test]
    fn test_reopen_truncated() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "E1", "t1", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        let mut store2 = BinaryStore::new(&tmp.path).unwrap();
        store2.reopen_truncated().unwrap();

        let mut buf2 = Vec::new();
        encode_create_entity(&mut buf2, "E2", "t2", &[]).unwrap();
        store2.write_record(RecordKind::CreateEntity, &buf2).unwrap();
        drop(store2);

        // Only the record written after truncation may survive.
        let names: Vec<String> = replay_all(&tmp.path)
            .iter()
            .filter_map(|(_, data)| decode_create_entity(data))
            .map(|(name, _, _)| name.to_string())
            .collect();
        assert_eq!(names, vec!["E2"]);
    }

    #[test]
    fn test_encode_decode_add_observations() {
        let mut buf = Vec::new();
        encode_add_observations(&mut buf, "Alice", &["obs1".into(), "obs2".into()]).unwrap();
        let (name, obs) = decode_add_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["obs1", "obs2"]);
    }

    #[test]
    fn test_encode_decode_delete_entity() {
        let mut buf = Vec::new();
        encode_delete_entity(&mut buf, "ToDelete").unwrap();
        let name = decode_delete_entity(&buf).unwrap();
        assert_eq!(name, "ToDelete");
    }

    #[test]
    fn test_encode_decode_delete_observations() {
        let mut buf = Vec::new();
        encode_delete_observations(&mut buf, "Alice", &["o1".into()]).unwrap();
        let (name, obs) = decode_delete_observations(&buf).unwrap();
        assert_eq!(name, "Alice");
        assert_eq!(obs, vec!["o1"]);
    }

    #[test]
    fn test_encode_decode_delete_relation() {
        let mut buf = Vec::new();
        encode_delete_relation(&mut buf, "A", "B", "knows").unwrap();
        let (from, to, rtype) = decode_delete_relation(&buf).unwrap();
        assert_eq!(from, "A");
        assert_eq!(to, "B");
        assert_eq!(rtype, "knows");
    }

    #[test]
    fn test_record_too_large() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();

        let huge = vec![0u8; (1 << 20) + 1];
        let err = store
            .write_record(RecordKind::CreateEntity, &huge)
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);

        // The limit applies to the framed record: 5 header bytes plus the payload.
        let largest_ok = vec![0u8; MAX_RECORD_BYTES as usize - 5];
        store
            .write_record(RecordKind::CreateEntity, &largest_ok)
            .unwrap();
        let one_too_many = vec![0u8; MAX_RECORD_BYTES as usize - 4];
        assert!(store
            .write_record(RecordKind::CreateEntity, &one_too_many)
            .is_err());
    }

    #[test]
    fn test_multiple_writes_and_replay() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();
        for i in 0..100 {
            let mut buf = Vec::new();
            encode_create_entity(&mut buf, &format!("E{i}"), "type", &[]).unwrap();
            store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        }
        drop(store);

        let replayed = replay_all(&tmp.path);
        assert_eq!(replayed.len(), 100);
        for (i, (kind, data)) in replayed.iter().enumerate() {
            assert_eq!(*kind, RecordKind::CreateEntity);
            // Records must come back in the order they were written.
            let (name, _, _) = decode_create_entity(data).unwrap();
            assert_eq!(name, format!("E{i}"));
        }
    }

    #[test]
    fn test_truncated_log_handling() {
        let tmp = TempLog::new();
        let mut store = BinaryStore::new(&tmp.path).unwrap();
        let mut buf = Vec::new();
        encode_create_entity(&mut buf, "Alice", "person", &[]).unwrap();
        store.write_record(RecordKind::CreateEntity, &buf).unwrap();
        drop(store);

        // Keep the 8-byte header plus 2 of the 4 length-prefix bytes of the only record.
        let file = OpenOptions::new().write(true).open(&tmp.path).unwrap();
        file.set_len(10).unwrap();
        drop(file);

        assert!(replay_all(&tmp.path).is_empty());
    }
}