// mentedb_storage/engine.rs
//! Storage Engine: facade that ties the page manager, WAL, and buffer pool together.

use std::path::Path;

use mentedb_core::MemoryNode;
use mentedb_core::error::{MenteError, MenteResult};

use tracing::info;

use crate::buffer::BufferPool;
use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
use crate::wal::{Wal, WalEntryType};

/// Default number of page frames in the buffer pool.
///
/// Each frame caches one page; see [`BufferPool::new`].
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
15
/// The unified storage engine for MenteDB.
///
/// Coordinates page allocation, caching, and write-ahead logging to provide
/// crash-safe, page-oriented storage for memory nodes.
pub struct StorageEngine {
    /// Owns the on-disk page file: allocation, raw page read/write, fsync.
    page_manager: PageManager,
    /// In-memory page cache with `DEFAULT_BUFFER_POOL_SIZE` frames.
    buffer_pool: BufferPool,
    /// Write-ahead log; every page write is logged here before the page is touched.
    wal: Wal,
}
25
impl StorageEngine {
    /// Open (or create) a storage engine rooted at `path`.
    ///
    /// `path` must be a directory; it will be created if it does not exist.
    /// After opening, any uncommitted WAL entries are replayed for crash recovery.
    pub fn open(path: &Path) -> MenteResult<Self> {
        std::fs::create_dir_all(path)?;

        let page_manager = PageManager::open(path)?;
        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
        let wal = Wal::open(path)?;

        let mut engine = Self {
            page_manager,
            buffer_pool,
            wal,
        };

        // Recovery runs on every open. After a clean shutdown the WAL is
        // empty (checkpoint truncated it), so this is effectively a no-op.
        let recovered = engine.recover()?;
        if recovered > 0 {
            info!(recovered, ?path, "storage engine opened with WAL recovery");
        } else {
            info!(?path, "storage engine opened");
        }

        Ok(engine)
    }

    /// Replay WAL entries to recover writes that were not checkpointed.
    ///
    /// For each `PageWrite` entry the serialized data is written back to its page.
    /// After replay the WAL is truncated. Returns the number of entries replayed.
    ///
    /// Replay is idempotent: rewriting a page with the same logged payload and
    /// recomputed checksum is harmless, so crashing mid-recovery is safe.
    pub fn recover(&mut self) -> MenteResult<usize> {
        let entries = self.wal.iterate()?;
        let mut count = 0usize;

        for entry in &entries {
            match entry.entry_type {
                WalEntryType::PageWrite => {
                    let page_id = PageId(entry.page_id);

                    // Ensure the page file is large enough for this page id.
                    // A logged page may have been allocated after the last
                    // flush, so the on-disk file can be shorter than the log.
                    while self.page_manager.page_count() <= entry.page_id {
                        self.page_manager.allocate_page()?;
                    }

                    let mut page = self.page_manager.read_page(page_id)?;
                    // Clamp to the payload capacity and zero the tail so stale
                    // bytes from an earlier, longer write cannot survive.
                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
                    if copy_len < PAGE_DATA_SIZE {
                        page.data[copy_len..].fill(0);
                    }
                    // Rebuild the header the same way write_page() does, so a
                    // recovered page is indistinguishable from a normal write.
                    page.header.page_id = entry.page_id;
                    page.header.lsn = entry.lsn;
                    page.header.page_type = PageType::Data as u8;
                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
                    page.header.checksum = page.compute_checksum();

                    // NOTE(review): this bypasses the buffer pool. Safe when
                    // called from open() (pool is empty); if recover() is ever
                    // invoked on a warm engine, cached copies could go stale —
                    // verify before calling it elsewhere.
                    self.page_manager.write_page(page_id, &page)?;
                    count += 1;
                }
                WalEntryType::Checkpoint | WalEntryType::Commit => {
                    // No page data to replay for these entry types.
                }
            }
        }

        if count > 0 {
            // Durability ordering: persist the replayed pages before
            // discarding the log that could re-create them.
            self.page_manager.sync()?;
            // Truncate the entire WAL — all entries have been applied.
            let next_lsn = self.wal.next_lsn();
            self.wal.truncate(next_lsn)?;
            info!(count, "WAL recovery replayed entries");
        }

        Ok(count)
    }

    /// Gracefully shut down: flush dirty pages, sync files.
    ///
    /// Does NOT truncate the WAL; use [`Self::checkpoint`] for that.
    pub fn close(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;
        self.wal.sync()?;
        info!("storage engine closed");
        Ok(())
    }

    // ---- low-level page operations ----

    /// Allocate a fresh page.
    ///
    /// Goes straight to the page manager; the new page is not cached yet.
    pub fn allocate_page(&mut self) -> MenteResult<PageId> {
        self.page_manager.allocate_page()
    }

    /// Read a page through the buffer pool.
    ///
    /// NOTE(review): the returned page is pinned by fetch_page; callers are
    /// expected to unpin when done (see the tests and load_memory) — confirm
    /// against BufferPool's pinning contract.
    pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
        self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
    }

    /// Write data into a page with WAL protection.
    ///
    /// `data` longer than `PAGE_DATA_SIZE` is silently truncated; shorter
    /// data zero-fills the remainder of the page.
    pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
        // WAL-first: log before modifying the page.
        let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;

        // Load the page into the buffer pool (or get cached copy).
        let mut page = self
            .buffer_pool
            .fetch_page(page_id, &mut self.page_manager)?;

        let copy_len = data.len().min(PAGE_DATA_SIZE);
        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
        // Zero out remaining space if data is shorter than existing content.
        if copy_len < PAGE_DATA_SIZE {
            page.data[copy_len..].fill(0);
        }
        page.header.lsn = lsn;
        page.header.page_type = PageType::Data as u8;
        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
        page.header.checksum = page.compute_checksum();

        // Push modified page back into the buffer pool.
        if self.buffer_pool.update_page(page_id, &page).is_err() {
            // Not cached (shouldn't happen after fetch, but be safe).
            self.page_manager.write_page(page_id, &page)?;
        }
        // Unpin with dirty=true so flush_all/eviction persists the change;
        // errors are ignored for the direct-write fallback path above.
        self.buffer_pool.unpin_page(page_id, true).ok();

        Ok(())
    }

    // ---- high-level memory operations ----

    /// Serialize and store a [`MemoryNode`] into a single page.
    ///
    /// Returns the [`PageId`] where the node was stored.
    ///
    /// # Errors
    /// - [`MenteError::Serialization`] if JSON encoding fails.
    /// - [`MenteError::CapacityExceeded`] if the encoded node (plus the
    ///   4-byte length prefix) does not fit in one page.
    pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
        let serialized =
            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;

        // 4 bytes for the length prefix.
        if serialized.len() + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::CapacityExceeded(format!(
                "memory node serialized to {} bytes (max {})",
                serialized.len(),
                PAGE_DATA_SIZE - 4,
            )));
        }

        let page_id = self.allocate_page()?;

        // Layout: [length: u32 LE][JSON bytes]
        let mut buf = Vec::with_capacity(4 + serialized.len());
        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
        buf.extend_from_slice(&serialized);

        self.write_page(page_id, &buf)?;

        info!(
            page_id = page_id.0,
            bytes = serialized.len(),
            "stored memory node"
        );
        Ok(page_id)
    }

    /// Load and deserialize a [`MemoryNode`] from the given page.
    ///
    /// Expects the layout written by [`Self::store_memory`]:
    /// `[length: u32 LE][JSON bytes]`.
    pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
        let page = self.read_page(page_id)?;
        // Unpin immediately — we copy the data we need.
        // NOTE(review): `page` is still read below; this is sound only if
        // fetch_page hands back an owned copy (Box<Page>) rather than a
        // view into the frame — confirm against BufferPool.
        self.buffer_pool.unpin_page(page_id, false).ok();

        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
        // len == 0 means an unwritten/zeroed page; an oversized len means
        // corruption or a non-memory page.
        if len == 0 || len + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::Storage(format!(
                "invalid memory node length prefix: {len}"
            )));
        }

        serde_json::from_slice(&page.data[4..4 + len])
            .map_err(|e| MenteError::Serialization(e.to_string()))
    }

    // ---- durability ----

    /// Checkpoint: flush all dirty pages, sync to disk, and truncate the WAL.
    ///
    /// Ordering matters: pages are flushed and synced BEFORE the log is
    /// truncated, so a crash in between still has either the pages or the log.
    pub fn checkpoint(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;

        // NOTE(review): assumes Wal::truncate(lsn) discards everything up to
        // and including `lsn` (the checkpoint marker itself) — verify against
        // the Wal implementation.
        let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
        self.wal.sync()?;
        self.wal.truncate(lsn)?;

        info!(lsn, "checkpoint complete");
        Ok(())
    }

    /// Scan all pages and return (MemoryId, PageId) pairs for every valid memory node.
    ///
    /// Used to rebuild the page map on startup.
    ///
    /// Pages that fail to decode (unwritten, corrupt, or non-memory pages)
    /// are silently skipped rather than treated as errors.
    pub fn scan_all_memories(&mut self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
        let count = self.page_manager.page_count();
        let mut results = Vec::new();
        // Page 0 is the header page, start from 1
        for i in 1..count {
            let page_id = PageId(i);
            if let Ok(node) = self.load_memory(page_id) {
                results.push((node.id, page_id));
            }
        }
        results
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Build an engine inside a fresh temp dir; the dir guard keeps it alive.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let tmp = tempfile::tempdir().unwrap();
        let eng = StorageEngine::open(tmp.path()).unwrap();
        (tmp, eng)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_tmp, mut eng) = setup();

        let pid = eng.allocate_page().unwrap();
        eng.write_page(pid, b"hello storage engine").unwrap();

        let page = eng.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        eng.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_tmp, mut eng) = setup();

        let original = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let pid = eng.store_memory(&original).unwrap();
        let roundtripped = eng.load_memory(pid).unwrap();

        assert_eq!(original.id, roundtripped.id);
        assert_eq!(original.content, roundtripped.content);
        assert_eq!(original.embedding, roundtripped.embedding);
        assert_eq!(original.memory_type, roundtripped.memory_type);
    }

    #[test]
    fn test_checkpoint() {
        let (_tmp, mut eng) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );
        let pid = eng.store_memory(&node).unwrap();

        eng.checkpoint().unwrap();

        // Data should still be readable after checkpoint.
        assert_eq!(eng.load_memory(pid).unwrap().content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let tmp = tempfile::tempdir().unwrap();

        let pid = {
            let mut eng = StorageEngine::open(tmp.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            let pid = eng.store_memory(&node).unwrap();
            eng.close().unwrap();
            pid
        };

        let mut eng = StorageEngine::open(tmp.path()).unwrap();
        assert_eq!(
            eng.load_memory(pid).unwrap().content,
            "persist across close"
        );
    }

    #[test]
    fn test_crash_recovery() {
        let tmp = tempfile::tempdir().unwrap();
        let mut stored: Vec<(PageId, String)> = Vec::new();

        {
            let mut eng = StorageEngine::open(tmp.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = eng.store_memory(&node).unwrap();
                stored.push((pid, content));
            }
            // Simulate crash: sync the WAL but do NOT call close/checkpoint.
            eng.wal.sync().unwrap();
            // Drop without close — dirty pages may not be flushed.
        }

        // Reopen — WAL replay should recover the writes.
        let mut eng = StorageEngine::open(tmp.path()).unwrap();
        for (pid, expected) in &stored {
            assert_eq!(&eng.load_memory(*pid).unwrap().content, expected);
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let tmp = tempfile::tempdir().unwrap();
        let content = "idempotent-check".to_string();

        let pid = {
            let mut eng = StorageEngine::open(tmp.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            let pid = eng.store_memory(&node).unwrap();
            // Proper shutdown — checkpoint flushes pages and truncates WAL.
            eng.checkpoint().unwrap();
            eng.close().unwrap();
            pid
        };

        // Reopen after clean shutdown — WAL should be empty, no duplicate data.
        let mut eng = StorageEngine::open(tmp.path()).unwrap();
        assert_eq!(eng.load_memory(pid).unwrap().content, content);
    }

    #[test]
    fn test_partial_write_recovery() {
        let tmp = tempfile::tempdir().unwrap();
        let mut stored: Vec<(PageId, String)> = Vec::new();

        {
            let mut eng = StorageEngine::open(tmp.path()).unwrap();

            // Store 3 memories then checkpoint.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                stored.push((eng.store_memory(&node).unwrap(), content));
            }
            eng.checkpoint().unwrap();

            // Store 2 more without checkpoint (will only be in WAL).
            for i in 3..5 {
                let content = format!("unckeckpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                stored.push((eng.store_memory(&node).unwrap(), content));
            }
            // Simulate crash — sync WAL but don't close.
            eng.wal.sync().unwrap();
        }

        // All 5 memories should be present: 3 from checkpoint, 2 from WAL replay.
        let mut eng = StorageEngine::open(tmp.path()).unwrap();
        for (pid, expected) in &stored {
            assert_eq!(&eng.load_memory(*pid).unwrap().content, expected);
        }
    }
}