// mentedb_storage/engine.rs

//! Storage Engine — facade that ties the page manager, WAL, and buffer pool together.

use std::path::Path;

use mentedb_core::MemoryNode;
use mentedb_core::error::{MenteError, MenteResult};
use tracing::info;

use crate::buffer::BufferPool;
use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
use crate::wal::{Wal, WalEntryType};

/// Default number of page frames in the buffer pool.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
15
/// The unified storage engine for MenteDB.
///
/// Coordinates page allocation, caching, and write-ahead logging to provide
/// crash-safe, page-oriented storage for memory nodes.
pub struct StorageEngine {
    /// Owns the on-disk page file: allocation, raw page reads/writes, sync.
    page_manager: PageManager,
    /// In-memory page cache; dirty frames are flushed back through the page
    /// manager on `close` and `checkpoint` via `flush_all`.
    buffer_pool: BufferPool,
    /// Write-ahead log. `write_page` appends here *before* modifying the
    /// page, so `recover` can replay un-checkpointed writes after a crash.
    wal: Wal,
}
25
26impl StorageEngine {
27    /// Open (or create) a storage engine rooted at `path`.
28    ///
29    /// `path` must be a directory; it will be created if it does not exist.
30    /// After opening, any uncommitted WAL entries are replayed for crash recovery.
31    pub fn open(path: &Path) -> MenteResult<Self> {
32        std::fs::create_dir_all(path)?;
33
34        let page_manager = PageManager::open(path)?;
35        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
36        let wal = Wal::open(path)?;
37
38        let mut engine = Self {
39            page_manager,
40            buffer_pool,
41            wal,
42        };
43
44        let recovered = engine.recover()?;
45        if recovered > 0 {
46            info!(recovered, ?path, "storage engine opened with WAL recovery");
47        } else {
48            info!(?path, "storage engine opened");
49        }
50
51        Ok(engine)
52    }
53
    /// Replay WAL entries to recover writes that were not checkpointed.
    ///
    /// For each `PageWrite` entry the serialized data is written back to its page.
    /// After replay the WAL is truncated. Returns the number of entries replayed.
    ///
    /// Replay is idempotent: every entry rewrites the full page image
    /// (data, header, checksum), so re-applying an entry that already
    /// reached disk is harmless.
    pub fn recover(&mut self) -> MenteResult<usize> {
        let entries = self.wal.iterate()?;
        let mut count = 0usize;

        for entry in &entries {
            match entry.entry_type {
                WalEntryType::PageWrite => {
                    let page_id = PageId(entry.page_id);

                    // Ensure the page file is large enough for this page id.
                    // A crash may have lost allocations whose pages were never
                    // flushed, so re-extend the file one page at a time.
                    while self.page_manager.page_count() <= entry.page_id {
                        self.page_manager.allocate_page()?;
                    }

                    // Rebuild the page straight through the page manager,
                    // bypassing the buffer pool — `open` calls this before
                    // any page has been cached.
                    let mut page = self.page_manager.read_page(page_id)?;
                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
                    if copy_len < PAGE_DATA_SIZE {
                        // The logged data is authoritative for the whole page:
                        // zero any stale tail bytes left from prior contents.
                        page.data[copy_len..].fill(0);
                    }
                    page.header.page_id = entry.page_id;
                    page.header.lsn = entry.lsn;
                    page.header.page_type = PageType::Data as u8;
                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
                    page.header.checksum = page.compute_checksum();

                    self.page_manager.write_page(page_id, &page)?;
                    count += 1;
                }
                WalEntryType::Checkpoint | WalEntryType::Commit => {
                    // No page data to replay for these entry types.
                }
            }
        }

        if count > 0 {
            // Make the replayed pages durable *before* discarding the log;
            // a crash between these two steps must not lose the writes.
            self.page_manager.sync()?;
            // Truncate the entire WAL — all entries have been applied.
            let next_lsn = self.wal.next_lsn();
            self.wal.truncate(next_lsn)?;
            info!(count, "WAL recovery replayed entries");
        }

        Ok(count)
    }
103
    /// Gracefully shut down: flush dirty pages, sync files.
    ///
    /// Note: the WAL is synced but *not* truncated here, so any remaining
    /// entries are simply replayed (idempotently) on the next `open`.
    /// Call [`StorageEngine::checkpoint`] first to also truncate the log.
    pub fn close(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;
        self.wal.sync()?;
        info!("storage engine closed");
        Ok(())
    }
112
    // ---- low-level page operations ----
114
    /// Allocate a fresh page.
    ///
    /// Delegates to the page manager. The allocation itself is not WAL-logged;
    /// `recover` re-extends the page file as needed when replaying writes.
    pub fn allocate_page(&mut self) -> MenteResult<PageId> {
        self.page_manager.allocate_page()
    }
119
    /// Read a page through the buffer pool.
    ///
    /// The page is fetched (and pinned) by the buffer pool; callers are
    /// expected to `unpin_page` when finished — see `load_memory` for the
    /// pattern used internally.
    pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
        self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
    }
124
125    /// Write data into a page with WAL protection.
126    pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
127        // WAL-first: log before modifying the page.
128        let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;
129
130        // Load the page into the buffer pool (or get cached copy).
131        let mut page = self
132            .buffer_pool
133            .fetch_page(page_id, &mut self.page_manager)?;
134
135        let copy_len = data.len().min(PAGE_DATA_SIZE);
136        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
137        // Zero out remaining space if data is shorter than existing content.
138        if copy_len < PAGE_DATA_SIZE {
139            page.data[copy_len..].fill(0);
140        }
141        page.header.lsn = lsn;
142        page.header.page_type = PageType::Data as u8;
143        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
144        page.header.checksum = page.compute_checksum();
145
146        // Push modified page back into the buffer pool.
147        if self.buffer_pool.update_page(page_id, &page).is_err() {
148            // Not cached (shouldn't happen after fetch, but be safe).
149            self.page_manager.write_page(page_id, &page)?;
150        }
151        self.buffer_pool.unpin_page(page_id, true).ok();
152
153        Ok(())
154    }
155
    // ---- high-level memory operations ----
157
158    /// Serialize and store a [`MemoryNode`] into a single page.
159    ///
160    /// Returns the [`PageId`] where the node was stored.
161    pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
162        let serialized =
163            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
164
165        // 4 bytes for the length prefix.
166        if serialized.len() + 4 > PAGE_DATA_SIZE {
167            return Err(MenteError::CapacityExceeded(format!(
168                "memory node serialized to {} bytes (max {})",
169                serialized.len(),
170                PAGE_DATA_SIZE - 4,
171            )));
172        }
173
174        let page_id = self.allocate_page()?;
175
176        // Layout: [length: u32 LE][JSON bytes]
177        let mut buf = Vec::with_capacity(4 + serialized.len());
178        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
179        buf.extend_from_slice(&serialized);
180
181        self.write_page(page_id, &buf)?;
182
183        info!(
184            page_id = page_id.0,
185            bytes = serialized.len(),
186            "stored memory node"
187        );
188        Ok(page_id)
189    }
190
191    /// Load and deserialize a [`MemoryNode`] from the given page.
192    pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
193        let page = self.read_page(page_id)?;
194        // Unpin immediately — we copy the data we need.
195        self.buffer_pool.unpin_page(page_id, false).ok();
196
197        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
198        if len == 0 || len + 4 > PAGE_DATA_SIZE {
199            return Err(MenteError::Storage(format!(
200                "invalid memory node length prefix: {len}"
201            )));
202        }
203
204        serde_json::from_slice(&page.data[4..4 + len])
205            .map_err(|e| MenteError::Serialization(e.to_string()))
206    }
207
    // ---- durability ----
209
    /// Checkpoint: flush all dirty pages, sync to disk, and truncate the WAL.
    ///
    /// Ordering matters: pages must be durable on disk *before* the log is
    /// truncated, otherwise a crash in between could lose writes whose only
    /// record was the WAL.
    pub fn checkpoint(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;

        // Record the checkpoint itself, make it durable, then truncate the
        // log at the checkpoint LSN.
        let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
        self.wal.sync()?;
        self.wal.truncate(lsn)?;

        info!(lsn, "checkpoint complete");
        Ok(())
    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use uuid::Uuid;

    /// Create a fresh engine in a temp directory. The `TempDir` is returned
    /// so the caller keeps it alive (dropping it deletes the backing files).
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, mut engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            Uuid::new_v4(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            Uuid::new_v4(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        // Data should still be readable after checkpoint.
        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                Uuid::new_v4(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    Uuid::new_v4(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash: sync the WAL but do NOT call close/checkpoint.
            engine.wal.sync().unwrap();
            // Drop without close — dirty pages may not be flushed.
        }
        {
            // Reopen — WAL replay should recover the writes.
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                Uuid::new_v4(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            pid = engine.store_memory(&node).unwrap();
            // Proper shutdown — checkpoint flushes pages and truncates WAL.
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            // Reopen after clean shutdown — WAL should be empty, no duplicate data.
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // Store 3 memories then checkpoint.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    Uuid::new_v4(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            // Store 2 more without checkpoint (will only be in WAL).
            // Fixed label typo: was "unckeckpointed-{i}".
            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = MemoryNode::new(
                    Uuid::new_v4(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Simulate crash — sync WAL but don't close.
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // All 5 memories should be present: 3 from checkpoint, 2 from WAL replay.
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}