1use std::fs::File;
4use std::path::Path;
5
6use mentedb_core::MemoryNode;
7use mentedb_core::error::{MenteError, MenteResult};
8
9use fs2::FileExt;
10use tracing::info;
11
12use crate::buffer::BufferPool;
13use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
14use crate::wal::{Wal, WalEntryType};
/// Capacity passed to `BufferPool::new` when [`StorageEngine::open`] builds the
/// engine's buffer pool (presumably a number of pages — confirm against `BufferPool`).
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;

/// Page-oriented storage engine: a page file, a buffer-pool cache in front of
/// it, and a write-ahead log, all guarded by an exclusive lock file so only one
/// process can open the database directory at a time.
pub struct StorageEngine {
    /// Reads, writes, allocates and syncs pages in the on-disk page file.
    page_manager: PageManager,
    /// In-memory page cache; `write_page` updates pages here and `close` /
    /// `checkpoint` flush it back through `page_manager`.
    buffer_pool: BufferPool,
    /// Write-ahead log; `write_page` appends a record here before modifying the page.
    wal: Wal,
    /// Handle to `mentedb.lock` holding the exclusive lock. Never read; it is
    /// kept only so the lock lives as long as the engine. As the last field it
    /// is dropped after the page file and WAL handles.
    _lock_file: File,
}
29
30impl StorageEngine {
31 pub fn open(path: &Path) -> MenteResult<Self> {
36 std::fs::create_dir_all(path)?;
37
38 let lock_path = path.join("mentedb.lock");
40 let lock_file = File::create(&lock_path)
41 .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
42 lock_file.try_lock_exclusive().map_err(|_| {
43 MenteError::Storage(
44 "Database is locked by another process. Only one instance can access the database at a time.".to_string()
45 )
46 })?;
47
48 let page_manager = PageManager::open(path)?;
49 let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
50 let wal = Wal::open(path)?;
51
52 let mut engine = Self {
53 page_manager,
54 buffer_pool,
55 wal,
56 _lock_file: lock_file,
57 };
58
59 let recovered = engine.recover()?;
60 if recovered > 0 {
61 info!(recovered, ?path, "storage engine opened with WAL recovery");
62 } else {
63 info!(?path, "storage engine opened");
64 }
65
66 Ok(engine)
67 }
68
69 pub fn recover(&mut self) -> MenteResult<usize> {
74 let entries = self.wal.iterate()?;
75 let mut count = 0usize;
76
77 for entry in &entries {
78 match entry.entry_type {
79 WalEntryType::PageWrite => {
80 let page_id = PageId(entry.page_id);
81
82 while self.page_manager.page_count() <= entry.page_id {
84 self.page_manager.allocate_page()?;
85 }
86
87 let mut page = self.page_manager.read_page(page_id)?;
88 let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
89 page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
90 if copy_len < PAGE_DATA_SIZE {
91 page.data[copy_len..].fill(0);
92 }
93 page.header.page_id = entry.page_id;
94 page.header.lsn = entry.lsn;
95 page.header.page_type = PageType::Data as u8;
96 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
97 page.header.checksum = page.compute_checksum();
98
99 self.page_manager.write_page(page_id, &page)?;
100 count += 1;
101 }
102 WalEntryType::Checkpoint | WalEntryType::Commit => {
103 }
105 }
106 }
107
108 if count > 0 {
109 self.page_manager.sync()?;
110 let next_lsn = self.wal.next_lsn();
112 self.wal.truncate(next_lsn)?;
113 info!(count, "WAL recovery replayed entries");
114 }
115
116 Ok(count)
117 }
118
119 pub fn close(&mut self) -> MenteResult<()> {
121 self.buffer_pool.flush_all(&mut self.page_manager)?;
122 self.page_manager.sync()?;
123 self.wal.sync()?;
124 info!("storage engine closed");
125 Ok(())
126 }
127
128 pub fn allocate_page(&mut self) -> MenteResult<PageId> {
132 self.page_manager.allocate_page()
133 }
134
135 pub fn read_page(&mut self, page_id: PageId) -> MenteResult<Box<Page>> {
137 self.buffer_pool.fetch_page(page_id, &mut self.page_manager)
138 }
139
140 pub fn write_page(&mut self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
142 let lsn = self.wal.append(WalEntryType::PageWrite, page_id.0, data)?;
144
145 let mut page = self
147 .buffer_pool
148 .fetch_page(page_id, &mut self.page_manager)?;
149
150 let copy_len = data.len().min(PAGE_DATA_SIZE);
151 page.data[..copy_len].copy_from_slice(&data[..copy_len]);
152 if copy_len < PAGE_DATA_SIZE {
154 page.data[copy_len..].fill(0);
155 }
156 page.header.lsn = lsn;
157 page.header.page_type = PageType::Data as u8;
158 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
159 page.header.checksum = page.compute_checksum();
160
161 if self.buffer_pool.update_page(page_id, &page).is_err() {
163 self.page_manager.write_page(page_id, &page)?;
165 }
166 self.buffer_pool.unpin_page(page_id, true).ok();
167
168 Ok(())
169 }
170
171 pub fn store_memory(&mut self, node: &MemoryNode) -> MenteResult<PageId> {
177 let serialized =
178 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
179
180 if serialized.len() + 4 > PAGE_DATA_SIZE {
182 return Err(MenteError::CapacityExceeded(format!(
183 "memory node serialized to {} bytes (max {})",
184 serialized.len(),
185 PAGE_DATA_SIZE - 4,
186 )));
187 }
188
189 let page_id = self.allocate_page()?;
190
191 let mut buf = Vec::with_capacity(4 + serialized.len());
193 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
194 buf.extend_from_slice(&serialized);
195
196 self.write_page(page_id, &buf)?;
197
198 info!(
199 page_id = page_id.0,
200 bytes = serialized.len(),
201 "stored memory node"
202 );
203 Ok(page_id)
204 }
205
206 pub fn load_memory(&mut self, page_id: PageId) -> MenteResult<MemoryNode> {
208 let page = self.read_page(page_id)?;
209 self.buffer_pool.unpin_page(page_id, false).ok();
211
212 let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
213 if len == 0 || len + 4 > PAGE_DATA_SIZE {
214 return Err(MenteError::Storage(format!(
215 "invalid memory node length prefix: {len}"
216 )));
217 }
218
219 serde_json::from_slice(&page.data[4..4 + len])
220 .map_err(|e| MenteError::Serialization(e.to_string()))
221 }
222
223 pub fn checkpoint(&mut self) -> MenteResult<()> {
227 self.buffer_pool.flush_all(&mut self.page_manager)?;
228 self.page_manager.sync()?;
229
230 let lsn = self.wal.append(WalEntryType::Checkpoint, 0, &[])?;
231 self.wal.sync()?;
232 self.wal.truncate(lsn)?;
233
234 info!(lsn, "checkpoint complete");
235 Ok(())
236 }
237
238 pub fn scan_all_memories(&mut self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
242 let count = self.page_manager.page_count();
243 let mut results = Vec::new();
244 for i in 1..count {
246 let page_id = PageId(i);
247 if let Ok(node) = self.load_memory(page_id) {
248 results.push((node.id, page_id));
249 }
250 }
251 results
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Opens a fresh engine in a new temporary directory.
    ///
    /// The `TempDir` is returned so the directory outlives the engine.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let dir = tempfile::tempdir().unwrap();
        let engine = StorageEngine::open(dir.path()).unwrap();
        (dir, engine)
    }

    /// Builds a memory node owned by a fresh agent.
    fn make_node(memory_type: MemoryType, content: &str, embedding: Vec<f32>) -> MemoryNode {
        MemoryNode::new(AgentId::new(), memory_type, content.to_string(), embedding)
    }

    #[test]
    fn test_allocate_write_read() {
        let (_dir, mut engine) = setup();
        let payload = b"hello storage engine";

        let pid = engine.allocate_page().unwrap();
        engine.write_page(pid, payload).unwrap();

        let page = engine.read_page(pid).unwrap();
        assert_eq!(&page.data[..payload.len()], payload);
        // `write_page` zero-fills everything after the payload.
        assert!(page.data[payload.len()..].iter().all(|&b| b == 0));
        engine.buffer_pool.unpin_page(pid, false).ok();
    }

    #[test]
    fn test_store_and_load_memory() {
        let (_dir, mut engine) = setup();

        let node = make_node(
            MemoryType::Episodic,
            "The user prefers Rust over Go",
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&node).unwrap();
        let loaded = engine.load_memory(page_id).unwrap();

        assert_eq!(node.id, loaded.id);
        assert_eq!(node.content, loaded.content);
        assert_eq!(node.embedding, loaded.embedding);
        assert_eq!(node.memory_type, loaded.memory_type);
    }

    #[test]
    fn test_checkpoint() {
        let (_dir, mut engine) = setup();

        let node = make_node(MemoryType::Semantic, "checkpoint test", vec![1.0, 2.0]);

        let pid = engine.store_memory(&node).unwrap();
        engine.checkpoint().unwrap();

        // Data must still be readable after the WAL was truncated.
        let loaded = engine.load_memory(pid).unwrap();
        assert_eq!(loaded.content, "checkpoint test");
    }

    #[test]
    fn test_close_and_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Procedural, "persist across close", vec![0.5]);
            pid = engine.store_memory(&node).unwrap();
            engine.close().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, "persist across close");
        }
    }

    #[test]
    fn test_crash_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                ids.push(engine.store_memory(&node).unwrap());
                contents.push(content);
            }
            // Make the log durable, then drop without `close` to simulate a
            // crash: the buffer pool is never explicitly flushed.
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }

    #[test]
    fn test_recovery_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let pid;
        let content = "idempotent-check".to_string();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let node = make_node(MemoryType::Semantic, &content, vec![1.0, 2.0]);
            pid = engine.store_memory(&node).unwrap();
            // Crash without checkpointing so the first reopen has WAL records to replay.
            engine.wal.sync().unwrap();
        }
        // Every reopen runs recovery; repeated runs must leave the data intact.
        for _ in 0..3 {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            let loaded = engine.load_memory(pid).unwrap();
            assert_eq!(loaded.content, content);
        }
    }

    #[test]
    fn test_partial_write_recovery() {
        let dir = tempfile::tempdir().unwrap();
        let mut ids = Vec::new();
        let mut contents = Vec::new();
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            // These writes are covered by the checkpoint below.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = make_node(MemoryType::Semantic, &content, vec![i as f32]);
                ids.push(engine.store_memory(&node).unwrap());
                contents.push(content);
            }
            engine.checkpoint().unwrap();

            // These writes exist only in the WAL when the "crash" happens.
            for i in 3..5 {
                let content = format!("uncheckpointed-{i}");
                let node = make_node(MemoryType::Episodic, &content, vec![i as f32]);
                ids.push(engine.store_memory(&node).unwrap());
                contents.push(content);
            }
            engine.wal.sync().unwrap();
        }
        {
            let mut engine = StorageEngine::open(dir.path()).unwrap();
            for (pid, expected) in ids.iter().zip(contents.iter()) {
                let loaded = engine.load_memory(*pid).unwrap();
                assert_eq!(&loaded.content, expected);
            }
        }
    }
}