1use std::path::Path;
4
5use mentedb_core::MemoryNode;
6use mentedb_core::error::{MenteError, MenteResult};
7
8use parking_lot::Mutex;
9use tracing::info;
10
11use crate::buffer::BufferPool;
12use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
13use crate::wal::{Wal, WalEntryType};
/// Capacity passed to `BufferPool::new` when `StorageEngine::open` builds the
/// engine's buffer pool (units are whatever `BufferPool` uses — presumably pages;
/// TODO confirm).
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
16
/// Page-based storage engine combining a page file (`PageManager`), a buffer pool
/// of fetched pages (`BufferPool`), and a write-ahead log (`Wal`).
///
/// The page manager and the WAL each sit behind their own `parking_lot::Mutex`, so
/// all operations take `&self` and the engine can be shared across threads (the
/// tests share it through an `Arc`). `store_memory`, `checkpoint` and `recover`
/// lock `wal` before `page_manager`.
pub struct StorageEngine {
    /// Page file manager; every access goes through this mutex.
    page_manager: Mutex<PageManager>,
    /// Buffer pool used by `read_page`/`write_page`. It is not wrapped in a mutex
    /// here, so it presumably synchronizes internally — TODO confirm.
    buffer_pool: BufferPool,
    /// Write-ahead log; every access goes through this mutex.
    wal: Mutex<Wal>,
}
39
40impl StorageEngine {
41 pub fn open(path: &Path) -> MenteResult<Self> {
56 std::fs::create_dir_all(path)?;
57
58 let page_manager = PageManager::open(path)?;
59 let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
60 let wal = Wal::open(path)?;
61
62 let engine = Self {
63 page_manager: Mutex::new(page_manager),
64 buffer_pool,
65 wal: Mutex::new(wal),
66 };
67
68 let recovered = engine.recover()?;
69 if recovered > 0 {
70 info!(recovered, ?path, "storage engine opened with WAL recovery");
71 } else {
72 info!(?path, "storage engine opened");
73 }
74
75 Ok(engine)
76 }
77
78 pub fn recover(&self) -> MenteResult<usize> {
83 let mut wal = self.wal.lock();
84 wal.lock_exclusive()?;
85 let entries = wal.iterate()?;
86 let mut count = 0usize;
87 let mut pm = self.page_manager.lock();
88
89 pm.reload_header()?;
91
92 for entry in &entries {
93 match entry.entry_type {
94 WalEntryType::PageWrite => {
95 let page_id = PageId(entry.page_id);
96
97 while pm.page_count() <= entry.page_id {
98 pm.allocate_page()?;
99 }
100
101 let mut page = pm.read_page(page_id)?;
102 let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
103 page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
104 if copy_len < PAGE_DATA_SIZE {
105 page.data[copy_len..].fill(0);
106 }
107 page.header.page_id = entry.page_id;
108 page.header.lsn = entry.lsn;
109 page.header.page_type = PageType::Data as u8;
110 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
111 page.header.checksum = page.compute_checksum();
112
113 pm.write_page(page_id, &page)?;
114 count += 1;
115 }
116 WalEntryType::Checkpoint | WalEntryType::Commit => {}
117 }
118 }
119
120 if count > 0 {
121 pm.sync()?;
122 let next_lsn = wal.next_lsn();
123 wal.truncate(next_lsn)?;
124 info!(count, "WAL recovery replayed entries");
125 }
126
127 wal.unlock()?;
128 Ok(count)
129 }
130
131 pub fn close(&self) -> MenteResult<()> {
142 let mut pm = self.page_manager.lock();
143 self.buffer_pool.flush_all(&mut pm)?;
144 pm.sync()?;
145 self.wal.lock().sync()?;
146 info!("storage engine closed");
147 Ok(())
148 }
149
150 pub fn allocate_page(&self) -> MenteResult<PageId> {
157 self.page_manager.lock().allocate_page()
158 }
159
160 pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
162 self.buffer_pool
163 .fetch_page(page_id, &mut self.page_manager.lock())
164 }
165
166 pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
171 let lsn = {
172 let mut wal = self.wal.lock();
173 wal.lock_exclusive()?;
174 wal.reload_lsn()?;
175 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, data)?;
176 wal.sync()?;
177 wal.unlock()?;
178 lsn
179 };
180
181 self.apply_page_write(page_id, data, lsn)
182 }
183
184 fn apply_page_write(&self, page_id: PageId, data: &[u8], lsn: u64) -> MenteResult<()> {
186 let mut pm = self.page_manager.lock();
187 let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
188 drop(pm);
189
190 let copy_len = data.len().min(PAGE_DATA_SIZE);
191 page.data[..copy_len].copy_from_slice(&data[..copy_len]);
192 if copy_len < PAGE_DATA_SIZE {
193 page.data[copy_len..].fill(0);
194 }
195 page.header.lsn = lsn;
196 page.header.page_type = PageType::Data as u8;
197 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
198 page.header.checksum = page.compute_checksum();
199
200 if self.buffer_pool.update_page(page_id, &page).is_err() {
201 self.page_manager.lock().write_page(page_id, &page)?;
202 }
203 self.buffer_pool.unpin_page(page_id, true).ok();
204
205 Ok(())
206 }
207
208 pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
232 let serialized =
233 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
234
235 if serialized.len() + 4 > PAGE_DATA_SIZE {
236 return Err(MenteError::CapacityExceeded(format!(
237 "memory node serialized to {} bytes (max {})",
238 serialized.len(),
239 PAGE_DATA_SIZE - 4,
240 )));
241 }
242
243 let mut buf = Vec::with_capacity(4 + serialized.len());
244 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
245 buf.extend_from_slice(&serialized);
246
247 let (page_id, lsn) = {
249 let mut wal = self.wal.lock();
250 let mut pm = self.page_manager.lock();
251
252 wal.lock_exclusive()?;
254 pm.reload_header()?;
255 wal.reload_lsn()?;
256
257 let page_id = pm.allocate_page()?;
259
260 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, &buf)?;
262 wal.sync()?;
263
264 let mut page = Page::zeroed();
266 page.header.page_id = page_id.0;
267 let copy_len = buf.len().min(PAGE_DATA_SIZE);
268 page.data[..copy_len].copy_from_slice(&buf[..copy_len]);
269 page.header.lsn = lsn;
270 page.header.page_type = PageType::Data as u8;
271 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
272 page.header.checksum = page.compute_checksum();
273 pm.write_page(page_id, &page)?;
274 pm.sync()?;
275
276 wal.unlock()?;
278
279 (page_id, lsn)
280 };
281
282 let _ = lsn; info!(
286 page_id = page_id.0,
287 bytes = serialized.len(),
288 "stored memory node"
289 );
290 Ok(page_id)
291 }
292
293 pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
305 let page = self.read_page(page_id)?;
306 self.buffer_pool.unpin_page(page_id, false).ok();
307
308 let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
309 if len == 0 || len + 4 > PAGE_DATA_SIZE {
310 return Err(MenteError::Storage(format!(
311 "invalid memory node length prefix: {len}"
312 )));
313 }
314
315 serde_json::from_slice(&page.data[4..4 + len])
316 .map_err(|e| MenteError::Serialization(e.to_string()))
317 }
318
319 pub fn checkpoint(&self) -> MenteResult<()> {
333 let mut wal = self.wal.lock();
334 let mut pm = self.page_manager.lock();
335
336 wal.lock_exclusive()?;
337 wal.reload_lsn()?;
338
339 self.buffer_pool.flush_all(&mut pm)?;
340 pm.sync()?;
341
342 let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
343 wal.sync()?;
344 wal.truncate(lsn)?;
345 wal.unlock()?;
346
347 info!(lsn, "checkpoint complete");
348 Ok(())
349 }
350
351 pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
368 let mut pm = self.page_manager.lock();
369 let _ = pm.reload_header();
371 let count = pm.page_count();
372 drop(pm);
373
374 let mut results = Vec::new();
375 for i in 1..count {
376 let page_id = PageId(i);
377 if let Ok(node) = self.load_memory(page_id) {
378 results.push((node.id, page_id));
379 }
380 }
381 results
382 }
383}
384
#[cfg(test)]
mod tests {
    // Unit tests for `StorageEngine`: page I/O, memory-node round trips,
    // checkpoints, WAL recovery, and concurrent access.
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Opens a fresh engine in a temporary directory. The `TempDir` must outlive
    /// the engine, so it is returned alongside it.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_write_page_overwrite_zero_fills_tail() {
        let (_dir, engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();
        engine.write_page(pid, b"bye").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..3], b"bye");
        // Bytes left over from the longer first payload must be cleared.
        assert!(page.data[3..20].iter().all(|&b| b == 0));
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_store_memory_rejects_oversized_node() {
        let (_dir, engine) = setup();

        // The content alone exceeds a page, so the serialized node cannot fit.
        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "x".repeat(PAGE_DATA_SIZE),
            vec![0.0],
        );

        assert!(matches!(
            engine.store_memory(&node),
            Err(MenteError::CapacityExceeded(_))
        ));
    }

    #[test]
    fn test_scan_all_memories_finds_stored_nodes() {
        let (_dir, engine) = setup();

        let first = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "scan-first".to_string(),
            vec![1.0],
        );
        let second = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "scan-second".to_string(),
            vec![2.0],
        );
        let pid_first = engine.store_memory(&first).unwrap();
        let pid_second = engine.store_memory(&second).unwrap();

        let found = engine.scan_all_memories();
        assert!(found
            .iter()
            .any(|(id, pid)| *id == first.id && pid.0 == pid_first.0));
        assert!(found
            .iter()
            .any(|(id, pid)| *id == second.id && pid.0 == pid_second.0));
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate a crash: sync the WAL but skip `close()`.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate a crash after the checkpoint: sync the WAL, skip `close()`.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_concurrent_open_no_lock_conflict() {
        let dir = tempfile::tempdir().unwrap();

        // Two engines on the same directory must both be able to open and write.
        let engine1 = StorageEngine::open(dir.path()).unwrap();
        let engine2 = StorageEngine::open(dir.path()).unwrap();

        let node1 = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "from engine 1".to_string(),
            vec![1.0],
        );
        let node2 = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "from engine 2".to_string(),
            vec![2.0],
        );

        let pid1 = engine1.store_memory(&node1).unwrap();
        let pid2 = engine2.store_memory(&node2).unwrap();

        let loaded1 = engine1.load_memory(pid1).unwrap();
        assert_eq!(loaded1.content, "from engine 1");

        let loaded2 = engine2.load_memory(pid2).unwrap();
        assert_eq!(loaded2.content, "from engine 2");
    }

    #[test]
    fn test_concurrent_writes_from_threads() {
        use std::sync::Arc;
        let dir = tempfile::tempdir().unwrap();
        let engine = Arc::new(StorageEngine::open(dir.path()).unwrap());

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let eng = Arc::clone(&engine);
                std::thread::spawn(move || {
                    let node = MemoryNode::new(
                        AgentId::new(),
                        MemoryType::Episodic,
                        format!("thread-{i}"),
                        vec![i as f32],
                    );
                    eng.store_memory(&node).unwrap()
                })
            })
            .collect();

        let pids: Vec<PageId> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Handles were joined in spawn order, so index `i` matches thread `i`.
        for (i, pid) in pids.iter().enumerate() {
            let loaded = engine.load_memory(*pid).unwrap();
            assert_eq!(loaded.content, format!("thread-{i}"));
        }
    }
}