use std::fs;
use std::path::{Path, PathBuf};
use tracing::info;
use crate::error::{Result, WalError};
use crate::record::WalRecord;
use crate::segment::{
DEFAULT_SEGMENT_TARGET_SIZE, SegmentMeta, TruncateResult, discover_segments, segment_path,
truncate_segments,
};
use crate::writer::{WalWriter, WalWriterConfig};
/// Configuration used by [`SegmentedWal::open`].
#[derive(Debug, Clone)]
pub struct SegmentedWalConfig {
/// Directory that holds the WAL segment files; created on open if missing.
pub wal_dir: PathBuf,
/// Roll-over threshold: once the active writer's file offset reaches this
/// value, the next append starts a new segment.
pub segment_target_size: u64,
/// Configuration handed to the [`WalWriter`] of every segment.
pub writer_config: WalWriterConfig,
}
impl SegmentedWalConfig {
    /// Builds a configuration for `wal_dir` using the default segment target
    /// size and the default writer configuration.
    pub fn new(wal_dir: PathBuf) -> Self {
        let segment_target_size = DEFAULT_SEGMENT_TARGET_SIZE;
        let writer_config = WalWriterConfig::default();
        Self {
            wal_dir,
            segment_target_size,
            writer_config,
        }
    }

    /// Same as [`SegmentedWalConfig::new`], except that the writer is
    /// configured with `use_direct_io` switched off.
    pub fn for_testing(wal_dir: PathBuf) -> Self {
        let mut config = Self::new(wal_dir);
        config.writer_config.use_direct_io = false;
        config
    }
}
/// A write-ahead log split across multiple segment files in one directory.
///
/// Appends go to a single active [`WalWriter`]; when that writer's file
/// offset reaches `segment_target_size`, a new segment is started.
pub struct SegmentedWal {
/// Directory containing all segment files.
wal_dir: PathBuf,
/// Writer for the currently active segment.
writer: WalWriter,
/// First LSN of the active segment (the value its path was derived from).
active_first_lsn: u64,
/// Writer file offset at which the next append rolls to a new segment.
segment_target_size: u64,
/// Writer configuration reused for every segment opened on roll-over.
writer_config: WalWriterConfig,
/// Key ring, if one was set; re-applied to each new writer on roll-over.
encryption_ring: Option<crate::crypto::KeyRing>,
}
impl SegmentedWal {
pub fn open(config: SegmentedWalConfig) -> Result<Self> {
fs::create_dir_all(&config.wal_dir).map_err(WalError::Io)?;
let segments = discover_segments(&config.wal_dir)?;
let (writer, active_first_lsn) = if segments.is_empty() {
let path = segment_path(&config.wal_dir, 1);
let writer = WalWriter::open(&path, config.writer_config.clone())?;
(writer, 1u64)
} else {
let last = &segments[segments.len() - 1];
let writer = WalWriter::open(&last.path, config.writer_config.clone())?;
(writer, last.first_lsn)
};
info!(
wal_dir = %config.wal_dir.display(),
segments = segments.len().max(1),
active_first_lsn,
next_lsn = writer.next_lsn(),
"segmented WAL opened"
);
Ok(Self {
wal_dir: config.wal_dir,
writer,
active_first_lsn,
segment_target_size: config.segment_target_size,
writer_config: config.writer_config,
encryption_ring: None,
})
}
pub fn set_encryption_ring(&mut self, ring: crate::crypto::KeyRing) {
self.writer.set_encryption_ring(ring.clone());
self.encryption_ring = Some(ring);
}
pub fn encryption_ring(&self) -> Option<&crate::crypto::KeyRing> {
self.encryption_ring.as_ref()
}
pub fn append(
&mut self,
record_type: u16,
tenant_id: u32,
vshard_id: u16,
payload: &[u8],
) -> Result<u64> {
if self.writer.file_offset() >= self.segment_target_size {
self.roll_segment()?;
}
self.writer
.append(record_type, tenant_id, vshard_id, payload)
}
pub fn sync(&mut self) -> Result<()> {
self.writer.sync()
}
pub fn next_lsn(&self) -> u64 {
self.writer.next_lsn()
}
pub fn active_segment_first_lsn(&self) -> u64 {
self.active_first_lsn
}
pub fn wal_dir(&self) -> &Path {
&self.wal_dir
}
pub fn truncate_before(&self, checkpoint_lsn: u64) -> Result<TruncateResult> {
truncate_segments(&self.wal_dir, checkpoint_lsn, self.active_first_lsn)
}
pub fn replay(&self) -> Result<Vec<WalRecord>> {
replay_all_segments(&self.wal_dir)
}
pub fn replay_from(&self, from_lsn: u64) -> Result<Vec<WalRecord>> {
let all = self.replay()?;
Ok(all
.into_iter()
.filter(|r| r.header.lsn >= from_lsn)
.collect())
}
pub fn replay_from_limit(
&self,
from_lsn: u64,
max_records: usize,
) -> Result<(Vec<WalRecord>, bool)> {
replay_from_limit_dir(&self.wal_dir, from_lsn, max_records)
}
pub fn list_segments(&self) -> Result<Vec<SegmentMeta>> {
discover_segments(&self.wal_dir)
}
pub fn total_size_bytes(&self) -> Result<u64> {
let segments = discover_segments(&self.wal_dir)?;
Ok(segments.iter().map(|s| s.file_size).sum())
}
fn roll_segment(&mut self) -> Result<()> {
self.writer.seal()?;
let new_first_lsn = self.writer.next_lsn();
let new_path = segment_path(&self.wal_dir, new_first_lsn);
let mut new_writer =
WalWriter::open_with_start_lsn(&new_path, self.writer_config.clone(), new_first_lsn)?;
if let Some(ref ring) = self.encryption_ring {
new_writer.set_encryption_ring(ring.clone());
}
self.writer = new_writer;
self.active_first_lsn = new_first_lsn;
let _ = crate::segment::fsync_directory(&self.wal_dir);
info!(
segment = %new_path.display(),
first_lsn = new_first_lsn,
"rolled to new WAL segment"
);
Ok(())
}
}
/// Reads every record from every segment found in `wal_dir`.
///
/// Segments are visited in the order [`discover_segments`] returns them and
/// records keep their in-segment order. The first failure (discovery,
/// opening a segment, or decoding a record) aborts the replay.
pub fn replay_all_segments(wal_dir: &Path) -> Result<Vec<WalRecord>> {
    let mut replayed = Vec::new();
    for segment in discover_segments(wal_dir)? {
        let reader = crate::reader::WalReader::open(&segment.path)?;
        for item in reader.records() {
            let record = item?;
            replayed.push(record);
        }
    }
    Ok(replayed)
}
/// Reads at most `max_records` records with LSN `>= from_lsn` from the
/// segments in `wal_dir`.
///
/// Returns the collected records together with a `has_more` flag. The flag
/// is `true` only when a further matching record was actually seen after the
/// page filled up, so an exactly-full final page reports `false`. With
/// `max_records == 0` the returned vector is empty and the flag tells whether
/// any matching record exists.
///
/// # Errors
///
/// Returns an error if segment discovery, opening a segment, or decoding a
/// record fails. This includes the single look-ahead record read to compute
/// `has_more`.
pub fn replay_from_limit_dir(
    wal_dir: &Path,
    from_lsn: u64,
    max_records: usize,
) -> Result<(Vec<WalRecord>, bool)> {
    let segments = discover_segments(wal_dir)?;
    // Cap the up-front allocation; callers may pass a very large limit.
    let mut records = Vec::with_capacity(max_records.min(4096));
    for seg in &segments {
        let reader = crate::reader::WalReader::open(&seg.path)?;
        for record_result in reader.records() {
            let record = record_result?;
            if record.header.lsn < from_lsn {
                continue;
            }
            // Check the limit before pushing: this keeps the page within
            // `max_records` (including zero) and proves that more data exists.
            if records.len() >= max_records {
                return Ok((records, true));
            }
            records.push(record);
        }
    }
    Ok((records, false))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::RecordType;

    /// Default test configuration rooted at `dir`.
    fn test_config(dir: &Path) -> SegmentedWalConfig {
        SegmentedWalConfig::for_testing(dir.to_path_buf())
    }

    /// Test configuration with a 100-byte segment target, small enough that a
    /// handful of short records forces a roll-over.
    fn small_segment_config(dir: &Path) -> SegmentedWalConfig {
        SegmentedWalConfig {
            segment_target_size: 100,
            ..test_config(dir)
        }
    }

    /// LSNs start at 1 and increase by one per append.
    #[test]
    fn create_and_append() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
        let lsn1 = wal.append(RecordType::Put as u16, 1, 0, b"hello").unwrap();
        let lsn2 = wal.append(RecordType::Put as u16, 1, 0, b"world").unwrap();
        wal.sync().unwrap();
        assert_eq!(lsn1, 1);
        assert_eq!(lsn2, 2);
        assert_eq!(wal.next_lsn(), 3);
    }

    /// Records written before a close are replayed in order after reopening.
    #[test]
    fn replay_after_close() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        {
            let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
            wal.append(RecordType::Put as u16, 1, 0, b"first").unwrap();
            wal.append(RecordType::Delete as u16, 2, 1, b"second")
                .unwrap();
            wal.append(RecordType::Put as u16, 1, 0, b"third").unwrap();
            wal.sync().unwrap();
        }
        let wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].payload, b"first");
        assert_eq!(records[1].payload, b"second");
        assert_eq!(records[2].payload, b"third");
        assert_eq!(wal.next_lsn(), 4);
    }

    /// A tiny segment target produces several segments without losing or
    /// reordering records.
    #[test]
    fn automatic_segment_rollover() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let mut wal = SegmentedWal::open(small_segment_config(&wal_dir)).unwrap();
        for i in 0..20u32 {
            let payload = format!("record-{i:04}");
            wal.append(RecordType::Put as u16, 1, 0, payload.as_bytes())
                .unwrap();
            wal.sync().unwrap();
        }
        let segments = wal.list_segments().unwrap();
        assert!(
            segments.len() > 1,
            "expected multiple segments, got {}",
            segments.len()
        );
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 20);
        for (i, record) in records.iter().enumerate() {
            assert_eq!(record.header.lsn, (i + 1) as u64);
            let expected = format!("record-{i:04}");
            assert_eq!(record.payload, expected.as_bytes());
        }
    }

    /// Truncation deletes old segments but keeps everything past the
    /// checkpoint replayable.
    #[test]
    fn truncation_removes_old_segments() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let mut wal = SegmentedWal::open(small_segment_config(&wal_dir)).unwrap();
        for i in 0..20u32 {
            let payload = format!("record-{i:04}");
            wal.append(RecordType::Put as u16, 1, 0, payload.as_bytes())
                .unwrap();
            wal.sync().unwrap();
        }
        let segments_before = wal.list_segments().unwrap();
        assert!(segments_before.len() > 1);

        let result = wal.truncate_before(15).unwrap();
        assert!(result.segments_deleted > 0);
        assert!(result.bytes_reclaimed > 0);

        let segments_after = wal.list_segments().unwrap();
        assert!(segments_after.len() < segments_before.len());

        let records = wal.replay().unwrap();
        let lsns: Vec<u64> = records.iter().map(|r| r.header.lsn).collect();
        assert!(lsns.iter().any(|&lsn| lsn >= 15));
        // What survives must be a gap-free suffix of the log...
        assert_eq!(lsns.last().copied(), Some(20));
        assert!(lsns.windows(2).all(|pair| pair[1] == pair[0] + 1));
        // ...that still contains every record past the checkpoint.
        assert!((16..=20u64).all(|lsn| lsns.contains(&lsn)));
    }

    /// `replay_from` returns only records at or above the requested LSN.
    #[test]
    fn replay_from_checkpoint_lsn() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
        for i in 0..10u32 {
            wal.append(RecordType::Put as u16, 1, 0, format!("r{i}").as_bytes())
                .unwrap();
        }
        wal.sync().unwrap();
        let records = wal.replay_from(6).unwrap();
        assert_eq!(records.len(), 5);
        assert_eq!(records[0].header.lsn, 6);
        assert_eq!(records[4].header.lsn, 10);
    }

    /// A WAL with synced data reports a non-zero total size.
    #[test]
    fn total_size_bytes() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
        wal.append(RecordType::Put as u16, 1, 0, b"data").unwrap();
        wal.sync().unwrap();
        let size = wal.total_size_bytes().unwrap();
        assert!(size > 0);
    }

    /// Reopening continues the LSN sequence instead of restarting it.
    #[test]
    fn reopen_continues_lsn() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        {
            let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
            wal.append(RecordType::Put as u16, 1, 0, b"a").unwrap();
            wal.append(RecordType::Put as u16, 1, 0, b"b").unwrap();
            wal.sync().unwrap();
        }
        {
            let mut wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
            assert_eq!(wal.next_lsn(), 3);
            let lsn = wal.append(RecordType::Put as u16, 1, 0, b"c").unwrap();
            assert_eq!(lsn, 3);
            wal.sync().unwrap();
        }
        let wal = SegmentedWal::open(test_config(&wal_dir)).unwrap();
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 3);
    }

    /// The encryption ring set on the WAL is carried over to every segment
    /// created by roll-over.
    #[test]
    fn encryption_persists_across_segments() {
        let dir = tempfile::tempdir().unwrap();
        let wal_dir = dir.path().join("wal");
        let key = crate::crypto::WalEncryptionKey::from_bytes(&[42u8; 32]);
        let ring = crate::crypto::KeyRing::new(key);
        let mut wal = SegmentedWal::open(small_segment_config(&wal_dir)).unwrap();
        wal.set_encryption_ring(ring);
        for i in 0..10u32 {
            wal.append(RecordType::Put as u16, 1, 0, format!("enc-{i}").as_bytes())
                .unwrap();
            wal.sync().unwrap();
        }
        assert!(wal.list_segments().unwrap().len() > 1);
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 10);
        assert!(records.iter().all(|r| r.is_encrypted()));
    }

    /// The record limit caps the page and `has_more` reflects remaining data.
    #[test]
    fn replay_from_limit_basic() {
        let dir = tempfile::tempdir().unwrap();
        let config = test_config(dir.path());
        let mut wal = SegmentedWal::open(config).unwrap();
        for i in 0..10u8 {
            wal.append(RecordType::Put as u16, 1, 0, &[i]).unwrap();
        }
        wal.sync().unwrap();

        let (records, has_more) = wal.replay_from_limit(1, 100).unwrap();
        assert_eq!(records.len(), 10);
        assert!(!has_more);

        let (records, has_more) = wal.replay_from_limit(1, 3).unwrap();
        assert_eq!(records.len(), 3);
        assert!(has_more);
        assert_eq!(records[0].header.lsn, 1);
        assert_eq!(records[2].header.lsn, 3);
    }

    /// The LSN filter and the record limit combine correctly.
    #[test]
    fn replay_from_limit_with_lsn_filter() {
        let dir = tempfile::tempdir().unwrap();
        let config = test_config(dir.path());
        let mut wal = SegmentedWal::open(config).unwrap();
        for i in 0..10u8 {
            wal.append(RecordType::Put as u16, 1, 0, &[i]).unwrap();
        }
        wal.sync().unwrap();

        let (records, has_more) = wal.replay_from_limit(6, 100).unwrap();
        assert_eq!(records.len(), 5);
        assert!(!has_more);
        assert_eq!(records[0].header.lsn, 6);

        let (records, has_more) = wal.replay_from_limit(6, 2).unwrap();
        assert_eq!(records.len(), 2);
        assert!(has_more);
    }

    /// Asking for LSNs beyond the end of the log yields an empty page.
    #[test]
    fn replay_from_limit_empty() {
        let dir = tempfile::tempdir().unwrap();
        let config = test_config(dir.path());
        let mut wal = SegmentedWal::open(config).unwrap();
        wal.append(RecordType::Put as u16, 1, 0, b"a").unwrap();
        wal.sync().unwrap();
        let (records, has_more) = wal.replay_from_limit(999, 100).unwrap();
        assert!(records.is_empty());
        assert!(!has_more);
    }

    /// Pagination continues seamlessly across a segment boundary.
    #[test]
    fn replay_from_limit_across_segments() {
        let dir = tempfile::tempdir().unwrap();
        let config = test_config(dir.path());
        let mut wal = SegmentedWal::open(config).unwrap();
        for i in 0..10u8 {
            wal.append(RecordType::Put as u16, 1, 0, &[i]).unwrap();
        }
        wal.sync().unwrap();
        // Force a second segment regardless of the configured target size.
        wal.roll_segment().unwrap();
        for i in 10..20u8 {
            wal.append(RecordType::Put as u16, 1, 0, &[i]).unwrap();
        }
        wal.sync().unwrap();

        let seg_count = wal.list_segments().unwrap().len();
        assert!(
            seg_count >= 2,
            "expected multiple segments, got {seg_count}"
        );

        let (records, has_more) = wal.replay_from_limit(1, 5).unwrap();
        assert_eq!(records.len(), 5);
        assert!(has_more);

        let next_lsn = records.last().unwrap().header.lsn + 1;
        let (records2, _) = wal.replay_from_limit(next_lsn, 200).unwrap();
        assert_eq!(records2.len(), 15);
    }
}