Skip to main content

mentedb_storage/
engine.rs

//! Storage engine: a facade tying together the page manager, WAL, and buffer pool.

3use std::fs::File;
4use std::path::Path;
5
6use mentedb_core::MemoryNode;
7use mentedb_core::error::{MenteError, MenteResult};
8
9use fs2::FileExt;
10use parking_lot::Mutex;
11use tracing::info;
12
13use crate::buffer::BufferPool;
14use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
15use crate::wal::{Wal, WalEntryType};
/// Default number of page frames in the buffer pool.
///
/// Passed to `BufferPool::new` by `StorageEngine::open`.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;

19/// The unified storage engine for MenteDB.
20///
21/// Coordinates page allocation, caching, and write-ahead logging to provide
22/// crash-safe, page-oriented storage for memory nodes.
23///
24/// All internal state is protected by fine-grained locks so every public method
25/// takes `&self`, enabling concurrent reads from multiple threads.
26pub struct StorageEngine {
27    page_manager: Mutex<PageManager>,
28    buffer_pool: BufferPool,
29    wal: Mutex<Wal>,
30    /// Exclusive lock file — held for the lifetime of the engine to prevent concurrent access.
31    _lock_file: File,
32}
33
34impl StorageEngine {
35    /// Open (or create) a storage engine rooted at `path`.
36    ///
37    /// `path` must be a directory; it will be created if it does not exist.
38    /// After opening, any uncommitted WAL entries are replayed for crash recovery.
39    pub fn open(path: &Path) -> MenteResult<Self> {
40        std::fs::create_dir_all(path)?;
41
42        // Acquire exclusive lock to prevent concurrent access corruption
43        let lock_path = path.join("mentedb.lock");
44        let lock_file = File::create(&lock_path)
45            .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
46        lock_file.try_lock_exclusive().map_err(|_| {
47            MenteError::Storage(
48                "Database is locked by another process. Only one instance can access the database at a time.".to_string()
49            )
50        })?;
51
52        let page_manager = PageManager::open(path)?;
53        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
54        let wal = Wal::open(path)?;
55
56        let engine = Self {
57            page_manager: Mutex::new(page_manager),
58            buffer_pool,
59            wal: Mutex::new(wal),
60            _lock_file: lock_file,
61        };
62
63        let recovered = engine.recover()?;
64        if recovered > 0 {
65            info!(recovered, ?path, "storage engine opened with WAL recovery");
66        } else {
67            info!(?path, "storage engine opened");
68        }
69
70        Ok(engine)
71    }
72
73    /// Replay WAL entries to recover writes that were not checkpointed.
74    ///
75    /// For each `PageWrite` entry the serialized data is written back to its page.
76    /// After replay the WAL is truncated. Returns the number of entries replayed.
77    pub fn recover(&self) -> MenteResult<usize> {
78        let mut wal = self.wal.lock();
79        let entries = wal.iterate()?;
80        let mut count = 0usize;
81        let mut pm = self.page_manager.lock();
82
83        for entry in &entries {
84            match entry.entry_type {
85                WalEntryType::PageWrite => {
86                    let page_id = PageId(entry.page_id);
87
88                    while pm.page_count() <= entry.page_id {
89                        pm.allocate_page()?;
90                    }
91
92                    let mut page = pm.read_page(page_id)?;
93                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
94                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
95                    if copy_len < PAGE_DATA_SIZE {
96                        page.data[copy_len..].fill(0);
97                    }
98                    page.header.page_id = entry.page_id;
99                    page.header.lsn = entry.lsn;
100                    page.header.page_type = PageType::Data as u8;
101                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
102                    page.header.checksum = page.compute_checksum();
103
104                    pm.write_page(page_id, &page)?;
105                    count += 1;
106                }
107                WalEntryType::Checkpoint | WalEntryType::Commit => {}
108            }
109        }
110
111        if count > 0 {
112            pm.sync()?;
113            let next_lsn = wal.next_lsn();
114            wal.truncate(next_lsn)?;
115            info!(count, "WAL recovery replayed entries");
116        }
117
118        Ok(count)
119    }
120
121    /// Gracefully shut down: flush dirty pages, sync files.
122    pub fn close(&self) -> MenteResult<()> {
123        let mut pm = self.page_manager.lock();
124        self.buffer_pool.flush_all(&mut pm)?;
125        pm.sync()?;
126        self.wal.lock().sync()?;
127        info!("storage engine closed");
128        Ok(())
129    }
130
131    // ---- low-level page operations ----
132
133    /// Allocate a fresh page.
134    pub fn allocate_page(&self) -> MenteResult<PageId> {
135        self.page_manager.lock().allocate_page()
136    }
137
138    /// Read a page through the buffer pool.
139    pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
140        self.buffer_pool
141            .fetch_page(page_id, &mut self.page_manager.lock())
142    }
143
144    /// Write data into a page with WAL protection.
145    pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
146        let lsn = {
147            let mut wal = self.wal.lock();
148            let lsn = wal.append(WalEntryType::PageWrite, page_id.0, data)?;
149            wal.sync()?;
150            lsn
151        };
152
153        let mut pm = self.page_manager.lock();
154        let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
155        drop(pm); // release page_manager lock while we work on the page
156
157        let copy_len = data.len().min(PAGE_DATA_SIZE);
158        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
159        if copy_len < PAGE_DATA_SIZE {
160            page.data[copy_len..].fill(0);
161        }
162        page.header.lsn = lsn;
163        page.header.page_type = PageType::Data as u8;
164        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
165        page.header.checksum = page.compute_checksum();
166
167        if self.buffer_pool.update_page(page_id, &page).is_err() {
168            self.page_manager.lock().write_page(page_id, &page)?;
169        }
170        self.buffer_pool.unpin_page(page_id, true).ok();
171
172        Ok(())
173    }
174
175    // ---- high-level memory operations ----
176
177    /// Serialize and store a [`MemoryNode`] into a single page.
178    ///
179    /// Returns the [`PageId`] where the node was stored.
180    pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
181        let serialized =
182            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
183
184        if serialized.len() + 4 > PAGE_DATA_SIZE {
185            return Err(MenteError::CapacityExceeded(format!(
186                "memory node serialized to {} bytes (max {})",
187                serialized.len(),
188                PAGE_DATA_SIZE - 4,
189            )));
190        }
191
192        let page_id = self.allocate_page()?;
193
194        let mut buf = Vec::with_capacity(4 + serialized.len());
195        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
196        buf.extend_from_slice(&serialized);
197
198        self.write_page(page_id, &buf)?;
199
200        info!(
201            page_id = page_id.0,
202            bytes = serialized.len(),
203            "stored memory node"
204        );
205        Ok(page_id)
206    }
207
208    /// Load and deserialize a [`MemoryNode`] from the given page.
209    pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
210        let page = self.read_page(page_id)?;
211        self.buffer_pool.unpin_page(page_id, false).ok();
212
213        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
214        if len == 0 || len + 4 > PAGE_DATA_SIZE {
215            return Err(MenteError::Storage(format!(
216                "invalid memory node length prefix: {len}"
217            )));
218        }
219
220        serde_json::from_slice(&page.data[4..4 + len])
221            .map_err(|e| MenteError::Serialization(e.to_string()))
222    }
223
224    // ---- durability ----
225
226    /// Checkpoint: flush all dirty pages, sync to disk, and truncate the WAL.
227    pub fn checkpoint(&self) -> MenteResult<()> {
228        let mut pm = self.page_manager.lock();
229        self.buffer_pool.flush_all(&mut pm)?;
230        pm.sync()?;
231
232        let mut wal = self.wal.lock();
233        let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
234        wal.sync()?;
235        wal.truncate(lsn)?;
236
237        info!(lsn, "checkpoint complete");
238        Ok(())
239    }
240
241    /// Scan all pages and return (MemoryId, PageId) pairs for every valid memory node.
242    ///
243    /// Used to rebuild the page map on startup.
244    pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
245        let count = self.page_manager.lock().page_count();
246        let mut results = Vec::new();
247        for i in 1..count {
248            let page_id = PageId(i);
249            if let Ok(node) = self.load_memory(page_id) {
250                results.push((node.id, page_id));
251            }
252        }
253        results
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Open a fresh engine in a new temporary directory.
    ///
    /// The `TempDir` is returned alongside the engine so the directory is not
    /// deleted while the engine still uses it.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    /// Build a memory node owned by a freshly generated agent.
    fn make_node(memory_type: MemoryType, content: &str, embedding: Vec<f32>) -> MemoryNode {
        MemoryNode::new(AgentId::new(), memory_type, content.to_string(), embedding)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        // `read_page` leaves the frame pinned; release it.
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, engine) = setup();

        let node = make_node(
            MemoryType::Episodic,
            "The user prefers Rust over Go",
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_store_memory_rejects_oversized_node() {
        let (_dir, engine) = setup();

        let huge = make_node(MemoryType::Semantic, &"x".repeat(PAGE_DATA_SIZE), vec![]);
        let result = engine.store_memory(&huge);

        assert!(matches!(result, Err(MenteError::CapacityExceeded(_))));
    }

    #[test]
    fn test_load_memory_from_unwritten_page_fails() {
        let (_dir, engine) = setup();

        let pid = engine.allocate_page().unwrap();
        assert!(engine.load_memory(pid).is_err());
    }

    #[test]
    fn test_second_open_is_rejected_while_locked() {
        let (dir, _engine) = setup();

        // The first engine still holds the exclusive directory lock.
        assert!(StorageEngine::open(dir.path()).is_err());
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, engine) = setup();

        let node = make_node(MemoryType::Semantic, "checkpoint test", vec![1.0, 2.0]);

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Procedural, "persist across close", vec![0.5]);
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_scan_all_memories_after_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let mut stored = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let node = make_node(MemoryType::Semantic, &format!("scan-{i}"), vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                stored.push((node.id, pid));
            }
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let found = engine.scan_all_memories();
            for (id, pid) in &stored {
                assert!(
                    found
                        .iter()
                        .any(|(found_id, found_pid)| found_id == id && found_pid.0 == pid.0)
                );
            }
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash: sync the WAL but do NOT call close/checkpoint.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Semantic, &content, vec![1.0, 2.0]);
            pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = make_node(MemoryType::Semantic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash — sync WAL but don't close.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}