1use anyhow::{anyhow, Result};
36use bincode::{Decode, Encode};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::fs::{File, OpenOptions};
40use std::io::{BufReader, BufWriter, Read, Write};
41use std::path::PathBuf;
42use std::sync::{Arc, Mutex};
43use std::time::{SystemTime, UNIX_EPOCH};
44
/// Four-byte signature written at the very start of each WAL file and
/// verified again when a file is read back during recovery.
const WAL_MAGIC: &[u8; 4] = b"WALV";

/// On-disk format version stored right after the magic bytes; recovery skips
/// files whose stored version differs from this value.
const WAL_VERSION: u32 = 1;
50
51#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
53pub enum WalEntry {
54 Insert {
56 id: String,
57 vector: Vec<f32>,
58 metadata: Option<HashMap<String, String>>,
59 timestamp: u64,
60 },
61 Update {
63 id: String,
64 vector: Vec<f32>,
65 metadata: Option<HashMap<String, String>>,
66 timestamp: u64,
67 },
68 Delete { id: String, timestamp: u64 },
70 Batch {
72 entries: Vec<WalEntry>,
73 timestamp: u64,
74 },
75 Checkpoint {
77 sequence_number: u64,
78 timestamp: u64,
79 },
80 BeginTransaction { transaction_id: u64, timestamp: u64 },
82 CommitTransaction { transaction_id: u64, timestamp: u64 },
84 AbortTransaction { transaction_id: u64, timestamp: u64 },
86}
87
88impl WalEntry {
89 pub fn timestamp(&self) -> u64 {
91 match self {
92 WalEntry::Insert { timestamp, .. }
93 | WalEntry::Update { timestamp, .. }
94 | WalEntry::Delete { timestamp, .. }
95 | WalEntry::Batch { timestamp, .. }
96 | WalEntry::Checkpoint { timestamp, .. }
97 | WalEntry::BeginTransaction { timestamp, .. }
98 | WalEntry::CommitTransaction { timestamp, .. }
99 | WalEntry::AbortTransaction { timestamp, .. } => *timestamp,
100 }
101 }
102
103 pub fn is_checkpoint(&self) -> bool {
105 matches!(self, WalEntry::Checkpoint { .. })
106 }
107}
108
/// Tunables for a `WalManager`.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory that holds the `wal-*.log` files; created on start-up if it
    /// does not exist.
    pub wal_directory: PathBuf,

    /// File size in bytes at which an append triggers rotation to a new file.
    pub max_file_size: u64,

    /// When `true`, every write is flushed and `sync_all`-ed before the call
    /// returns.
    pub sync_on_write: bool,

    /// An append triggers a checkpoint once its sequence number is at least
    /// this far past the last checkpoint.
    pub checkpoint_interval: u64,

    /// How many of the newest WAL files are kept when old files are cleaned
    /// up after a checkpoint.
    pub checkpoint_retention: usize,

    /// Capacity, in bytes, of the buffered writer wrapped around the current
    /// WAL file.
    pub buffer_size: usize,
}
125
126impl Default for WalConfig {
127 fn default() -> Self {
128 Self {
129 wal_directory: PathBuf::from("./wal"),
130 max_file_size: 100 * 1024 * 1024, sync_on_write: false, checkpoint_interval: 10000,
133 checkpoint_retention: 3,
134 buffer_size: 64 * 1024, }
136 }
137}
138
139pub struct WalManager {
141 config: WalConfig,
142 current_file: Arc<Mutex<Option<BufWriter<File>>>>,
143 current_file_path: Arc<Mutex<PathBuf>>,
144 sequence_number: Arc<Mutex<u64>>,
145 last_checkpoint: Arc<Mutex<u64>>,
146}
147
148impl WalManager {
149 pub fn new(config: WalConfig) -> Result<Self> {
151 std::fs::create_dir_all(&config.wal_directory)?;
153
154 let manager = Self {
155 config,
156 current_file: Arc::new(Mutex::new(None)),
157 current_file_path: Arc::new(Mutex::new(PathBuf::new())),
158 sequence_number: Arc::new(Mutex::new(0)),
159 last_checkpoint: Arc::new(Mutex::new(0)),
160 };
161
162 manager.rotate_wal_file()?;
164
165 Ok(manager)
166 }
167
168 pub fn append(&self, entry: WalEntry) -> Result<u64> {
170 let seq = {
171 let mut seq_guard = self.sequence_number.lock().unwrap();
172 let seq = *seq_guard;
173 *seq_guard += 1;
174 seq
175 };
176
177 let needs_checkpoint = {
179 let mut file_guard = self.current_file.lock().unwrap();
180
181 if let Some(ref mut writer) = *file_guard {
182 let entry_bytes = bincode::encode_to_vec(&entry, bincode::config::standard())
184 .map_err(|e| anyhow!("Failed to serialize WAL entry: {}", e))?;
185 let entry_len = entry_bytes.len() as u32;
186
187 writer.write_all(&seq.to_le_bytes())?;
189 writer.write_all(&entry_len.to_le_bytes())?;
190 writer.write_all(&entry_bytes)?;
191
192 if self.config.sync_on_write {
193 writer.flush()?;
194 writer.get_ref().sync_all()?;
195 }
196
197 let needs_rotation = if let Ok(metadata) = writer.get_ref().metadata() {
199 metadata.len() >= self.config.max_file_size
200 } else {
201 false
202 };
203
204 if needs_rotation {
205 drop(file_guard);
206 self.rotate_wal_file()?;
207 }
208
209 let last_checkpoint = *self.last_checkpoint.lock().unwrap();
211 seq - last_checkpoint >= self.config.checkpoint_interval
212 } else {
213 return Err(anyhow!("WAL file not open"));
214 }
215 };
216
217 if needs_checkpoint {
219 self.checkpoint(seq)?;
220 }
221
222 Ok(seq)
223 }
224
225 pub fn checkpoint(&self, sequence_number: u64) -> Result<()> {
227 tracing::info!("Creating WAL checkpoint at sequence {}", sequence_number);
228
229 let timestamp = SystemTime::now()
230 .duration_since(UNIX_EPOCH)
231 .unwrap()
232 .as_secs();
233
234 let checkpoint_entry = WalEntry::Checkpoint {
235 sequence_number,
236 timestamp,
237 };
238
239 let seq = {
241 let mut seq_guard = self.sequence_number.lock().unwrap();
242 let seq = *seq_guard;
243 *seq_guard += 1;
244 seq
245 };
246
247 {
248 let mut file_guard = self.current_file.lock().unwrap();
249 if let Some(ref mut writer) = *file_guard {
250 let entry_bytes =
251 bincode::encode_to_vec(&checkpoint_entry, bincode::config::standard())
252 .map_err(|e| anyhow!("Failed to serialize checkpoint entry: {}", e))?;
253 let entry_len = entry_bytes.len() as u32;
254
255 writer.write_all(&seq.to_le_bytes())?;
256 writer.write_all(&entry_len.to_le_bytes())?;
257 writer.write_all(&entry_bytes)?;
258
259 if self.config.sync_on_write {
260 writer.flush()?;
261 writer.get_ref().sync_all()?;
262 }
263 }
264 }
265
266 let mut last_checkpoint = self.last_checkpoint.lock().unwrap();
267 *last_checkpoint = sequence_number;
268
269 self.cleanup_old_files()?;
271
272 Ok(())
273 }
274
275 fn rotate_wal_file(&self) -> Result<()> {
277 let timestamp = SystemTime::now()
278 .duration_since(UNIX_EPOCH)
279 .unwrap()
280 .as_secs();
281
282 let filename = format!("wal-{:016x}.log", timestamp);
283 let filepath = self.config.wal_directory.join(&filename);
284
285 tracing::info!("Rotating WAL to new file: {:?}", filepath);
286
287 let file = OpenOptions::new()
288 .create(true)
289 .append(true)
290 .open(&filepath)?;
291
292 let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
293
294 writer.write_all(WAL_MAGIC)?;
296 writer.write_all(&WAL_VERSION.to_le_bytes())?;
297 writer.write_all(×tamp.to_le_bytes())?;
298
299 if self.config.sync_on_write {
300 writer.flush()?;
301 writer.get_ref().sync_all()?;
302 }
303
304 let mut file_guard = self.current_file.lock().unwrap();
305 let mut path_guard = self.current_file_path.lock().unwrap();
306
307 if let Some(mut old_writer) = file_guard.take() {
309 old_writer.flush()?;
310 }
311
312 *file_guard = Some(writer);
313 *path_guard = filepath;
314
315 Ok(())
316 }
317
318 fn cleanup_old_files(&self) -> Result<()> {
320 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
321 .filter_map(|entry| entry.ok())
322 .filter(|entry| {
323 entry
324 .file_name()
325 .to_str()
326 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
327 .unwrap_or(false)
328 })
329 .collect();
330
331 wal_files.sort_by_key(|entry| entry.file_name());
333
334 if wal_files.len() > self.config.checkpoint_retention {
336 let to_remove = wal_files.len() - self.config.checkpoint_retention;
337 for entry in wal_files.iter().take(to_remove) {
338 tracing::info!("Removing old WAL file: {:?}", entry.path());
339 std::fs::remove_file(entry.path())?;
340 }
341 }
342
343 Ok(())
344 }
345
346 pub fn recover(&self) -> Result<Vec<WalEntry>> {
348 tracing::info!("Starting WAL recovery");
349
350 let mut all_entries = Vec::new();
351 let mut last_checkpoint_seq = 0u64;
352
353 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
355 .filter_map(|entry| entry.ok())
356 .filter(|entry| {
357 entry
358 .file_name()
359 .to_str()
360 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
361 .unwrap_or(false)
362 })
363 .collect();
364
365 wal_files.sort_by_key(|entry| entry.file_name());
367
368 for entry in wal_files {
370 let path = entry.path();
371 tracing::debug!("Reading WAL file: {:?}", path);
372
373 let file = File::open(&path)?;
374 let mut reader = BufReader::new(file);
375
376 let mut magic = [0u8; 4];
378 reader.read_exact(&mut magic)?;
379 if &magic != WAL_MAGIC {
380 tracing::warn!("Invalid WAL file magic number: {:?}", path);
381 continue;
382 }
383
384 let mut version_bytes = [0u8; 4];
386 reader.read_exact(&mut version_bytes)?;
387 let version = u32::from_le_bytes(version_bytes);
388 if version != WAL_VERSION {
389 tracing::warn!("Unsupported WAL version {} in {:?}", version, path);
390 continue;
391 }
392
393 let mut timestamp_bytes = [0u8; 8];
395 reader.read_exact(&mut timestamp_bytes)?;
396
397 loop {
399 let mut seq_bytes = [0u8; 8];
401 match reader.read_exact(&mut seq_bytes) {
402 Ok(_) => {}
403 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
404 tracing::debug!("Reached end of WAL file (expected)");
405 break;
406 }
407 Err(e) => return Err(e.into()),
408 }
409 let seq = u64::from_le_bytes(seq_bytes);
410
411 let mut len_bytes = [0u8; 4];
413 match reader.read_exact(&mut len_bytes) {
414 Ok(_) => {}
415 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
416 tracing::warn!(
417 "Incomplete entry at sequence {}: missing length field. Skipping rest of file.",
418 seq
419 );
420 break;
421 }
422 Err(e) => return Err(e.into()),
423 }
424 let len = u32::from_le_bytes(len_bytes);
425
426 if len > 100_000_000 {
428 tracing::warn!(
430 "Entry at sequence {} has suspicious length {}. Possibly corrupted. Skipping.",
431 seq,
432 len
433 );
434 break;
435 }
436
437 let mut entry_bytes = vec![0u8; len as usize];
439 match reader.read_exact(&mut entry_bytes) {
440 Ok(_) => {}
441 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
442 tracing::warn!(
443 "Incomplete entry at sequence {}: expected {} bytes but reached EOF. Skipping rest of file.",
444 seq,
445 len
446 );
447 break;
448 }
449 Err(e) => return Err(e.into()),
450 }
451
452 let entry: WalEntry =
454 match bincode::decode_from_slice(&entry_bytes, bincode::config::standard()) {
455 Ok((e, _)) => e,
456 Err(e) => {
457 tracing::warn!(
458 "Failed to deserialize entry at sequence {}: {}. Skipping entry.",
459 seq,
460 e
461 );
462 continue; }
464 };
465
466 if let WalEntry::Checkpoint {
468 sequence_number, ..
469 } = &entry
470 {
471 last_checkpoint_seq = *sequence_number;
472 }
473
474 all_entries.push((seq, entry));
475 }
476 }
477
478 let recovered_entries: Vec<_> = all_entries
482 .iter()
483 .filter(|(seq, _)| {
484 if last_checkpoint_seq == 0 {
485 true } else {
487 *seq > last_checkpoint_seq }
489 })
490 .map(|(_, entry)| entry.clone())
491 .collect();
492
493 tracing::info!(
494 "Recovered {} entries from WAL (after checkpoint {})",
495 recovered_entries.len(),
496 last_checkpoint_seq
497 );
498
499 if let Some((max_seq, _)) = all_entries.iter().max_by_key(|(seq, _)| seq) {
501 let mut seq = self.sequence_number.lock().unwrap();
502 *seq = max_seq + 1;
503 }
504
505 Ok(recovered_entries)
506 }
507
508 pub fn flush(&self) -> Result<()> {
510 let mut file_guard = self.current_file.lock().unwrap();
511 if let Some(ref mut writer) = *file_guard {
512 writer.flush()?;
513 writer.get_ref().sync_all()?;
514 }
515 Ok(())
516 }
517
518 pub fn current_sequence(&self) -> u64 {
520 *self.sequence_number.lock().unwrap()
521 }
522
523 pub fn last_checkpoint_sequence(&self) -> u64 {
525 *self.last_checkpoint.lock().unwrap()
526 }
527}
528
529impl Drop for WalManager {
530 fn drop(&mut self) {
531 let _ = self.flush();
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Default configuration rooted in `dir`.
    fn config_in(dir: &TempDir) -> WalConfig {
        WalConfig {
            wal_directory: dir.path().to_path_buf(),
            ..Default::default()
        }
    }

    /// Insert entry without metadata.
    fn insert_entry(id: String, vector: Vec<f32>, timestamp: u64) -> WalEntry {
        WalEntry::Insert {
            id,
            vector,
            metadata: None,
            timestamp,
        }
    }

    /// A freshly created manager starts numbering at zero.
    #[test]
    fn test_wal_creation() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();

        assert_eq!(wal.current_sequence(), 0);
    }

    /// The first appended entry receives sequence number zero.
    #[test]
    fn test_wal_append() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config).unwrap();

        let seq = wal
            .append(insert_entry("vec1".to_string(), vec![1.0, 2.0, 3.0], 12345))
            .unwrap();

        assert_eq!(seq, 0);
    }

    /// Entries written by one manager are returned, in order, by a second
    /// manager opened on the same directory.
    #[test]
    fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 100,
            ..config_in(&temp_dir)
        };

        {
            let writer_wal = WalManager::new(config.clone()).unwrap();

            for i in 0..5u64 {
                let entry = insert_entry(
                    format!("vec{}", i),
                    vec![i as f32, (i * 2) as f32],
                    (i + 1) * 1000,
                );
                writer_wal.append(entry).unwrap();
            }

            writer_wal.flush().unwrap();
            drop(writer_wal);
        }

        std::thread::sleep(std::time::Duration::from_millis(100));

        {
            let reader_wal = WalManager::new(config).unwrap();
            let recovered = reader_wal.recover().unwrap();

            assert_eq!(
                recovered.len(),
                5,
                "Expected exactly 5 entries, got {}",
                recovered.len()
            );

            let timestamps: Vec<u64> = recovered.iter().map(WalEntry::timestamp).collect();
            assert_eq!(timestamps, vec![1000, 2000, 3000, 4000, 5000]);
        }
    }

    /// With an interval of three, five appends must have produced a
    /// checkpoint.
    #[test]
    fn test_wal_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 3,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config).unwrap();

        for i in 0..5u64 {
            wal.append(insert_entry(format!("vec{}", i), vec![i as f32], i))
                .unwrap();
        }

        assert!(wal.last_checkpoint_sequence() > 0);
    }

    /// A batch entry with nested entries can be appended and flushed.
    #[test]
    fn test_wal_batch_operation() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();

        let nested = vec![
            insert_entry("vec1".to_string(), vec![1.0], 1),
            WalEntry::Update {
                id: "vec2".to_string(),
                vector: vec![2.0],
                metadata: None,
                timestamp: 2,
            },
        ];
        let batch = WalEntry::Batch {
            entries: nested,
            timestamp: 3,
        };

        wal.append(batch).unwrap();
        wal.flush().unwrap();
    }

    /// Begin / insert / commit markers can be appended and flushed.
    #[test]
    fn test_wal_transaction() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();

        let steps = vec![
            WalEntry::BeginTransaction {
                transaction_id: 1,
                timestamp: 100,
            },
            insert_entry("vec1".to_string(), vec![1.0], 101),
            WalEntry::CommitTransaction {
                transaction_id: 1,
                timestamp: 102,
            },
        ];
        for step in steps {
            wal.append(step).unwrap();
        }

        wal.flush().unwrap();
    }
}