use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use tracing::{debug, info, span, warn, Level};
use crate::annotations::TripleAnnotation;
use crate::StarResult;
/// Tunable settings for the write-ahead log.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory that holds the `wal_<id>.log` segment files.
    pub wal_dir: PathBuf,
    /// Segment size, in bytes, at or above which the log rotates to a new segment.
    pub segment_size_threshold: usize,
    /// When `true`, every append is flushed and synced to disk before it returns.
    pub enable_fsync: bool,
    /// Capacity, in bytes, of the buffered writer wrapped around each segment file.
    pub write_buffer_size: usize,
    /// Upper bound on the number of segment files kept on disk after a rotation.
    pub max_segments: usize,
    /// Compression toggle. NOTE(review): nothing in this file reads it yet.
    pub enable_compression: bool,
}

impl Default for WalConfig {
    /// Defaults: `<temp dir>/oxirs_wal`, 64 MiB segments, fsync on every append,
    /// an 8 KiB write buffer, at most 10 segments, compression off.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        let wal_dir = std::env::temp_dir().join("oxirs_wal");
        Self {
            wal_dir,
            segment_size_threshold: 64 * MIB,
            enable_fsync: true,
            write_buffer_size: 8 * KIB,
            max_segments: 10,
            enable_compression: false,
        }
    }
}
/// Kind of operation recorded by a [`WalEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WalEntryType {
/// Annotation write for a key (built by `WalEntry::write`).
Write,
/// Delete record for a key (built by `WalEntry::delete`).
Delete,
/// Checkpoint marker (built by `WalEntry::checkpoint`).
Checkpoint,
/// Transaction begin marker. NOTE(review): no constructor in this file produces it.
BeginTxn,
/// Transaction commit marker. NOTE(review): no constructor in this file produces it.
CommitTxn,
/// Transaction abort marker. NOTE(review): no constructor in this file produces it.
AbortTxn,
}
/// One record in the write-ahead log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalEntry {
/// Log sequence number supplied by whoever builds the entry.
pub sequence: u64,
/// Which operation this record describes.
pub entry_type: WalEntryType,
/// Key the operation applies to; checkpoint entries use `0`.
pub key: u64,
/// Annotation payload: `Some` for write entries, `None` for delete and checkpoint entries.
pub annotation: Option<TripleAnnotation>,
/// Wall-clock time (`Utc::now()`) captured when the entry was constructed.
pub timestamp: DateTime<Utc>,
/// Optional transaction identifier passed through from the caller.
pub transaction_id: Option<u64>,
/// Checksum derived from `sequence`, `key` and `entry_type` only.
pub checksum: u32,
}
impl WalEntry {
pub fn write(
sequence: u64,
key: u64,
annotation: TripleAnnotation,
transaction_id: Option<u64>,
) -> Self {
let mut entry = Self {
sequence,
entry_type: WalEntryType::Write,
key,
annotation: Some(annotation),
timestamp: Utc::now(),
transaction_id,
checksum: 0,
};
entry.checksum = entry.calculate_checksum();
entry
}
pub fn delete(sequence: u64, key: u64, transaction_id: Option<u64>) -> Self {
let mut entry = Self {
sequence,
entry_type: WalEntryType::Delete,
key,
annotation: None,
timestamp: Utc::now(),
transaction_id,
checksum: 0,
};
entry.checksum = entry.calculate_checksum();
entry
}
pub fn checkpoint(sequence: u64) -> Self {
let mut entry = Self {
sequence,
entry_type: WalEntryType::Checkpoint,
key: 0,
annotation: None,
timestamp: Utc::now(),
transaction_id: None,
checksum: 0,
};
entry.checksum = entry.calculate_checksum();
entry
}
fn calculate_checksum(&self) -> u32 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
self.sequence.hash(&mut hasher);
self.key.hash(&mut hasher);
format!("{:?}", self.entry_type).hash(&mut hasher);
hasher.finish() as u32
}
pub fn verify_checksum(&self) -> bool {
self.checksum == self.calculate_checksum()
}
}
/// An open, append-only segment file of the write-ahead log.
struct WalSegment {
/// Numeric segment id; also embedded in the file name.
id: u64,
/// Location of the segment file on disk.
#[allow(dead_code)]
path: PathBuf,
/// Buffered writer over the segment file.
writer: BufWriter<File>,
/// Bytes appended through this handle, length prefixes included.
size_bytes: usize,
/// Number of entries appended through this handle.
entry_count: usize,
/// Time at which this handle was created.
#[allow(dead_code)]
created_at: DateTime<Utc>,
}
impl WalSegment {
    /// Opens (creating it if necessary) the segment file `wal_<id>.log` inside `wal_dir`
    /// in append mode and wraps it in a buffered writer of `buffer_size` bytes.
    ///
    /// The size and entry counters start at zero even when the file already exists.
    ///
    /// # Errors
    /// Returns a serialization error if the file cannot be opened.
    fn create(id: u64, wal_dir: &Path, buffer_size: usize) -> StarResult<Self> {
        let path = wal_dir.join(format!("wal_{:08}.log", id));
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        let writer = BufWriter::with_capacity(buffer_size, file);
        Ok(Self {
            id,
            path,
            writer,
            size_bytes: 0,
            entry_count: 0,
            created_at: Utc::now(),
        })
    }

    /// Appends `entry` as one frame: a 4-byte little-endian payload length followed by
    /// the oxicode-encoded payload.
    ///
    /// With `enable_fsync` set, the buffer is flushed and the file synced before
    /// returning; otherwise the bytes may remain in the write buffer.
    ///
    /// # Errors
    /// Returns a serialization error if encoding fails, if the payload does not fit the
    /// 32-bit length prefix, or if any write, flush or sync fails.
    fn append(&mut self, entry: &WalEntry, enable_fsync: bool) -> StarResult<()> {
        let entry_bytes = oxicode::serde::encode_to_vec(entry, oxicode::config::standard())
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        let payload_len = entry_bytes.len();
        // The frame header is a u32. Reject oversized payloads instead of letting an
        // `as` cast truncate the length and desynchronise every later frame.
        if payload_len > u32::MAX as usize {
            return Err(crate::StarError::serialization_error(format!(
                "WAL entry {} is too large to frame ({} bytes)",
                entry.sequence, payload_len
            )));
        }
        let frame_len = payload_len as u32; // lossless: bounded by the check above
        self.writer
            .write_all(&frame_len.to_le_bytes())
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        self.writer
            .write_all(&entry_bytes)
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        self.size_bytes += 4 + payload_len;
        self.entry_count += 1;
        if enable_fsync {
            self.writer
                .flush()
                .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
            self.writer
                .get_ref()
                .sync_all()
                .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        }
        Ok(())
    }

    /// Flushes any buffered bytes and drops the segment handle.
    ///
    /// # Errors
    /// Returns a serialization error if the flush fails.
    fn close(mut self) -> StarResult<()> {
        self.writer
            .flush()
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        Ok(())
    }

    /// Reads every well-formed entry from the segment file at `path`.
    ///
    /// Entries whose checksum does not verify are logged and skipped. A record that is
    /// cut short at the end of the file (for example by a crash in the middle of an
    /// append) is logged and treated as the end of the segment rather than as an error,
    /// so one torn write does not block start-up or recovery.
    ///
    /// # Errors
    /// Returns a parse error if the file cannot be opened, an I/O error other than
    /// end-of-file occurs, or a complete record fails to decode.
    fn read_entries(path: &Path) -> StarResult<Vec<WalEntry>> {
        let file = File::open(path).map_err(|e| crate::StarError::parse_error(e.to_string()))?;
        let mut reader = BufReader::new(file);
        let mut entries = Vec::new();
        // One payload buffer reused for every record.
        let mut entry_bytes = Vec::new();
        loop {
            let mut len_bytes = [0u8; 4];
            match reader.read_exact(&mut len_bytes) {
                Ok(()) => {}
                // Clean end of file, or a length prefix that was only partly written.
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(crate::StarError::parse_error(e.to_string())),
            }
            let len = u32::from_le_bytes(len_bytes) as usize;
            entry_bytes.clear();
            // Read at most `len` bytes without pre-allocating from an unvalidated length;
            // a short read means the record was truncated.
            let read = reader
                .by_ref()
                .take(len as u64)
                .read_to_end(&mut entry_bytes)
                .map_err(|e| crate::StarError::parse_error(e.to_string()))?;
            if read < len {
                warn!(
                    "Truncated WAL record in {:?} (expected {} bytes, found {}), ignoring tail",
                    path, len, read
                );
                break;
            }
            let entry: WalEntry =
                oxicode::serde::decode_from_slice(&entry_bytes, oxicode::config::standard())
                    .map_err(|e| crate::StarError::parse_error(e.to_string()))?
                    .0;
            if !entry.verify_checksum() {
                warn!("Checksum mismatch for entry {}, skipping", entry.sequence);
                continue;
            }
            entries.push(entry);
        }
        Ok(entries)
    }
}
/// Segmented write-ahead log that appends entries to a current segment file.
pub struct WriteAheadLog {
/// Settings supplied at construction.
config: WalConfig,
/// Segment currently receiving appends, if one is open.
current_segment: Option<WalSegment>,
/// Id that will be given to the next segment created.
next_segment_id: u64,
/// Sequence number that will be given to the next entry appended.
next_sequence: u64,
/// Running counters for this instance.
stats: WalStatistics,
}
/// Counters describing the activity of one `WriteAheadLog` instance.
#[derive(Debug, Clone, Default)]
pub struct WalStatistics {
/// Number of entries appended successfully.
pub total_entries: usize,
/// Running byte counter updated on every append.
pub bytes_written: usize,
/// Number of segments opened, including the initial one created by `WriteAheadLog::new`.
pub rotations: usize,
/// Number of checkpoint entries written.
pub checkpoints: usize,
/// Recovery counter. NOTE(review): nothing in this file increments it.
pub recoveries: usize,
}
impl WriteAheadLog {
    /// Opens the write-ahead log stored in `config.wal_dir`, creating the directory if
    /// needed.
    ///
    /// Existing segments are scanned to pick the next segment id and the next sequence
    /// number, then a fresh segment is opened for writing. Opening that first segment
    /// counts as one rotation in the statistics.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or listed, an existing segment cannot
    /// be read, or the new segment cannot be opened.
    pub fn new(config: WalConfig) -> StarResult<Self> {
        let span = span!(Level::INFO, "wal_new");
        let _enter = span.enter();
        fs::create_dir_all(&config.wal_dir)
            .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        let existing_segments = Self::list_segments(&config.wal_dir)?;
        let next_segment_id = existing_segments
            .iter()
            .map(|(id, _)| id + 1)
            .max()
            .unwrap_or(1);
        let next_sequence = Self::find_max_sequence(&existing_segments, &config.wal_dir)? + 1;
        info!(
            "Initialized WAL at {:?}, next segment: {}, next sequence: {}",
            config.wal_dir, next_segment_id, next_sequence
        );
        let mut wal = Self {
            config,
            current_segment: None,
            next_segment_id,
            next_sequence,
            stats: WalStatistics::default(),
        };
        wal.rotate_segment()?;
        Ok(wal)
    }

    /// Logs a write of `annotation` for `key` and returns the sequence number assigned.
    ///
    /// The sequence number is consumed even if the append fails.
    pub fn append_write(
        &mut self,
        key: u64,
        annotation: TripleAnnotation,
        transaction_id: Option<u64>,
    ) -> StarResult<u64> {
        let sequence = self.allocate_sequence();
        let entry = WalEntry::write(sequence, key, annotation, transaction_id);
        self.append_entry(&entry)?;
        Ok(sequence)
    }

    /// Logs a delete of `key` and returns the sequence number assigned.
    ///
    /// The sequence number is consumed even if the append fails.
    pub fn append_delete(&mut self, key: u64, transaction_id: Option<u64>) -> StarResult<u64> {
        let sequence = self.allocate_sequence();
        let entry = WalEntry::delete(sequence, key, transaction_id);
        self.append_entry(&entry)?;
        Ok(sequence)
    }

    /// Writes a checkpoint marker and returns its sequence number.
    ///
    /// [`WriteAheadLog::recover`] returns only the entries logged after the most recent
    /// checkpoint.
    pub fn checkpoint(&mut self) -> StarResult<u64> {
        let span = span!(Level::INFO, "wal_checkpoint");
        let _enter = span.enter();
        let sequence = self.allocate_sequence();
        let entry = WalEntry::checkpoint(sequence);
        self.append_entry(&entry)?;
        self.stats.checkpoints += 1;
        info!("Wrote checkpoint at sequence {}", sequence);
        Ok(sequence)
    }

    /// Hands out the next sequence number and advances the counter.
    fn allocate_sequence(&mut self) -> u64 {
        let sequence = self.next_sequence;
        self.next_sequence += 1;
        sequence
    }

    /// Appends `entry` to the active segment, updates the statistics and rotates once
    /// the segment reaches `segment_size_threshold`.
    ///
    /// # Errors
    /// Fails if no segment is open, the append fails, or the rotation fails.
    fn append_entry(&mut self, entry: &WalEntry) -> StarResult<()> {
        let segment = self
            .current_segment
            .as_mut()
            .ok_or_else(|| crate::StarError::serialization_error("No active segment"))?;
        let size_before = segment.size_bytes;
        segment.append(entry, self.config.enable_fsync)?;
        self.stats.total_entries += 1;
        // Count only the bytes this append added; adding the segment's running size
        // each time would make the counter grow quadratically.
        self.stats.bytes_written += segment.size_bytes - size_before;
        if segment.size_bytes >= self.config.segment_size_threshold {
            debug!("Segment {} reached size threshold, rotating", segment.id);
            self.rotate_segment()?;
        }
        Ok(())
    }

    /// Closes the active segment (if any), opens the next one and prunes old segment
    /// files.
    fn rotate_segment(&mut self) -> StarResult<()> {
        let span = span!(Level::DEBUG, "rotate_segment");
        let _enter = span.enter();
        if let Some(segment) = self.current_segment.take() {
            segment.close()?;
        }
        let segment_id = self.next_segment_id;
        self.next_segment_id += 1;
        let new_segment = WalSegment::create(
            segment_id,
            &self.config.wal_dir,
            self.config.write_buffer_size,
        )?;
        self.current_segment = Some(new_segment);
        self.stats.rotations += 1;
        self.cleanup_old_segments()?;
        debug!("Rotated to segment {}", segment_id);
        Ok(())
    }

    /// Removes the oldest segment files so that at most `max_segments` remain.
    ///
    /// At least one file is always kept: the newest segment is the one currently open
    /// for writing, so `max_segments == 0` must not delete it. Failures to delete a
    /// file are logged and otherwise ignored.
    fn cleanup_old_segments(&self) -> StarResult<()> {
        let segments = Self::list_segments(&self.config.wal_dir)?;
        let keep = self.config.max_segments.max(1);
        let excess = segments.len().saturating_sub(keep);
        // `list_segments` sorts by id, so the first `excess` entries are the oldest.
        for (_, path) in segments.iter().take(excess) {
            match fs::remove_file(path) {
                Ok(()) => debug!("Deleted old WAL segment {:?}", path),
                Err(e) => warn!("Failed to delete old WAL segment {:?}: {}", path, e),
            }
        }
        Ok(())
    }

    /// Lists the `wal_<id>.log` files in `wal_dir` as `(id, path)` pairs sorted by id.
    ///
    /// Entries whose name does not match the pattern, or whose id does not parse as a
    /// `u64`, are ignored.
    fn list_segments(wal_dir: &Path) -> StarResult<Vec<(u64, PathBuf)>> {
        let mut segments = Vec::new();
        let entries =
            fs::read_dir(wal_dir).map_err(|e| crate::StarError::parse_error(e.to_string()))?;
        for entry in entries {
            let entry = entry.map_err(|e| crate::StarError::parse_error(e.to_string()))?;
            let path = entry.path();
            let id = path
                .file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| name.strip_prefix("wal_"))
                .and_then(|rest| rest.strip_suffix(".log"))
                .and_then(|digits| digits.parse::<u64>().ok());
            if let Some(id) = id {
                segments.push((id, path));
            }
        }
        segments.sort_by_key(|(id, _)| *id);
        Ok(segments)
    }

    /// Returns the largest sequence number found at the tail of any segment, or `0`
    /// when no segment holds a readable entry.
    fn find_max_sequence(segments: &[(u64, PathBuf)], _wal_dir: &Path) -> StarResult<u64> {
        let mut max_seq = 0u64;
        for (_, path) in segments {
            let entries = WalSegment::read_entries(path)?;
            if let Some(last_entry) = entries.last() {
                max_seq = max_seq.max(last_entry.sequence);
            }
        }
        Ok(max_seq)
    }

    /// Reads every segment from disk in id order and returns the entries that follow
    /// the last checkpoint, or all entries when no checkpoint exists.
    ///
    /// Only bytes that have reached the segment files are visible. With `enable_fsync`
    /// off, call [`WriteAheadLog::flush`] first to include entries that are still in
    /// the write buffer.
    pub fn recover(&self) -> StarResult<Vec<WalEntry>> {
        let span = span!(Level::INFO, "wal_recover");
        let _enter = span.enter();
        let mut all_entries = Vec::new();
        for (_, path) in Self::list_segments(&self.config.wal_dir)? {
            all_entries.extend(WalSegment::read_entries(&path)?);
        }
        let checkpoint_pos = all_entries
            .iter()
            .rposition(|e| e.entry_type == WalEntryType::Checkpoint);
        let recovery_entries = match checkpoint_pos {
            Some(pos) => all_entries.split_off(pos + 1),
            None => all_entries,
        };
        info!("Recovered {} entries from WAL", recovery_entries.len());
        Ok(recovery_entries)
    }

    /// Returns the counters accumulated by this instance.
    pub fn statistics(&self) -> &WalStatistics {
        &self.stats
    }

    /// Flushes the active segment's write buffer to the underlying file.
    ///
    /// This does not call `sync_all`, and it is a no-op when no segment is open.
    pub fn flush(&mut self) -> StarResult<()> {
        if let Some(segment) = self.current_segment.as_mut() {
            segment
                .writer
                .flush()
                .map_err(|e| crate::StarError::serialization_error(e.to_string()))?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a scratch directory unique to `test_name` and this process.
    ///
    /// Tests in one binary run on parallel threads and share a process id, so the test
    /// name must be part of the path. Anything left behind by an earlier run is removed
    /// so each test starts from an empty log.
    fn scratch_dir(test_name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "oxirs_wal_test_{}_{}",
            test_name,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// A log can be opened on an empty directory with the default settings.
    #[test]
    fn test_wal_creation() {
        let temp_dir = scratch_dir("creation");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            ..Default::default()
        };
        let wal = WriteAheadLog::new(config);
        assert!(wal.is_ok());
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// Appending one write succeeds and is counted.
    #[test]
    fn test_append_write() {
        let temp_dir = scratch_dir("append_write");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            enable_fsync: false,
            ..Default::default()
        };
        let mut wal = WriteAheadLog::new(config).unwrap();
        let annotation = TripleAnnotation::new().with_confidence(0.9);
        let seq = wal.append_write(123, annotation, None);
        assert!(seq.is_ok());
        let stats = wal.statistics();
        assert_eq!(stats.total_entries, 1);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// Appending one delete succeeds and is counted.
    #[test]
    fn test_append_delete() {
        let temp_dir = scratch_dir("delete");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            enable_fsync: false,
            ..Default::default()
        };
        let mut wal = WriteAheadLog::new(config).unwrap();
        let seq = wal.append_delete(456, None);
        assert!(seq.is_ok());
        let stats = wal.statistics();
        assert_eq!(stats.total_entries, 1);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// Writing a checkpoint bumps the checkpoint counter.
    #[test]
    fn test_checkpoint() {
        let temp_dir = scratch_dir("checkpoint");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            enable_fsync: false,
            ..Default::default()
        };
        let mut wal = WriteAheadLog::new(config).unwrap();
        wal.checkpoint().unwrap();
        let stats = wal.statistics();
        assert_eq!(stats.checkpoints, 1);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// Recovery returns only the entries written after the last checkpoint.
    #[test]
    fn test_recovery() {
        let temp_dir = scratch_dir("recovery");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            enable_fsync: false,
            ..Default::default()
        };
        let mut wal = WriteAheadLog::new(config).unwrap();
        let ann1 = TripleAnnotation::new().with_confidence(0.8);
        wal.append_write(1, ann1, None).unwrap();
        let ann2 = TripleAnnotation::new().with_confidence(0.9);
        wal.append_write(2, ann2, None).unwrap();
        wal.checkpoint().unwrap();
        let ann3 = TripleAnnotation::new().with_confidence(0.95);
        wal.append_write(3, ann3, None).unwrap();
        // fsync is off, so push the buffered entries to the file before reading back.
        wal.flush().unwrap();
        let entries = wal.recover().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, 3);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// A tiny size threshold forces rotations beyond the initial segment.
    #[test]
    fn test_segment_rotation() {
        let temp_dir = scratch_dir("rotation");
        let config = WalConfig {
            wal_dir: temp_dir.clone(),
            enable_fsync: false,
            segment_size_threshold: 100,
            ..Default::default()
        };
        let mut wal = WriteAheadLog::new(config).unwrap();
        for i in 0..100 {
            let annotation = TripleAnnotation::new().with_confidence(0.9);
            wal.append_write(i, annotation, None).unwrap();
        }
        let stats = wal.statistics();
        // Opening the first segment already counts as one rotation, so require more.
        assert!(stats.rotations > 1);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// A freshly built entry verifies against its own checksum.
    #[test]
    fn test_checksum_verification() {
        let annotation = TripleAnnotation::new().with_confidence(0.9);
        let entry = WalEntry::write(1, 123, annotation, None);
        assert!(entry.verify_checksum());
    }
}