use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use parking_lot::{Mutex, RwLock};
/// Default size (64 MiB) at which the active segment is rotated.
pub const DEFAULT_SEGMENT_MAX_SIZE: u64 = 64 << 20;
/// Default maximum age (5 minutes) of the active segment before it is rotated.
pub const DEFAULT_ROTATION_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// Default value for `SegmentConfig::checkpoint_interval` (1 minute).
pub const DEFAULT_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(60);
/// Segment-file magic: ASCII `WLSG` read as a big-endian `u32` (stored little-endian on disk).
const SEGMENT_MAGIC: u32 = u32::from_be_bytes([b'W', b'L', b'S', b'G']);
/// Newest segment format version this code can decode.
const SEGMENT_VERSION: u16 = 1;
/// Byte length of the fixed header at the start of every segment file.
const SEGMENT_HEADER_SIZE: usize = 32;
/// Checkpoint-file magic: ASCII `CHKP` read as a big-endian `u32`.
const CHECKPOINT_MAGIC: u32 = u32::from_be_bytes([b'C', b'H', b'K', b'P']);
/// Tunables for a `WalSegmentManager`.
#[derive(Debug, Clone)]
pub struct SegmentConfig {
    /// Rotate the active segment once its write offset reaches this many bytes.
    /// When `preallocate` is set, this is also the length new segment files are sized to.
    pub max_size: u64,
    /// Rotate the active segment once it has been open for at least this long.
    pub rotation_interval: Duration,
    /// Checkpoint cadence. NOTE(review): not read by the segment manager itself;
    /// presumably used by whatever schedules `create_checkpoint` — confirm.
    pub checkpoint_interval: Duration,
    /// Directory holding the segment files and the `checkpoint` file; created if missing.
    pub wal_dir: PathBuf,
    /// When true, each append is flushed and fsynced before it returns.
    pub sync_on_write: bool,
    /// When true, new segment files are extended to `max_size` bytes (`set_len`) on creation.
    pub preallocate: bool,
}
impl Default for SegmentConfig {
fn default() -> Self {
Self {
max_size: DEFAULT_SEGMENT_MAX_SIZE,
rotation_interval: DEFAULT_ROTATION_INTERVAL,
checkpoint_interval: DEFAULT_CHECKPOINT_INTERVAL,
wal_dir: PathBuf::from("wal"),
sync_on_write: true,
preallocate: true,
}
}
}
impl SegmentConfig {
pub fn with_wal_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
self.wal_dir = dir.as_ref().to_path_buf();
self
}
pub fn with_max_size(mut self, size: u64) -> Self {
self.max_size = size;
self
}
}
/// Fixed 32-byte header at the start of every segment file, little-endian:
/// magic (4) | version (2) | flags (2) | sequence (8) | first_lsn (8) | created_at (8).
#[derive(Debug, Clone)]
pub struct SegmentHeader {
    /// Must equal `SEGMENT_MAGIC`; `decode` rejects anything else.
    pub magic: u32,
    /// Format version; `decode` rejects versions newer than `SEGMENT_VERSION`.
    pub version: u16,
    /// Flag bits; always written as 0 by `new`.
    pub flags: u16,
    /// Segment number, also embedded (as 16 hex digits) in the segment file name.
    pub sequence: u64,
    /// LSN the first record appended to this segment receives.
    pub first_lsn: u64,
    /// Creation time in milliseconds since the Unix epoch (0 if the clock is before the epoch).
    pub created_at: u64,
    /// Not serialized: the 32 header bytes are fully used by the fields above,
    /// so `encode` never writes this and `decode` always yields zeros.
    pub reserved: [u8; 8],
}
impl SegmentHeader {
fn new(sequence: u64, first_lsn: u64) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
Self {
magic: SEGMENT_MAGIC,
version: SEGMENT_VERSION,
flags: 0,
sequence,
first_lsn,
created_at: now,
reserved: [0; 8],
}
}
fn encode(&self) -> [u8; SEGMENT_HEADER_SIZE] {
let mut buf = [0u8; SEGMENT_HEADER_SIZE];
buf[0..4].copy_from_slice(&self.magic.to_le_bytes());
buf[4..6].copy_from_slice(&self.version.to_le_bytes());
buf[6..8].copy_from_slice(&self.flags.to_le_bytes());
buf[8..16].copy_from_slice(&self.sequence.to_le_bytes());
buf[16..24].copy_from_slice(&self.first_lsn.to_le_bytes());
buf[24..32].copy_from_slice(&self.created_at.to_le_bytes());
buf
}
fn decode(buf: &[u8]) -> Option<Self> {
if buf.len() < SEGMENT_HEADER_SIZE {
return None;
}
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != SEGMENT_MAGIC {
return None;
}
let version = u16::from_le_bytes([buf[4], buf[5]]);
if version > SEGMENT_VERSION {
return None;
}
Some(Self {
magic,
version,
flags: u16::from_le_bytes([buf[6], buf[7]]),
sequence: u64::from_le_bytes([
buf[8], buf[9], buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
]),
first_lsn: u64::from_le_bytes([
buf[16], buf[17], buf[18], buf[19], buf[20], buf[21], buf[22], buf[23],
]),
created_at: u64::from_le_bytes([
buf[24], buf[25], buf[26], buf[27], buf[28], buf[29], buf[30], buf[31],
]),
reserved: [0; 8],
})
}
}
/// The segment currently receiving appends, held in `WalSegmentManager::active`.
struct ActiveSegment {
    /// Buffered writer over the segment file.
    file: BufWriter<File>,
    /// Location of the segment file. NOTE(review): not read anywhere in this file.
    path: PathBuf,
    /// Header written at offset 0 when the segment was created.
    header: SegmentHeader,
    /// Bytes written so far (header plus records); drives size-based rotation.
    offset: u64,
    /// Monotonic creation time; drives time-based rotation.
    created_at: Instant,
}
/// Checkpoint persisted to `<wal_dir>/checkpoint`.
///
/// The on-disk form is 48 bytes: magic (4), the five `u64` fields below in
/// declaration order (little-endian), and a CRC32 of the preceding 44 bytes.
#[derive(Debug, Clone)]
pub struct CheckpointRecord {
    /// The manager's next-LSN counter at the moment the checkpoint was taken.
    pub lsn: u64,
    /// Highest segment sequence considered fully covered by this checkpoint (0 if none);
    /// segments at or below it are eligible for `cleanup_old_segments`.
    pub last_segment: u64,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Caller-supplied value, stored and returned verbatim.
    pub memtable_checksum: u64,
    /// Caller-supplied value, stored and returned verbatim.
    pub entry_count: u64,
}
impl CheckpointRecord {
fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(48);
buf.extend_from_slice(&CHECKPOINT_MAGIC.to_le_bytes());
buf.extend_from_slice(&self.lsn.to_le_bytes());
buf.extend_from_slice(&self.last_segment.to_le_bytes());
buf.extend_from_slice(&self.timestamp.to_le_bytes());
buf.extend_from_slice(&self.memtable_checksum.to_le_bytes());
buf.extend_from_slice(&self.entry_count.to_le_bytes());
let checksum = crc32fast::hash(&buf);
buf.extend_from_slice(&checksum.to_le_bytes());
buf
}
fn decode(buf: &[u8]) -> Option<Self> {
if buf.len() < 48 {
return None;
}
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != CHECKPOINT_MAGIC {
return None;
}
let stored_checksum = u32::from_le_bytes([buf[44], buf[45], buf[46], buf[47]]);
let computed_checksum = crc32fast::hash(&buf[0..44]);
if stored_checksum != computed_checksum {
return None;
}
Some(Self {
lsn: u64::from_le_bytes([
buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
]),
last_segment: u64::from_le_bytes([
buf[12], buf[13], buf[14], buf[15], buf[16], buf[17], buf[18], buf[19],
]),
timestamp: u64::from_le_bytes([
buf[20], buf[21], buf[22], buf[23], buf[24], buf[25], buf[26], buf[27],
]),
memtable_checksum: u64::from_le_bytes([
buf[28], buf[29], buf[30], buf[31], buf[32], buf[33], buf[34], buf[35],
]),
entry_count: u64::from_le_bytes([
buf[36], buf[37], buf[38], buf[39], buf[40], buf[41], buf[42], buf[43],
]),
})
}
}
/// Manages a directory of rotating WAL segment files plus a checkpoint file.
pub struct WalSegmentManager {
    /// Configuration supplied at construction.
    config: SegmentConfig,
    /// Segment currently receiving appends; `None` before the first append and after shutdown.
    active: Mutex<Option<ActiveSegment>>,
    /// Next LSN to assign (appends take it with `fetch_add`).
    current_lsn: AtomicU64,
    /// Sequence number the next created segment will receive.
    segment_sequence: AtomicU64,
    /// Metadata for every known segment, keyed (and therefore ordered) by sequence.
    segments: RwLock<BTreeMap<u64, SegmentMetadata>>,
    /// Most recently written or recovered checkpoint.
    last_checkpoint: RwLock<Option<CheckpointRecord>>,
    /// Set by `shutdown`.
    shutdown: AtomicBool,
}
/// Bookkeeping for one segment file, tracked in `WalSegmentManager::segments`.
#[derive(Debug, Clone)]
pub struct SegmentMetadata {
    /// Segment sequence number (also the map key).
    pub sequence: u64,
    /// `first_lsn` from the segment header.
    pub first_lsn: u64,
    /// LSN bound recorded when the segment stops being active; `None` means no
    /// bound is known (for example, while the segment is still active).
    pub last_lsn: Option<u64>,
    /// Path of the segment file.
    pub path: PathBuf,
    /// Size in bytes as last recorded (at creation, rotation, or recovery);
    /// not refreshed on each append.
    pub size: u64,
    /// Set when the segment is created as the active one; cleared when it is rotated out.
    pub is_active: bool,
}
impl WalSegmentManager {
pub fn new(config: SegmentConfig) -> std::io::Result<Self> {
fs::create_dir_all(&config.wal_dir)?;
let manager = Self {
config,
active: Mutex::new(None),
current_lsn: AtomicU64::new(0),
segment_sequence: AtomicU64::new(0),
segments: RwLock::new(BTreeMap::new()),
last_checkpoint: RwLock::new(None),
shutdown: AtomicBool::new(false),
};
manager.recover()?;
Ok(manager)
}
fn recover(&self) -> std::io::Result<()> {
let checkpoint_path = self.config.wal_dir.join("checkpoint");
if checkpoint_path.exists() {
let mut file = File::open(&checkpoint_path)?;
let mut buf = Vec::new();
file.read_to_end(&mut buf)?;
if let Some(record) = CheckpointRecord::decode(&buf) {
*self.last_checkpoint.write() = Some(record);
}
}
let entries = fs::read_dir(&self.config.wal_dir)?;
let mut max_sequence = 0u64;
let mut max_lsn = 0u64;
for entry in entries {
let entry = entry?;
let path = entry.path();
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
if name.starts_with("segment_") && name.ends_with(".wal") {
let mut file = File::open(&path)?;
let mut header_buf = [0u8; SEGMENT_HEADER_SIZE];
if file.read_exact(&mut header_buf).is_ok() {
if let Some(header) = SegmentHeader::decode(&header_buf) {
max_sequence = max_sequence.max(header.sequence);
max_lsn = max_lsn.max(header.first_lsn);
let metadata = file.metadata()?;
self.segments.write().insert(
header.sequence,
SegmentMetadata {
sequence: header.sequence,
first_lsn: header.first_lsn,
last_lsn: None,
path: path.clone(),
size: metadata.len(),
is_active: false,
},
);
}
}
}
}
}
self.segment_sequence
.store(max_sequence + 1, Ordering::SeqCst);
self.current_lsn.store(max_lsn, Ordering::SeqCst);
Ok(())
}
pub fn append(&self, data: &[u8]) -> std::io::Result<u64> {
let mut active = self.active.lock();
if self.needs_rotation(&active) {
self.rotate_segment(&mut active)?;
}
if active.is_none() {
self.create_new_segment(&mut active)?;
}
let segment = active.as_mut().unwrap();
let lsn = self.current_lsn.fetch_add(1, Ordering::SeqCst);
let record_len = 4 + 8 + data.len() + 4;
let mut record = Vec::with_capacity(record_len);
record.extend_from_slice(&(data.len() as u32).to_le_bytes());
record.extend_from_slice(&lsn.to_le_bytes());
record.extend_from_slice(data);
let checksum = crc32fast::hash(&record);
record.extend_from_slice(&checksum.to_le_bytes());
segment.file.write_all(&record)?;
segment.offset += record_len as u64;
if self.config.sync_on_write {
segment.file.flush()?;
segment.file.get_ref().sync_all()?;
}
Ok(lsn)
}
fn needs_rotation(&self, active: &Option<ActiveSegment>) -> bool {
match active {
Some(segment) => {
segment.offset >= self.config.max_size
|| segment.created_at.elapsed() >= self.config.rotation_interval
}
None => false,
}
}
fn rotate_segment(&self, active: &mut Option<ActiveSegment>) -> std::io::Result<()> {
if let Some(mut segment) = active.take() {
segment.file.flush()?;
segment
.file
.into_inner()
.map_err(|e| e.into_error())?
.sync_all()?;
let current_lsn = self.current_lsn.load(Ordering::SeqCst);
if let Some(meta) = self.segments.write().get_mut(&segment.header.sequence) {
meta.is_active = false;
meta.last_lsn = Some(current_lsn);
meta.size = segment.offset;
}
}
Ok(())
}
fn create_new_segment(&self, active: &mut Option<ActiveSegment>) -> std::io::Result<()> {
let sequence = self.segment_sequence.fetch_add(1, Ordering::SeqCst);
let first_lsn = self.current_lsn.load(Ordering::SeqCst);
let path = self
.config
.wal_dir
.join(format!("segment_{:016x}.wal", sequence));
let file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&path)?;
if self.config.preallocate {
file.set_len(self.config.max_size)?;
}
let mut writer = BufWriter::new(file);
let header = SegmentHeader::new(sequence, first_lsn);
writer.write_all(&header.encode())?;
let segment = ActiveSegment {
file: writer,
path: path.clone(),
header: header.clone(),
offset: SEGMENT_HEADER_SIZE as u64,
created_at: Instant::now(),
};
self.segments.write().insert(
sequence,
SegmentMetadata {
sequence,
first_lsn,
last_lsn: None,
path,
size: SEGMENT_HEADER_SIZE as u64,
is_active: true,
},
);
*active = Some(segment);
Ok(())
}
pub fn create_checkpoint(
&self,
memtable_checksum: u64,
entry_count: u64,
) -> std::io::Result<CheckpointRecord> {
let lsn = self.current_lsn.load(Ordering::SeqCst);
let segments = self.segments.read();
let last_segment = segments
.values()
.filter(|s| s.last_lsn.map(|l| l < lsn).unwrap_or(false))
.map(|s| s.sequence)
.max()
.unwrap_or(0);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let record = CheckpointRecord {
lsn,
last_segment,
timestamp: now,
memtable_checksum,
entry_count,
};
let checkpoint_path = self.config.wal_dir.join("checkpoint");
let temp_path = self.config.wal_dir.join("checkpoint.tmp");
let mut file = File::create(&temp_path)?;
file.write_all(&record.encode())?;
file.sync_all()?;
fs::rename(&temp_path, &checkpoint_path)?;
*self.last_checkpoint.write() = Some(record.clone());
Ok(record)
}
pub fn cleanup_old_segments(&self) -> std::io::Result<usize> {
let checkpoint = self.last_checkpoint.read().clone();
let last_safe_segment = match checkpoint {
Some(cp) => cp.last_segment,
None => return Ok(0),
};
let mut segments = self.segments.write();
let old_segments: Vec<u64> = segments
.keys()
.filter(|&&seq| seq <= last_safe_segment)
.copied()
.collect();
let mut cleaned = 0;
for sequence in old_segments {
if let Some(meta) = segments.remove(&sequence) {
if meta.path.exists() {
fs::remove_file(&meta.path)?;
cleaned += 1;
}
}
}
Ok(cleaned)
}
pub fn stats(&self) -> SegmentStats {
let segments = self.segments.read();
let total_size: u64 = segments.values().map(|s| s.size).sum();
let checkpoint = self.last_checkpoint.read().clone();
SegmentStats {
segment_count: segments.len(),
total_size,
current_lsn: self.current_lsn.load(Ordering::SeqCst),
current_sequence: self.segment_sequence.load(Ordering::SeqCst),
last_checkpoint_lsn: checkpoint.as_ref().map(|c| c.lsn),
}
}
pub fn recovery_iterator(&self, from_lsn: u64) -> RecoveryIterator<'_> {
RecoveryIterator::new(self, from_lsn)
}
pub fn flush(&self) -> std::io::Result<()> {
let mut active = self.active.lock();
if let Some(ref mut segment) = *active {
segment.file.flush()?;
}
Ok(())
}
pub fn shutdown(&self) -> std::io::Result<()> {
self.shutdown.store(true, Ordering::SeqCst);
let mut active = self.active.lock();
if let Some(mut segment) = active.take() {
segment.file.flush()?;
segment
.file
.into_inner()
.map_err(|e| e.into_error())?
.sync_all()?;
}
Ok(())
}
}
/// Point-in-time summary returned by `WalSegmentManager::stats`.
#[derive(Debug, Clone)]
pub struct SegmentStats {
    /// Number of segments in the manager's metadata map.
    pub segment_count: usize,
    /// Sum of the recorded segment sizes; may lag behind the active segment's real size.
    pub total_size: u64,
    /// Next LSN the manager will assign.
    pub current_lsn: u64,
    /// Sequence number the next created segment will receive.
    pub current_sequence: u64,
    /// LSN of the most recent checkpoint, if any.
    pub last_checkpoint_lsn: Option<u64>,
}
/// Sequential reader over WAL records, created by `WalSegmentManager::recovery_iterator`.
pub struct RecoveryIterator<'a> {
    /// Manager whose segment metadata is consulted to locate segment files.
    manager: &'a WalSegmentManager,
    /// Index into `segment_sequences` of the segment currently being read.
    current_segment_idx: usize,
    /// Sequence numbers selected at construction time, in ascending order.
    segment_sequences: Vec<u64>,
    /// Open reader for the current segment; `None` between segments.
    current_reader: Option<BufReader<File>>,
    /// Byte offset in the current segment just past the last record read.
    current_offset: u64,
    /// Records with an LSN below this are skipped.
    from_lsn: u64,
}
impl<'a> RecoveryIterator<'a> {
    /// Length prefixes above this are treated as garbage ending the segment.
    /// Must match the payload limit the writer enforces.
    const MAX_RECORD_DATA_LEN: usize = 100 * 1024 * 1024;

    /// Snapshots, in sequence order, the segments that may hold LSNs >= `from_lsn`.
    /// These are segments that start at or after it, whose recorded last LSN reaches
    /// it, or that have no recorded bound.
    fn new(manager: &'a WalSegmentManager, from_lsn: u64) -> Self {
        // The map is keyed by sequence, so `values()` already yields ascending order.
        let segment_sequences: Vec<u64> = manager
            .segments
            .read()
            .values()
            .filter(|s| {
                s.first_lsn >= from_lsn || s.last_lsn.map_or(true, |l| l >= from_lsn)
            })
            .map(|s| s.sequence)
            .collect();
        Self {
            manager,
            current_segment_idx: 0,
            segment_sequences,
            current_reader: None,
            current_offset: SEGMENT_HEADER_SIZE as u64,
            from_lsn,
        }
    }

    /// Returns the next intact record with LSN >= `from_lsn`, or `Ok(None)` once
    /// every selected segment has been read.
    ///
    /// Within a segment, reading moves on to the next segment at:
    /// - a zero or oversized length prefix (unwritten or preallocated space);
    /// - a record cut short by end of file, i.e. a torn write at the tail of the
    ///   log after a crash. Previously this surfaced as an `UnexpectedEof` error
    ///   that made recovery impossible.
    ///
    /// Segments removed from the manager since the iterator was created are skipped.
    ///
    /// # Errors
    /// * `InvalidData` if a complete record fails its CRC32 check.
    /// * Any I/O error from opening, seeking, or reading a segment file.
    pub fn next_entry(&mut self) -> std::io::Result<Option<WalEntry>> {
        loop {
            if self.current_reader.is_none() {
                let sequence = match self.segment_sequences.get(self.current_segment_idx) {
                    Some(&sequence) => sequence,
                    None => return Ok(None),
                };
                // Copy the path out so the metadata lock is not held during file I/O.
                let path = self
                    .manager
                    .segments
                    .read()
                    .get(&sequence)
                    .map(|meta| meta.path.clone());
                let path = match path {
                    Some(path) => path,
                    None => {
                        self.current_segment_idx += 1;
                        continue;
                    }
                };
                let mut reader = BufReader::new(File::open(&path)?);
                reader.seek(SeekFrom::Start(SEGMENT_HEADER_SIZE as u64))?;
                self.current_reader = Some(reader);
                self.current_offset = SEGMENT_HEADER_SIZE as u64;
            }
            let reader = self
                .current_reader
                .as_mut()
                .expect("segment reader was opened above");

            let mut len_buf = [0u8; 4];
            if !Self::read_full(reader, &mut len_buf)? {
                self.finish_segment();
                continue;
            }
            let data_len = u32::from_le_bytes(len_buf) as usize;
            if data_len == 0 || data_len > Self::MAX_RECORD_DATA_LEN {
                self.finish_segment();
                continue;
            }

            let mut lsn_buf = [0u8; 8];
            let mut data = vec![0u8; data_len];
            let mut checksum_buf = [0u8; 4];
            let complete = Self::read_full(reader, &mut lsn_buf)?
                && Self::read_full(reader, &mut data)?
                && Self::read_full(reader, &mut checksum_buf)?;
            if !complete {
                // The file ended part-way through a record: a torn tail write.
                self.finish_segment();
                continue;
            }

            let stored_checksum = u32::from_le_bytes(checksum_buf);
            let mut verify_buf = Vec::with_capacity(4 + 8 + data_len);
            verify_buf.extend_from_slice(&len_buf);
            verify_buf.extend_from_slice(&lsn_buf);
            verify_buf.extend_from_slice(&data);
            if crc32fast::hash(&verify_buf) != stored_checksum {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "WAL entry checksum mismatch",
                ));
            }

            self.current_offset += (4 + 8 + data_len + 4) as u64;
            let lsn = u64::from_le_bytes(lsn_buf);
            if lsn < self.from_lsn {
                continue;
            }
            return Ok(Some(WalEntry { lsn, data }));
        }
    }

    /// Drops the current reader and advances to the next selected segment.
    fn finish_segment(&mut self) {
        self.current_reader = None;
        self.current_segment_idx += 1;
    }

    /// Like `read_exact`, but reports a premature end of file as `Ok(false)`.
    fn read_full(reader: &mut BufReader<File>, buf: &mut [u8]) -> std::io::Result<bool> {
        match reader.read_exact(buf) {
            Ok(()) => Ok(true),
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(false),
            Err(e) => Err(e),
        }
    }
}
/// One record read back from the WAL by `RecoveryIterator::next_entry`.
#[derive(Debug, Clone)]
pub struct WalEntry {
    /// LSN stored in the record, as assigned by `WalSegmentManager::append`.
    pub lsn: u64,
    /// Payload bytes exactly as appended.
    pub data: Vec<u8>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_segment_manager_basic() {
        let dir = tempdir().unwrap();
        let config = SegmentConfig::default()
            .with_wal_dir(dir.path())
            .with_max_size(1024);
        let manager = WalSegmentManager::new(config).unwrap();

        // A fresh directory hands out LSNs densely from zero.
        for expected in 0u64..100 {
            let payload = format!("entry_{}", expected);
            assert_eq!(manager.append(payload.as_bytes()).unwrap(), expected);
        }

        let stats = manager.stats();
        assert!(stats.segment_count > 0);
        assert_eq!(stats.current_lsn, 100);
        manager.shutdown().unwrap();
    }

    #[test]
    fn test_checkpoint_and_cleanup() {
        let dir = tempdir().unwrap();
        // A tiny size cap forces several rotations over 50 records.
        let manager = WalSegmentManager::new(
            SegmentConfig::default()
                .with_wal_dir(dir.path())
                .with_max_size(256),
        )
        .unwrap();
        (0..50).for_each(|i| {
            manager.append(format!("entry_{:04}", i).as_bytes()).unwrap();
        });
        manager.flush().unwrap();

        let checkpoint = manager.create_checkpoint(12345, 50).unwrap();
        assert!(checkpoint.lsn > 0);
        // Only checks that cleanup succeeds; the removed count is not asserted.
        manager.cleanup_old_segments().unwrap();
        manager.shutdown().unwrap();
    }

    #[test]
    fn test_recovery() {
        let dir = tempdir().unwrap();
        let config = SegmentConfig::default().with_wal_dir(dir.path());

        // First session: write ten records, then close cleanly.
        {
            let writer = WalSegmentManager::new(config.clone()).unwrap();
            for i in 0..10 {
                writer.append(format!("data_{}", i).as_bytes()).unwrap();
            }
            writer.shutdown().unwrap();
        }

        // Second session: every record must be replayed from LSN 0.
        let reader = WalSegmentManager::new(config).unwrap();
        let mut iter = reader.recovery_iterator(0);
        let mut replayed = 0;
        while let Some(entry) = iter.next_entry().unwrap() {
            assert!(String::from_utf8_lossy(&entry.data).starts_with("data_"));
            replayed += 1;
        }
        assert_eq!(replayed, 10);
    }

    #[test]
    fn test_segment_header_encoding() {
        let decoded = SegmentHeader::decode(&SegmentHeader::new(42, 12345).encode()).unwrap();
        assert_eq!(decoded.magic, SEGMENT_MAGIC);
        assert_eq!(decoded.sequence, 42);
        assert_eq!(decoded.first_lsn, 12345);
    }

    #[test]
    fn test_segment_header_rejects_future_version() {
        let mut encoded = SegmentHeader::new(1, 1).encode();
        encoded[4..6].copy_from_slice(&(SEGMENT_VERSION + 1).to_le_bytes());
        assert!(SegmentHeader::decode(&encoded).is_none());
    }

    #[test]
    fn test_checkpoint_record_encoding() {
        let original = CheckpointRecord {
            lsn: 1000,
            last_segment: 5,
            timestamp: 123456789,
            memtable_checksum: 0xDEADBEEF,
            entry_count: 500,
        };
        let decoded = CheckpointRecord::decode(&original.encode()).unwrap();
        assert_eq!(decoded.lsn, 1000);
        assert_eq!(decoded.last_segment, 5);
        assert_eq!(decoded.memtable_checksum, 0xDEADBEEF);
        assert_eq!(decoded.entry_count, 500);
    }
}