use common::{DakeraError, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
/// Configuration for the write-ahead log.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory where segment files (`*.wal`) are stored.
    pub wal_dir: PathBuf,
    /// Maximum size in bytes of a segment before a new one is started.
    pub max_segment_size: u64,
    /// Policy for flushing buffered writes (see [`WalSyncMode`]).
    pub sync_mode: WalSyncMode,
    /// Intended number of entries between checkpoints.
    /// NOTE(review): not read anywhere in this file — confirm callers enforce it.
    pub checkpoint_threshold: u64,
}
impl Default for WalConfig {
fn default() -> Self {
Self {
wal_dir: PathBuf::from("./data/wal"),
max_segment_size: 64 * 1024 * 1024, sync_mode: WalSyncMode::EveryWrite,
checkpoint_threshold: 10000,
}
}
}
/// Controls how often buffered WAL writes are flushed to the OS.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalSyncMode {
    /// Flush the buffered writer after every append.
    EveryWrite,
    /// Flush after every N appends (a value of 0 never flushes automatically).
    Periodic(u32),
    /// Flush only on explicit `flush()` / `checkpoint()` calls.
    Manual,
}
/// A single logical record in the WAL, serialized as one JSON line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalEntry {
    /// Insert-or-update a batch of vectors in a namespace.
    Upsert {
        namespace: String,
        vectors: Vec<SerializedVector>,
    },
    /// Delete vectors by id from a namespace.
    Delete { namespace: String, ids: Vec<String> },
    /// Namespace creation marker.
    CreateNamespace { namespace: String },
    /// Namespace deletion marker.
    DeleteNamespace { namespace: String },
    /// Checkpoint marker recording the LSN at checkpoint time.
    Checkpoint { lsn: u64 },
    /// Transaction control records; filtered out during `recover()`.
    TxnBegin { txn_id: u64 },
    TxnCommit { txn_id: u64 },
    TxnRollback { txn_id: u64 },
}
/// WAL-serializable form of a stored vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedVector {
    /// Caller-assigned vector id.
    pub id: String,
    /// Vector components.
    pub values: Vec<f32>,
    /// Optional metadata payload.
    /// NOTE(review): stored as an opaque string here — presumably
    /// pre-serialized JSON; confirm against the upsert path.
    pub metadata: Option<String>,
}
/// In-memory bookkeeping for the segment currently being written.
#[derive(Debug)]
struct WalSegment {
    /// Bytes appended to this segment so far (drives rotation).
    size: u64,
    /// LSN counter value at the moment the segment was created.
    start_lsn: u64,
    /// LSN of the most recent entry appended to this segment.
    end_lsn: u64,
}
/// Append-only write-ahead log stored as JSON-line segment files.
pub struct WriteAheadLog {
    /// Static configuration (directory, rotation size, sync policy).
    config: WalConfig,
    /// Next LSN to hand out; monotonically increasing.
    lsn: AtomicU64,
    /// Bookkeeping for the active segment, if one has been opened.
    current_segment: RwLock<Option<WalSegment>>,
    /// Buffered writer over the active segment file, if open.
    writer: RwLock<Option<BufWriter<File>>>,
    /// Appends since the last checkpoint record was written.
    entries_since_checkpoint: AtomicU64,
    /// LSN recorded by the most recent checkpoint (0 = none this run).
    last_checkpoint_lsn: AtomicU64,
    /// Total appends this run; drives `WalSyncMode::Periodic` flushing.
    write_count: AtomicU64,
}
impl WriteAheadLog {
    /// Opens (or creates) a WAL rooted at `config.wal_dir` and restores the
    /// LSN counter from any segments already on disk.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if the directory cannot be created or
    /// an existing segment cannot be read.
    pub fn new(config: WalConfig) -> Result<Self> {
        fs::create_dir_all(&config.wal_dir)
            .map_err(|e| DakeraError::Storage(format!("Failed to create WAL dir: {}", e)))?;
        let wal = Self {
            config,
            lsn: AtomicU64::new(0),
            current_segment: RwLock::new(None),
            writer: RwLock::new(None),
            entries_since_checkpoint: AtomicU64::new(0),
            last_checkpoint_lsn: AtomicU64::new(0),
            write_count: AtomicU64::new(0),
        };
        wal.recover_lsn()?;
        Ok(wal)
    }

    /// Sets the LSN counter to one past the last entry found on disk so new
    /// appends never reuse an LSN.
    ///
    /// NOTE(review): `last_checkpoint_lsn` is NOT restored here, so after a
    /// restart `recover()` replays every entry, including pre-checkpoint
    /// ones — confirm that is the intended recovery semantics.
    fn recover_lsn(&self) -> Result<()> {
        let segments = self.list_segments()?;
        if let Some(last_segment) = segments.last() {
            let entries = self.read_segment(last_segment)?;
            if let Some(last_entry) = entries.last() {
                self.lsn.store(last_entry.0 + 1, Ordering::SeqCst);
            }
        }
        Ok(())
    }

    /// Lists all `*.wal` files in the WAL directory, sorted by path.
    /// Segment names are zero-padded LSNs, so lexicographic order equals
    /// LSN order.
    fn list_segments(&self) -> Result<Vec<PathBuf>> {
        let mut segments = Vec::new();
        if let Ok(entries) = fs::read_dir(&self.config.wal_dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().map(|e| e == "wal").unwrap_or(false) {
                    segments.push(path);
                }
            }
        }
        segments.sort();
        Ok(segments)
    }

    /// Reads one segment file into `(lsn, entry)` pairs.
    ///
    /// Each line is `<lsn>|<json>`. Blank lines are skipped, and lines whose
    /// JSON fails to parse are silently dropped — a torn tail write after a
    /// crash is expected and must not abort recovery.
    fn read_segment(&self, path: &Path) -> Result<Vec<(u64, WalEntry)>> {
        let file = File::open(path)
            .map_err(|e| DakeraError::Storage(format!("Failed to open WAL: {}", e)))?;
        let reader = BufReader::new(file);
        let mut entries = Vec::new();
        for line_result in reader.lines() {
            let line =
                line_result.map_err(|e| DakeraError::Storage(format!("WAL read error: {}", e)))?;
            if line.trim().is_empty() {
                continue;
            }
            if let Some((lsn_str, entry_json)) = line.split_once('|') {
                // A malformed LSN prefix falls back to 0 rather than failing
                // the whole segment read.
                let lsn: u64 = lsn_str.parse().unwrap_or(0);
                if let Ok(entry) = serde_json::from_str::<WalEntry>(entry_json) {
                    entries.push((lsn, entry));
                }
            }
        }
        Ok(entries)
    }

    /// Appends `entry` to the log and returns the LSN assigned to it.
    ///
    /// Flushing follows `config.sync_mode`; flush failures are logged rather
    /// than returned so a slow disk cannot fail the append itself.
    ///
    /// NOTE(review): `flush()` only moves the `BufWriter` buffer into the OS
    /// page cache — it does not fsync. If `EveryWrite` is meant to survive
    /// power loss it also needs `File::sync_data`; confirm the intended
    /// durability guarantee.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` on serialization or write failure.
    pub fn append(&self, entry: WalEntry) -> Result<u64> {
        // Reserve the LSN first; if the write below fails the LSN is simply
        // skipped, keeping the sequence monotonic.
        let lsn = self.lsn.fetch_add(1, Ordering::SeqCst);
        self.ensure_segment()?;
        let entry_json = serde_json::to_string(&entry)
            .map_err(|e| DakeraError::Storage(format!("WAL serialize error: {}", e)))?;
        let line = format!("{}|{}\n", lsn, entry_json);
        {
            let mut writer_guard = self.writer.write();
            let writer = writer_guard.as_mut().ok_or_else(|| {
                DakeraError::Storage("WAL writer not available after ensure_segment".to_string())
            })?;
            writer
                .write_all(line.as_bytes())
                .map_err(|e| DakeraError::Storage(format!("WAL write error: {}", e)))?;
            let write_count = self.write_count.fetch_add(1, Ordering::Relaxed) + 1;
            match self.config.sync_mode {
                WalSyncMode::EveryWrite => {
                    if let Err(e) = writer.flush() {
                        tracing::warn!(error = %e, "WAL flush failed during EveryWrite sync");
                    }
                }
                WalSyncMode::Periodic(n) if n > 0 && write_count.is_multiple_of(n as u64) => {
                    if let Err(e) = writer.flush() {
                        tracing::warn!(error = %e, write_count, "WAL flush failed during periodic sync");
                    }
                }
                _ => {}
            }
        }
        {
            // Update segment bookkeeping after the write succeeds so rotation
            // decisions see the true on-disk size.
            let mut segment_guard = self.current_segment.write();
            if let Some(ref mut segment) = *segment_guard {
                segment.size += line.len() as u64;
                segment.end_lsn = lsn;
            }
        }
        let _entries = self
            .entries_since_checkpoint
            .fetch_add(1, Ordering::Relaxed);
        Ok(lsn)
    }

    /// Opens a fresh segment if none is active or the active one has reached
    /// `max_segment_size`.
    fn ensure_segment(&self) -> Result<()> {
        let needs_new_segment = {
            let segment_guard = self.current_segment.read();
            match &*segment_guard {
                None => true,
                Some(seg) => seg.size >= self.config.max_segment_size,
            }
        };
        if needs_new_segment {
            self.rotate_segment()?;
        }
        Ok(())
    }

    /// Flushes and drops the current writer, then opens a new segment file
    /// named after the current LSN (zero-padded so paths sort by LSN).
    fn rotate_segment(&self) -> Result<()> {
        {
            let mut writer_guard = self.writer.write();
            if let Some(ref mut writer) = *writer_guard {
                // Best-effort: a failed flush must not block rotation.
                if let Err(e) = writer.flush() {
                    tracing::warn!(error = %e, "WAL flush failed during segment rotation");
                }
            }
            *writer_guard = None;
        }
        let current_lsn = self.lsn.load(Ordering::SeqCst);
        let segment_id = current_lsn;
        let segment_path = self.config.wal_dir.join(format!("{:020}.wal", segment_id));
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&segment_path)
            .map_err(|e| DakeraError::Storage(format!("Failed to create WAL segment: {}", e)))?;
        let writer = BufWriter::new(file);
        {
            let mut segment_guard = self.current_segment.write();
            *segment_guard = Some(WalSegment {
                size: 0,
                start_lsn: current_lsn,
                end_lsn: current_lsn,
            });
        }
        tracing::debug!(
            segment_id = segment_id,
            path = %segment_path.display(),
            "Created new WAL segment"
        );
        {
            let mut writer_guard = self.writer.write();
            *writer_guard = Some(writer);
        }
        Ok(())
    }

    /// Writes a checkpoint record at the current LSN, resets the
    /// since-checkpoint counter, and flushes the writer. Returns the
    /// checkpointed LSN.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if the checkpoint record cannot be
    /// appended.
    pub fn checkpoint(&self) -> Result<u64> {
        let lsn = self.lsn.load(Ordering::SeqCst);
        self.append(WalEntry::Checkpoint { lsn })?;
        self.last_checkpoint_lsn.store(lsn, Ordering::SeqCst);
        self.entries_since_checkpoint.store(0, Ordering::SeqCst);
        {
            let mut writer_guard = self.writer.write();
            if let Some(ref mut writer) = *writer_guard {
                if let Err(e) = writer.flush() {
                    tracing::warn!(error = %e, lsn, "WAL flush failed during checkpoint");
                }
            }
        }
        Ok(lsn)
    }

    /// Replays the log and returns the data entries that should be
    /// re-applied to storage.
    ///
    /// Entries at or before the in-memory `last_checkpoint_lsn` are skipped
    /// (only advancing the checkpoint marker as checkpoint records pass by),
    /// and transaction/checkpoint control records are filtered out of the
    /// result.
    pub fn recover(&self) -> Result<Vec<WalEntry>> {
        let segments = self.list_segments()?;
        let checkpoint_lsn = self.last_checkpoint_lsn.load(Ordering::SeqCst);
        let mut entries = Vec::new();
        for segment_path in segments {
            let segment_entries = self.read_segment(&segment_path)?;
            for (lsn, entry) in segment_entries {
                if checkpoint_lsn > 0 && lsn <= checkpoint_lsn {
                    if let WalEntry::Checkpoint { lsn: cp_lsn } = entry {
                        self.last_checkpoint_lsn.store(cp_lsn, Ordering::SeqCst);
                    }
                    continue;
                }
                match entry {
                    WalEntry::TxnBegin { .. }
                    | WalEntry::TxnCommit { .. }
                    | WalEntry::TxnRollback { .. }
                    | WalEntry::Checkpoint { .. } => continue,
                    _ => entries.push(entry),
                }
            }
        }
        Ok(entries)
    }

    /// Deletes whole segments whose entries all have LSN `<= up_to_lsn`,
    /// never touching the segment currently being written. Returns the
    /// number of entries contained in the segments actually removed.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if a segment cannot be read; deletion
    /// itself is best-effort.
    pub fn truncate(&self, up_to_lsn: u64) -> Result<u64> {
        let segments = self.list_segments()?;
        let mut removed_count = 0u64;
        let active_start_lsn = {
            let segment_guard = self.current_segment.read();
            segment_guard.as_ref().map(|s| s.start_lsn)
        };
        for segment_path in segments {
            let segment_entries = self.read_segment(&segment_path)?;
            // Skip the active segment: its first entry carries the segment's
            // start LSN.
            if let Some((first_lsn, _)) = segment_entries.first() {
                if active_start_lsn == Some(*first_lsn) {
                    continue;
                }
            }
            if let Some((last_lsn, _)) = segment_entries.last() {
                // Only count entries whose file was actually deleted; the
                // original unconditionally incremented the count even when
                // remove_file failed.
                if *last_lsn <= up_to_lsn && fs::remove_file(&segment_path).is_ok() {
                    removed_count += segment_entries.len() as u64;
                }
            }
        }
        Ok(removed_count)
    }

    /// Returns the next LSN that will be assigned (one past the last entry).
    pub fn current_lsn(&self) -> u64 {
        self.lsn.load(Ordering::SeqCst)
    }

    /// Returns a point-in-time snapshot of WAL counters. `segment_count`
    /// re-lists the directory, so it reflects the on-disk state.
    pub fn stats(&self) -> WalStats {
        let segment_count = self.list_segments().map(|s| s.len()).unwrap_or(0);
        let (current_segment_size, current_segment_entries) = {
            let segment_guard = self.current_segment.read();
            match &*segment_guard {
                // A non-empty segment spans contiguous LSNs start..=end, so
                // it holds end - start + 1 entries; the original
                // `end - start` undercounted by one. `size > 0` is the
                // emptiness test because end_lsn == start_lsn for a fresh
                // segment as well.
                Some(seg) if seg.size > 0 => (seg.size, seg.end_lsn - seg.start_lsn + 1),
                Some(seg) => (seg.size, 0),
                None => (0, 0),
            }
        };
        WalStats {
            current_lsn: self.lsn.load(Ordering::SeqCst),
            last_checkpoint_lsn: self.last_checkpoint_lsn.load(Ordering::SeqCst),
            segment_count,
            current_segment_size,
            current_segment_entries,
            entries_since_checkpoint: self.entries_since_checkpoint.load(Ordering::Relaxed),
        }
    }

    /// Flushes any buffered writes to the OS.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if the flush fails.
    pub fn flush(&self) -> Result<()> {
        let mut writer_guard = self.writer.write();
        if let Some(ref mut writer) = *writer_guard {
            writer
                .flush()
                .map_err(|e| DakeraError::Storage(format!("WAL flush error: {}", e)))?;
        }
        Ok(())
    }
}
/// Point-in-time snapshot of WAL counters, as returned by
/// [`WriteAheadLog::stats`].
#[derive(Debug, Clone)]
pub struct WalStats {
    /// Next LSN that will be assigned.
    pub current_lsn: u64,
    /// LSN of the most recent checkpoint this run (0 = none).
    pub last_checkpoint_lsn: u64,
    /// Number of `*.wal` files currently on disk.
    pub segment_count: usize,
    /// Bytes written to the active segment.
    pub current_segment_size: u64,
    /// Entry count attributed to the active segment.
    pub current_segment_entries: u64,
    /// Appends since the last checkpoint.
    pub entries_since_checkpoint: u64,
}
/// Pairs a storage backend `S` with its own write-ahead log.
/// NOTE(review): nothing here forwards mutations to the WAL — presumably
/// callers log through `wal()` before touching `inner()`; confirm.
pub struct WalStorage<S> {
    /// Wrapped storage backend.
    inner: S,
    /// Write-ahead log owned alongside the backend.
    wal: WriteAheadLog,
}
impl<S> WalStorage<S> {
    /// Wraps `inner` with a freshly opened write-ahead log configured by
    /// `wal_config`.
    ///
    /// # Errors
    /// Propagates any error from opening the WAL.
    pub fn new(inner: S, wal_config: WalConfig) -> Result<Self> {
        WriteAheadLog::new(wal_config).map(|wal| Self { inner, wal })
    }

    /// Borrows the underlying storage backend.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Borrows the wrapped write-ahead log.
    pub fn wal(&self) -> &WriteAheadLog {
        &self.wal
    }

    /// Writes a checkpoint record to the WAL and returns its LSN.
    pub fn checkpoint(&self) -> Result<u64> {
        self.wal.checkpoint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Small-segment config rooted at `dir`, flushing on every write.
    fn test_config(dir: &Path) -> WalConfig {
        WalConfig {
            wal_dir: dir.to_path_buf(),
            max_segment_size: 1024,
            sync_mode: WalSyncMode::EveryWrite,
            checkpoint_threshold: 100,
        }
    }

    /// Builds a single-vector upsert entry in the "test" namespace.
    fn upsert_one(id: &str, values: Vec<f32>, metadata: Option<String>) -> WalEntry {
        WalEntry::Upsert {
            namespace: "test".to_string(),
            vectors: vec![SerializedVector {
                id: id.to_string(),
                values,
                metadata,
            }],
        }
    }

    #[test]
    fn test_wal_basic_operations() {
        let dir = TempDir::new().unwrap();
        let log = WriteAheadLog::new(test_config(dir.path())).unwrap();
        let first = log
            .append(WalEntry::CreateNamespace {
                namespace: "test".to_string(),
            })
            .unwrap();
        let second = log
            .append(upsert_one("v1", vec![1.0, 2.0, 3.0], None))
            .unwrap();
        // LSNs are assigned sequentially starting at 0.
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(log.current_lsn(), 2);
    }

    #[test]
    fn test_wal_recovery() {
        let dir = TempDir::new().unwrap();
        let config = test_config(dir.path());
        {
            // First "process": write two entries and flush to disk.
            let log = WriteAheadLog::new(config.clone()).unwrap();
            log.append(WalEntry::CreateNamespace {
                namespace: "test".to_string(),
            })
            .unwrap();
            log.append(upsert_one("v1", vec![1.0, 2.0], None)).unwrap();
            log.flush().unwrap();
        }
        // Second "process": reopen the same directory and replay.
        let log = WriteAheadLog::new(config).unwrap();
        let replayed = log.recover().unwrap();
        assert_eq!(replayed.len(), 2);
        assert!(matches!(replayed[0], WalEntry::CreateNamespace { .. }));
        assert!(matches!(replayed[1], WalEntry::Upsert { .. }));
    }

    #[test]
    fn test_wal_checkpoint() {
        let dir = TempDir::new().unwrap();
        let log = WriteAheadLog::new(test_config(dir.path())).unwrap();
        log.append(WalEntry::CreateNamespace {
            namespace: "test".to_string(),
        })
        .unwrap();
        let _checkpoint_lsn = log.checkpoint().unwrap();
        // One append after the checkpoint.
        log.append(WalEntry::Upsert {
            namespace: "test".to_string(),
            vectors: vec![],
        })
        .unwrap();
        let stats = log.stats();
        assert!(stats.last_checkpoint_lsn > 0);
        assert_eq!(stats.entries_since_checkpoint, 1);
    }

    #[test]
    fn test_wal_stats() {
        let dir = TempDir::new().unwrap();
        let log = WriteAheadLog::new(test_config(dir.path())).unwrap();
        for idx in 0..5 {
            log.append(upsert_one(&format!("v{}", idx), vec![idx as f32], None))
                .unwrap();
        }
        let stats = log.stats();
        assert_eq!(stats.current_lsn, 5);
        assert_eq!(stats.entries_since_checkpoint, 5);
    }

    #[test]
    fn test_segment_rotation() {
        let dir = TempDir::new().unwrap();
        // A 100-byte cap forces rotation after nearly every entry.
        let config = WalConfig {
            wal_dir: dir.path().to_path_buf(),
            max_segment_size: 100,
            sync_mode: WalSyncMode::EveryWrite,
            checkpoint_threshold: 1000,
        };
        let log = WriteAheadLog::new(config).unwrap();
        for idx in 0..10 {
            log.append(upsert_one(
                &format!("v{}", idx),
                vec![idx as f32; 10],
                Some("some metadata here".to_string()),
            ))
            .unwrap();
        }
        assert!(log.stats().segment_count > 1);
    }
}