1use std::fs::File;
4use std::path::Path;
5
6use mentedb_core::MemoryNode;
7use mentedb_core::error::{MenteError, MenteResult};
8
9use fs2::FileExt;
10use parking_lot::Mutex;
11use tracing::info;
12
13use crate::buffer::BufferPool;
14use crate::page::{PAGE_DATA_SIZE, Page, PageId, PageManager, PageType};
15use crate::wal::{Wal, WalEntryType};
16const DEFAULT_BUFFER_POOL_SIZE: usize = 1024;
18
19pub struct StorageEngine {
27 page_manager: Mutex<PageManager>,
28 buffer_pool: BufferPool,
29 wal: Mutex<Wal>,
30 _lock_file: File,
32}
33
34impl StorageEngine {
35 pub fn open(path: &Path) -> MenteResult<Self> {
40 std::fs::create_dir_all(path)?;
41
42 let lock_path = path.join("mentedb.lock");
44 let lock_file = File::create(&lock_path)
45 .map_err(|e| MenteError::Storage(format!("failed to create lock file: {e}")))?;
46 lock_file.try_lock_exclusive().map_err(|_| {
47 MenteError::Storage(
48 "Database is locked by another process. Only one instance can access the database at a time.".to_string()
49 )
50 })?;
51
52 let page_manager = PageManager::open(path)?;
53 let buffer_pool = BufferPool::new(DEFAULT_BUFFER_POOL_SIZE);
54 let wal = Wal::open(path)?;
55
56 let engine = Self {
57 page_manager: Mutex::new(page_manager),
58 buffer_pool,
59 wal: Mutex::new(wal),
60 _lock_file: lock_file,
61 };
62
63 let recovered = engine.recover()?;
64 if recovered > 0 {
65 info!(recovered, ?path, "storage engine opened with WAL recovery");
66 } else {
67 info!(?path, "storage engine opened");
68 }
69
70 Ok(engine)
71 }
72
73 pub fn recover(&self) -> MenteResult<usize> {
78 let mut wal = self.wal.lock();
79 let entries = wal.iterate()?;
80 let mut count = 0usize;
81 let mut pm = self.page_manager.lock();
82
83 for entry in &entries {
84 match entry.entry_type {
85 WalEntryType::PageWrite => {
86 let page_id = PageId(entry.page_id);
87
88 while pm.page_count() <= entry.page_id {
89 pm.allocate_page()?;
90 }
91
92 let mut page = pm.read_page(page_id)?;
93 let copy_len = entry.data.len().min(PAGE_DATA_SIZE);
94 page.data[..copy_len].copy_from_slice(&entry.data[..copy_len]);
95 if copy_len < PAGE_DATA_SIZE {
96 page.data[copy_len..].fill(0);
97 }
98 page.header.page_id = entry.page_id;
99 page.header.lsn = entry.lsn;
100 page.header.page_type = PageType::Data as u8;
101 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
102 page.header.checksum = page.compute_checksum();
103
104 pm.write_page(page_id, &page)?;
105 count += 1;
106 }
107 WalEntryType::Checkpoint | WalEntryType::Commit => {}
108 }
109 }
110
111 if count > 0 {
112 pm.sync()?;
113 let next_lsn = wal.next_lsn();
114 wal.truncate(next_lsn)?;
115 info!(count, "WAL recovery replayed entries");
116 }
117
118 Ok(count)
119 }
120
121 pub fn close(&self) -> MenteResult<()> {
123 let mut pm = self.page_manager.lock();
124 self.buffer_pool.flush_all(&mut pm)?;
125 pm.sync()?;
126 self.wal.lock().sync()?;
127 info!("storage engine closed");
128 Ok(())
129 }
130
131 pub fn allocate_page(&self) -> MenteResult<PageId> {
135 self.page_manager.lock().allocate_page()
136 }
137
138 pub fn read_page(&self, page_id: PageId) -> MenteResult<Box<Page>> {
140 self.buffer_pool
141 .fetch_page(page_id, &mut self.page_manager.lock())
142 }
143
144 pub fn write_page(&self, page_id: PageId, data: &[u8]) -> MenteResult<()> {
146 let lsn = {
147 let mut wal = self.wal.lock();
148 let lsn = wal.append(WalEntryType::PageWrite, page_id.0, data)?;
149 wal.sync()?;
150 lsn
151 };
152
153 let mut pm = self.page_manager.lock();
154 let mut page = self.buffer_pool.fetch_page(page_id, &mut pm)?;
155 drop(pm); let copy_len = data.len().min(PAGE_DATA_SIZE);
158 page.data[..copy_len].copy_from_slice(&data[..copy_len]);
159 if copy_len < PAGE_DATA_SIZE {
160 page.data[copy_len..].fill(0);
161 }
162 page.header.lsn = lsn;
163 page.header.page_type = PageType::Data as u8;
164 page.header.free_space = (PAGE_DATA_SIZE - copy_len) as u16;
165 page.header.checksum = page.compute_checksum();
166
167 if self.buffer_pool.update_page(page_id, &page).is_err() {
168 self.page_manager.lock().write_page(page_id, &page)?;
169 }
170 self.buffer_pool.unpin_page(page_id, true).ok();
171
172 Ok(())
173 }
174
175 pub fn store_memory(&self, node: &MemoryNode) -> MenteResult<PageId> {
181 let serialized =
182 serde_json::to_vec(node).map_err(|e| MenteError::Serialization(e.to_string()))?;
183
184 if serialized.len() + 4 > PAGE_DATA_SIZE {
185 return Err(MenteError::CapacityExceeded(format!(
186 "memory node serialized to {} bytes (max {})",
187 serialized.len(),
188 PAGE_DATA_SIZE - 4,
189 )));
190 }
191
192 let page_id = self.allocate_page()?;
193
194 let mut buf = Vec::with_capacity(4 + serialized.len());
195 buf.extend_from_slice(&(serialized.len() as u32).to_le_bytes());
196 buf.extend_from_slice(&serialized);
197
198 self.write_page(page_id, &buf)?;
199
200 info!(
201 page_id = page_id.0,
202 bytes = serialized.len(),
203 "stored memory node"
204 );
205 Ok(page_id)
206 }
207
208 pub fn load_memory(&self, page_id: PageId) -> MenteResult<MemoryNode> {
210 let page = self.read_page(page_id)?;
211 self.buffer_pool.unpin_page(page_id, false).ok();
212
213 let len = u32::from_le_bytes(page.data[..4].try_into().unwrap()) as usize;
214 if len == 0 || len + 4 > PAGE_DATA_SIZE {
215 return Err(MenteError::Storage(format!(
216 "invalid memory node length prefix: {len}"
217 )));
218 }
219
220 serde_json::from_slice(&page.data[4..4 + len])
221 .map_err(|e| MenteError::Serialization(e.to_string()))
222 }
223
224 pub fn checkpoint(&self) -> MenteResult<()> {
228 let mut pm = self.page_manager.lock();
229 self.buffer_pool.flush_all(&mut pm)?;
230 pm.sync()?;
231
232 let mut wal = self.wal.lock();
233 let lsn = wal.append(WalEntryType::Checkpoint, 0, &[])?;
234 wal.sync()?;
235 wal.truncate(lsn)?;
236
237 info!(lsn, "checkpoint complete");
238 Ok(())
239 }
240
241 pub fn scan_all_memories(&self) -> Vec<(mentedb_core::types::MemoryId, PageId)> {
245 let count = self.page_manager.lock().page_count();
246 let mut results = Vec::new();
247 for i in 1..count {
248 let page_id = PageId(i);
249 if let Ok(node) = self.load_memory(page_id) {
250 results.push((node.id, page_id));
251 }
252 }
253 results
254 }
255}
256
257#[cfg(test)]
258mod tests {
259 use super::*;
260 use mentedb_core::memory::MemoryType;
261 use mentedb_core::types::AgentId;
262
263 fn setup() -> (tempfile::TempDir, StorageEngine) {
264 let dir = tempfile::tempdir().unwrap();
265 let engine = StorageEngine::open(dir.path()).unwrap();
266 (dir, engine)
267 }
268
269 #[test]
270 fn test_allocate_write_read() {
271 let (_dir, engine) = setup();
272
273 let pid = engine.allocate_page().unwrap();
274 engine.write_page(pid, b"hello storage engine").unwrap();
275
276 let page = engine.read_page(pid).unwrap();
277 assert_eq!(&page.data[..20], b"hello storage engine");
278 engine.buffer_pool.unpin_page(pid, false).ok();
279 }
280
281 #[test]
282 fn test_store_and_load_memory() {
283 let (_dir, engine) = setup();
284
285 let node = MemoryNode::new(
286 AgentId::new(),
287 MemoryType::Episodic,
288 "The user prefers Rust over Go".to_string(),
289 vec![0.1, 0.2, 0.3, 0.4],
290 );
291
292 let page_id = engine.store_memory(&node).unwrap();
293 let loaded = engine.load_memory(page_id).unwrap();
294
295 assert_eq!(node.id, loaded.id);
296 assert_eq!(node.content, loaded.content);
297 assert_eq!(node.embedding, loaded.embedding);
298 assert_eq!(node.memory_type, loaded.memory_type);
299 }
300
301 #[test]
302 fn test_checkpoint() {
303 let (_dir, engine) = setup();
304
305 let node = MemoryNode::new(
306 AgentId::new(),
307 MemoryType::Semantic,
308 "checkpoint test".to_string(),
309 vec![1.0, 2.0],
310 );
311
312 let pid = engine.store_memory(&node).unwrap();
313 engine.checkpoint().unwrap();
314
315 let loaded = engine.load_memory(pid).unwrap();
316 assert_eq!(loaded.content, "checkpoint test");
317 }
318
319 #[test]
320 fn test_close_and_reopen() {
321 let dir = tempfile::tempdir().unwrap();
322 let pid;
323 {
324 let engine = StorageEngine::open(dir.path()).unwrap();
325 let node = MemoryNode::new(
326 AgentId::new(),
327 MemoryType::Procedural,
328 "persist across close".to_string(),
329 vec![0.5],
330 );
331 pid = engine.store_memory(&node).unwrap();
332 engine.close().unwrap();
333 }
334 {
335 let engine = StorageEngine::open(dir.path()).unwrap();
336 let loaded = engine.load_memory(pid).unwrap();
337 assert_eq!(loaded.content, "persist across close");
338 }
339 }
340
341 #[test]
342 fn test_crash_recovery() {
343 let dir = tempfile::tempdir().unwrap();
344 let mut ids = Vec::new();
345 let mut contents = Vec::new();
346 {
347 let engine = StorageEngine::open(dir.path()).unwrap();
348 for i in 0..3 {
349 let content = format!("crash-recovery-{i}");
350 let node = MemoryNode::new(
351 AgentId::new(),
352 MemoryType::Episodic,
353 content.clone(),
354 vec![i as f32],
355 );
356 let pid = engine.store_memory(&node).unwrap();
357 ids.push(pid);
358 contents.push(content);
359 }
360 engine.wal.lock().sync().unwrap();
362 }
363 {
364 let engine = StorageEngine::open(dir.path()).unwrap();
365 for (pid, expected) in ids.iter().zip(contents.iter()) {
366 let loaded = engine.load_memory(*pid).unwrap();
367 assert_eq!(&loaded.content, expected);
368 }
369 }
370 }
371
372 #[test]
373 fn test_recovery_idempotent() {
374 let dir = tempfile::tempdir().unwrap();
375 let pid;
376 let content = "idempotent-check".to_string();
377 {
378 let engine = StorageEngine::open(dir.path()).unwrap();
379 let node = MemoryNode::new(
380 AgentId::new(),
381 MemoryType::Semantic,
382 content.clone(),
383 vec![1.0, 2.0],
384 );
385 pid = engine.store_memory(&node).unwrap();
386 engine.checkpoint().unwrap();
387 engine.close().unwrap();
388 }
389 {
390 let engine = StorageEngine::open(dir.path()).unwrap();
391 let loaded = engine.load_memory(pid).unwrap();
392 assert_eq!(loaded.content, content);
393 }
394 }
395
396 #[test]
397 fn test_partial_write_recovery() {
398 let dir = tempfile::tempdir().unwrap();
399 let mut ids = Vec::new();
400 let mut contents = Vec::new();
401 {
402 let engine = StorageEngine::open(dir.path()).unwrap();
403 for i in 0..3 {
404 let content = format!("checkpointed-{i}");
405 let node = MemoryNode::new(
406 AgentId::new(),
407 MemoryType::Semantic,
408 content.clone(),
409 vec![i as f32],
410 );
411 let pid = engine.store_memory(&node).unwrap();
412 ids.push(pid);
413 contents.push(content);
414 }
415 engine.checkpoint().unwrap();
416
417 for i in 3..5 {
418 let content = format!("unckeckpointed-{i}");
419 let node = MemoryNode::new(
420 AgentId::new(),
421 MemoryType::Episodic,
422 content.clone(),
423 vec![i as f32],
424 );
425 let pid = engine.store_memory(&node).unwrap();
426 ids.push(pid);
427 contents.push(content);
428 }
429 engine.wal.lock().sync().unwrap();
431 }
432 {
433 let engine = StorageEngine::open(dir.path()).unwrap();
434 for (pid, expected) in ids.iter().zip(contents.iter()) {
435 let loaded = engine.load_memory(*pid).unwrap();
436 assert_eq!(&loaded.content, expected);
437 }
438 }
439 }
440}