1use std::path::Path;
4
5use mentedb_core::MemoryNode;
6use mentedb_core::error::{MenteError, MenteResult};
7
8use tracing::info;
9
10use crate::buffer::BufferPool;
11use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
12use crate::wal::{Wal, WalEntryType};
/// Default capacity handed to `BufferPool::new` when an engine is opened.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
15
/// Disk-backed storage engine combining an on-disk page file, an in-memory
/// buffer pool, and a write-ahead log (WAL) for crash recovery.
pub struct StorageEngine {
    // Owns the on-disk page file: allocation, raw page reads/writes, sync.
    page_manager: PageManager,
    // In-memory page cache sitting in front of `page_manager`.
    buffer_pool: BufferPool,
    // Write-ahead log; replayed by `recover` when the engine is reopened.
    wal: Wal,
}
25
impl StorageEngine {
    /// Opens (or creates) a storage engine rooted at `path`.
    ///
    /// Creates the directory if needed, opens the page file and the WAL,
    /// then replays any WAL entries left over from a previous run.
    ///
    /// # Errors
    /// Propagates failures from directory creation, opening the page file
    /// or WAL, and WAL recovery.
    pub fn open(path: &Path) -> MenteResult<Self> {
        std::fs::create_dir_all(path)?;

        let page_manager = PageManager::open(path)?;
        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
        let wal = Wal::open(path)?;

        let mut engine = Self {
            page_manager,
            buffer_pool,
            wal,
        };

        // Replay whatever the WAL still holds; returns 0 (and does no I/O
        // beyond iterating the WAL) after a clean shutdown.
        let recovered = engine.recover()?;
        if recovered > 0 {
            info!(recovered, ?path, "storage engine opened with WAL recovery");
        } else {
            info!(?path, "storage engine opened");
        }

        Ok(engine)
    }

    /// Replays outstanding WAL entries into the page file.
    ///
    /// Every `PageWrite` entry is re-applied to its page (allocating pages
    /// as needed); `Checkpoint`/`Commit` entries are skipped. If anything
    /// was replayed, the page file is synced and the WAL truncated.
    ///
    /// # Returns
    /// The number of `PageWrite` entries applied.
    ///
    /// # Errors
    /// Propagates WAL iteration, page allocation/read/write, sync, and
    /// truncation failures.
    pub fn recover(&mut self) -> MenteResult<usize> {
        let entries = self.wal.iterate()?;
        let mut count = 0usize;

        for entry in &entries {
            match entry.entry_type {
                WalEntryType::PageWrite => {
                    let page_id = PageId(entry.page_id);

                    // The WAL may reference pages that were never persisted
                    // before the crash; allocate until the page exists.
                    while self.page_manager.page_count() <= entry.page_id {
                        self.page_manager.allocate_page()?;
                    }

                    let mut page = self.page_manager.read_page(page_id)?;
                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
                    // Zero the tail so stale bytes from the pre-crash page
                    // image cannot leak past the replayed payload.
                    if copy_len < PAGE_DATA_SIZE {
                        page.data[copy_len..].fill(0);
                    }
                    page.header.page_id = entry.page_id;
                    page.header.lsn = entry.lsn;
                    page.header.page_type = PageType::Data as u8;
                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
                    // Checksum is computed last, over the fully rebuilt page.
                    page.header.checksum = page.compute_checksum();

                    // Bypass the buffer pool: recovery runs before any page
                    // is cached, writing straight to the page file.
                    self.page_manager.write_page(page_id, &page)?;
                    count += 1;
                }
                WalEntryType::Checkpoint | WalEntryType::Commit => {
                    // No page payload to replay for these entry types.
                }
            }
        }

        if count > 0 {
            // Durably persist the replayed pages BEFORE truncating the WAL,
            // otherwise a crash here would lose the only copy of the data.
            self.page_manager.sync()?;
            let next_lsn = self.wal.next_lsn();
            self.wal.truncate(next_lsn)?;
            info!(count, "WAL recovery replayed entries");
        }

        Ok(count)
    }

    /// Flushes all dirty pages and syncs both the page file and the WAL.
    ///
    /// # Errors
    /// Propagates flush and sync failures.
    pub fn close(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;
        self.wal.sync()?;
        info!("storage engine closed");
        Ok(())
    }

    /// Allocates a fresh page in the page file and returns its id.
    pub fn allocate_page(&mut self) -> MenteResult<PageId> {
        self.page_manager.allocate_page()
    }

    /// Fetches a page through the buffer pool.
    ///
    /// NOTE(review): judging by `load_memory` and the tests, the fetched
    /// page stays pinned and the caller is expected to call
    /// `buffer_pool.unpin_page` when done — confirm against `BufferPool`.
    pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
        self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
    }

    /// Writes `data` into page `page_id`, WAL-first.
    ///
    /// The payload is logged to the WAL before the page image is touched,
    /// so a crash mid-write is repaired by `recover`.
    ///
    /// NOTE(review): `data` longer than `PAGE_DATA_SIZE` is silently
    /// truncated here. `store_memory` guards against oversized payloads,
    /// but other callers are not checked — consider returning
    /// `CapacityExceeded` instead.
    ///
    /// # Errors
    /// Propagates WAL append and page fetch/write failures.
    pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
        // WAL append first: the log entry must be durable-ordered ahead of
        // the in-place page mutation.
        let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;

        let mut page = self
            .buffer_pool
            .fetch_page(page_id, &mut self.page_manager)?;

        let copy_len = data.len().min(PAGE_DATA_SIZE);
        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
        // Clear any remainder left over from the page's previous contents.
        if copy_len < PAGE_DATA_SIZE {
            page.data[copy_len..].fill(0);
        }
        page.header.lsn = lsn;
        page.header.page_type = PageType::Data as u8;
        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
        page.header.checksum = page.compute_checksum();

        // Prefer updating the cached copy; fall back to writing the page
        // file directly if the pool no longer holds this page.
        if self.buffer_pool.update_page(page_id, &page).is_err() {
            self.page_manager.write_page(page_id, &page)?;
        }
        // Best-effort unpin (dirty = true); the page was pinned by
        // `fetch_page` above.
        self.buffer_pool.unpin_page(page_id, true).ok();

        Ok(())
    }

    /// Serializes `node` as JSON and stores it on a freshly allocated page.
    ///
    /// On-page layout: a little-endian `u32` length prefix followed by the
    /// JSON bytes (see `load_memory` for the inverse).
    ///
    /// # Returns
    /// The id of the page holding the node.
    ///
    /// # Errors
    /// `Serialization` if JSON encoding fails, `CapacityExceeded` if the
    /// encoded node (plus the 4-byte prefix) does not fit in one page, or
    /// any allocation/write failure.
    pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
        let serialized =
            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;

        // +4 accounts for the u32 length prefix written below.
        if serialized.len() + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::CapacityExceeded(format!(
                "memory node serialized to {} bytes (max {})",
                serialized.len(),
                PAGE_DATA_SIZE - 4,
            )));
        }

        let page_id = self.allocate_page()?;

        let mut buf = Vec::with_capacity(4 + serialized.len());
        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
        buf.extend_from_slice(&serialized);

        self.write_page(page_id, &buf)?;

        info!(
            page_id = page_id.0,
            bytes = serialized.len(),
            "stored memory node"
        );
        Ok(page_id)
    }

    /// Loads the `MemoryNode` stored on `page_id` (inverse of
    /// `store_memory`).
    ///
    /// NOTE(review): the page is unpinned before its data is read; this is
    /// safe only because `read_page` returns an owned `Box<Page>` copy —
    /// confirm `fetch_page` does not hand out a shared frame.
    ///
    /// # Errors
    /// `Storage` if the length prefix is zero or exceeds the page capacity;
    /// `Serialization` if the JSON payload does not decode.
    pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
        let page = self.read_page(page_id)?;
        self.buffer_pool.unpin_page(page_id, false).ok();

        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
        if len == 0 || len + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::Storage(format!(
                "invalid memory node length prefix: {len}"
            )));
        }

        serde_json::from_slice(&page.data[4..4 + len])
            .map_err(|e| MenteError::Serialization(e.to_string()))
    }

    /// Flushes all dirty pages, syncs the page file, then marks and
    /// truncates the WAL.
    ///
    /// Pages are durable before the WAL is cut, so a crash at any point
    /// leaves either the WAL or the page file holding the data.
    ///
    /// # Errors
    /// Propagates flush, sync, append, and truncation failures.
    pub fn checkpoint(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;

        let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
        self.wal.sync()?;
        // NOTE(review): truncates at the checkpoint entry's own LSN, while
        // `recover` truncates at `next_lsn` — confirm `Wal::truncate`
        // semantics make both correct.
        self.wal.truncate(lsn)?;

        info!(lsn, "checkpoint complete");
        Ok(())
    }

    /// Scans every page and returns `(MemoryId, PageId)` for each page that
    /// decodes as a valid memory node; undecodable pages are skipped.
    pub fn scan_all_memories(&mut self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
        let count = self.page_manager.page_count();
        let mut results = Vec::new();
        // Starts at 1 — page 0 is presumably reserved (meta/header page);
        // TODO confirm against `PageManager::allocate_page`.
        for i in 1..count {
            let page_id = PageId(i);
            if let Ok(node) = self.load_memory(page_id) {
                results.push((node.id, page_id));
            }
        }
        results
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Creates a fresh engine in a temporary directory.
    ///
    /// The `TempDir` is returned alongside the engine so it outlives the
    /// test body (dropping it deletes the on-disk files).
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    /// Raw page round-trip: allocate, write bytes, read them back.
    #[test]
    fn test_allocate_write_read() {
        let (_dir, mut engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        // read_page pins the page; release it so the pool stays balanced.
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    /// MemoryNode round-trip through store_memory / load_memory.
    #[test]
    fn test_store_and_load_memory() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    /// A checkpoint must not lose data written before it.
    #[test]
    fn test_checkpoint() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    /// Data stored before a clean close must be readable after reopening.
    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    /// Simulated crash: the WAL is synced but `close()` is never called,
    /// so recovery on reopen must reconstruct every stored node.
    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Sync only the WAL, then drop the engine without close():
            // the dirty pages are lost, the WAL is the sole copy.
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    /// Reopening after checkpoint + clean close (empty WAL) must still
    /// serve the data — recovery has to be a harmless no-op.
    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    /// Mixed durability: some nodes checkpointed (already on disk), some
    /// only in the WAL. Recovery must serve both sets after a crash.
    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // First batch: persisted via checkpoint.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            // Second batch: reaches only the WAL before the "crash".
            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}