rag_plusplus_core/wal/
writer.rs1use crate::error::{Error, Result};
6use crate::types::{MemoryRecord, RecordId};
7use crate::wal::entry::WalEntry;
8use parking_lot::Mutex;
9use std::fs::{File, OpenOptions};
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12use std::sync::atomic::{AtomicU64, Ordering};
13
/// Settings that control where and how a [`WalWriter`] persists entries.
#[derive(Clone, Debug)]
pub struct WalConfig {
    /// Directory in which the WAL files are created.
    pub directory: PathBuf,
    /// Tracked file size, in bytes, at or above which the writer rotates to a
    /// new file before the next entry is written.
    pub max_file_size: u64,
    /// When `true`, every entry is flushed and synced to disk as soon as it
    /// has been written.
    pub sync_on_write: bool,
    /// Capacity, in bytes, of the in-memory buffer placed in front of the
    /// open WAL file.
    pub buffer_size: usize,
}
26
27impl Default for WalConfig {
28 fn default() -> Self {
29 Self {
30 directory: PathBuf::from("./wal"),
31 max_file_size: 64 * 1024 * 1024, sync_on_write: true,
33 buffer_size: 64 * 1024, }
35 }
36}
37
38impl WalConfig {
39 #[must_use]
41 pub fn new(directory: impl Into<PathBuf>) -> Self {
42 Self {
43 directory: directory.into(),
44 ..Default::default()
45 }
46 }
47
48 #[must_use]
50 pub const fn with_max_file_size(mut self, size: u64) -> Self {
51 self.max_file_size = size;
52 self
53 }
54
55 #[must_use]
57 pub const fn with_sync_on_write(mut self, sync: bool) -> Self {
58 self.sync_on_write = sync;
59 self
60 }
61}
62
63pub struct WalWriter {
84 config: WalConfig,
85 sequence: AtomicU64,
87 file: Mutex<Option<BufWriter<File>>>,
89 file_size: AtomicU64,
91 file_number: AtomicU64,
93}
94
95impl std::fmt::Debug for WalWriter {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("WalWriter")
98 .field("config", &self.config)
99 .field("sequence", &self.sequence)
100 .field("file_size", &self.file_size)
101 .field("file_number", &self.file_number)
102 .finish()
103 }
104}
105
106impl WalWriter {
107 pub fn new(config: WalConfig) -> Result<Self> {
115 std::fs::create_dir_all(&config.directory).map_err(|e| Error::WalWrite {
116 reason: format!("Failed to create WAL directory: {e}"),
117 })?;
118
119 let writer = Self {
120 config,
121 sequence: AtomicU64::new(0),
122 file: Mutex::new(None),
123 file_size: AtomicU64::new(0),
124 file_number: AtomicU64::new(0),
125 };
126
127 writer.recover_sequence()?;
129
130 Ok(writer)
131 }
132
133 fn recover_sequence(&self) -> Result<()> {
135 let mut max_seq = 0u64;
136 let mut max_file = 0u64;
137
138 if let Ok(entries) = std::fs::read_dir(&self.config.directory) {
139 for entry in entries.flatten() {
140 let path = entry.path();
141 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
142 if let Some(num_str) = name.strip_prefix("wal_").and_then(|s| s.strip_suffix(".log")) {
143 if let Ok(num) = num_str.parse::<u64>() {
144 max_file = max_file.max(num);
145
146 if let Ok(last_seq) = self.get_last_sequence(&path) {
148 max_seq = max_seq.max(last_seq);
149 }
150 }
151 }
152 }
153 }
154 }
155
156 self.sequence.store(max_seq, Ordering::SeqCst);
157 self.file_number.store(max_file + 1, Ordering::SeqCst);
158
159 Ok(())
160 }
161
162 fn get_last_sequence(&self, path: &Path) -> Result<u64> {
164 let reader = crate::wal::reader::WalReader::open(path)?;
165 let entries: Vec<_> = reader.collect();
166
167 if let Some(Ok(last)) = entries.last() {
168 Ok(last.sequence)
169 } else {
170 Ok(0)
171 }
172 }
173
174 fn current_file_path(&self) -> PathBuf {
176 let num = self.file_number.load(Ordering::Relaxed);
177 self.config.directory.join(format!("wal_{num:08}.log"))
178 }
179
180 fn ensure_file(&self) -> Result<()> {
182 let mut file_guard = self.file.lock();
183
184 if file_guard.is_none() {
185 let path = self.current_file_path();
186 let file = OpenOptions::new()
187 .create(true)
188 .append(true)
189 .open(&path)
190 .map_err(|e| Error::WalWrite {
191 reason: format!("Failed to open WAL file: {e}"),
192 })?;
193
194 let size = file.metadata().map(|m| m.len()).unwrap_or(0);
195 self.file_size.store(size, Ordering::SeqCst);
196
197 *file_guard = Some(BufWriter::with_capacity(self.config.buffer_size, file));
198 }
199
200 Ok(())
201 }
202
203 fn maybe_rotate(&self) -> Result<()> {
205 let size = self.file_size.load(Ordering::Relaxed);
206
207 if size >= self.config.max_file_size {
208 let mut file_guard = self.file.lock();
209
210 if let Some(mut f) = file_guard.take() {
212 f.flush().map_err(|e| Error::WalWrite {
213 reason: format!("Failed to flush WAL: {e}"),
214 })?;
215 }
216
217 self.file_number.fetch_add(1, Ordering::SeqCst);
219 self.file_size.store(0, Ordering::SeqCst);
220 }
221
222 Ok(())
223 }
224
225 fn write_entry(&self, entry: &WalEntry) -> Result<()> {
227 self.maybe_rotate()?;
228 self.ensure_file()?;
229
230 let bytes = entry.to_bytes();
231 let entry_size = bytes.len() as u64;
232
233 let mut file_guard = self.file.lock();
234 let writer = file_guard.as_mut().ok_or_else(|| Error::WalWrite {
235 reason: "WAL file not open".into(),
236 })?;
237
238 writer
240 .write_all(&(bytes.len() as u32).to_le_bytes())
241 .map_err(|e| Error::WalWrite {
242 reason: format!("Failed to write length: {e}"),
243 })?;
244
245 writer.write_all(&bytes).map_err(|e| Error::WalWrite {
247 reason: format!("Failed to write entry: {e}"),
248 })?;
249
250 if self.config.sync_on_write {
252 writer.flush().map_err(|e| Error::WalWrite {
253 reason: format!("Failed to flush: {e}"),
254 })?;
255
256 writer.get_ref().sync_all().map_err(|e| Error::WalWrite {
258 reason: format!("Failed to sync: {e}"),
259 })?;
260 }
261
262 self.file_size
264 .fetch_add(4 + entry_size, Ordering::Relaxed);
265
266 Ok(())
267 }
268
269 fn next_sequence(&self) -> u64 {
271 self.sequence.fetch_add(1, Ordering::SeqCst) + 1
272 }
273
274 pub fn log_insert(&self, record: &MemoryRecord) -> Result<u64> {
280 let seq = self.next_sequence();
281 let entry = WalEntry::insert(seq, record);
282 self.write_entry(&entry)?;
283 Ok(seq)
284 }
285
286 pub fn log_update_stats(&self, record_id: &RecordId, outcome: f64) -> Result<u64> {
288 let seq = self.next_sequence();
289 let entry = WalEntry::update_stats(seq, record_id, outcome);
290 self.write_entry(&entry)?;
291 Ok(seq)
292 }
293
294 pub fn log_delete(&self, record_id: &RecordId) -> Result<u64> {
296 let seq = self.next_sequence();
297 let entry = WalEntry::delete(seq, record_id);
298 self.write_entry(&entry)?;
299 Ok(seq)
300 }
301
302 pub fn log_checkpoint(&self) -> Result<u64> {
304 let seq = self.next_sequence();
305 let entry = WalEntry::checkpoint(seq);
306 self.write_entry(&entry)?;
307 Ok(seq)
308 }
309
310 #[must_use]
312 pub fn sequence(&self) -> u64 {
313 self.sequence.load(Ordering::SeqCst)
314 }
315
316 pub fn flush(&self) -> Result<()> {
318 let mut file_guard = self.file.lock();
319 if let Some(writer) = file_guard.as_mut() {
320 writer.flush().map_err(|e| Error::WalWrite {
321 reason: format!("Failed to flush: {e}"),
322 })?;
323 writer.get_ref().sync_all().map_err(|e| Error::WalWrite {
324 reason: format!("Failed to sync: {e}"),
325 })?;
326 }
327 Ok(())
328 }
329
330 pub fn close(&self) -> Result<()> {
332 self.flush()?;
333 let mut file_guard = self.file.lock();
334 *file_guard = None;
335 Ok(())
336 }
337
338 #[must_use]
340 pub fn directory(&self) -> &Path {
341 &self.config.directory
342 }
343
344 pub fn list_files(&self) -> Result<Vec<PathBuf>> {
346 let mut files = Vec::new();
347
348 if let Ok(entries) = std::fs::read_dir(&self.config.directory) {
349 for entry in entries.flatten() {
350 let path = entry.path();
351 if path.extension().map_or(false, |e| e == "log") {
352 files.push(path);
353 }
354 }
355 }
356
357 files.sort();
358 Ok(files)
359 }
360
361 pub fn truncate_before(&self, checkpoint_seq: u64) -> Result<()> {
365 let files = self.list_files()?;
366
367 for file_path in files {
368 let reader = crate::wal::reader::WalReader::open(&file_path)?;
370 let entries: Vec<_> = reader.collect();
371
372 if let Some(Ok(last)) = entries.last() {
373 if last.sequence < checkpoint_seq {
374 std::fs::remove_file(&file_path).map_err(|e| Error::WalWrite {
376 reason: format!("Failed to remove old WAL file: {e}"),
377 })?;
378 }
379 }
380 }
381
382 Ok(())
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::RecordStatus;
    use tempfile::TempDir;

    /// Builds a small active record whose id and context derive from `id`.
    fn create_test_record(id: &str) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            embedding: vec![1.0, 2.0, 3.0],
            context: format!("Context for {id}"),
            outcome: 0.5,
            metadata: Default::default(),
            created_at: 1_234_567_890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Opens a writer with default settings rooted at `dir`.
    fn open_writer(dir: &TempDir) -> WalWriter {
        WalWriter::new(WalConfig::new(dir.path())).unwrap()
    }

    /// Opens a writer rooted at `dir` that rotates at `max_file_size` bytes.
    fn open_writer_with_limit(dir: &TempDir, max_file_size: u64) -> WalWriter {
        let config = WalConfig::new(dir.path()).with_max_file_size(max_file_size);
        WalWriter::new(config).unwrap()
    }

    /// Logs `count` inserts with ids `rec-0`, `rec-1`, ...
    fn insert_many(writer: &WalWriter, count: usize) {
        for i in 0..count {
            writer
                .log_insert(&create_test_record(&format!("rec-{i}")))
                .unwrap();
        }
    }

    #[test]
    fn test_wal_writer_creation() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer(&dir);

        assert_eq!(wal.sequence(), 0);
    }

    #[test]
    fn test_log_insert() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer(&dir);

        let seq = wal.log_insert(&create_test_record("test-1")).unwrap();

        assert_eq!(seq, 1);
        assert_eq!(wal.sequence(), 1);
    }

    #[test]
    fn test_log_multiple_operations() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer(&dir);

        wal.log_insert(&create_test_record("rec-1")).unwrap();
        wal.log_insert(&create_test_record("rec-2")).unwrap();
        wal.log_update_stats(&"rec-1".into(), 0.8).unwrap();
        wal.log_delete(&"rec-2".into()).unwrap();

        assert_eq!(wal.sequence(), 4);
    }

    #[test]
    fn test_wal_file_creation() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer(&dir);

        wal.log_insert(&create_test_record("test")).unwrap();
        wal.flush().unwrap();

        assert_eq!(wal.list_files().unwrap().len(), 1);
    }

    #[test]
    fn test_sequence_recovery() {
        let dir = TempDir::new().unwrap();

        // First writer: log three records, then go away.
        let first = open_writer(&dir);
        for id in ["rec-1", "rec-2", "rec-3"] {
            first.log_insert(&create_test_record(id)).unwrap();
        }
        first.flush().unwrap();
        drop(first);

        // Second writer: must resume after the last recovered sequence.
        let second = open_writer(&dir);
        assert_eq!(second.sequence(), 3);

        let seq = second.log_insert(&create_test_record("rec-4")).unwrap();
        assert_eq!(seq, 4);
    }

    #[test]
    fn test_file_rotation() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer_with_limit(&dir, 1024);

        insert_many(&wal, 50);
        wal.flush().unwrap();

        let files = wal.list_files().unwrap();
        assert!(files.len() > 1, "Expected multiple WAL files after rotation");
    }

    #[test]
    fn test_checkpoint_and_truncate() {
        let dir = TempDir::new().unwrap();
        let wal = open_writer_with_limit(&dir, 512);

        insert_many(&wal, 20);

        let checkpoint_seq = wal.log_checkpoint().unwrap();
        wal.flush().unwrap();

        let files_before = wal.list_files().unwrap().len();

        wal.truncate_before(checkpoint_seq).unwrap();

        let files_after = wal.list_files().unwrap().len();
        assert!(files_after <= files_before);
    }
}