1use std::path::Path;
4
5use mentedb_core::MemoryNode;
6use mentedb_core::error::{MenteError, MenteResult};
7
8use tracing::info;
9
10use crate::buffer::BufferPool;
11use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
12use crate::wal::{Wal, WalEntryType};
/// Default capacity passed to `BufferPool::new` when opening an engine.
/// NOTE(review): presumably a page count, not a byte size — confirm against `BufferPool`.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
15
/// Durable, crash-recoverable page store: an on-disk page file fronted by an
/// in-memory buffer pool, with a write-ahead log replayed on open.
///
/// Field order is significant: Rust drops fields in declaration order, so do
/// not reorder without considering shutdown/drop semantics of the components.
pub struct StorageEngine {
    // On-disk page allocation, reads, writes, and fsync.
    page_manager: PageManager,
    // In-memory page cache; dirty pages are flushed via `flush_all`.
    buffer_pool: BufferPool,
    // Write-ahead log; entries are appended before any page mutation.
    wal: Wal,
}
25
impl StorageEngine {
    /// Opens (creating if necessary) a storage engine rooted at `path`.
    ///
    /// Opens the page file and WAL inside the directory, then immediately
    /// replays any outstanding WAL entries so callers never observe a page
    /// file that lags the log.
    ///
    /// # Errors
    /// Propagates directory-creation, open, and recovery failures.
    pub fn open(path: &Path) -> MenteResult<Self> {
        std::fs::create_dir_all(path)?;

        let page_manager = PageManager::open(path)?;
        let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
        let wal = Wal::open(path)?;

        let mut engine = Self {
            page_manager,
            buffer_pool,
            wal,
        };

        // Crash recovery: replay WAL entries left by a previous process
        // that died between `wal.append` and a durable page write.
        let recovered = engine.recover()?;
        if recovered > 0 {
            info!(recovered, ?path, "storage engine opened with WAL recovery");
        } else {
            info!(?path, "storage engine opened");
        }

        Ok(engine)
    }

    /// Replays the WAL against the page file and returns the number of
    /// `PageWrite` entries applied.
    ///
    /// Replay is idempotent: each entry overwrites the full page data area
    /// and rebuilds the header, so applying an already-applied entry yields
    /// the same bytes. After a successful replay the page file is synced and
    /// the WAL is truncated.
    ///
    /// # Errors
    /// Propagates WAL iteration and page-manager I/O failures.
    pub fn recover(&mut self) -> MenteResult<usize> {
        let entries = self.wal.iterate()?;
        let mut count = 0usize;

        for entry in &entries {
            match entry.entry_type {
                WalEntryType::PageWrite => {
                    let page_id = PageId(entry.page_id);

                    // The logged page may lie past the current end of the
                    // page file (the data write was lost in the crash);
                    // grow the file until the target page exists.
                    while self.page_manager.page_count() <= entry.page_id {
                        self.page_manager.allocate_page()?;
                    }

                    let mut page = self.page_manager.read_page(page_id)?;
                    // Copy the logged payload, clamped to the page data
                    // area, and zero the tail so stale bytes cannot survive
                    // a replay of a shorter payload.
                    let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
                    page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
                    if copy_len < PAGE_DATA_SIZE {
                        page.data[copy_len..].fill(0);
                    }
                    // Rebuild the header exactly as `write_page` would have,
                    // including a fresh checksum over the new contents.
                    page.header.page_id = entry.page_id;
                    page.header.lsn = entry.lsn;
                    page.header.page_type = PageType::Data as u8;
                    page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
                    page.header.checksum = page.compute_checksum();

                    self.page_manager.write_page(page_id, &page)?;
                    count += 1;
                }
                WalEntryType::Checkpoint | WalEntryType::Commit => {
                    // Marker entries carry no page state — nothing to replay.
                }
            }
        }

        if count > 0 {
            // Ordering matters: persist the replayed pages *before*
            // discarding the log, or a crash here would lose them.
            self.page_manager.sync()?;
            let next_lsn = self.wal.next_lsn();
            self.wal.truncate(next_lsn)?;
            info!(count, "WAL recovery replayed entries");
        }

        Ok(count)
    }

    /// Flushes all dirty buffer-pool pages to the page file, then syncs the
    /// page file and the WAL. Call before dropping for a clean shutdown.
    ///
    /// # Errors
    /// Propagates flush and sync failures.
    pub fn close(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;
        self.wal.sync()?;
        info!("storage engine closed");
        Ok(())
    }

    /// Allocates a fresh page in the page file and returns its id.
    ///
    /// # Errors
    /// Propagates page-manager allocation failures.
    pub fn allocate_page(&mut self) -> MenteResult<PageId> {
        self.page_manager.allocate_page()
    }

    /// Fetches a page through the buffer pool (reading from disk on a miss).
    ///
    /// NOTE(review): the fetched page appears to be pinned by the pool —
    /// callers here pair this with `unpin_page`; confirm against `BufferPool`.
    ///
    /// # Errors
    /// Propagates buffer-pool / page-manager read failures.
    pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
        self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
    }

    /// Writes `data` into `page_id`, write-ahead-logging it first.
    ///
    /// WAL-first ordering is the durability contract: the entry is appended
    /// (yielding the page's new LSN) *before* the page is mutated, so a
    /// crash at any later point is repaired by `recover`. `data` longer than
    /// `PAGE_DATA_SIZE` is silently truncated; shorter data zero-fills the
    /// remainder of the page.
    ///
    /// # Errors
    /// Propagates WAL-append and page I/O failures.
    pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
        let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;

        let mut page = self
            .buffer_pool
            .fetch_page(page_id, &mut self.page_manager)?;

        let copy_len = data.len().min(PAGE_DATA_SIZE);
        page.data[..copy_len].copy_from_slice(&data[..copy_len]);
        if copy_len < PAGE_DATA_SIZE {
            page.data[copy_len..].fill(0);
        }
        page.header.lsn = lsn;
        page.header.page_type = PageType::Data as u8;
        page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
        page.header.checksum = page.compute_checksum();

        // Prefer updating the cached copy; fall back to a direct disk write
        // if the pool refuses (e.g. the page is not resident — confirm
        // against `BufferPool::update_page`). The WAL entry above covers us
        // either way.
        if self.buffer_pool.update_page(page_id, &page).is_err() {
            self.page_manager.write_page(page_id, &page)?;
        }
        // Best-effort unpin, marked dirty; failure is ignored because the
        // fallback path above may have bypassed the pool entirely.
        self.buffer_pool.unpin_page(page_id, true).ok();

        Ok(())
    }

    /// Serializes a memory node to JSON and stores it in a newly allocated
    /// page, returning that page's id.
    ///
    /// On-page layout: `[len: u32 little-endian][serde_json bytes]`, which
    /// `load_memory` expects.
    ///
    /// # Errors
    /// `MenteError::Serialization` if JSON encoding fails;
    /// `MenteError::CapacityExceeded` if the prefixed payload would not fit
    /// in one page; otherwise page allocation / write failures.
    pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
        let serialized =
            serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;

        // Reject up front rather than let write_page silently truncate.
        if serialized.len() + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::CapacityExceeded(format!(
                "memory node serialized to {} bytes (max {})",
                serialized.len(),
                PAGE_DATA_SIZE - 4,
            )));
        }

        let page_id = self.allocate_page()?;

        // Length prefix lets load_memory recover the JSON span from a
        // zero-padded page.
        let mut buf = Vec::with_capacity(4 + serialized.len());
        buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
        buf.extend_from_slice(&serialized);

        self.write_page(page_id, &buf)?;

        info!(
            page_id = page_id.0,
            bytes = serialized.len(),
            "stored memory node"
        );
        Ok(page_id)
    }

    /// Loads and deserializes the memory node stored at `page_id` by
    /// `store_memory`.
    ///
    /// # Errors
    /// `MenteError::Storage` if the length prefix is zero or exceeds the
    /// page data area; `MenteError::Serialization` if the JSON is invalid;
    /// otherwise page read failures.
    pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
        let page = self.read_page(page_id)?;
        // Unpinning before use is fine: `read_page` returned an owned
        // Box<Page>, so `page` is our copy regardless of pool state.
        self.buffer_pool.unpin_page(page_id, false).ok();

        // First 4 bytes are the little-endian length prefix written by
        // store_memory; the try_into cannot fail on a 4-byte slice.
        let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
        if len == 0 || len + 4 > PAGE_DATA_SIZE {
            return Err(MenteError::Storage(format!(
                "invalid memory node length prefix: {len}"
            )));
        }

        serde_json::from_slice(&page.data[4..4 + len])
            .map_err(|e| MenteError::Serialization(e.to_string()))
    }

    /// Checkpoints the engine: flushes and syncs all pages, records a
    /// checkpoint marker, then truncates the WAL up to that marker.
    ///
    /// Ordering matters: pages must be durable *before* the log entries
    /// describing them are discarded.
    ///
    /// # Errors
    /// Propagates flush, sync, WAL-append, and truncate failures.
    pub fn checkpoint(&mut self) -> MenteResult<()> {
        self.buffer_pool.flush_all(&mut self.page_manager)?;
        self.page_manager.sync()?;

        let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
        self.wal.sync()?;
        self.wal.truncate(lsn)?;

        info!(lsn, "checkpoint complete");
        Ok(())
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    // Opens a fresh engine in a temp directory. The TempDir is returned so
    // it outlives the engine — dropping it deletes the directory.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    // Raw page round-trip: allocate, write bytes, read them back.
    #[test]
    fn test_allocate_write_read() {
        let (_dir, mut engine) = setup();

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, b"hello storage engine").unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..20], b"hello storage engine");
        // read_page pins the page in the pool; release it.
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    // Memory-node round-trip through the JSON length-prefixed page format.
    #[test]
    fn test_store_and_load_memory() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Episodic,
            "The user prefers Rust over Go".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    // A checkpoint (flush + WAL truncate) must not lose stored data.
    #[test]
    fn test_checkpoint() {
        let (_dir, mut engine) = setup();

        let node = MemoryNode::new(
            AgentId::new(),
            MemoryType::Semantic,
            "checkpoint test".to_string(),
            vec![1.0, 2.0],
        );

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    // Clean shutdown path: close() then reopen the same directory and read
    // the data back.
    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Procedural,
                "persist across close".to_string(),
                vec![0.5],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    // Simulated crash: sync the WAL but drop the engine WITHOUT close(), so
    // dirty buffer-pool pages may never reach the page file. Reopen must
    // rebuild the pages by replaying the WAL.
    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Make the log durable, then "crash" by dropping without close().
            engine.wal.sync().unwrap();
        }
        {
            // open() runs recover() internally before returning.
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    // Checkpoint truncates the WAL, so a clean close + reopen must find no
    // entries to replay (and replaying nothing must leave the data intact).
    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = MemoryNode::new(
                AgentId::new(),
                MemoryType::Semantic,
                content.clone(),
                vec![1.0, 2.0],
            );
            pid = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    // Mixed durability: the first batch is checkpointed (safe in the page
    // file), the second is only WAL-synced before a simulated crash. Both
    // batches must survive the reopen.
    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Semantic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            for i in 3..5 {
                // (typo "unckeckpointed" is harmless — the same format
                // string produces both the stored and expected values)
                let content = format!("unckeckpointed-{i}");
                let node = MemoryNode::new(
                    AgentId::new(),
                    MemoryType::Episodic,
                    content.clone(),
                    vec![i as f32],
                );
                let pid = engine.store_memory(&node).unwrap();
                ids.push(pid);
                contents.push(content);
            }
            // Durable log, then "crash" (no close()).
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}