use super::record::{MIN_RECORD_SIZE, WalRecord, WalRecordType};
use crate::error::{Result, XervError};
use crate::types::{NodeId, TraceId};
use byteorder::{LittleEndian, ReadBytesExt};
use fs2::FileExt;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Configuration for the write-ahead log.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory where WAL segment files (`wal_<hex>.log`) are created.
    pub directory: PathBuf,
    /// Size threshold in bytes after which the active segment is rotated.
    pub max_file_size: u64,
    /// When true, every `write` flushes and fsyncs before returning.
    pub sync_on_write: bool,
    /// Capacity of the in-memory write buffer, in bytes.
    pub buffer_size: usize,
}

impl Default for WalConfig {
    /// Defaults: `/tmp/xerv/wal`, 64 MiB segments, synchronous writes,
    /// and a 64 KiB write buffer.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        WalConfig {
            directory: PathBuf::from("/tmp/xerv/wal"),
            max_file_size: 64 * MIB,
            sync_on_write: true,
            buffer_size: 64 * 1024,
        }
    }
}
impl WalConfig {
pub fn in_memory() -> Self {
Self {
directory: std::env::temp_dir().join(format!("xerv_wal_{}", uuid::Uuid::new_v4())),
max_file_size: 64 * 1024 * 1024,
sync_on_write: false,
buffer_size: 64 * 1024,
}
}
pub fn with_directory(mut self, dir: impl Into<PathBuf>) -> Self {
self.directory = dir.into();
self
}
pub fn with_sync(mut self, sync: bool) -> Self {
self.sync_on_write = sync;
self
}
}
/// Mutable WAL state guarded by the `Mutex` inside [`Wal`].
struct WalInner {
    /// Buffered writer over the active segment file.
    file: BufWriter<File>,
    /// Path of the active segment file.
    path: PathBuf,
    /// Bytes written to the active segment so far (tracked in memory,
    /// seeded from file metadata on open).
    file_size: u64,
    /// Configuration the WAL was opened with.
    config: WalConfig,
    /// Sequence number of the active segment (encoded in its filename).
    sequence: u64,
}
/// Append-only write-ahead log handle. All state sits behind an
/// `Arc<Mutex<..>>`, so every method takes `&self` and access is serialized.
pub struct Wal {
    // Shared, lock-protected segment file and bookkeeping.
    inner: Arc<Mutex<WalInner>>,
}
impl Wal {
    /// Opens (or creates) a write-ahead log rooted at `config.directory`.
    ///
    /// Creates the directory if needed, resumes the highest-numbered existing
    /// `wal_*.log` segment (see `find_or_create_wal_file`), opens it in
    /// append mode, and takes an exclusive advisory lock so no other process
    /// can append to the same segment.
    ///
    /// # Errors
    /// Returns `XervError::WalWrite` if the directory cannot be created or
    /// the segment file cannot be opened or locked.
    pub fn open(config: WalConfig) -> Result<Self> {
        std::fs::create_dir_all(&config.directory).map_err(|e| XervError::WalWrite {
            trace_id: TraceId::new(),
            cause: format!("Failed to create WAL directory: {}", e),
        })?;
        let (path, sequence) = find_or_create_wal_file(&config.directory)?;
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .map_err(|e| XervError::WalWrite {
                trace_id: TraceId::new(),
                cause: format!("Failed to open WAL file: {}", e),
            })?;
        // Advisory lock (fs2): fail fast if another process owns this segment.
        file.try_lock_exclusive().map_err(|e| XervError::WalWrite {
            trace_id: TraceId::new(),
            cause: format!("Failed to lock WAL file: {}", e),
        })?;
        // Seed the in-memory size from disk so the rotation threshold
        // accounts for records already present in a resumed segment.
        let file_size = file.metadata().map(|m| m.len()).unwrap_or(0);
        let inner = WalInner {
            file: BufWriter::with_capacity(config.buffer_size, file),
            path,
            file_size,
            config,
            sequence,
        };
        Ok(Self {
            inner: Arc::new(Mutex::new(inner)),
        })
    }

    /// Appends one record to the log.
    ///
    /// Rotates to a new segment first if this record would push the current
    /// one past `max_file_size`. With `sync_on_write` set, the buffer is
    /// flushed and file data fsynced (`sync_data`) before returning, making
    /// the record durable; otherwise it may remain in the `BufWriter` until
    /// `flush` is called.
    ///
    /// # Errors
    /// Returns `XervError::WalWrite` (tagged with the record's trace id) on
    /// serialization or I/O failure.
    pub fn write(&self, record: &WalRecord) -> Result<()> {
        let mut inner = self.inner.lock();
        let bytes = record.to_bytes().map_err(|e| XervError::WalWrite {
            trace_id: record.trace_id,
            cause: e.to_string(),
        })?;
        // Rotate before writing so a segment never exceeds max_file_size.
        if inner.file_size + bytes.len() as u64 > inner.config.max_file_size {
            self.rotate_locked(&mut inner)?;
        }
        inner
            .file
            .write_all(&bytes)
            .map_err(|e| XervError::WalWrite {
                trace_id: record.trace_id,
                cause: e.to_string(),
            })?;
        inner.file_size += bytes.len() as u64;
        if inner.config.sync_on_write {
            // Flush the BufWriter into the OS, then fsync file data
            // (sync_data, which may skip pure-metadata updates).
            inner.file.flush().map_err(|e| XervError::WalWrite {
                trace_id: record.trace_id,
                cause: e.to_string(),
            })?;
            inner
                .file
                .get_ref()
                .sync_data()
                .map_err(|e| XervError::WalWrite {
                    trace_id: record.trace_id,
                    cause: e.to_string(),
                })?;
        }
        Ok(())
    }

    /// Flushes buffered records to the OS and fsyncs file data to disk.
    /// Call this before reading the log back when `sync_on_write` is off.
    pub fn flush(&self) -> Result<()> {
        let mut inner = self.inner.lock();
        inner.file.flush().map_err(|e| XervError::WalWrite {
            trace_id: TraceId::new(),
            cause: e.to_string(),
        })?;
        inner
            .file
            .get_ref()
            .sync_data()
            .map_err(|e| XervError::WalWrite {
                trace_id: TraceId::new(),
                cause: e.to_string(),
            })
    }

    /// Switches writing to a fresh, higher-numbered segment file.
    ///
    /// The caller must already hold the inner mutex (hence `_locked`). The
    /// old segment is flushed (to the OS; not fsynced here) and its advisory
    /// lock released only after the new segment is created and locked.
    fn rotate_locked(&self, inner: &mut WalInner) -> Result<()> {
        inner.file.flush().map_err(|e| XervError::WalWrite {
            trace_id: TraceId::new(),
            cause: e.to_string(),
        })?;
        inner.sequence += 1;
        let new_path = inner
            .config
            .directory
            .join(format!("wal_{:016x}.log", inner.sequence));
        let new_file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&new_path)
            .map_err(|e| XervError::WalWrite {
                trace_id: TraceId::new(),
                cause: format!("Failed to create new WAL file: {}", e),
            })?;
        new_file
            .try_lock_exclusive()
            .map_err(|e| XervError::WalWrite {
                trace_id: TraceId::new(),
                cause: format!("Failed to lock new WAL file: {}", e),
            })?;
        // Best-effort unlock of the finished segment; errors are ignored.
        let _ = inner.file.get_ref().unlock();
        inner.file = BufWriter::with_capacity(inner.config.buffer_size, new_file);
        inner.path = new_path;
        inner.file_size = 0;
        Ok(())
    }

    /// Path of the segment file currently being appended to.
    pub fn path(&self) -> PathBuf {
        self.inner.lock().path.clone()
    }

    /// Creates a reader over this WAL's directory. Records still in the
    /// write buffer are not visible to the reader until `flush` (or a
    /// synced `write`) pushes them to the file.
    pub fn reader(&self) -> WalReader {
        let inner = self.inner.lock();
        WalReader {
            directory: inner.config.directory.clone(),
        }
    }
}
impl Drop for Wal {
fn drop(&mut self) {
if let Some(inner) = Arc::get_mut(&mut self.inner) {
let inner = inner.get_mut();
let _ = inner.file.flush();
let _ = inner.file.get_ref().unlock();
}
}
}
/// Picks the WAL segment to append to, returning its path and sequence.
///
/// Scans `directory` for `wal_<hex>.log` files and selects the highest
/// sequence number seen (0 when the directory is empty or unreadable). If
/// that segment already exists and exceeds 32 MiB, the next sequence number
/// is chosen instead so a fresh segment gets created.
///
/// NOTE(review): the 32 MiB roll-over threshold here is hard-coded and does
/// not track `WalConfig::max_file_size` (default 64 MiB) — confirm intended.
fn find_or_create_wal_file(directory: &Path) -> Result<(PathBuf, u64)> {
    let mut max_sequence = 0u64;
    if let Ok(entries) = std::fs::read_dir(directory) {
        for entry in entries.flatten() {
            let name = entry.file_name();
            let parsed = name
                .to_string_lossy()
                .strip_prefix("wal_")
                .and_then(|rest| rest.strip_suffix(".log"))
                .and_then(|hex| u64::from_str_radix(hex, 16).ok());
            if let Some(seq) = parsed {
                max_sequence = max_sequence.max(seq);
            }
        }
    }
    let path = directory.join(format!("wal_{:016x}.log", max_sequence));
    // An existing over-sized segment rolls forward to the next sequence.
    let roll_over = std::fs::metadata(&path)
        .map(|meta| meta.len() > 32 * 1024 * 1024)
        .unwrap_or(false);
    if roll_over {
        let next = max_sequence + 1;
        let next_path = directory.join(format!("wal_{:016x}.log", next));
        return Ok((next_path, next));
    }
    Ok((path, max_sequence))
}
/// Read-side companion to [`Wal`]: replays records from a directory of
/// `.log` segment files.
pub struct WalReader {
    // Directory containing the WAL segment files to read.
    directory: PathBuf,
}
impl WalReader {
    /// Creates a reader over the given WAL directory.
    pub fn new(directory: impl Into<PathBuf>) -> Self {
        Self {
            directory: directory.into(),
        }
    }

    /// Reads every record from every `.log` file in the directory.
    ///
    /// Segments are processed in lexicographic path order; since filenames
    /// embed a zero-padded hex sequence number, that equals WAL order.
    ///
    /// # Errors
    /// Returns an error if a segment file cannot be opened or read, or if a
    /// record length prefix is corrupt mid-file.
    pub fn read_all(&self) -> Result<Vec<WalRecord>> {
        let mut records = Vec::new();
        let mut files: Vec<PathBuf> = Vec::new();
        if let Ok(entries) = std::fs::read_dir(&self.directory) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_some_and(|ext| ext == "log") {
                    files.push(path);
                }
            }
        }
        files.sort();
        for path in files {
            records.extend(self.read_file(&path)?);
        }
        Ok(records)
    }

    /// Reads the length-prefixed records from one segment file.
    ///
    /// On disk each record is a little-endian `u32` total length (covering
    /// the prefix itself) followed by the record body. A record truncated at
    /// the end of the file — a torn write from a crash mid-append — ends
    /// recovery of that file gracefully instead of erroring, so the records
    /// read so far are preserved. A record that fails to decode is skipped
    /// with a warning; the length prefix still tells us where the next one
    /// starts.
    fn read_file(&self, path: &Path) -> Result<Vec<WalRecord>> {
        let file = File::open(path).map_err(|e| XervError::WalRead {
            cause: format!("Failed to open {}: {}", path.display(), e),
        })?;
        // Upper bound for sanity-checking length prefixes: a corrupt prefix
        // must not be allowed to drive a multi-gigabyte allocation below.
        let file_len = file.metadata().map(|m| m.len()).unwrap_or(u64::MAX);
        let mut reader = BufReader::new(file);
        let mut records = Vec::new();
        loop {
            let length = match reader.read_u32::<LittleEndian>() {
                Ok(len) => len as usize,
                // Clean EOF at a record boundary: end of this segment.
                Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => {
                    return Err(XervError::WalRead {
                        cause: format!("Failed to read record length: {}", e),
                    });
                }
            };
            // A record must at least cover its own 4-byte prefix and cannot
            // be larger than the file it sits in.
            if length < MIN_RECORD_SIZE.max(4) || length as u64 > file_len {
                return Err(XervError::WalCorruption {
                    position: reader.stream_position().unwrap_or(0),
                    cause: format!("Invalid record length: {}", length),
                });
            }
            // Reconstruct the prefix bytes instead of seeking backwards:
            // `BufReader::seek` discards its internal buffer, which would
            // force a refill for every record.
            let mut buf = vec![0u8; length];
            buf[..4].copy_from_slice(&(length as u32).to_le_bytes());
            match reader.read_exact(&mut buf[4..]) {
                Ok(()) => {}
                Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    // Torn final record (crash mid-append): keep what we have.
                    tracing::warn!(
                        "Truncated WAL record at end of {}; stopping recovery of this file",
                        path.display()
                    );
                    break;
                }
                Err(e) => {
                    return Err(XervError::WalRead {
                        cause: format!("Failed to read record: {}", e),
                    });
                }
            }
            match WalRecord::from_bytes(&buf) {
                Ok(record) => records.push(record),
                Err(e) => {
                    tracing::warn!("Corrupted WAL record at {}: {}", path.display(), e);
                }
            }
        }
        Ok(records)
    }

    /// Replays all records and returns the traces that never reached a
    /// terminal record (`TraceComplete` or `TraceFailed`), keyed by trace id.
    ///
    /// For each live trace this tracks: nodes started but not yet done, the
    /// completed nodes with the arena location of their outputs, the most
    /// recently completed node, and the suspension point if the trace was
    /// suspended and not resumed.
    pub fn get_incomplete_traces(&self) -> Result<HashMap<TraceId, TraceRecoveryState>> {
        let records = self.read_all()?;
        let mut traces: HashMap<TraceId, TraceRecoveryState> = HashMap::new();
        for record in records {
            match record.record_type {
                WalRecordType::TraceStart => {
                    traces.insert(
                        record.trace_id,
                        TraceRecoveryState {
                            trace_id: record.trace_id,
                            last_completed_node: None,
                            suspended_at: None,
                            started_nodes: Vec::new(),
                            completed_nodes: HashMap::new(),
                        },
                    );
                }
                WalRecordType::NodeStart => {
                    if let Some(state) = traces.get_mut(&record.trace_id) {
                        state.started_nodes.push(record.node_id);
                    }
                }
                WalRecordType::NodeDone => {
                    if let Some(state) = traces.get_mut(&record.trace_id) {
                        state.last_completed_node = Some(record.node_id);
                        // A done node is no longer "in flight".
                        state.started_nodes.retain(|&n| n != record.node_id);
                        state.completed_nodes.insert(
                            record.node_id,
                            NodeOutputLocation {
                                offset: record.output_offset,
                                size: record.output_size,
                                schema_hash: record.schema_hash,
                            },
                        );
                    }
                }
                // Terminal records: the trace needs no recovery.
                WalRecordType::TraceComplete | WalRecordType::TraceFailed => {
                    traces.remove(&record.trace_id);
                }
                WalRecordType::TraceSuspended => {
                    if let Some(state) = traces.get_mut(&record.trace_id) {
                        state.suspended_at = Some(record.node_id);
                    }
                }
                WalRecordType::TraceResumed => {
                    if let Some(state) = traces.get_mut(&record.trace_id) {
                        state.suspended_at = None;
                    }
                }
                // Other record types carry no recovery-relevant state.
                _ => {}
            }
        }
        Ok(traces)
    }
}
/// In-memory reconstruction of an unfinished trace, built by
/// [`WalReader::get_incomplete_traces`] while replaying the log.
#[derive(Debug, Clone)]
pub struct TraceRecoveryState {
    /// Trace this recovery state belongs to.
    pub trace_id: TraceId,
    /// Most recent node to log a `NodeDone` record, if any.
    pub last_completed_node: Option<NodeId>,
    /// Node at which `TraceSuspended` was logged; cleared by `TraceResumed`.
    pub suspended_at: Option<NodeId>,
    /// Nodes with a `NodeStart` record but no matching `NodeDone` yet.
    pub started_nodes: Vec<NodeId>,
    /// Completed nodes mapped to the recorded location of their outputs.
    pub completed_nodes: HashMap<NodeId, NodeOutputLocation>,
}
/// Where a completed node's output was persisted, as recorded in the WAL's
/// `NodeDone` records.
#[derive(Debug, Clone, Copy)]
pub struct NodeOutputLocation {
    /// Arena offset of the serialized output.
    pub offset: crate::types::ArenaOffset,
    /// Size of the output, in bytes.
    pub size: u32,
    /// Schema hash stored alongside the output — presumably used to validate
    /// the payload on reload; confirm against the arena read path.
    pub schema_hash: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ArenaOffset;
    use tempfile::tempdir;

    /// Round-trip: four records written through the `Wal` are read back in
    /// the same order with the same types.
    #[test]
    fn wal_write_and_read() {
        let tmp = tempdir().unwrap();
        let wal = Wal::open(
            WalConfig::default()
                .with_directory(tmp.path())
                .with_sync(false),
        )
        .unwrap();
        let trace = TraceId::new();
        let node = NodeId::new(1);

        let to_write = [
            WalRecord::trace_start(trace),
            WalRecord::node_start(trace, node),
            WalRecord::node_done(trace, node, ArenaOffset::new(0x100), 64, 0),
            WalRecord::trace_complete(trace),
        ];
        for record in &to_write {
            wal.write(record).unwrap();
        }
        wal.flush().unwrap();

        let replayed = wal.reader().read_all().unwrap();
        assert_eq!(replayed.len(), 4);
        assert_eq!(replayed[0].record_type, WalRecordType::TraceStart);
        assert_eq!(replayed[1].record_type, WalRecordType::NodeStart);
        assert_eq!(replayed[2].record_type, WalRecordType::NodeDone);
        assert_eq!(replayed[3].record_type, WalRecordType::TraceComplete);
    }

    /// A trace that never logs a terminal record is reported as incomplete,
    /// with its in-flight node; a finished trace is not reported.
    #[test]
    fn wal_incomplete_trace_detection() {
        let tmp = tempdir().unwrap();
        let wal = Wal::open(
            WalConfig::default()
                .with_directory(tmp.path())
                .with_sync(false),
        )
        .unwrap();
        let finished = TraceId::new();
        let dangling = TraceId::new();
        let node = NodeId::new(1);

        // One trace runs to completion...
        wal.write(&WalRecord::trace_start(finished)).unwrap();
        wal.write(&WalRecord::node_done(finished, node, ArenaOffset::NULL, 0, 0))
            .unwrap();
        wal.write(&WalRecord::trace_complete(finished)).unwrap();
        // ...and another stops right after a node starts.
        wal.write(&WalRecord::trace_start(dangling)).unwrap();
        wal.write(&WalRecord::node_start(dangling, node)).unwrap();
        wal.flush().unwrap();

        let incomplete = wal.reader().get_incomplete_traces().unwrap();
        assert!(!incomplete.contains_key(&finished));
        let state = incomplete.get(&dangling).expect("dangling trace not found");
        assert!(state.started_nodes.contains(&node));
    }
}