Skip to main content

mentedb_storage/
engine.rs

1//! Storage Engine: facade that ties the page manager, WAL, and buffer pool together.
2
3use std::fs::File;
4use std::path::Path;
5
6use mentedb_core::MemoryNode;
7use mentedb_core::error::{MenteError, MenteResult};
8
9use fs2::FileExt;
10use parking_lot::Mutex;
11use tracing::info;
12
13use crate::buffer::BufferPool;
14use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
15use crate::wal::{Wal, WalEntryType};
/// How many page frames the buffer pool is created with in [`StorageEngine::open`] (1024).
const DEFAULT_BUFFER_POOL_SIZE: usize = 1 << 10;
18
19/// The unified storage engine for MenteDB.
20///
21/// Coordinates page allocation, caching, and write-ahead logging to provide
22/// crash-safe, page-oriented storage for memory nodes.
23///
24/// All internal state is protected by fine-grained locks so every public method
25/// takes `&self`, enabling concurrent reads from multiple threads.
26pub struct StorageEngine {
27    page_manager: Mutex<PageManager>,
28    buffer_pool: BufferPool,
29    wal: Mutex<Wal>,
30    /// Exclusive lock file — held for the lifetime of the engine to prevent concurrent access.
31    _lock_file: File,
32}
33
34impl StorageEngine {
35    /// Open (or create) a storage engine rooted at `path`.
36    ///
37    /// `path` must be a directory; it will be created if it does not exist.
38    /// After opening, any uncommitted WAL entries are replayed for crash recovery.
39    pub fn open(path: &Path) -> MenteResult<Self> {
40        std::fs::create_dir_all(path)?;
41
42        // Acquire exclusive lock to prevent concurrent access corruption
43        let lock_path = path.join("mentedb.lock");
44        let lock_file = File::create(&lock_path)
45            .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
46        lock_file.try_lock_exclusive().map_err(|_| {
47            MenteError::Storage(
48                "Database is locked by another process. Only one instance can access the database at a time.".to_string()
49            )
50        })?;
51
52        let page_manager = PageManager::open(path)?;
53        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
54        let wal = Wal::open(path)?;
55
56        let engine = Self {
57            page_manager: Mutex::new(page_manager),
58            buffer_pool,
59            wal: Mutex::new(wal),
60            _lock_file: lock_file,
61        };
62
63        let recovered = engine.recover()?;
64        if recovered > 0 {
65            info!(recovered, ?path, "storage engine opened with WAL recovery");
66        } else {
67            info!(?path, "storage engine opened");
68        }
69
70        Ok(engine)
71    }
72
73    /// Replay WAL entries to recover writes that were not checkpointed.
74    ///
75    /// For each `PageWrite` entry the serialized data is written back to its page.
76    /// After replay the WAL is truncated. Returns the number of entries replayed.
77    pub fn recover(&self) -> MenteResult<usize> {
78        let mut wal = self.wal.lock();
79        let entries = wal.iterate()?;
80        let mut count = 0usize;
81        let mut pm = self.page_manager.lock();
82
83        for entry in &entries {
84            match entry.entry_type {
85                WalEntryType::PageWrite => {
86                    let page_id = PageId(entry.page_id);
87
88                    while pm.page_count() <= entry.page_id {
89                        pm.allocate_page()?;
90                    }
91
92                    let mut page = pm.read_page(page_id)?;
93                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
94                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
95                    if copy_len < PAGE_DATA_SIZE {
96                        page.data[copy_len..].fill(0);
97                    }
98                    page.header.page_id = entry.page_id;
99                    page.header.lsn = entry.lsn;
100                    page.header.page_type = PageType::Data as u8;
101                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
102                    page.header.checksum = page.compute_checksum();
103
104                    pm.write_page(page_id, &page)?;
105                    count += 1;
106                }
107                WalEntryType::Checkpoint | WalEntryType::Commit => {}
108            }
109        }
110
111        if count > 0 {
112            pm.sync()?;
113            let next_lsn = wal.next_lsn();
114            wal.truncate(next_lsn)?;
115            info!(count, "WAL recovery replayed entries");
116        }
117
118        Ok(count)
119    }
120
121    /// Gracefully shut down: flush dirty pages, sync files.
122    pub fn close(&self) -> MenteResult<()> {
123        let mut pm = self.page_manager.lock();
124        self.buffer_pool.flush_all(&mut pm)?;
125        pm.sync()?;
126        self.wal.lock().sync()?;
127        info!("storage engine closed");
128        Ok(())
129    }
130
131    // ---- low-level page operations ----
132
133    /// Allocate a fresh page.
134    pub fn allocate_page(&self) -> MenteResult<PageId> {
135        self.page_manager.lock().allocate_page()
136    }
137
138    /// Read a page through the buffer pool.
139    pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
140        self.buffer_pool
141            .fetch_page(page_id, &mut self.page_manager.lock())
142    }
143
144    /// Write data into a page with WAL protection.
145    pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
146        let lsn = self
147            .wal
148            .lock()
149            .append(WalEntryType::PageWrite, page_id.0, data)?;
150
151        let mut pm = self.page_manager.lock();
152        let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
153        drop(pm); // release page_manager lock while we work on the page
154
155        let copy_len = data.len().min(PAGE_DATA_SIZE);
156        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
157        if copy_len < PAGE_DATA_SIZE {
158            page.data[copy_len..].fill(0);
159        }
160        page.header.lsn = lsn;
161        page.header.page_type = PageType::Data as u8;
162        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
163        page.header.checksum = page.compute_checksum();
164
165        if self.buffer_pool.update_page(page_id, &page).is_err() {
166            self.page_manager.lock().write_page(page_id, &page)?;
167        }
168        self.buffer_pool.unpin_page(page_id, true).ok();
169
170        Ok(())
171    }
172
173    // ---- high-level memory operations ----
174
175    /// Serialize and store a [`MemoryNode`] into a single page.
176    ///
177    /// Returns the [`PageId`] where the node was stored.
178    pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
179        let serialized =
180            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
181
182        if serialized.len() + 4 > PAGE_DATA_SIZE {
183            return Err(MenteError::CapacityExceeded(format!(
184                "memory node serialized to {} bytes (max {})",
185                serialized.len(),
186                PAGE_DATA_SIZE - 4,
187            )));
188        }
189
190        let page_id = self.allocate_page()?;
191
192        let mut buf = Vec::with_capacity(4 + serialized.len());
193        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
194        buf.extend_from_slice(&serialized);
195
196        self.write_page(page_id, &buf)?;
197
198        info!(
199            page_id = page_id.0,
200            bytes = serialized.len(),
201            "stored memory node"
202        );
203        Ok(page_id)
204    }
205
206    /// Load and deserialize a [`MemoryNode`] from the given page.
207    pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
208        let page = self.read_page(page_id)?;
209        self.buffer_pool.unpin_page(page_id, false).ok();
210
211        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
212        if len == 0 || len + 4 > PAGE_DATA_SIZE {
213            return Err(MenteError::Storage(format!(
214                "invalid memory node length prefix: {len}"
215            )));
216        }
217
218        serde_json::from_slice(&page.data[4..4 + len])
219            .map_err(|e| MenteError::Serialization(e.to_string()))
220    }
221
222    // ---- durability ----
223
224    /// Checkpoint: flush all dirty pages, sync to disk, and truncate the WAL.
225    pub fn checkpoint(&self) -> MenteResult<()> {
226        let mut pm = self.page_manager.lock();
227        self.buffer_pool.flush_all(&mut pm)?;
228        pm.sync()?;
229
230        let mut wal = self.wal.lock();
231        let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
232        wal.sync()?;
233        wal.truncate(lsn)?;
234
235        info!(lsn, "checkpoint complete");
236        Ok(())
237    }
238
239    /// Scan all pages and return (MemoryId, PageId) pairs for every valid memory node.
240    ///
241    /// Used to rebuild the page map on startup.
242    pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
243        let count = self.page_manager.lock().page_count();
244        let mut results = Vec::new();
245        for i in 1..count {
246            let page_id = PageId(i);
247            if let Ok(node) = self.load_memory(page_id) {
248                results.push((node.id, page_id));
249            }
250        }
251        results
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Open an engine in a fresh temporary directory.
    ///
    /// The `TempDir` is returned too so the directory outlives the engine.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    /// Build a node owned by a brand-new agent.
    fn make_node(kind: MemoryType, content: String, embedding: Vec<f32>) -> MemoryNode {
        MemoryNode::new(AgentId::new(), kind, content, embedding)
    }

    /// Load every stored page and check that its content matches what was written.
    fn assert_all_present(engine: &StorageEngine, stored: &[(PageId, String)]) {
        for (pid, expected) in stored {
            let loaded = engine.load_memory(*pid).unwrap();
            assert_eq!(&loaded.content, expected);
        }
    }

    /// Raw bytes written to a page can be read back.
    #[test]
    fn test_allocate_write_read() {
        let (_dir, engine) = setup();
        let payload = b"hello storage engine";

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, payload).unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], payload);
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    /// A node survives a store / load round trip unchanged.
    #[test]
    fn test_store_and_load_memory() {
        let (_dir, engine) = setup();

        let original = make_node(
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&original).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(original.id, loaded.id);
        assert_eq!(original.content, loaded.content);
        assert_eq!(original.embedding, loaded.embedding);
        assert_eq!(original.memory_type, loaded.memory_type);
    }

    /// A node is still readable after a checkpoint.
    #[test]
    fn test_checkpoint() {
        let (_dir, engine) = setup();

        let node = make_node(
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    /// Data written before `close` is visible to a newly opened engine.
    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();

        let pid = {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            let pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
            pid
        };

        let engine = StorageEngine::open(dir.path()).unwrap();
        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "persist across close");
    }

    /// Writes that only reached the WAL are restored when the engine is reopened.
    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut stored: Vec<(PageId, String)> = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = make_node(MemoryType::Episodic, content.clone(), vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                stored.push((pid, content));
            }
            // Simulate crash: sync the WAL but do NOT call close/checkpoint.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            assert_all_present(&engine, &stored);
        }
    }

    /// Data that was checkpointed and closed cleanly loads after a reopen.
    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let content = "idempotent-check".to_string();

        let pid = {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Semantic, content.clone(), vec![1.0, 2.0]);
            let pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
            pid
        };

        let engine = StorageEngine::open(dir.path()).unwrap();
        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, content);
    }

    /// Both checkpointed writes and later WAL-only writes are present after a reopen.
    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut stored: Vec<(PageId, String)> = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();

            // First batch: made durable by a checkpoint.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = make_node(MemoryType::Semantic, content.clone(), vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                stored.push((pid, content));
            }
            engine.checkpoint().unwrap();

            // Second batch: written after the checkpoint, never checkpointed.
            for i in 3..5 {
                let content = format!("unckeckpointed-{i}");
                let node = make_node(MemoryType::Episodic, content.clone(), vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                stored.push((pid, content));
            }
            // Simulate crash — sync WAL but don't close.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            assert_all_present(&engine, &stored);
        }
    }
}
438}