use crate::error::{Error, Result};
use crate::types::{MemoryRecord, RecordId};
use crate::wal::entry::WalEntry;
use parking_lot::Mutex;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
/// Settings controlling where and how the write-ahead log is written.
#[derive(Clone, Debug)]
pub struct WalConfig {
    /// Directory that holds the `wal_NNNNNNNN.log` segment files.
    pub directory: PathBuf,
    /// Segment size in bytes at or above which the writer rotates to a new file.
    pub max_file_size: u64,
    /// When `true`, every append is flushed and fsynced before it returns.
    pub sync_on_write: bool,
    /// Capacity in bytes of the in-memory buffer placed in front of each segment file.
    pub buffer_size: usize,
}
impl Default for WalConfig {
    /// Returns the stock configuration: segments in `./wal`, rotation at 64 MiB,
    /// fsync after every append, and a 64 KiB write buffer.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: u64 = 1024 * 1024;
        Self {
            buffer_size: 64 * KIB,
            sync_on_write: true,
            max_file_size: 64 * MIB,
            directory: PathBuf::from("./wal"),
        }
    }
}
impl WalConfig {
    /// Creates a configuration rooted at `directory`; every other setting comes from
    /// [`WalConfig::default`].
    #[must_use]
    pub fn new(directory: impl Into<PathBuf>) -> Self {
        Self {
            directory: directory.into(),
            ..Default::default()
        }
    }

    /// Sets the segment size, in bytes, at or above which the writer rotates to a new file.
    #[must_use]
    pub const fn with_max_file_size(mut self, size: u64) -> Self {
        self.max_file_size = size;
        self
    }

    /// Sets whether every append is flushed and fsynced before it returns.
    #[must_use]
    pub const fn with_sync_on_write(mut self, sync: bool) -> Self {
        self.sync_on_write = sync;
        self
    }

    /// Sets the capacity, in bytes, of the in-memory buffer placed in front of each
    /// segment file. Completes the builder set: every other setting already has one.
    #[must_use]
    pub const fn with_buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }
}
/// Appends length-prefixed WAL entries to numbered segment files in a directory.
///
/// Counters are atomics and the open file sits behind a mutex, so the writer's
/// methods operate on `&self`.
pub struct WalWriter {
    /// Directory, rotation threshold, durability and buffering settings.
    config: WalConfig,
    /// Most recently assigned sequence number (0 before any entry exists).
    sequence: AtomicU64,
    /// Buffered handle to the segment being appended to; `None` until a write opens
    /// it (initially, after a rotation, and after `close`).
    file: Mutex<Option<BufWriter<File>>>,
    /// Bytes in the current segment as tracked by the writer; drives rotation.
    file_size: AtomicU64,
    /// Number of the segment that writes go to (`wal_{file_number:08}.log`).
    file_number: AtomicU64,
}
impl std::fmt::Debug for WalWriter {
    /// Formats the configuration and the three counters; the file handle is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces a decision about printing it.
        let Self {
            config,
            sequence,
            file: _,
            file_size,
            file_number,
        } = self;
        let mut out = f.debug_struct("WalWriter");
        out.field("config", config);
        out.field("sequence", sequence);
        out.field("file_size", file_size);
        out.field("file_number", file_number);
        out.finish()
    }
}
impl WalWriter {
pub fn new(config: WalConfig) -> Result<Self> {
std::fs::create_dir_all(&config.directory).map_err(|e| Error::WalWrite {
reason: format!("Failed to create WAL directory: {e}"),
})?;
let writer = Self {
config,
sequence: AtomicU64::new(0),
file: Mutex::new(None),
file_size: AtomicU64::new(0),
file_number: AtomicU64::new(0),
};
writer.recover_sequence()?;
Ok(writer)
}
fn recover_sequence(&self) -> Result<()> {
let mut max_seq = 0u64;
let mut max_file = 0u64;
if let Ok(entries) = std::fs::read_dir(&self.config.directory) {
for entry in entries.flatten() {
let path = entry.path();
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
if let Some(num_str) = name.strip_prefix("wal_").and_then(|s| s.strip_suffix(".log")) {
if let Ok(num) = num_str.parse::<u64>() {
max_file = max_file.max(num);
if let Ok(last_seq) = self.get_last_sequence(&path) {
max_seq = max_seq.max(last_seq);
}
}
}
}
}
}
self.sequence.store(max_seq, Ordering::SeqCst);
self.file_number.store(max_file + 1, Ordering::SeqCst);
Ok(())
}
fn get_last_sequence(&self, path: &Path) -> Result<u64> {
let reader = crate::wal::reader::WalReader::open(path)?;
let entries: Vec<_> = reader.collect();
if let Some(Ok(last)) = entries.last() {
Ok(last.sequence)
} else {
Ok(0)
}
}
fn current_file_path(&self) -> PathBuf {
let num = self.file_number.load(Ordering::Relaxed);
self.config.directory.join(format!("wal_{num:08}.log"))
}
fn ensure_file(&self) -> Result<()> {
let mut file_guard = self.file.lock();
if file_guard.is_none() {
let path = self.current_file_path();
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&path)
.map_err(|e| Error::WalWrite {
reason: format!("Failed to open WAL file: {e}"),
})?;
let size = file.metadata().map(|m| m.len()).unwrap_or(0);
self.file_size.store(size, Ordering::SeqCst);
*file_guard = Some(BufWriter::with_capacity(self.config.buffer_size, file));
}
Ok(())
}
fn maybe_rotate(&self) -> Result<()> {
let size = self.file_size.load(Ordering::Relaxed);
if size >= self.config.max_file_size {
let mut file_guard = self.file.lock();
if let Some(mut f) = file_guard.take() {
f.flush().map_err(|e| Error::WalWrite {
reason: format!("Failed to flush WAL: {e}"),
})?;
}
self.file_number.fetch_add(1, Ordering::SeqCst);
self.file_size.store(0, Ordering::SeqCst);
}
Ok(())
}
fn write_entry(&self, entry: &WalEntry) -> Result<()> {
self.maybe_rotate()?;
self.ensure_file()?;
let bytes = entry.to_bytes();
let entry_size = bytes.len() as u64;
let mut file_guard = self.file.lock();
let writer = file_guard.as_mut().ok_or_else(|| Error::WalWrite {
reason: "WAL file not open".into(),
})?;
writer
.write_all(&(bytes.len() as u32).to_le_bytes())
.map_err(|e| Error::WalWrite {
reason: format!("Failed to write length: {e}"),
})?;
writer.write_all(&bytes).map_err(|e| Error::WalWrite {
reason: format!("Failed to write entry: {e}"),
})?;
if self.config.sync_on_write {
writer.flush().map_err(|e| Error::WalWrite {
reason: format!("Failed to flush: {e}"),
})?;
writer.get_ref().sync_all().map_err(|e| Error::WalWrite {
reason: format!("Failed to sync: {e}"),
})?;
}
self.file_size
.fetch_add(4 + entry_size, Ordering::Relaxed);
Ok(())
}
fn next_sequence(&self) -> u64 {
self.sequence.fetch_add(1, Ordering::SeqCst) + 1
}
pub fn log_insert(&self, record: &MemoryRecord) -> Result<u64> {
let seq = self.next_sequence();
let entry = WalEntry::insert(seq, record);
self.write_entry(&entry)?;
Ok(seq)
}
pub fn log_update_stats(&self, record_id: &RecordId, outcome: f64) -> Result<u64> {
let seq = self.next_sequence();
let entry = WalEntry::update_stats(seq, record_id, outcome);
self.write_entry(&entry)?;
Ok(seq)
}
pub fn log_delete(&self, record_id: &RecordId) -> Result<u64> {
let seq = self.next_sequence();
let entry = WalEntry::delete(seq, record_id);
self.write_entry(&entry)?;
Ok(seq)
}
pub fn log_checkpoint(&self) -> Result<u64> {
let seq = self.next_sequence();
let entry = WalEntry::checkpoint(seq);
self.write_entry(&entry)?;
Ok(seq)
}
#[must_use]
pub fn sequence(&self) -> u64 {
self.sequence.load(Ordering::SeqCst)
}
pub fn flush(&self) -> Result<()> {
let mut file_guard = self.file.lock();
if let Some(writer) = file_guard.as_mut() {
writer.flush().map_err(|e| Error::WalWrite {
reason: format!("Failed to flush: {e}"),
})?;
writer.get_ref().sync_all().map_err(|e| Error::WalWrite {
reason: format!("Failed to sync: {e}"),
})?;
}
Ok(())
}
pub fn close(&self) -> Result<()> {
self.flush()?;
let mut file_guard = self.file.lock();
*file_guard = None;
Ok(())
}
#[must_use]
pub fn directory(&self) -> &Path {
&self.config.directory
}
pub fn list_files(&self) -> Result<Vec<PathBuf>> {
let mut files = Vec::new();
if let Ok(entries) = std::fs::read_dir(&self.config.directory) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map_or(false, |e| e == "log") {
files.push(path);
}
}
}
files.sort();
Ok(files)
}
pub fn truncate_before(&self, checkpoint_seq: u64) -> Result<()> {
let files = self.list_files()?;
for file_path in files {
let reader = crate::wal::reader::WalReader::open(&file_path)?;
let entries: Vec<_> = reader.collect();
if let Some(Ok(last)) = entries.last() {
if last.sequence < checkpoint_seq {
std::fs::remove_file(&file_path).map_err(|e| Error::WalWrite {
reason: format!("Failed to remove old WAL file: {e}"),
})?;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::RecordStatus;
    use tempfile::TempDir;

    /// Builds an active record with a fixed three-element embedding for `id`.
    fn create_test_record(id: &str) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            embedding: vec![1.0, 2.0, 3.0],
            context: format!("Context for {id}"),
            outcome: 0.5,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Opens a writer in `dir` with the default configuration.
    fn open_writer(dir: &TempDir) -> WalWriter {
        WalWriter::new(WalConfig::new(dir.path())).unwrap()
    }

    /// Opens a writer in `dir` whose segments rotate at `max_file_size` bytes.
    fn open_rotating_writer(dir: &TempDir, max_file_size: u64) -> WalWriter {
        let config = WalConfig::new(dir.path()).with_max_file_size(max_file_size);
        WalWriter::new(config).unwrap()
    }

    /// Logs `count` inserts with ids `rec-0`, `rec-1`, and so on.
    fn insert_numbered(writer: &WalWriter, count: usize) {
        for i in 0..count {
            writer
                .log_insert(&create_test_record(&format!("rec-{i}")))
                .unwrap();
        }
    }

    #[test]
    fn test_wal_writer_creation() {
        let dir = TempDir::new().unwrap();
        assert_eq!(open_writer(&dir).sequence(), 0);
    }

    #[test]
    fn test_log_insert() {
        let dir = TempDir::new().unwrap();
        let writer = open_writer(&dir);
        let seq = writer.log_insert(&create_test_record("test-1")).unwrap();
        assert_eq!(seq, 1);
        assert_eq!(writer.sequence(), 1);
    }

    #[test]
    fn test_log_multiple_operations() {
        let dir = TempDir::new().unwrap();
        let writer = open_writer(&dir);
        for id in &["rec-1", "rec-2"] {
            writer.log_insert(&create_test_record(id)).unwrap();
        }
        writer.log_update_stats(&"rec-1".into(), 0.8).unwrap();
        writer.log_delete(&"rec-2".into()).unwrap();
        assert_eq!(writer.sequence(), 4);
    }

    #[test]
    fn test_wal_file_creation() {
        let dir = TempDir::new().unwrap();
        let writer = open_writer(&dir);
        writer.log_insert(&create_test_record("test")).unwrap();
        writer.flush().unwrap();
        assert_eq!(writer.list_files().unwrap().len(), 1);
    }

    #[test]
    fn test_sequence_recovery() {
        let dir = TempDir::new().unwrap();
        {
            let writer = open_writer(&dir);
            for id in &["rec-1", "rec-2", "rec-3"] {
                writer.log_insert(&create_test_record(id)).unwrap();
            }
            writer.flush().unwrap();
        }
        let reopened = open_writer(&dir);
        assert_eq!(reopened.sequence(), 3);
        let seq = reopened.log_insert(&create_test_record("rec-4")).unwrap();
        assert_eq!(seq, 4);
    }

    #[test]
    fn test_file_rotation() {
        let dir = TempDir::new().unwrap();
        let writer = open_rotating_writer(&dir, 1024);
        insert_numbered(&writer, 50);
        writer.flush().unwrap();
        let files = writer.list_files().unwrap();
        assert!(files.len() > 1, "Expected multiple WAL files after rotation");
    }

    #[test]
    fn test_checkpoint_and_truncate() {
        let dir = TempDir::new().unwrap();
        let writer = open_rotating_writer(&dir, 512);
        insert_numbered(&writer, 20);
        let checkpoint_seq = writer.log_checkpoint().unwrap();
        writer.flush().unwrap();
        let count_files = || writer.list_files().unwrap().len();
        let files_before = count_files();
        writer.truncate_before(checkpoint_seq).unwrap();
        assert!(count_files() <= files_before);
    }
}