1use std::path::Path;
4
5use mentedb_core::MemoryNode;
6use mentedb_core::error::{MenteError, MenteResult};
7
8use parking_lot::Mutex;
9use tracing::info;
10
11use crate::buffer::BufferPool;
12use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
13use crate::wal::{Wal, WalEntryType};
/// Capacity passed to `BufferPool::new` when an engine is opened
/// (presumably a count of cached pages — confirm against `BufferPool`).
const DEFAULT_BUFFER_POOL_SIZE: usize = 1 << 10;

/// WAL file size in bytes (8 MiB); once the WAL grows past this, a store
/// operation triggers an automatic checkpoint.
const WAL_AUTO_CHECKPOINT_BYTES: u64 = 8 << 20;
19
/// Page-oriented storage engine: a page file (`PageManager`), a buffer pool
/// (`BufferPool`) and a write-ahead log (`Wal`), all opened from one directory.
///
/// Every method takes `&self`; the page manager and the WAL are each guarded
/// by a `parking_lot::Mutex`. Code that needs both at once should lock `wal`
/// before `page_manager` — the order used by `recover`, `store_memory`,
/// `store_memory_batch` and `checkpoint` — to avoid lock-order inversions.
pub struct StorageEngine {
    /// Page file: page allocation, raw page reads/writes and the file header.
    page_manager: Mutex<PageManager>,
    /// Buffer pool through which `read_page` fetches pages; created with
    /// `DEFAULT_BUFFER_POOL_SIZE` in `open`.
    buffer_pool: BufferPool,
    /// Write-ahead log; page writes are appended here before they are applied
    /// to pages.
    wal: Mutex<Wal>,
}
42
43impl StorageEngine {
44 pub fn open(path: &Path) -> MenteResult<Self> {
59 std::fs::create_dir_all(path)?;
60
61 let page_manager = PageManager::open(path)?;
62 let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
63 let wal = Wal::open(path)?;
64
65 let engine = Self {
66 page_manager: Mutex::new(page_manager),
67 buffer_pool,
68 wal: Mutex::new(wal),
69 };
70
71 let recovered = engine.recover()?;
72 if recovered > 0 {
73 info!(recovered, ?path, "storage engine opened with WAL recovery");
74 } else {
75 info!(?path, "storage engine opened");
76 }
77
78 Ok(engine)
79 }
80
81 pub fn recover(&self) -> MenteResult<usize> {
86 let mut wal = self.wal.lock();
87 wal.lock_exclusive()?;
88 let entries = wal.iterate()?;
89 let mut count = 0usize;
90 let mut pm = self.page_manager.lock();
91
92 pm.reload_header()?;
94
95 for entry in &entries {
96 match entry.entry_type {
97 WalEntryType::PageWrite => {
98 let page_id = PageId(entry.page_id);
99
100 while pm.page_count() <= entry.page_id {
101 pm.allocate_page()?;
102 }
103
104 let mut page = pm.read_page(page_id)?;
105 let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
106 page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
107 if copy_len < PAGE_DATA_SIZE {
108 page.data[copy_len..].fill(0);
109 }
110 page.header.page_id = entry.page_id;
111 page.header.lsn = entry.lsn;
112 page.header.page_type = PageType::Data as u8;
113 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
114 page.header.checksum = page.compute_checksum();
115
116 pm.write_page(page_id, &page)?;
117 count += 1;
118 }
119 WalEntryType::Checkpoint | WalEntryType::Commit => {}
120 }
121 }
122
123 if count > 0 {
124 pm.sync()?;
125 let next_lsn = wal.next_lsn();
126 wal.truncate(next_lsn)?;
127 info!(count, "WAL recovery replayed entries");
128 }
129
130 wal.unlock()?;
131 Ok(count)
132 }
133
134 pub fn close(&self) -> MenteResult<()> {
145 let mut pm = self.page_manager.lock();
146 self.buffer_pool.flush_all(&mut pm)?;
147 pm.sync()?;
148 self.wal.lock().sync()?;
149 info!("storage engine closed");
150 Ok(())
151 }
152
153 pub fn allocate_page(&self) -> MenteResult<PageId> {
160 self.page_manager.lock().allocate_page()
161 }
162
163 pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
165 self.buffer_pool
166 .fetch_page(page_id, &mut self.page_manager.lock())
167 }
168
169 pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
174 let lsn = {
175 let mut wal = self.wal.lock();
176 wal.lock_exclusive()?;
177 wal.reload_lsn()?;
178 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, data)?;
179 wal.sync()?;
180 wal.unlock()?;
181 lsn
182 };
183
184 self.apply_page_write(page_id, data, lsn)
185 }
186
187 fn apply_page_write(&self, page_id: PageId, data: &[u8], lsn: u64) -> MenteResult<()> {
189 let mut pm = self.page_manager.lock();
190 let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
191 drop(pm);
192
193 let copy_len = data.len().min(PAGE_DATA_SIZE);
194 page.data[..copy_len].copy_from_slice(&data[..copy_len]);
195 if copy_len < PAGE_DATA_SIZE {
196 page.data[copy_len..].fill(0);
197 }
198 page.header.lsn = lsn;
199 page.header.page_type = PageType::Data as u8;
200 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
201 page.header.checksum = page.compute_checksum();
202
203 if self.buffer_pool.update_page(page_id, &page).is_err() {
204 self.page_manager.lock().write_page(page_id, &page)?;
205 }
206 self.buffer_pool.unpin_page(page_id, true).ok();
207
208 Ok(())
209 }
210
211 pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
235 let serialized =
236 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
237
238 if serialized.len() + 4 > PAGE_DATA_SIZE {
239 return Err(MenteError::CapacityExceeded(format!(
240 "memory node serialized to {} bytes (max {})",
241 serialized.len(),
242 PAGE_DATA_SIZE - 4,
243 )));
244 }
245
246 let mut buf = Vec::with_capacity(4 + serialized.len());
247 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
248 buf.extend_from_slice(&serialized);
249
250 let (page_id, lsn) = {
252 let mut wal = self.wal.lock();
253 let mut pm = self.page_manager.lock();
254
255 wal.lock_exclusive()?;
257 pm.reload_header()?;
258 wal.reload_lsn()?;
259
260 let page_id = pm.allocate_page()?;
262
263 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, &buf)?;
266 wal.sync()?;
267
268 let mut page = Page::zeroed();
270 page.header.page_id = page_id.0;
271 let copy_len = buf.len().min(PAGE_DATA_SIZE);
272 page.data[..copy_len].copy_from_slice(&buf[..copy_len]);
273 page.header.lsn = lsn;
274 page.header.page_type = PageType::Data as u8;
275 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
276 page.header.checksum = page.compute_checksum();
277 pm.write_page(page_id, &page)?;
278
279 wal.unlock()?;
281
282 (page_id, lsn)
283 };
284
285 let _ = lsn; if self.wal.lock().file_size() > WAL_AUTO_CHECKPOINT_BYTES
291 && let Err(e) = self.checkpoint()
292 {
293 tracing::warn!("auto-checkpoint failed: {e}");
294 }
295
296 info!(
297 page_id = page_id.0,
298 bytes = serialized.len(),
299 "stored memory node"
300 );
301 Ok(page_id)
302 }
303
304 pub fn store_memory_batch(&self, nodes: &[MemoryNode]) -> MenteResult<Vec<PageId>> {
310 let mut bufs = Vec::with_capacity(nodes.len());
312 for node in nodes {
313 let serialized =
314 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
315 if serialized.len() + 4 > PAGE_DATA_SIZE {
316 return Err(MenteError::CapacityExceeded(format!(
317 "memory node serialized to {} bytes (max {})",
318 serialized.len(),
319 PAGE_DATA_SIZE - 4,
320 )));
321 }
322 let mut buf = Vec::with_capacity(4 + serialized.len());
323 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
324 buf.extend_from_slice(&serialized);
325 bufs.push(buf);
326 }
327
328 let page_ids = {
330 let mut wal = self.wal.lock();
331 let mut pm = self.page_manager.lock();
332
333 wal.lock_exclusive()?;
334 pm.reload_header()?;
335 wal.reload_lsn()?;
336
337 let mut ids = Vec::with_capacity(bufs.len());
338 for buf in &bufs {
339 let page_id = pm.allocate_page()?;
340 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, buf)?;
341
342 let mut page = Page::zeroed();
343 page.header.page_id = page_id.0;
344 let copy_len = buf.len().min(PAGE_DATA_SIZE);
345 page.data[..copy_len].copy_from_slice(&buf[..copy_len]);
346 page.header.lsn = lsn;
347 page.header.page_type = PageType::Data as u8;
348 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
349 page.header.checksum = page.compute_checksum();
350 pm.write_page(page_id, &page)?;
351
352 ids.push(page_id);
353 }
354
355 wal.sync()?;
358 wal.unlock()?;
359
360 ids
361 };
362
363 if self.wal.lock().file_size() > WAL_AUTO_CHECKPOINT_BYTES
365 && let Err(e) = self.checkpoint()
366 {
367 tracing::warn!("auto-checkpoint failed: {e}");
368 }
369
370 info!(count = page_ids.len(), "stored memory batch");
371 Ok(page_ids)
372 }
373
374 pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
386 let page = self.read_page(page_id)?;
387 self.buffer_pool.unpin_page(page_id, false).ok();
388
389 let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
390 if len == 0 || len + 4 > PAGE_DATA_SIZE {
391 return Err(MenteError::Storage(format!(
392 "invalid memory node length prefix: {len}"
393 )));
394 }
395
396 serde_json::from_slice(&page.data[4..4 + len])
397 .map_err(|e| MenteError::Serialization(e.to_string()))
398 }
399
400 pub fn checkpoint(&self) -> MenteResult<()> {
414 let mut wal = self.wal.lock();
415 let mut pm = self.page_manager.lock();
416
417 wal.lock_exclusive()?;
418 wal.reload_lsn()?;
419
420 self.buffer_pool.flush_all(&mut pm)?;
421 pm.sync()?;
422
423 let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
424 wal.sync()?;
425 wal.truncate(lsn)?;
426 wal.unlock()?;
427
428 info!(lsn, "checkpoint complete");
429 Ok(())
430 }
431
432 pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
449 let mut pm = self.page_manager.lock();
450 let _ = pm.reload_header();
452 let count = pm.page_count();
453 drop(pm);
454
455 let mut results = Vec::new();
456 for i in 1..count {
457 let page_id = PageId(i);
458 if let Ok(node) = self.load_memory(page_id) {
459 results.push((node.id, page_id));
460 }
461 }
462 results
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Open a fresh engine in a new temporary directory.
    ///
    /// The returned `TempDir` must outlive the engine: dropping it deletes the
    /// directory.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    /// Build a memory node owned by a freshly generated agent.
    fn make_node(memory_type: MemoryType, content: &str, embedding: Vec<f32>) -> MemoryNode {
        MemoryNode::new(AgentId::new(), memory_type, content.to_string(), embedding)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        // Release the buffer-pool pin presumably taken by `read_page`.
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, engine) = setup();

        let node = make_node(
            MemoryType::Episodic,
            "The user prefers Rust over Go",
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_store_memory_batch() {
        let (_dir, engine) = setup();

        let nodes: Vec<MemoryNode> = (0..4)
            .map(|i| make_node(MemoryType::Semantic, &format!("batch-{i}"), vec![i as f32]))
            .collect();

        let pids = engine.store_memory_batch(&nodes).unwrap();
        assert_eq!(pids.len(), nodes.len());

        // Returned ids are in input order.
        for (pid, node) in pids.iter().zip(&nodes) {
            let loaded = engine.load_memory(*pid).unwrap();
            assert_eq!(loaded.id, node.id);
            assert_eq!(loaded.content, node.content);
        }
    }

    #[test]
    fn test_store_memory_batch_empty() {
        let (_dir, engine) = setup();
        assert!(engine.store_memory_batch(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_store_memory_rejects_oversized_node() {
        let (_dir, engine) = setup();

        // The JSON for this node is longer than a page's whole data area.
        let node = make_node(MemoryType::Semantic, &"x".repeat(PAGE_DATA_SIZE), vec![]);
        let result = engine.store_memory(&node);
        assert!(matches!(result, Err(MenteError::CapacityExceeded(_))));
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, engine) = setup();

        let node = make_node(MemoryType::Semantic, "checkpoint test", vec![1.0, 2.0]);

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Procedural, "persist across close", vec![0.5]);
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.wal.lock().sync().unwrap();
            // Drop without `close()` to simulate a crash.
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Semantic, &content, vec![1.0, 2.0]);
            pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = make_node(MemoryType::Semantic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            // These writes are only in the WAL tail written after the checkpoint.
            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.wal.lock().sync().unwrap();
            // Drop without `close()` to simulate a crash.
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_concurrent_open_no_lock_conflict() {
        let dir = tempfile::tempdir().unwrap();

        // Two engines on the same directory must both open without
        // deadlocking on the WAL lock.
        let engine1 = StorageEngine::open(dir.path()).unwrap();
        let engine2 = StorageEngine::open(dir.path()).unwrap();

        let node1 = make_node(MemoryType::Episodic, "from engine 1", vec![1.0]);
        let node2 = make_node(MemoryType::Episodic, "from engine 2", vec![2.0]);

        let pid1 = engine1.store_memory(&node1).unwrap();
        let pid2 = engine2.store_memory(&node2).unwrap();

        let loaded1 = engine1.load_memory(pid1).unwrap();
        assert_eq!(loaded1.content, "from engine 1");

        let loaded2 = engine2.load_memory(pid2).unwrap();
        assert_eq!(loaded2.content, "from engine 2");
    }

    #[test]
    fn test_concurrent_writes_from_threads() {
        use std::sync::Arc;
        let dir = tempfile::tempdir().unwrap();
        let engine = Arc::new(StorageEngine::open(dir.path()).unwrap());

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let eng = Arc::clone(&engine);
                std::thread::spawn(move || {
                    let node = make_node(MemoryType::Episodic, &format!("thread-{i}"), vec![i as f32]);
                    eng.store_memory(&node).unwrap()
                })
            })
            .collect();

        let pids: Vec<PageId> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // `handles` was joined in spawn order, so index i is thread i's page.
        for (i, pid) in pids.iter().enumerate() {
            let loaded = engine.load_memory(*pid).unwrap();
            assert_eq!(loaded.content, format!("thread-{i}"));
        }
    }
}