// mentedb_storage/engine.rs
//! Storage Engine: facade that ties the page manager, WAL, and buffer pool together.

use std::fs::File;
use std::path::Path;

use fs2::FileExt;
use tracing::info;

use mentedb_core::MemoryNode;
use mentedb_core::error::{MenteError, MenteResult};

use crate::buffer::BufferPool;
use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
use crate::wal::{Wal, WalEntryType};
/// Default number of page frames in the buffer pool.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
17
18/// The unified storage engine for MenteDB.
19///
20/// Coordinates page allocation, caching, and write-ahead logging to provide
21/// crash-safe, page-oriented storage for memory nodes.
22pub struct StorageEngine {
23    page_manager: PageManager,
24    buffer_pool: BufferPool,
25    wal: Wal,
26    /// Exclusive lock file — held for the lifetime of the engine to prevent concurrent access.
27    _lock_file: File,
28}
29
30impl StorageEngine {
31    /// Open (or create) a storage engine rooted at `path`.
32    ///
33    /// `path` must be a directory; it will be created if it does not exist.
34    /// After opening, any uncommitted WAL entries are replayed for crash recovery.
35    pub fn open(path: &Path) -> MenteResult<Self> {
36        std::fs::create_dir_all(path)?;
37
38        // Acquire exclusive lock to prevent concurrent access corruption
39        let lock_path = path.join("mentedb.lock");
40        let lock_file = File::create(&lock_path)
41            .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
42        lock_file.try_lock_exclusive().map_err(|_| {
43            MenteError::Storage(
44                "Database is locked by another process. Only one instance can access the database at a time.".to_string()
45            )
46        })?;
47
48        let page_manager = PageManager::open(path)?;
49        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
50        let wal = Wal::open(path)?;
51
52        let mut engine = Self {
53            page_manager,
54            buffer_pool,
55            wal,
56            _lock_file: lock_file,
57        };
58
59        let recovered = engine.recover()?;
60        if recovered > 0 {
61            info!(recovered, ?path, "storage engine opened with WAL recovery");
62        } else {
63            info!(?path, "storage engine opened");
64        }
65
66        Ok(engine)
67    }
68
69    /// Replay WAL entries to recover writes that were not checkpointed.
70    ///
71    /// For each `PageWrite` entry the serialized data is written back to its page.
72    /// After replay the WAL is truncated. Returns the number of entries replayed.
73    pub fn recover(&mut self) -> MenteResult<usize> {
74        let entries = self.wal.iterate()?;
75        let mut count = 0usize;
76
77        for entry in &entries {
78            match entry.entry_type {
79                WalEntryType::PageWrite => {
80                    let page_id = PageId(entry.page_id);
81
82                    // Ensure the page file is large enough for this page id.
83                    while self.page_manager.page_count() <= entry.page_id {
84                        self.page_manager.allocate_page()?;
85                    }
86
87                    let mut page = self.page_manager.read_page(page_id)?;
88                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
89                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
90                    if copy_len < PAGE_DATA_SIZE {
91                        page.data[copy_len..].fill(0);
92                    }
93                    page.header.page_id = entry.page_id;
94                    page.header.lsn = entry.lsn;
95                    page.header.page_type = PageType::Data as u8;
96                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
97                    page.header.checksum = page.compute_checksum();
98
99                    self.page_manager.write_page(page_id, &page)?;
100                    count += 1;
101                }
102                WalEntryType::Checkpoint | WalEntryType::Commit => {
103                    // No page data to replay for these entry types.
104                }
105            }
106        }
107
108        if count > 0 {
109            self.page_manager.sync()?;
110            // Truncate the entire WAL — all entries have been applied.
111            let next_lsn = self.wal.next_lsn();
112            self.wal.truncate(next_lsn)?;
113            info!(count, "WAL recovery replayed entries");
114        }
115
116        Ok(count)
117    }
118
119    /// Gracefully shut down: flush dirty pages, sync files.
120    pub fn close(&mut self) -> MenteResult<()> {
121        self.buffer_pool.flush_all(&mut self.page_manager)?;
122        self.page_manager.sync()?;
123        self.wal.sync()?;
124        info!("storage engine closed");
125        Ok(())
126    }
127
128    // ---- low-level page operations ----
129
130    /// Allocate a fresh page.
131    pub fn allocate_page(&mut self) -> MenteResult<PageId> {
132        self.page_manager.allocate_page()
133    }
134
135    /// Read a page through the buffer pool.
136    pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
137        self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
138    }
139
140    /// Write data into a page with WAL protection.
141    pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
142        // WAL-first: log before modifying the page.
143        let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;
144
145        // Load the page into the buffer pool (or get cached copy).
146        let mut page = self
147            .buffer_pool
148            .fetch_page(page_id, &mut self.page_manager)?;
149
150        let copy_len = data.len().min(PAGE_DATA_SIZE);
151        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
152        // Zero out remaining space if data is shorter than existing content.
153        if copy_len < PAGE_DATA_SIZE {
154            page.data[copy_len..].fill(0);
155        }
156        page.header.lsn = lsn;
157        page.header.page_type = PageType::Data as u8;
158        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
159        page.header.checksum = page.compute_checksum();
160
161        // Push modified page back into the buffer pool.
162        if self.buffer_pool.update_page(page_id, &page).is_err() {
163            // Not cached (shouldn't happen after fetch, but be safe).
164            self.page_manager.write_page(page_id, &page)?;
165        }
166        self.buffer_pool.unpin_page(page_id, true).ok();
167
168        Ok(())
169    }
170
171    // ---- high-level memory operations ----
172
173    /// Serialize and store a [`MemoryNode`] into a single page.
174    ///
175    /// Returns the [`PageId`] where the node was stored.
176    pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
177        let serialized =
178            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
179
180        // 4 bytes for the length prefix.
181        if serialized.len() + 4 > PAGE_DATA_SIZE {
182            return Err(MenteError::CapacityExceeded(format!(
183                "memory node serialized to {} bytes (max {})",
184                serialized.len(),
185                PAGE_DATA_SIZE - 4,
186            )));
187        }
188
189        let page_id = self.allocate_page()?;
190
191        // Layout: [length: u32 LE][JSON bytes]
192        let mut buf = Vec::with_capacity(4 + serialized.len());
193        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
194        buf.extend_from_slice(&serialized);
195
196        self.write_page(page_id, &buf)?;
197
198        info!(
199            page_id = page_id.0,
200            bytes = serialized.len(),
201            "stored memory node"
202        );
203        Ok(page_id)
204    }
205
206    /// Load and deserialize a [`MemoryNode`] from the given page.
207    pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
208        let page = self.read_page(page_id)?;
209        // Unpin immediately — we copy the data we need.
210        self.buffer_pool.unpin_page(page_id, false).ok();
211
212        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
213        if len == 0 || len + 4 > PAGE_DATA_SIZE {
214            return Err(MenteError::Storage(format!(
215                "invalid memory node length prefix: {len}"
216            )));
217        }
218
219        serde_json::from_slice(&page.data[4..4 + len])
220            .map_err(|e| MenteError::Serialization(e.to_string()))
221    }
222
223    // ---- durability ----
224
225    /// Checkpoint: flush all dirty pages, sync to disk, and truncate the WAL.
226    pub fn checkpoint(&mut self) -> MenteResult<()> {
227        self.buffer_pool.flush_all(&mut self.page_manager)?;
228        self.page_manager.sync()?;
229
230        let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
231        self.wal.sync()?;
232        self.wal.truncate(lsn)?;
233
234        info!(lsn, "checkpoint complete");
235        Ok(())
236    }
237
238    /// Scan all pages and return (MemoryId, PageId) pairs for every valid memory node.
239    ///
240    /// Used to rebuild the page map on startup.
241    pub fn scan_all_memories(&mut self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
242        let count = self.page_manager.page_count();
243        let mut results = Vec::new();
244        // Page 0 is the header page, start from 1
245        for i in 1..count {
246            let page_id = PageId(i);
247            if let Ok(node) = self.load_memory(page_id) {
248                results.push((node.id, page_id));
249            }
250        }
251        results
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Create a fresh engine in a temp dir; the dir must outlive the engine.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, mut engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        // Data should still be readable after checkpoint.
        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash: sync the WAL but do NOT call close/checkpoint.
            engine.wal.sync().unwrap();
            // Drop without close — dirty pages may not be flushed.
        }
        {
            // Reopen — WAL replay should recover the writes.
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            pid = engine.store_memory(&node).unwrap();
            // Proper shutdown — checkpoint flushes pages and truncates WAL.
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            // Reopen after clean shutdown — WAL should be empty, no duplicate data.
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // Store 3 memories then checkpoint.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            // Store 2 more without checkpoint (will only be in WAL).
            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash — sync WAL but don't close.
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // All 5 memories should be present: 3 from checkpoint, 2 from WAL replay.
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}