use crate::{MemError, MemResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::{Mutex, RwLock};
use uuid::Uuid;
// Segment files are named `wal_<id:016x>.log`; ids are zero-padded hex so a
// lexical sort matches numeric order.
const SEGMENT_PREFIX: &str = "wal_";
const SEGMENT_EXTENSION: &str = ".log";
// Checkpoint files are named `checkpoint_<id:016x>.ckpt`.
const CHECKPOINT_PREFIX: &str = "checkpoint_";
const CHECKPOINT_EXTENSION: &str = ".ckpt";
// ASCII "WAL1": magic bytes framing every on-disk WAL entry.
const WAL_MAGIC: [u8; 4] = [0x57, 0x41, 0x4C, 0x31];
// ASCII "CKPT": magic bytes at the head of every checkpoint file.
const CHECKPOINT_MAGIC: [u8; 4] = [0x43, 0x4B, 0x50, 0x54];
// Current on-disk format version shared by WAL entries and checkpoints.
const WAL_VERSION: u8 = 1;
// Number of most-recent checkpoint files kept on disk by default.
const DEFAULT_CHECKPOINT_RETENTION: usize = 2;
/// Monotonically increasing sequence number identifying a WAL record.
///
/// Ordering derives from the wrapped `u64`, so comparing LSNs compares
/// append order.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
)]
pub struct LogSequenceNumber(pub u64);
impl LogSequenceNumber {
    /// Wraps a raw `u64` as an LSN.
    pub const fn new(value: u64) -> Self {
        Self(value)
    }
    /// Returns the raw `u64` value.
    pub const fn value(&self) -> u64 {
        self.0
    }
    /// Returns the LSN immediately after this one.
    pub fn next(&self) -> Self {
        Self(self.0 + 1)
    }
}
impl std::fmt::Display for LogSequenceNumber {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "LSN({})", self.0)
    }
}
impl From<u64> for LogSequenceNumber {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
/// Identifier of a checkpoint; allocated from a per-WAL counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct CheckpointId(pub u64);
impl CheckpointId {
    /// Wraps a raw `u64` as a checkpoint id.
    pub const fn new(value: u64) -> Self {
        Self(value)
    }
    /// Returns the raw `u64` value.
    pub const fn value(&self) -> u64 {
        self.0
    }
}
impl std::fmt::Display for CheckpointId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CKPT({})", self.0)
    }
}
/// The set of operations a WAL entry can record.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WalOperation {
    /// Insert one document with its content and embedding vector.
    Insert {
        id: Uuid,
        content: String,
        embedding: Vec<f32>,
    },
    /// Overwrite an existing document's content and embedding.
    Update {
        id: Uuid,
        content: String,
        embedding: Vec<f32>,
    },
    /// Delete one document by id.
    Delete {
        id: Uuid,
    },
    /// Insert several documents as one logical operation.
    BatchInsert {
        items: Vec<BatchItem>,
    },
    /// Delete several documents as one logical operation.
    BatchDelete {
        ids: Vec<Uuid>,
    },
    /// In-log checkpoint marker; recovery replays only entries after `lsn`.
    Checkpoint {
        lsn: u64,
        checkpoint_id: u64,
    },
    /// Transaction boundary markers (begin/commit/rollback).
    TxnBegin {
        txn_id: Uuid,
    },
    TxnCommit {
        txn_id: Uuid,
    },
    TxnRollback {
        txn_id: Uuid,
    },
}
/// One document inside a `WalOperation::BatchInsert`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BatchItem {
    pub id: Uuid,
    pub content: String,
    pub embedding: Vec<f32>,
}
impl WalOperation {
pub fn op_type(&self) -> &'static str {
match self {
WalOperation::Insert { .. } => "INSERT",
WalOperation::Update { .. } => "UPDATE",
WalOperation::Delete { .. } => "DELETE",
WalOperation::BatchInsert { .. } => "BATCH_INSERT",
WalOperation::BatchDelete { .. } => "BATCH_DELETE",
WalOperation::Checkpoint { .. } => "CHECKPOINT",
WalOperation::TxnBegin { .. } => "TXN_BEGIN",
WalOperation::TxnCommit { .. } => "TXN_COMMIT",
WalOperation::TxnRollback { .. } => "TXN_ROLLBACK",
}
}
}
/// A single WAL record: an operation stamped with its LSN, a timestamp, and
/// a CRC-32C checksum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalEntry {
    // LSN assigned at append time.
    pub lsn: u64,
    // Microseconds since the Unix epoch (chrono `timestamp_micros`).
    pub timestamp: i64,
    pub operation: WalOperation,
    // CRC-32C of the entry serialized with `checksum` set to 0.
    pub checksum: u32,
}
impl WalEntry {
    /// Builds an entry for `operation` at `lsn`, stamping the current time
    /// and filling in the checksum last.
    pub fn new(lsn: u64, operation: WalOperation) -> Self {
        let timestamp = chrono::Utc::now().timestamp_micros();
        let mut entry = WalEntry {
            lsn,
            timestamp,
            operation,
            checksum: 0,
        };
        // The checksum is computed with the checksum field zeroed (see
        // `compute_checksum`), then stored.
        entry.checksum = compute_checksum(&entry);
        entry
    }
    /// Returns true when the stored checksum matches a fresh recomputation.
    pub fn verify(&self) -> bool {
        verify_checksum(self)
    }
    /// On-disk footprint of this entry: magic(4) + version(1) + length(4)
    /// + payload + trailing length(4). Must stay in step with
    /// `SegmentWriter::write_entry` and `read_entry`.
    pub fn serialized_size(&self) -> MemResult<usize> {
        let data = wal_serialize(self)?;
        Ok(4 + 1 + 4 + data.len() + 4)
    }
    /// This entry's LSN as a typed value.
    pub fn lsn(&self) -> LogSequenceNumber {
        LogSequenceNumber::new(self.lsn)
    }
}
/// A persisted recovery point: everything at or below `lsn` is covered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    pub id: CheckpointId,
    pub lsn: LogSequenceNumber,
    // Seconds since the Unix epoch at creation time.
    pub created_at: u64,
    pub document_ids: Vec<Uuid>,
    pub chunk_ids: Vec<Uuid>,
    pub metadata: HashMap<String, String>,
    // CRC-32C over (id, lsn, created_at, document_ids); see compute_checksum.
    pub checksum: u32,
}
impl Checkpoint {
    /// Builds a checkpoint at `lsn`, stamping the current wall-clock time
    /// (seconds since the Unix epoch) and computing its checksum.
    pub fn new(
        id: u64,
        lsn: LogSequenceNumber,
        document_ids: Vec<Uuid>,
        chunk_ids: Vec<Uuid>,
    ) -> Self {
        let created_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let mut checkpoint = Self {
            id: CheckpointId::new(id),
            lsn,
            created_at,
            document_ids,
            chunk_ids,
            metadata: HashMap::new(),
            checksum: 0,
        };
        checkpoint.checksum = checkpoint.compute_checksum();
        checkpoint
    }
    /// CRC-32C over a JSON encoding of (id, lsn, created_at, document_ids).
    /// NOTE(review): `chunk_ids` and `metadata` are NOT covered by this
    /// checksum. Extending coverage would invalidate checkpoints already on
    /// disk, so the gap is only flagged here.
    fn compute_checksum(&self) -> u32 {
        let data = serde_json::to_vec(&(&self.id, &self.lsn, &self.created_at, &self.document_ids))
            .unwrap_or_default();
        crc32_compute(&data)
    }
    /// Returns true when the stored checksum matches a recomputation.
    pub fn validate(&self) -> bool {
        let expected = self.compute_checksum();
        self.checksum == expected
    }
    /// Encodes as: CHECKPOINT_MAGIC(4) + version(1) + JSON body.
    pub fn to_bytes(&self) -> MemResult<Vec<u8>> {
        let json = serde_json::to_vec(self)
            .map_err(|e| MemError::storage(format!("Failed to serialize checkpoint: {}", e)))?;
        let mut buf = Vec::with_capacity(4 + 1 + json.len());
        buf.extend_from_slice(&CHECKPOINT_MAGIC);
        buf.push(WAL_VERSION);
        buf.extend_from_slice(&json);
        Ok(buf)
    }
    /// Decodes and validates a checkpoint produced by `to_bytes`.
    /// Rejects bad magic, versions newer than ours, and checksum mismatches
    /// (older versions are accepted).
    pub fn from_bytes(buf: &[u8]) -> MemResult<Self> {
        if buf.len() < 5 {
            return Err(MemError::storage("Buffer too small for checkpoint"));
        }
        if buf[0..4] != CHECKPOINT_MAGIC {
            return Err(MemError::storage("Invalid checkpoint magic"));
        }
        if buf[4] > WAL_VERSION {
            return Err(MemError::storage(format!(
                "Unsupported checkpoint version: {}",
                buf[4]
            )));
        }
        let checkpoint: Checkpoint = serde_json::from_slice(&buf[5..])
            .map_err(|e| MemError::storage(format!("Failed to deserialize checkpoint: {}", e)))?;
        if !checkpoint.validate() {
            return Err(MemError::storage("Checkpoint checksum validation failed"));
        }
        Ok(checkpoint)
    }
}
/// Summary produced by `recover()`.
#[derive(Debug, Clone, Default)]
pub struct RecoveryReport {
    // Entries read and accepted (checksum verified).
    pub entries_recovered: u64,
    // Entries skipped because their checksum failed verification.
    pub entries_skipped: u64,
    // Entries counted for replay (currently tracks entries_recovered).
    pub entries_replayed: u64,
    // Highest LSN among verified entries.
    pub last_valid_lsn: LogSequenceNumber,
    // Checkpoint the replay started from, if one was found on disk.
    pub checkpoint_used: Option<CheckpointId>,
    pub duration_ms: u64,
    pub errors: Vec<RecoveryError>,
    // True when no error in `errors` was fatal.
    pub success: bool,
}
/// One problem encountered during recovery; `fatal` decides overall success.
#[derive(Debug, Clone)]
pub struct RecoveryError {
    // LSN of the offending entry, when one could be read.
    pub lsn: Option<LogSequenceNumber>,
    // Segment file the error occurred in, when known.
    pub segment: Option<PathBuf>,
    pub message: String,
    pub fatal: bool,
}
/// Durability policy applied after each append.
#[derive(Debug, Clone)]
pub enum SyncMode {
    /// fsync before `append` returns (strongest, slowest).
    Immediate,
    /// A background task fsyncs on this interval.
    Batched(Duration),
    /// Flush to the OS only; fsync happens on explicit `sync()`/rotation.
    Async,
}
impl Default for SyncMode {
    /// Default: batched fsync every 100 ms.
    fn default() -> Self {
        SyncMode::Batched(Duration::from_millis(100))
    }
}
/// Configuration for the file-backed WAL.
#[derive(Debug, Clone)]
pub struct WalConfig {
    // Directory holding segment and checkpoint files.
    pub dir: PathBuf,
    // Segment rollover threshold in megabytes.
    pub segment_size_mb: usize,
    pub sync_mode: SyncMode,
    // How many checkpoint files to keep on disk.
    pub checkpoint_retention: usize,
    // NOTE(review): declared but not read anywhere in this module yet.
    pub preallocate_segments: bool,
}
impl Default for WalConfig {
    /// Defaults: `./wal`, 64 MiB segments, batched sync, retention of 2.
    fn default() -> Self {
        Self {
            dir: PathBuf::from("./wal"),
            segment_size_mb: 64,
            sync_mode: SyncMode::default(),
            checkpoint_retention: DEFAULT_CHECKPOINT_RETENTION,
            preallocate_segments: true,
        }
    }
}
impl WalConfig {
    /// Creates a config rooted at `dir` with every other field defaulted.
    pub fn new(dir: PathBuf) -> Self {
        Self {
            dir,
            ..Self::default()
        }
    }
    /// Sets the segment rollover size in megabytes (builder style).
    pub fn with_segment_size(mut self, size_mb: usize) -> Self {
        self.segment_size_mb = size_mb;
        self
    }
    /// Sets the durability mode applied after appends (builder style).
    pub fn with_sync_mode(mut self, mode: SyncMode) -> Self {
        self.sync_mode = mode;
        self
    }
    /// Sets how many checkpoint files are retained (builder style).
    pub fn with_checkpoint_retention(mut self, count: usize) -> Self {
        self.checkpoint_retention = count;
        self
    }
    /// Segment rollover threshold expressed in bytes.
    fn segment_size_bytes(&self) -> usize {
        self.segment_size_mb * 1024 * 1024
    }
}
/// Abstract WAL interface implemented by the durable `WriteAheadLog` and
/// the test-oriented `InMemoryWal`.
#[async_trait]
pub trait WriteAheadLogTrait: Send + Sync {
    /// Appends one operation and returns its LSN.
    async fn append(&self, op: WalOperation) -> MemResult<LogSequenceNumber>;
    /// Appends a batch; returns the first LSN of the contiguous range.
    async fn append_batch(&self, ops: Vec<WalOperation>) -> MemResult<LogSequenceNumber>;
    /// Makes all appended entries durable.
    async fn sync(&self) -> MemResult<()>;
    /// Creates a checkpoint and returns its id.
    async fn checkpoint(&self) -> MemResult<CheckpointId>;
    /// Scans the log from the last checkpoint and reports what was found.
    async fn recover(&self) -> MemResult<RecoveryReport>;
    /// Drops log data wholly before `lsn`.
    async fn truncate_before(&self, lsn: LogSequenceNumber) -> MemResult<()>;
    /// Next LSN that will be assigned.
    async fn current_lsn(&self) -> LogSequenceNumber;
    /// Highest LSN known to be durable.
    async fn synced_lsn(&self) -> LogSequenceNumber;
    /// Reads up to `limit` verified entries with LSN >= `start_lsn`.
    async fn read_from(
        &self,
        start_lsn: LogSequenceNumber,
        limit: usize,
    ) -> MemResult<Vec<WalEntry>>;
    /// Flushes and shuts the log down; further appends fail.
    async fn close(&self) -> MemResult<()>;
}
/// Bookkeeping for one segment file.
#[derive(Debug, Clone)]
struct SegmentMeta {
    segment_id: u64,
    path: PathBuf,
    // First LSN intended for this segment (0 when unknown).
    first_lsn: u64,
    // Last LSN written, or `None` for an empty segment.
    last_lsn: Option<u64>,
    size_bytes: u64,
}
/// Buffered append-only writer for the active segment.
struct SegmentWriter {
    writer: BufWriter<File>,
    meta: SegmentMeta,
    // True when bytes have been written since the last fsync.
    needs_sync: bool,
}
impl SegmentWriter {
    /// Opens (creating if needed) the segment file in append mode behind a
    /// 64 KiB buffered writer.
    fn new(meta: SegmentMeta) -> MemResult<Self> {
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&meta.path)
            .map_err(|e| MemError::storage(format!("Failed to open WAL segment: {}", e)))?;
        Ok(Self {
            writer: BufWriter::with_capacity(64 * 1024, file),
            meta,
            needs_sync: false,
        })
    }
    /// Writes one entry framed as:
    /// magic(4) | version(1) | len(4, LE) | payload | len(4, LE).
    /// The trailing length lets readers detect torn writes. Must stay in
    /// step with `read_entry` and `WalEntry::serialized_size`.
    fn write_entry(&mut self, entry: &WalEntry) -> MemResult<()> {
        let data = wal_serialize(entry)?;
        let len = data.len() as u32;
        self.writer
            .write_all(&WAL_MAGIC)
            .map_err(|e| MemError::storage(format!("Failed to write WAL magic: {}", e)))?;
        self.writer
            .write_all(&[WAL_VERSION])
            .map_err(|e| MemError::storage(format!("Failed to write WAL version: {}", e)))?;
        self.writer
            .write_all(&len.to_le_bytes())
            .map_err(|e| MemError::storage(format!("Failed to write entry length: {}", e)))?;
        self.writer
            .write_all(&data)
            .map_err(|e| MemError::storage(format!("Failed to write entry data: {}", e)))?;
        self.writer
            .write_all(&len.to_le_bytes())
            .map_err(|e| MemError::storage(format!("Failed to write trailing length: {}", e)))?;
        // Track on-disk size and the freshest LSN for rotation decisions.
        self.meta.size_bytes += 4 + 1 + 4 + data.len() as u64 + 4;
        self.meta.last_lsn = Some(entry.lsn);
        self.needs_sync = true;
        Ok(())
    }
    /// Flushes buffered bytes to the OS (no durability guarantee).
    fn flush(&mut self) -> MemResult<()> {
        self.writer
            .flush()
            .map_err(|e| MemError::storage(format!("Failed to flush WAL buffer: {}", e)))
    }
    /// Flushes and fsyncs the file, then clears the dirty flag.
    fn sync(&mut self) -> MemResult<()> {
        self.flush()?;
        self.writer
            .get_ref()
            .sync_all()
            .map_err(|e| MemError::storage(format!("Failed to sync WAL segment: {}", e)))?;
        self.needs_sync = false;
        Ok(())
    }
}
/// Durable, segmented write-ahead log with checkpointing.
pub struct WriteAheadLog {
    config: WalConfig,
    // Active segment being appended to.
    current_segment: Arc<Mutex<SegmentWriter>>,
    // Next LSN to assign (fetch_add reserves it).
    current_lsn: AtomicU64,
    // Highest LSN known to be fsynced.
    synced_lsn: AtomicU64,
    // LSN covered by the most recent checkpoint.
    last_checkpoint: AtomicU64,
    // Next checkpoint id to assign.
    checkpoint_counter: AtomicU64,
    // Sealed (and possibly the reopened tail) segments, ordered by id.
    segments: Arc<RwLock<Vec<SegmentMeta>>>,
    #[allow(dead_code)]
    sync_handle: Option<tokio::task::JoinHandle<()>>,
    // 0/1 flag telling the background sync task to stop.
    shutdown: Arc<AtomicU64>,
    // 0/1 flag; non-zero rejects further appends.
    closed: AtomicU64,
}
impl WriteAheadLog {
/// Opens (or creates) a WAL rooted at `config.dir`.
///
/// Scans the directory to restore the LSN counter and checkpoint state,
/// reopens the last segment if it still has room (otherwise starts a
/// fresh one), and, for `SyncMode::Batched`, spawns a background task
/// that fsyncs on the configured interval.
pub async fn new(config: WalConfig) -> MemResult<Self> {
    std::fs::create_dir_all(&config.dir)
        .map_err(|e| MemError::storage(format!("Failed to create WAL directory: {}", e)))?;
    let (segments, current_lsn, last_checkpoint, checkpoint_id) =
        Self::discover_segments(&config).await?;
    let current_segment = if segments.is_empty() {
        Self::create_new_segment(&config, 0, current_lsn)?
    } else {
        // `unwrap` is safe: the empty case is handled above.
        let last = segments.last().unwrap();
        if last.size_bytes >= config.segment_size_bytes() as u64 {
            Self::create_new_segment(&config, last.segment_id + 1, current_lsn)?
        } else {
            // Reopen the partially-filled tail segment for appending.
            // NOTE(review): this segment also remains in `segments`, so
            // readers must avoid scanning it twice.
            SegmentWriter::new(last.clone())?
        }
    };
    let segments = Arc::new(RwLock::new(segments));
    let current_segment = Arc::new(Mutex::new(current_segment));
    let shutdown = Arc::new(AtomicU64::new(0));
    let sync_handle = match &config.sync_mode {
        SyncMode::Batched(interval) => {
            let interval = *interval;
            let segment_clone = current_segment.clone();
            let shutdown_clone = shutdown.clone();
            Some(tokio::spawn(async move {
                let mut ticker = tokio::time::interval(interval);
                loop {
                    ticker.tick().await;
                    // A non-zero shutdown flag ends the background task.
                    if shutdown_clone.load(Ordering::SeqCst) != 0 {
                        break;
                    }
                    let mut segment = segment_clone.lock().await;
                    if segment.needs_sync {
                        if let Err(e) = segment.sync() {
                            tracing::error!("Background WAL sync failed: {}", e);
                        }
                    }
                }
            }))
        }
        _ => None,
    };
    Ok(Self {
        config,
        current_segment,
        current_lsn: AtomicU64::new(current_lsn),
        // Nothing has been synced in this session yet.
        synced_lsn: AtomicU64::new(current_lsn.saturating_sub(1)),
        last_checkpoint: AtomicU64::new(last_checkpoint),
        checkpoint_counter: AtomicU64::new(checkpoint_id),
        segments,
        sync_handle,
        shutdown,
        closed: AtomicU64::new(0),
    })
}
/// Scans `config.dir` for existing WAL segments and checkpoint state.
///
/// Returns `(segments, next_lsn, last_checkpoint_lsn, next_checkpoint_id)`
/// with `segments` ordered by segment id. Torn/corrupt segment tails are
/// tolerated: scanning a segment stops at the first unreadable entry.
async fn discover_segments(config: &WalConfig) -> MemResult<(Vec<SegmentMeta>, u64, u64, u64)> {
    let mut segments = Vec::new();
    let mut max_lsn: u64 = 0;
    let mut last_checkpoint: u64 = 0;
    let mut max_checkpoint_id: u64 = 0;
    let entries = std::fs::read_dir(&config.dir)
        .map_err(|e| MemError::storage(format!("Failed to read WAL directory: {}", e)))?;
    let mut segment_files: Vec<(u64, PathBuf)> = Vec::new();
    for entry in entries {
        let entry =
            entry.map_err(|e| MemError::storage(format!("Failed to read dir entry: {}", e)))?;
        let path = entry.path();
        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
            if name.starts_with(SEGMENT_PREFIX) && name.ends_with(SEGMENT_EXTENSION) {
                // Segment ids are hex between the prefix and extension.
                let id_str = name
                    .trim_start_matches(SEGMENT_PREFIX)
                    .trim_end_matches(SEGMENT_EXTENSION);
                if let Ok(id) = u64::from_str_radix(id_str, 16) {
                    segment_files.push((id, path));
                }
            }
        }
    }
    segment_files.sort_by_key(|(id, _)| *id);
    for (segment_id, path) in segment_files {
        let file = File::open(&path)
            .map_err(|e| MemError::storage(format!("Failed to open segment: {}", e)))?;
        let size_bytes = file
            .metadata()
            .map_err(|e| MemError::storage(format!("Failed to get segment metadata: {}", e)))?
            .len();
        let mut reader = BufReader::new(file);
        let mut first_lsn = None;
        let mut last_lsn = None;
        // Walk entries until EOF or the first torn/corrupt record.
        while let Ok(Some(entry)) = read_entry(&mut reader) {
            if first_lsn.is_none() {
                first_lsn = Some(entry.lsn);
            }
            last_lsn = Some(entry.lsn);
            if entry.lsn > max_lsn {
                max_lsn = entry.lsn;
            }
            // Checkpoint markers embedded in the log advance recovery state.
            if let WalOperation::Checkpoint { lsn, checkpoint_id } = &entry.operation {
                if *lsn > last_checkpoint {
                    last_checkpoint = *lsn;
                }
                if *checkpoint_id > max_checkpoint_id {
                    max_checkpoint_id = *checkpoint_id;
                }
            }
        }
        segments.push(SegmentMeta {
            segment_id,
            path,
            first_lsn: first_lsn.unwrap_or(0),
            last_lsn,
            size_bytes,
        });
    }
    // Standalone checkpoint files may be newer than any in-log marker.
    if let Ok(checkpoint) = Self::find_latest_checkpoint(&config.dir) {
        if checkpoint.lsn.value() > last_checkpoint {
            last_checkpoint = checkpoint.lsn.value();
        }
        if checkpoint.id.value() > max_checkpoint_id {
            max_checkpoint_id = checkpoint.id.value();
        }
    }
    Ok((
        segments,
        max_lsn + 1,
        last_checkpoint,
        max_checkpoint_id + 1,
    ))
}
/// Returns the newest checkpoint on disk that decodes and validates.
///
/// Candidates are tried newest-first; unreadable or invalid files are
/// skipped. Errors with `not_found` when no usable checkpoint exists.
fn find_latest_checkpoint(dir: &PathBuf) -> MemResult<Checkpoint> {
    let dir_entries = std::fs::read_dir(dir)
        .map_err(|e| MemError::storage(format!("Failed to read directory: {}", e)))?;
    // Collect (id, path) pairs for every well-formed checkpoint filename.
    let mut candidates: Vec<(u64, PathBuf)> = Vec::new();
    for dir_entry in dir_entries {
        let dir_entry =
            dir_entry.map_err(|e| MemError::storage(format!("Failed to read entry: {}", e)))?;
        let path = dir_entry.path();
        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
            if name.starts_with(CHECKPOINT_PREFIX) && name.ends_with(CHECKPOINT_EXTENSION) {
                let hex_id = name
                    .trim_start_matches(CHECKPOINT_PREFIX)
                    .trim_end_matches(CHECKPOINT_EXTENSION);
                if let Ok(id) = u64::from_str_radix(hex_id, 16) {
                    candidates.push((id, path));
                }
            }
        }
    }
    if candidates.is_empty() {
        return Err(MemError::not_found("No checkpoint files found"));
    }
    // Highest id first, so the first valid candidate is the newest one.
    candidates.sort_by_key(|(id, _)| std::cmp::Reverse(*id));
    for (_, path) in candidates {
        let data = std::fs::read(&path)
            .map_err(|e| MemError::storage(format!("Failed to read checkpoint: {}", e)))?;
        if let Ok(ckpt) = Checkpoint::from_bytes(&data) {
            if ckpt.validate() {
                return Ok(ckpt);
            }
        }
    }
    Err(MemError::not_found("No valid checkpoint files found"))
}
/// Creates a writer for a brand-new, empty segment file.
///
/// Files are named `wal_<zero-padded hex id>.log` so a lexical directory
/// sort matches numeric segment order.
fn create_new_segment(
    config: &WalConfig,
    segment_id: u64,
    first_lsn: u64,
) -> MemResult<SegmentWriter> {
    let path = config
        .dir
        .join(format!("{}{:016x}{}", SEGMENT_PREFIX, segment_id, SEGMENT_EXTENSION));
    SegmentWriter::new(SegmentMeta {
        segment_id,
        path,
        first_lsn,
        last_lsn: None,
        size_bytes: 0,
    })
}
/// Appends one operation and returns the LSN assigned to it.
///
/// Rotates to a fresh segment (after syncing the full one) when the size
/// threshold is reached. Durability depends on the sync mode: `Immediate`
/// fsyncs before returning; otherwise the entry is only flushed to the OS.
pub async fn append(&self, op: WalOperation) -> MemResult<u64> {
    if self.closed.load(Ordering::SeqCst) != 0 {
        return Err(MemError::storage("WAL is closed"));
    }
    // Reserve the LSN before taking the segment lock.
    let lsn = self.current_lsn.fetch_add(1, Ordering::SeqCst);
    let entry = WalEntry::new(lsn, op);
    let mut segment = self.current_segment.lock().await;
    if segment.meta.size_bytes >= self.config.segment_size_bytes() as u64 {
        // Seal the full segment: sync it, publish its last LSN, move its
        // metadata into the sealed list, then start a new segment.
        segment.sync()?;
        if let Some(last) = segment.meta.last_lsn {
            self.synced_lsn.store(last, Ordering::SeqCst);
        }
        let mut segments = self.segments.write().await;
        segments.push(segment.meta.clone());
        let new_segment_id = segment.meta.segment_id + 1;
        *segment = Self::create_new_segment(&self.config, new_segment_id, lsn)?;
    }
    segment.write_entry(&entry)?;
    match &self.config.sync_mode {
        SyncMode::Immediate => {
            segment.sync()?;
            self.synced_lsn.store(lsn, Ordering::SeqCst);
        }
        SyncMode::Async | SyncMode::Batched(_) => {
            segment.flush()?;
        }
    }
    tracing::trace!(
        lsn = lsn,
        op = entry.operation.op_type(),
        "WAL: Appended entry"
    );
    Ok(lsn)
}
/// Appends a batch of operations under one segment lock acquisition,
/// returning the first LSN of the contiguous range assigned to the batch.
///
/// An empty batch is a no-op that returns the current (next) LSN.
/// NOTE(review): unlike `append`, this path does not rotate segments, so a
/// large batch can push the active segment past the size threshold until
/// the next single `append` rotates it.
pub async fn append_batch(&self, ops: Vec<WalOperation>) -> MemResult<u64> {
    if ops.is_empty() {
        return Ok(self.current_lsn.load(Ordering::SeqCst));
    }
    if self.closed.load(Ordering::SeqCst) != 0 {
        return Err(MemError::storage("WAL is closed"));
    }
    // Reserve a contiguous LSN range for the whole batch.
    let count = ops.len() as u64;
    let first_lsn = self.current_lsn.fetch_add(count, Ordering::SeqCst);
    let last_lsn = first_lsn + count - 1;
    let mut segment = self.current_segment.lock().await;
    for (i, op) in ops.into_iter().enumerate() {
        let lsn = first_lsn + i as u64;
        let entry = WalEntry::new(lsn, op);
        segment.write_entry(&entry)?;
    }
    match &self.config.sync_mode {
        SyncMode::Immediate => {
            segment.sync()?;
            // BUG FIX: only advance synced_lsn to the end of THIS batch.
            // The previous code re-read current_lsn, which could already
            // include LSNs reserved by a concurrent append that has not
            // been written or synced yet. fetch_max keeps the high-water
            // mark monotonic if another caller synced a later LSN first.
            self.synced_lsn.fetch_max(last_lsn, Ordering::SeqCst);
        }
        SyncMode::Async | SyncMode::Batched(_) => {
            segment.flush()?;
        }
    }
    Ok(first_lsn)
}
/// Forces buffered bytes of the active segment to durable storage and
/// publishes the new synced high-water mark.
pub async fn sync(&self) -> MemResult<()> {
    let mut guard = self.current_segment.lock().await;
    guard.sync()?;
    // An empty segment has no LSN to publish.
    match guard.meta.last_lsn {
        Some(last) => self.synced_lsn.store(last, Ordering::SeqCst),
        None => {}
    }
    Ok(())
}
/// Creates a checkpoint: syncs the log, writes a `Checkpoint` marker entry
/// into the log, persists a standalone checkpoint file, and prunes old
/// checkpoint files beyond the retention count.
pub async fn checkpoint(&self) -> MemResult<CheckpointId> {
    self.sync().await?;
    let checkpoint_id = self.checkpoint_counter.fetch_add(1, Ordering::SeqCst);
    // The checkpoint covers everything strictly below the current LSN.
    let checkpoint_lsn = self.current_lsn.load(Ordering::SeqCst);
    let lsn = self
        .append(WalOperation::Checkpoint {
            lsn: checkpoint_lsn,
            checkpoint_id,
        })
        .await?;
    self.sync().await?;
    let checkpoint = Checkpoint::new(
        checkpoint_id,
        LogSequenceNumber::new(checkpoint_lsn),
        // Document/chunk inventories are not tracked by this layer yet.
        vec![],
        vec![],
    );
    let checkpoint_path = self.config.dir.join(format!(
        "{}{:016x}{}",
        CHECKPOINT_PREFIX, checkpoint_id, CHECKPOINT_EXTENSION
    ));
    let data = checkpoint.to_bytes()?;
    std::fs::write(&checkpoint_path, data)
        .map_err(|e| MemError::storage(format!("Failed to write checkpoint file: {}", e)))?;
    self.last_checkpoint.store(checkpoint_lsn, Ordering::SeqCst);
    self.cleanup_old_checkpoints().await?;
    tracing::info!(
        checkpoint_id = checkpoint_id,
        lsn = lsn,
        "WAL: Created checkpoint"
    );
    Ok(CheckpointId::new(checkpoint_id))
}
/// Deletes checkpoint files beyond the configured retention count,
/// keeping the newest ones. Deletion failures are logged, not fatal.
async fn cleanup_old_checkpoints(&self) -> MemResult<()> {
    let read_dir = std::fs::read_dir(&self.config.dir)
        .map_err(|e| MemError::storage(format!("Failed to read directory: {}", e)))?;
    // Gather (id, path) for every file matching checkpoint_<hex>.ckpt.
    let mut found: Vec<(u64, PathBuf)> = Vec::new();
    for item in read_dir {
        let item = item.map_err(|e| MemError::storage(format!("Failed to read entry: {}", e)))?;
        let path = item.path();
        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
            if name.starts_with(CHECKPOINT_PREFIX) && name.ends_with(CHECKPOINT_EXTENSION) {
                let hex_id = name
                    .trim_start_matches(CHECKPOINT_PREFIX)
                    .trim_end_matches(CHECKPOINT_EXTENSION);
                if let Ok(id) = u64::from_str_radix(hex_id, 16) {
                    found.push((id, path));
                }
            }
        }
    }
    // Newest first; everything past the retention window is removed.
    found.sort_by_key(|(id, _)| std::cmp::Reverse(*id));
    for (_, path) in found.iter().skip(self.config.checkpoint_retention) {
        if let Err(e) = std::fs::remove_file(path) {
            tracing::warn!(path = ?path, error = %e, "Failed to remove old checkpoint");
        }
    }
    Ok(())
}
/// Scans the log from the most recent checkpoint and reports what would
/// be replayed. Entries with bad checksums are counted as skipped;
/// unreadable segments are recorded as non-fatal errors.
pub async fn recover(&self) -> MemResult<RecoveryReport> {
    let start_time = Instant::now();
    let mut report = RecoveryReport::default();
    // Prefer the newest on-disk checkpoint; otherwise fall back to the
    // checkpoint LSN discovered when the WAL was opened.
    let replay_from = match Self::find_latest_checkpoint(&self.config.dir) {
        Ok(checkpoint) => {
            report.checkpoint_used = Some(checkpoint.id);
            checkpoint.lsn.value()
        }
        Err(_) => self.last_checkpoint.load(Ordering::SeqCst),
    };
    let segments = self.segments.read().await;
    for seg_meta in segments.iter() {
        // Segments that end before the checkpoint need no replay.
        if let Some(last_lsn) = seg_meta.last_lsn {
            if last_lsn < replay_from {
                continue;
            }
        }
        match File::open(&seg_meta.path) {
            Ok(file) => {
                let mut reader = BufReader::new(file);
                while let Ok(Some(entry)) = read_entry(&mut reader) {
                    if entry.lsn > replay_from
                        && !matches!(entry.operation, WalOperation::Checkpoint { .. })
                    {
                        if entry.verify() {
                            report.entries_recovered += 1;
                            report.entries_replayed += 1;
                            report.last_valid_lsn = entry.lsn();
                        } else {
                            report.entries_skipped += 1;
                            report.errors.push(RecoveryError {
                                lsn: Some(entry.lsn()),
                                segment: Some(seg_meta.path.clone()),
                                message: "Invalid checksum".to_string(),
                                fatal: false,
                            });
                        }
                    }
                }
            }
            Err(e) => {
                report.errors.push(RecoveryError {
                    lsn: None,
                    segment: Some(seg_meta.path.clone()),
                    message: format!("Failed to open segment: {}", e),
                    fatal: false,
                });
            }
        }
    }
    {
        let segment = self.current_segment.lock().await;
        // BUG FIX: when the WAL reopened a partially-filled segment, the
        // active segment is also present in `segments`, so scanning it
        // again here double-counted every entry. Skip it if its file was
        // already processed above.
        let already_scanned = segments.iter().any(|s| s.path == segment.meta.path);
        if !already_scanned && segment.meta.path.exists() {
            if let Ok(file) = File::open(&segment.meta.path) {
                let mut reader = BufReader::new(file);
                while let Ok(Some(entry)) = read_entry(&mut reader) {
                    if entry.lsn > replay_from
                        && !matches!(entry.operation, WalOperation::Checkpoint { .. })
                    {
                        if entry.verify() {
                            report.entries_recovered += 1;
                            report.entries_replayed += 1;
                            report.last_valid_lsn = entry.lsn();
                        } else {
                            report.entries_skipped += 1;
                        }
                    }
                }
            }
        }
    }
    report.duration_ms = start_time.elapsed().as_millis() as u64;
    report.success = report.errors.iter().all(|e| !e.fatal);
    tracing::info!(
        count = report.entries_recovered,
        skipped = report.entries_skipped,
        duration_ms = report.duration_ms,
        "WAL: Recovery completed"
    );
    Ok(report)
}
/// Removes sealed segments whose entries all precede `lsn`, deleting their
/// files. Deletion failures are logged and do not abort the operation.
pub async fn truncate_before(&self, lsn: LogSequenceNumber) -> MemResult<()> {
    let mut segments = self.segments.write().await;
    // Indices of segments wholly before the cutoff.
    let removable: Vec<usize> = segments
        .iter()
        .enumerate()
        .filter(|(_, seg)| matches!(seg.last_lsn, Some(last) if last < lsn.value()))
        .map(|(idx, _)| idx)
        .collect();
    // Remove back-to-front so earlier indices stay valid.
    for idx in removable.into_iter().rev() {
        let seg = segments.remove(idx);
        match std::fs::remove_file(&seg.path) {
            Err(e) => tracing::warn!(
                path = ?seg.path,
                error = %e,
                "WAL: Failed to remove old segment"
            ),
            Ok(()) => tracing::info!(
                segment_id = seg.segment_id,
                last_lsn = ?seg.last_lsn,
                "WAL: Removed old segment"
            ),
        }
    }
    Ok(())
}
/// Next LSN that will be assigned (not the last one written).
pub fn get_current_lsn(&self) -> u64 {
    self.current_lsn.load(Ordering::SeqCst)
}
/// LSN recorded by the most recent checkpoint.
pub fn last_checkpoint_lsn(&self) -> u64 {
    self.last_checkpoint.load(Ordering::SeqCst)
}
/// Reads up to `limit` checksum-verified entries with LSN >= `start_lsn`,
/// returned in ascending LSN order.
pub async fn read_from(
    &self,
    start_lsn: LogSequenceNumber,
    limit: usize,
) -> MemResult<Vec<WalEntry>> {
    let mut result = Vec::with_capacity(limit);
    let segments = self.segments.read().await;
    for seg_meta in segments.iter() {
        if result.len() >= limit {
            break;
        }
        // Segments that end before the requested range are skipped.
        if let Some(last_lsn) = seg_meta.last_lsn {
            if last_lsn < start_lsn.value() {
                continue;
            }
        }
        if let Ok(file) = File::open(&seg_meta.path) {
            let mut reader = BufReader::new(file);
            while let Ok(Some(entry)) = read_entry(&mut reader) {
                if entry.lsn >= start_lsn.value() && entry.verify() {
                    result.push(entry);
                    if result.len() >= limit {
                        break;
                    }
                }
            }
        }
    }
    {
        let segment = self.current_segment.lock().await;
        // BUG FIX: skip the active segment if it is also tracked in
        // `segments` (happens after reopening a partially-filled
        // segment); scanning it twice returned duplicate entries.
        let already_read = segments.iter().any(|s| s.path == segment.meta.path);
        if !already_read && result.len() < limit && segment.meta.path.exists() {
            if let Ok(file) = File::open(&segment.meta.path) {
                let mut reader = BufReader::new(file);
                while let Ok(Some(entry)) = read_entry(&mut reader) {
                    if entry.lsn >= start_lsn.value() && entry.verify() {
                        result.push(entry);
                        if result.len() >= limit {
                            break;
                        }
                    }
                }
            }
        }
    }
    result.sort_by_key(|e| e.lsn);
    Ok(result)
}
/// Shuts the WAL down: refuses further appends, stops the background sync
/// task, then makes everything already written durable.
pub async fn close(&self) -> MemResult<()> {
    self.closed.store(1, Ordering::SeqCst);
    self.shutdown.store(1, Ordering::SeqCst);
    self.sync().await
}
/// Returns a snapshot of LSN counters and on-disk footprint.
pub async fn stats(&self) -> WalStats {
    let sealed = self.segments.read().await;
    let active = self.current_segment.lock().await;
    // Total size = sealed segments plus the segment being written.
    let sealed_bytes: u64 = sealed.iter().map(|s| s.size_bytes).sum();
    WalStats {
        current_lsn: self.current_lsn.load(Ordering::SeqCst),
        synced_lsn: self.synced_lsn.load(Ordering::SeqCst),
        last_checkpoint_lsn: self.last_checkpoint.load(Ordering::SeqCst),
        total_segments: sealed.len() + 1,
        total_size_bytes: sealed_bytes + active.meta.size_bytes,
    }
}
}
impl Drop for WriteAheadLog {
    /// Signals the background sync task to exit. No flushing happens here;
    /// call `close()` for a durable shutdown.
    fn drop(&mut self) {
        self.shutdown.store(1, Ordering::SeqCst);
    }
}
// Trait facade over the file-backed WAL: delegates to the inherent methods
// and wraps raw `u64` LSNs in the typed `LogSequenceNumber`.
#[async_trait]
impl WriteAheadLogTrait for WriteAheadLog {
    async fn append(&self, op: WalOperation) -> MemResult<LogSequenceNumber> {
        let lsn = WriteAheadLog::append(self, op).await?;
        Ok(LogSequenceNumber::new(lsn))
    }
    async fn append_batch(&self, ops: Vec<WalOperation>) -> MemResult<LogSequenceNumber> {
        let lsn = WriteAheadLog::append_batch(self, ops).await?;
        Ok(LogSequenceNumber::new(lsn))
    }
    async fn sync(&self) -> MemResult<()> {
        WriteAheadLog::sync(self).await
    }
    async fn checkpoint(&self) -> MemResult<CheckpointId> {
        WriteAheadLog::checkpoint(self).await
    }
    async fn recover(&self) -> MemResult<RecoveryReport> {
        WriteAheadLog::recover(self).await
    }
    async fn truncate_before(&self, lsn: LogSequenceNumber) -> MemResult<()> {
        WriteAheadLog::truncate_before(self, lsn).await
    }
    async fn current_lsn(&self) -> LogSequenceNumber {
        LogSequenceNumber::new(self.get_current_lsn())
    }
    async fn synced_lsn(&self) -> LogSequenceNumber {
        LogSequenceNumber::new(self.synced_lsn.load(Ordering::SeqCst))
    }
    async fn read_from(
        &self,
        start_lsn: LogSequenceNumber,
        limit: usize,
    ) -> MemResult<Vec<WalEntry>> {
        WriteAheadLog::read_from(self, start_lsn, limit).await
    }
    async fn close(&self) -> MemResult<()> {
        WriteAheadLog::close(self).await
    }
}
/// Non-durable WAL kept entirely in memory; intended for tests and
/// configurations that do not need crash recovery.
pub struct InMemoryWal {
    entries: RwLock<Vec<WalEntry>>,
    // Next LSN to assign; starts at 1.
    current_lsn: AtomicU64,
    checkpoints: RwLock<Vec<Checkpoint>>,
}
impl Default for InMemoryWal {
    /// Equivalent to `InMemoryWal::new()`.
    fn default() -> Self {
        Self::new()
    }
}
impl InMemoryWal {
    /// Creates an empty in-memory WAL; the first assigned LSN is 1.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(Vec::new()),
            current_lsn: AtomicU64::new(1),
            checkpoints: RwLock::new(Vec::new()),
        }
    }
}
#[async_trait]
impl WriteAheadLogTrait for InMemoryWal {
/// Claims the next LSN and records the entry under the write lock.
async fn append(&self, op: WalOperation) -> MemResult<LogSequenceNumber> {
    let lsn = self.current_lsn.fetch_add(1, Ordering::SeqCst);
    self.entries.write().await.push(WalEntry::new(lsn, op));
    Ok(LogSequenceNumber::new(lsn))
}
/// Reserves a contiguous LSN range for the batch, then appends the
/// entries in order. An empty batch returns the current (next) LSN.
async fn append_batch(&self, ops: Vec<WalOperation>) -> MemResult<LogSequenceNumber> {
    if ops.is_empty() {
        return Ok(LogSequenceNumber::new(
            self.current_lsn.load(Ordering::SeqCst),
        ));
    }
    let first_lsn = self
        .current_lsn
        .fetch_add(ops.len() as u64, Ordering::SeqCst);
    let mut log = self.entries.write().await;
    log.extend(
        ops.into_iter()
            .enumerate()
            .map(|(i, op)| WalEntry::new(first_lsn + i as u64, op)),
    );
    Ok(LogSequenceNumber::new(first_lsn))
}
/// In-memory entries are immediately "durable"; nothing to do.
async fn sync(&self) -> MemResult<()> {
    Ok(())
}
/// Records a checkpoint at the current LSN and returns its id
/// (1-based, derived from the number of stored checkpoints).
async fn checkpoint(&self) -> MemResult<CheckpointId> {
    // BUG FIX: hold the write lock across both id computation and
    // insertion. The previous read-lock/drop/write-lock sequence let two
    // concurrent checkpoints observe the same length and allocate the
    // same id.
    let mut checkpoints = self.checkpoints.write().await;
    let id = checkpoints.len() as u64 + 1;
    let lsn = LogSequenceNumber::new(self.current_lsn.load(Ordering::SeqCst));
    checkpoints.push(Checkpoint::new(id, lsn, vec![], vec![]));
    Ok(CheckpointId::new(id))
}
/// Everything in memory is intact by construction, so the report simply
/// summarizes the current state.
async fn recover(&self) -> MemResult<RecoveryReport> {
    let log = self.entries.read().await;
    let ckpts = self.checkpoints.read().await;
    let total = log.len() as u64;
    Ok(RecoveryReport {
        entries_recovered: total,
        entries_skipped: 0,
        entries_replayed: total,
        last_valid_lsn: log.last().map(|e| e.lsn()).unwrap_or_default(),
        checkpoint_used: ckpts.last().map(|c| c.id),
        duration_ms: 0,
        errors: Vec::new(),
        success: true,
    })
}
/// Drops every entry strictly older than `lsn`.
async fn truncate_before(&self, lsn: LogSequenceNumber) -> MemResult<()> {
    let cutoff = lsn.value();
    self.entries.write().await.retain(|e| e.lsn >= cutoff);
    Ok(())
}
/// Next LSN that will be assigned.
async fn current_lsn(&self) -> LogSequenceNumber {
    LogSequenceNumber::new(self.current_lsn.load(Ordering::SeqCst))
}
/// Memory writes are always "durable", so this equals `current_lsn`.
async fn synced_lsn(&self) -> LogSequenceNumber {
    self.current_lsn().await
}
/// Returns up to `limit` entries with LSN >= `start_lsn`. Entries are
/// stored in append (LSN) order, so a filtered forward scan suffices.
async fn read_from(
    &self,
    start_lsn: LogSequenceNumber,
    limit: usize,
) -> MemResult<Vec<WalEntry>> {
    let floor = start_lsn.value();
    let log = self.entries.read().await;
    let mut out = Vec::new();
    for entry in log.iter() {
        if out.len() == limit {
            break;
        }
        if entry.lsn >= floor {
            out.push(entry.clone());
        }
    }
    Ok(out)
}
/// Nothing to flush or release for the in-memory log.
async fn close(&self) -> MemResult<()> {
    Ok(())
}
}
/// Point-in-time WAL metrics returned by `WriteAheadLog::stats`.
#[derive(Debug, Clone)]
pub struct WalStats {
    // Next LSN to be assigned.
    pub current_lsn: u64,
    // Highest LSN known to be fsynced.
    pub synced_lsn: u64,
    pub last_checkpoint_lsn: u64,
    // Sealed segments plus the active one.
    pub total_segments: usize,
    pub total_size_bytes: u64,
}
/// CRC-32C over the entry serialized with its checksum field forced to
/// zero, so the stored checksum never feeds into itself. Serialization
/// failure degrades to 0 rather than panicking.
pub fn compute_checksum(entry: &WalEntry) -> u32 {
    let canonical = WalEntry {
        checksum: 0,
        lsn: entry.lsn,
        timestamp: entry.timestamp,
        operation: entry.operation.clone(),
    };
    wal_serialize(&canonical)
        .map(|bytes| crc32_compute(&bytes))
        .unwrap_or(0)
}
/// An entry is valid only when the recomputed CRC matches the stored one.
pub fn verify_checksum(entry: &WalEntry) -> bool {
    entry.checksum == compute_checksum(entry)
}
/// Standard reflected table-driven CRC-32C: start from all-ones, consume
/// one byte per step, and return the bitwise complement of the state.
fn crc32_compute(data: &[u8]) -> u32 {
    const TABLE: [u32; 256] = generate_crc32c_table();
    let mut state: u32 = 0xFFFF_FFFF;
    for &byte in data {
        let idx = ((state ^ u32::from(byte)) & 0xFF) as usize;
        state = (state >> 8) ^ TABLE[idx];
    }
    !state
}
/// Builds the 256-entry lookup table for CRC-32C (Castagnoli, reflected
/// polynomial 0x82F63B78) at compile time.
const fn generate_crc32c_table() -> [u32; 256] {
    const POLYNOMIAL: u32 = 0x82F63B78;
    let mut table = [0u32; 256];
    let mut i = 0;
    while i < 256 {
        let mut crc = i as u32;
        let mut j = 0;
        while j < 8 {
            crc = if crc & 1 != 0 {
                (crc >> 1) ^ POLYNOMIAL
            } else {
                crc >> 1
            };
            j += 1;
        }
        table[i] = crc;
        i += 1;
    }
    table
}
/// JSON-encodes a WAL payload (entries are stored as JSON on disk).
fn wal_serialize<T: Serialize>(value: &T) -> MemResult<Vec<u8>> {
    serde_json::to_vec(value)
        .map_err(|e| MemError::storage(format!("Failed to serialize WAL entry: {}", e)))
}
/// Inverse of `wal_serialize`.
fn wal_deserialize<T: for<'de> Deserialize<'de>>(data: &[u8]) -> MemResult<T> {
    serde_json::from_slice(data)
        .map_err(|e| MemError::storage(format!("Failed to deserialize WAL entry: {}", e)))
}
/// Reads one framed WAL entry from `reader`.
///
/// Returns `Ok(None)` on clean EOF (no bytes where the magic should be).
/// Any other framing problem — bad magic, unknown version, an absurd
/// length, a short read, or a trailing-length mismatch — is an error;
/// callers scanning segments treat the first error as end-of-log. The
/// entry's checksum is NOT verified here.
fn read_entry<R: Read + BufRead>(reader: &mut R) -> MemResult<Option<WalEntry>> {
    let mut magic = [0u8; 4];
    match reader.read_exact(&mut magic) {
        Ok(()) => {}
        // Clean EOF at an entry boundary: the segment simply ends here.
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => {
            return Err(MemError::storage(format!(
                "Failed to read WAL magic: {}",
                e
            )))
        }
    }
    if magic != WAL_MAGIC {
        return Err(MemError::storage("Invalid WAL magic bytes"));
    }
    let mut version = [0u8; 1];
    reader
        .read_exact(&mut version)
        .map_err(|e| MemError::storage(format!("Failed to read WAL version: {}", e)))?;
    if version[0] != WAL_VERSION {
        return Err(MemError::storage(format!(
            "Unsupported WAL version: {}",
            version[0]
        )));
    }
    let mut len_bytes = [0u8; 4];
    reader
        .read_exact(&mut len_bytes)
        .map_err(|e| MemError::storage(format!("Failed to read entry length: {}", e)))?;
    let len = u32::from_le_bytes(len_bytes) as usize;
    // Sanity cap: reject lengths that cannot be a real entry to avoid
    // huge allocations on corrupt data.
    if len > 100 * 1024 * 1024 {
        return Err(MemError::storage(format!(
            "WAL entry too large: {} bytes",
            len
        )));
    }
    let mut data = vec![0u8; len];
    reader
        .read_exact(&mut data)
        .map_err(|e| MemError::storage(format!("Failed to read entry data: {}", e)))?;
    // The trailing length must echo the leading one; a mismatch indicates
    // a torn or corrupted write.
    let mut trailing_len = [0u8; 4];
    reader
        .read_exact(&mut trailing_len)
        .map_err(|e| MemError::storage(format!("Failed to read trailing length: {}", e)))?;
    if u32::from_le_bytes(trailing_len) != len as u32 {
        return Err(MemError::storage("WAL entry length mismatch"));
    }
    let entry: WalEntry = wal_deserialize(&data)?;
    Ok(Some(entry))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// LSNs must be totally ordered, compare equal by value, and
    /// `next()` must advance by exactly one.
    #[test]
    fn test_lsn_ordering() {
        let lsn1 = LogSequenceNumber::new(1);
        let lsn2 = LogSequenceNumber::new(2);
        let lsn3 = LogSequenceNumber::new(2);
        assert!(lsn1 < lsn2);
        assert_eq!(lsn2, lsn3);
        assert_eq!(lsn1.next(), lsn2);
    }

    /// A freshly built entry must pass checksum verification, and any
    /// mutation of a checksummed field (here: the LSN) must fail it.
    #[test]
    fn test_checksum_computation() {
        let entry = WalEntry::new(
            1,
            WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "test content".to_string(),
                embedding: vec![0.1, 0.2, 0.3],
            },
        );
        assert!(entry.verify());
        let mut modified = entry.clone();
        modified.lsn = 999;
        assert!(!modified.verify());
    }

    /// "123456789" is the standard CRC check input; 0xE3069283 is the
    /// well-known CRC-32C (Castagnoli) check value for it.
    #[test]
    fn test_crc32_computation() {
        let data = b"123456789";
        let crc = crc32_compute(data);
        assert_eq!(crc, 0xE3069283);
    }

    /// A checkpoint must round-trip through its byte encoding with its
    /// id, LSN, and document set intact.
    #[test]
    fn test_checkpoint_serialization() {
        let checkpoint = Checkpoint::new(
            1,
            LogSequenceNumber::new(100),
            vec![Uuid::new_v4(), Uuid::new_v4()],
            vec![Uuid::new_v4()],
        );
        let bytes = checkpoint.to_bytes().unwrap();
        let recovered = Checkpoint::from_bytes(&bytes).unwrap();
        assert_eq!(checkpoint.id, recovered.id);
        assert_eq!(checkpoint.lsn, recovered.lsn);
        assert_eq!(checkpoint.document_ids.len(), recovered.document_ids.len());
    }

    /// The in-memory WAL must assign monotonically increasing LSNs
    /// starting at 1, return appended entries on read, and report all
    /// entries as recovered.
    #[tokio::test]
    async fn test_in_memory_wal() {
        let wal = InMemoryWal::new();
        let lsn1 = wal
            .append(WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "doc1".to_string(),
                embedding: vec![0.1, 0.2, 0.3],
            })
            .await
            .unwrap();
        let lsn2 = wal
            .append(WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "doc2".to_string(),
                embedding: vec![0.4, 0.5, 0.6],
            })
            .await
            .unwrap();
        assert_eq!(lsn1.value(), 1);
        assert_eq!(lsn2.value(), 2);
        let entries = wal.read_from(LogSequenceNumber::new(1), 10).await.unwrap();
        assert_eq!(entries.len(), 2);
        let ckpt = wal.checkpoint().await.unwrap();
        assert_eq!(ckpt.value(), 1);
        let report = wal.recover().await.unwrap();
        assert!(report.success);
        assert_eq!(report.entries_recovered, 2);
    }

    /// Basic file-backed WAL flow: two appends get increasing LSNs,
    /// a checkpoint succeeds, and everything fits in one segment.
    #[tokio::test]
    async fn test_wal_basic_operations() {
        let temp_dir = TempDir::new().unwrap();
        let config =
            WalConfig::new(temp_dir.path().to_path_buf()).with_sync_mode(SyncMode::Immediate);
        let wal = WriteAheadLog::new(config).await.unwrap();
        let id1 = Uuid::new_v4();
        let lsn1 = wal
            .append(WalOperation::Insert {
                id: id1,
                content: "First entry".to_string(),
                embedding: vec![0.1, 0.2],
            })
            .await
            .unwrap();
        let id2 = Uuid::new_v4();
        let lsn2 = wal
            .append(WalOperation::Insert {
                id: id2,
                content: "Second entry".to_string(),
                embedding: vec![0.3, 0.4],
            })
            .await
            .unwrap();
        assert!(lsn2 > lsn1);
        wal.checkpoint().await.unwrap();
        let stats = wal.stats().await;
        // Two inserts plus the checkpoint record consume at least 3 LSNs.
        assert!(stats.current_lsn >= 3);
        assert_eq!(stats.total_segments, 1);
    }

    /// After a clean close, a reopened WAL must recover at least the
    /// entries written after the last checkpoint.
    #[tokio::test]
    async fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config =
            WalConfig::new(temp_dir.path().to_path_buf()).with_sync_mode(SyncMode::Immediate);
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        // Scope the first WAL so it is closed before reopening the directory.
        {
            let wal = WriteAheadLog::new(config.clone()).await.unwrap();
            wal.append(WalOperation::Insert {
                id: id1,
                content: "Before checkpoint".to_string(),
                embedding: vec![0.1],
            })
            .await
            .unwrap();
            wal.checkpoint().await.unwrap();
            wal.append(WalOperation::Insert {
                id: id2,
                content: "After checkpoint".to_string(),
                embedding: vec![0.2],
            })
            .await
            .unwrap();
            wal.close().await.unwrap();
        }
        let wal2 = WriteAheadLog::new(config).await.unwrap();
        let report = wal2.recover().await.unwrap();
        assert!(report.success);
        assert!(report.entries_recovered >= 1);
    }

    /// A batch append must assign consecutive LSNs starting at the
    /// returned first LSN, in submission order.
    #[tokio::test]
    async fn test_wal_batch_operations() {
        let temp_dir = TempDir::new().unwrap();
        let config =
            WalConfig::new(temp_dir.path().to_path_buf()).with_sync_mode(SyncMode::Immediate);
        let wal = WriteAheadLog::new(config).await.unwrap();
        let ops = vec![
            WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "Batch 1".to_string(),
                embedding: vec![0.1],
            },
            WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "Batch 2".to_string(),
                embedding: vec![0.2],
            },
            WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "Batch 3".to_string(),
                embedding: vec![0.3],
            },
        ];
        let first_lsn = wal.append_batch(ops).await.unwrap();
        let entries = wal
            .read_from(LogSequenceNumber::new(first_lsn), 10)
            .await
            .unwrap();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].lsn, first_lsn);
        assert_eq!(entries[1].lsn, first_lsn + 1);
        assert_eq!(entries[2].lsn, first_lsn + 2);
    }

    /// With a 1 MB segment limit and large entries, the WAL must keep
    /// accepting appends (rotating segments as needed).
    #[tokio::test]
    async fn test_wal_segment_rotation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig::new(temp_dir.path().to_path_buf())
            .with_segment_size(1)
            .with_sync_mode(SyncMode::Immediate);
        let wal = WriteAheadLog::new(config).await.unwrap();
        // ~40 KB per embedding to force segment growth quickly.
        let large_embedding: Vec<f32> = (0..10000).map(|i| i as f32).collect();
        for i in 0..20 {
            wal.append(WalOperation::Insert {
                id: Uuid::new_v4(),
                content: format!("Entry {} with large embedding", i),
                embedding: large_embedding.clone(),
            })
            .await
            .unwrap();
        }
        let stats = wal.stats().await;
        assert!(stats.total_segments >= 1);
    }

    /// Truncating everything before the last checkpoint's LSN must
    /// succeed without invalidating the checkpoint itself.
    #[tokio::test]
    async fn test_wal_truncate() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig::new(temp_dir.path().to_path_buf())
            .with_segment_size(1)
            .with_sync_mode(SyncMode::Immediate);
        let wal = WriteAheadLog::new(config).await.unwrap();
        for _ in 0..5 {
            wal.append(WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "test".to_string(),
                embedding: vec![0.1; 1000],
            })
            .await
            .unwrap();
        }
        let checkpoint_id = wal.checkpoint().await.unwrap();
        wal.truncate_before(LogSequenceNumber::new(wal.last_checkpoint_lsn()))
            .await
            .unwrap();
        assert!(checkpoint_id.value() > 0);
    }

    /// Each operation variant must report its canonical op-type string.
    #[test]
    fn test_wal_operation_types() {
        let insert = WalOperation::Insert {
            id: Uuid::new_v4(),
            content: "test".to_string(),
            embedding: vec![],
        };
        assert_eq!(insert.op_type(), "INSERT");
        let update = WalOperation::Update {
            id: Uuid::new_v4(),
            content: "test".to_string(),
            embedding: vec![],
        };
        assert_eq!(update.op_type(), "UPDATE");
        let delete = WalOperation::Delete { id: Uuid::new_v4() };
        assert_eq!(delete.op_type(), "DELETE");
        let checkpoint = WalOperation::Checkpoint {
            lsn: 0,
            checkpoint_id: 1,
        };
        assert_eq!(checkpoint.op_type(), "CHECKPOINT");
        let batch = WalOperation::BatchInsert { items: vec![] };
        assert_eq!(batch.op_type(), "BATCH_INSERT");
        let txn_begin = WalOperation::TxnBegin {
            txn_id: Uuid::new_v4(),
        };
        assert_eq!(txn_begin.op_type(), "TXN_BEGIN");
    }

    /// The builder must record each configured value and sync mode.
    #[test]
    fn test_wal_config_builder() {
        let config = WalConfig::new(PathBuf::from("/tmp/wal"))
            .with_segment_size(128)
            .with_sync_mode(SyncMode::Immediate)
            .with_checkpoint_retention(5);
        assert_eq!(config.segment_size_mb, 128);
        assert_eq!(config.checkpoint_retention, 5);
        assert!(matches!(config.sync_mode, SyncMode::Immediate));
    }

    /// The serialized size of an entry must be non-trivial (content,
    /// embedding, and framing overhead all contribute).
    #[test]
    fn test_entry_serialized_size() {
        let entry = WalEntry::new(
            1,
            WalOperation::Insert {
                id: Uuid::new_v4(),
                content: "Hello, World!".to_string(),
                embedding: vec![0.1, 0.2, 0.3, 0.4, 0.5],
            },
        );
        let size = entry.serialized_size().unwrap();
        // `> 50` subsumes `> 0`; a single lower bound is enough.
        assert!(size > 50);
    }

    /// Both WAL implementations must be usable through the shared
    /// `WriteAheadLogTrait` object interface.
    #[tokio::test]
    async fn test_wal_trait_interface() {
        // Generic driver that only sees the trait object.
        async fn use_wal(wal: &dyn WriteAheadLogTrait) -> MemResult<()> {
            let lsn = wal
                .append(WalOperation::Insert {
                    id: Uuid::new_v4(),
                    content: "test".to_string(),
                    embedding: vec![0.1],
                })
                .await?;
            assert!(lsn.value() > 0);
            Ok(())
        }
        let mem_wal = InMemoryWal::new();
        use_wal(&mem_wal).await.unwrap();
        let temp_dir = TempDir::new().unwrap();
        let config =
            WalConfig::new(temp_dir.path().to_path_buf()).with_sync_mode(SyncMode::Immediate);
        let file_wal = WriteAheadLog::new(config).await.unwrap();
        use_wal(&file_wal).await.unwrap();
    }
}