1use std::fs::File;
4use std::path::Path;
5
6use mentedb_core::MemoryNode;
7use mentedb_core::error::{MenteError, MenteResult};
8
9use fs2::FileExt;
10use parking_lot::Mutex;
11use tracing::info;
12
13use crate::buffer::BufferPool;
14use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
15use crate::wal::{Wal, WalEntryType};
/// Capacity handed to `BufferPool::new` whenever an engine is opened.
// NOTE(review): presumably counted in page frames rather than bytes — confirm
// against `BufferPool::new`.
const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
18
19pub struct StorageEngine {
27 page_manager: Mutex<PageManager>,
28 buffer_pool: BufferPool,
29 wal: Mutex<Wal>,
30 _lock_file: File,
32}
33
34impl StorageEngine {
35 pub fn open(path: &Path) -> MenteResult<Self> {
40 std::fs::create_dir_all(path)?;
41
42 let lock_path = path.join("mentedb.lock");
44 let lock_file = File::create(&lock_path)
45 .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
46 lock_file.try_lock_exclusive().map_err(|_| {
47 MenteError::Storage(
48 "Database is locked by another process. Only one instance can access the database at a time.".to_string()
49 )
50 })?;
51
52 let page_manager = PageManager::open(path)?;
53 let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
54 let wal = Wal::open(path)?;
55
56 let engine = Self {
57 page_manager: Mutex::new(page_manager),
58 buffer_pool,
59 wal: Mutex::new(wal),
60 _lock_file: lock_file,
61 };
62
63 let recovered = engine.recover()?;
64 if recovered > 0 {
65 info!(recovered, ?path, "storage engine opened with WAL recovery");
66 } else {
67 info!(?path, "storage engine opened");
68 }
69
70 Ok(engine)
71 }
72
73 pub fn recover(&self) -> MenteResult<usize> {
78 let mut wal = self.wal.lock();
79 let entries = wal.iterate()?;
80 let mut count = 0usize;
81 let mut pm = self.page_manager.lock();
82
83 for entry in &entries {
84 match entry.entry_type {
85 WalEntryType::PageWrite => {
86 let page_id = PageId(entry.page_id);
87
88 while pm.page_count() <= entry.page_id {
89 pm.allocate_page()?;
90 }
91
92 let mut page = pm.read_page(page_id)?;
93 let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
94 page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
95 if copy_len < PAGE_DATA_SIZE {
96 page.data[copy_len..].fill(0);
97 }
98 page.header.page_id = entry.page_id;
99 page.header.lsn = entry.lsn;
100 page.header.page_type = PageType::Data as u8;
101 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
102 page.header.checksum = page.compute_checksum();
103
104 pm.write_page(page_id, &page)?;
105 count += 1;
106 }
107 WalEntryType::Checkpoint | WalEntryType::Commit => {}
108 }
109 }
110
111 if count > 0 {
112 pm.sync()?;
113 let next_lsn = wal.next_lsn();
114 wal.truncate(next_lsn)?;
115 info!(count, "WAL recovery replayed entries");
116 }
117
118 Ok(count)
119 }
120
121 pub fn close(&self) -> MenteResult<()> {
123 let mut pm = self.page_manager.lock();
124 self.buffer_pool.flush_all(&mut pm)?;
125 pm.sync()?;
126 self.wal.lock().sync()?;
127 info!("storage engine closed");
128 Ok(())
129 }
130
131 pub fn allocate_page(&self) -> MenteResult<PageId> {
135 self.page_manager.lock().allocate_page()
136 }
137
138 pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
140 self.buffer_pool
141 .fetch_page(page_id, &mut self.page_manager.lock())
142 }
143
144 pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
146 let lsn = self
147 .wal
148 .lock()
149 .append(WalEntryType::PageWrite, page_id.0, data)?;
150
151 let mut pm = self.page_manager.lock();
152 let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
153 drop(pm); let copy_len = data.len().min(PAGE_DATA_SIZE);
156 page.data[..copy_len].copy_from_slice(&data[..copy_len]);
157 if copy_len < PAGE_DATA_SIZE {
158 page.data[copy_len..].fill(0);
159 }
160 page.header.lsn = lsn;
161 page.header.page_type = PageType::Data as u8;
162 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
163 page.header.checksum = page.compute_checksum();
164
165 if self.buffer_pool.update_page(page_id, &page).is_err() {
166 self.page_manager.lock().write_page(page_id, &page)?;
167 }
168 self.buffer_pool.unpin_page(page_id, true).ok();
169
170 Ok(())
171 }
172
173 pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
179 let serialized =
180 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
181
182 if serialized.len() + 4 > PAGE_DATA_SIZE {
183 return Err(MenteError::CapacityExceeded(format!(
184 "memory node serialized to {} bytes (max {})",
185 serialized.len(),
186 PAGE_DATA_SIZE - 4,
187 )));
188 }
189
190 let page_id = self.allocate_page()?;
191
192 let mut buf = Vec::with_capacity(4 + serialized.len());
193 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
194 buf.extend_from_slice(&serialized);
195
196 self.write_page(page_id, &buf)?;
197
198 info!(
199 page_id = page_id.0,
200 bytes = serialized.len(),
201 "stored memory node"
202 );
203 Ok(page_id)
204 }
205
206 pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
208 let page = self.read_page(page_id)?;
209 self.buffer_pool.unpin_page(page_id, false).ok();
210
211 let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
212 if len == 0 || len + 4 > PAGE_DATA_SIZE {
213 return Err(MenteError::Storage(format!(
214 "invalid memory node length prefix: {len}"
215 )));
216 }
217
218 serde_json::from_slice(&page.data[4..4 + len])
219 .map_err(|e| MenteError::Serialization(e.to_string()))
220 }
221
222 pub fn checkpoint(&self) -> MenteResult<()> {
226 let mut pm = self.page_manager.lock();
227 self.buffer_pool.flush_all(&mut pm)?;
228 pm.sync()?;
229
230 let mut wal = self.wal.lock();
231 let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
232 wal.sync()?;
233 wal.truncate(lsn)?;
234
235 info!(lsn, "checkpoint complete");
236 Ok(())
237 }
238
239 pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
243 let count = self.page_manager.lock().page_count();
244 let mut results = Vec::new();
245 for i in 1..count {
246 let page_id = PageId(i);
247 if let Ok(node) = self.load_memory(page_id) {
248 results.push((node.id, page_id));
249 }
250 }
251 results
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::memory::MemoryType;
    use mentedb_core::types::AgentId;

    /// Opens an engine in a brand-new temporary directory. The `TempDir` is
    /// returned as well so the directory outlives the engine.
    fn setup() -> (tempfile::TempDir, StorageEngine) {
        let tmp = tempfile::tempdir().unwrap();
        let opened = StorageEngine::open(tmp.path()).unwrap();
        (tmp, opened)
    }

    /// Builds a node owned by a freshly generated agent.
    fn new_node(kind: MemoryType, content: &str, embedding: Vec<f32>) -> MemoryNode {
        MemoryNode::new(AgentId::new(), kind, content.to_string(), embedding)
    }

    /// Loads every stored page and checks it carries the expected content.
    fn assert_all_present(engine: &StorageEngine, stored: &[(PageId, String)]) {
        for (pid, expected) in stored {
            let loaded = engine.load_memory(*pid).unwrap();
            assert_eq!(&loaded.content, expected);
        }
    }

    /// Raw bytes written to a page come back from `read_page`.
    #[test]
    fn test_allocate_write_read() {
        let (_dir, engine) = setup();

        let page_id = engine.allocate_page().unwrap();
        engine.write_page(page_id, b"hello storage engine").unwrap();

        let fetched = engine.read_page(page_id).unwrap();
        assert_eq!(&fetched.data[..20], b"hello storage engine");
        engine.buffer_pool.unpin_page(page_id, false).ok();
    }

    /// A stored node loads back with the same id, content, embedding and type.
    #[test]
    fn test_store_and_load_memory() {
        let (_dir, engine) = setup();

        let original = new_node(
            MemoryType::Episodic,
            "The user prefers Rust over Go",
            vec![0.1, 0.2, 0.3, 0.4],
        );

        let page_id = engine.store_memory(&original).unwrap();
        let restored = engine.load_memory(page_id).unwrap();

        assert_eq!(original.id, restored.id);
        assert_eq!(original.content, restored.content);
        assert_eq!(original.embedding, restored.embedding);
        assert_eq!(original.memory_type, restored.memory_type);
    }

    /// A node is still readable after a checkpoint.
    #[test]
    fn test_checkpoint() {
        let (_dir, engine) = setup();

        let original = new_node(MemoryType::Semantic, "checkpoint test", vec![1.0, 2.0]);

        let page_id = engine.store_memory(&original).unwrap();
        engine.checkpoint().unwrap();

        let restored = engine.load_memory(page_id).unwrap();
        assert_eq!(restored.content, "checkpoint test");
    }

    /// Data stored before `close` is readable from a new engine on the same directory.
    #[test]
    fn test_close_and_reopen() {
        let tmp = tempfile::tempdir().unwrap();

        let page_id = {
            let engine = StorageEngine::open(tmp.path()).unwrap();
            let original = new_node(MemoryType::Procedural, "persist across close", vec![0.5]);
            let page_id = engine.store_memory(&original).unwrap();
            engine.close().unwrap();
            page_id
        };

        let engine = StorageEngine::open(tmp.path()).unwrap();
        let restored = engine.load_memory(page_id).unwrap();
        assert_eq!(restored.content, "persist across close");
    }

    /// Nodes whose WAL was synced, but which were never closed or
    /// checkpointed, are readable after reopening.
    #[test]
    fn test_crash_recovery() {
        let tmp = tempfile::tempdir().unwrap();
        let mut stored = Vec::new();
        {
            let engine = StorageEngine::open(tmp.path()).unwrap();
            for i in 0..3 {
                let content = format!("crash-recovery-{i}");
                let node = new_node(MemoryType::Episodic, &content, vec![i as f32]);
                let page_id = engine.store_memory(&node).unwrap();
                stored.push((page_id, content));
            }
            // Only the WAL is synced; the engine is dropped without `close`.
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(tmp.path()).unwrap();
            assert_all_present(&engine, &stored);
        }
    }

    /// Store, checkpoint, close, reopen: the node is still readable.
    #[test]
    fn test_recovery_idempotent() {
        let tmp = tempfile::tempdir().unwrap();
        let content = "idempotent-check".to_string();

        let page_id = {
            let engine = StorageEngine::open(tmp.path()).unwrap();
            let node = new_node(MemoryType::Semantic, &content, vec![1.0, 2.0]);
            let page_id = engine.store_memory(&node).unwrap();
            engine.checkpoint().unwrap();
            engine.close().unwrap();
            page_id
        };

        let engine = StorageEngine::open(tmp.path()).unwrap();
        let restored = engine.load_memory(page_id).unwrap();
        assert_eq!(restored.content, content);
    }

    /// Both checkpointed nodes and nodes written after the checkpoint (WAL
    /// synced only) are readable after reopening.
    #[test]
    fn test_partial_write_recovery() {
        let tmp = tempfile::tempdir().unwrap();
        let mut stored = Vec::new();
        {
            let engine = StorageEngine::open(tmp.path()).unwrap();

            // First batch: written, then checkpointed.
            for i in 0..3 {
                let content = format!("checkpointed-{i}");
                let node = new_node(MemoryType::Semantic, &content, vec![i as f32]);
                let page_id = engine.store_memory(&node).unwrap();
                stored.push((page_id, content));
            }
            engine.checkpoint().unwrap();

            // Second batch: written after the checkpoint, WAL synced only.
            for i in 3..5 {
                let content = format!("unckeckpointed-{i}");
                let node = new_node(MemoryType::Episodic, &content, vec![i as f32]);
                let page_id = engine.store_memory(&node).unwrap();
                stored.push((page_id, content));
            }
            engine.wal.lock().sync().unwrap();
        }
        {
            let engine = StorageEngine::open(tmp.path()).unwrap();
            assert_all_present(&engine, &stored);
        }
    }
}