use std::borrow::Cow;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use super::LSN;
use super::ring_buffer::PendingEntry;
use crate::core::error::{Error, Result, StorageError};
use super::segment_reader::{WAL_HEADER_SIZE, WAL_MAGIC, WAL_VERSION, WAL_VERSION_ENCRYPTED};
/// Summary of a closed WAL segment's contents, persisted next to the
/// segment in a `.log.meta` sidecar file so truncation can decide whether
/// a segment is safe to delete without scanning it.
#[derive(Debug, Clone)]
pub struct SegmentMetadata {
    /// Smallest LSN recorded in the segment.
    pub min_lsn: LSN,
    /// Largest LSN recorded in the segment.
    pub max_lsn: LSN,
    /// Number of entries written to the segment.
    pub entry_count: u64,
}
impl SegmentMetadata {
    /// Construct metadata describing a segment's LSN range and entry count.
    pub fn new(min_lsn: LSN, max_lsn: LSN, entry_count: u64) -> Self {
        Self {
            min_lsn,
            max_lsn,
            entry_count,
        }
    }

    /// Serialize as 24 bytes: `min_lsn`, `max_lsn`, `entry_count`,
    /// each encoded as a little-endian u64.
    pub fn to_bytes(&self) -> [u8; 24] {
        let mut out = [0u8; 24];
        out[0..8].copy_from_slice(&self.min_lsn.0.to_le_bytes());
        out[8..16].copy_from_slice(&self.max_lsn.0.to_le_bytes());
        out[16..24].copy_from_slice(&self.entry_count.to_le_bytes());
        out
    }

    /// Inverse of [`Self::to_bytes`]. Returns `None` when fewer than 24
    /// bytes are supplied; trailing bytes beyond 24 are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        // Read one little-endian u64 out of the given byte range,
        // failing when the slice is too short.
        let read_u64 = |range: std::ops::Range<usize>| {
            bytes
                .get(range)
                .and_then(|slice| slice.try_into().ok())
                .map(u64::from_le_bytes)
        };
        Some(Self {
            min_lsn: LSN(read_u64(0..8)?),
            max_lsn: LSN(read_u64(8..16)?),
            entry_count: read_u64(16..24)?,
        })
    }
}
/// Tuning knobs for [`FlushCoordinator`].
#[derive(Clone)]
pub struct FlushCoordinatorConfig {
    /// Directory that holds the `NNNNNN.log` segment files.
    pub wal_dir: PathBuf,
    /// Rotate to a new segment once the current one reaches this many bytes.
    pub segment_size: usize,
    /// How many closed segments to keep before cleanup deletes older ones.
    pub segments_to_retain: usize,
    /// Interval of the background flush thread, in milliseconds.
    pub flush_interval_ms: u64,
    /// When true, `sync_data` the segment file on synchronous flushes.
    pub sync_on_flush: bool,
    /// Capacity of the `BufWriter` wrapping the segment file.
    pub write_buffer_size: usize,
    /// Optional cipher; when set, each entry is encrypted and length-framed.
    pub wal_cipher: Option<Arc<dyn crate::encryption::cipher::Cipher>>,
}
// Manual Debug impl: `dyn Cipher` does not implement `Debug`, so the
// cipher field is rendered as its algorithm name (or `None`) instead.
impl std::fmt::Debug for FlushCoordinatorConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlushCoordinatorConfig")
            .field("wal_dir", &self.wal_dir)
            .field("segment_size", &self.segment_size)
            .field("segments_to_retain", &self.segments_to_retain)
            .field("flush_interval_ms", &self.flush_interval_ms)
            .field("sync_on_flush", &self.sync_on_flush)
            .field("write_buffer_size", &self.write_buffer_size)
            .field(
                "wal_cipher",
                &self.wal_cipher.as_ref().map(|c| c.algorithm_name()),
            )
            .finish()
    }
}
impl Default for FlushCoordinatorConfig {
    /// Defaults: `data/wal` directory, 64 MiB segments, keep 10 closed
    /// segments, 10 ms flush interval, fsync on flush, 64 KiB write
    /// buffer, no encryption.
    fn default() -> Self {
        let wal_dir = PathBuf::from("data/wal");
        Self {
            wal_dir,
            segment_size: 64 * 1024 * 1024,
            segments_to_retain: 10,
            flush_interval_ms: 10,
            sync_on_flush: true,
            write_buffer_size: 64 * 1024,
            wal_cipher: None,
        }
    }
}
impl FlushCoordinatorConfig {
    /// Build a config that writes to `wal_dir`, keeping every other
    /// setting at its [`Default`] value.
    pub fn new(wal_dir: impl Into<PathBuf>) -> Self {
        let mut config = Self::default();
        config.wal_dir = wal_dir.into();
        config
    }
}
/// Outcome of a single [`FlushCoordinator::flush`] call.
#[derive(Debug, Clone, Default)]
pub struct FlushStats {
    /// Number of entries written in this flush.
    pub entries_flushed: usize,
    /// Bytes appended to the segment (including encryption framing).
    pub bytes_written: usize,
    /// Wall-clock time the flush took.
    pub flush_duration: Duration,
    /// True when this flush filled the segment and triggered a rotation.
    pub segment_rotated: bool,
}
/// Owns the active WAL segment file and performs batched, durable writes.
///
/// The buffered writer is guarded by a mutex; all counters are atomics so
/// stats and rotation bookkeeping can be read without taking the lock.
pub struct FlushCoordinator {
    config: FlushCoordinatorConfig,
    /// Id of the segment currently being written.
    current_segment_id: AtomicU64,
    /// Bytes written to the current segment so far (including the header).
    current_segment_size: AtomicU64,
    /// Buffered writer over the active segment; `None` before the first
    /// flush and between a rotation and the next flush.
    writer: Mutex<Option<BufWriter<File>>>,
    /// Cloned handle to the same file, used for `sync_data` without
    /// having to reach into the `BufWriter`.
    sync_handle: Mutex<Option<File>>,
    // Lifetime totals, exposed via the accessor methods below.
    total_entries_flushed: AtomicU64,
    total_bytes_written: AtomicU64,
    total_flushes: AtomicU64,
    /// Min LSN seen in the current segment (`u64::MAX` until first entry).
    current_segment_min_lsn: AtomicU64,
    /// Max LSN seen in the current segment (0 until first entry).
    current_segment_max_lsn: AtomicU64,
    /// Entries written to the current segment.
    current_segment_entry_count: AtomicU64,
}
impl FlushCoordinator {
    /// Create a coordinator rooted at `config.wal_dir`, creating the
    /// directory if needed and resuming segment numbering from any
    /// segments already on disk.
    ///
    /// # Errors
    /// Returns a storage error when the WAL directory cannot be created.
    pub fn new(config: FlushCoordinatorConfig) -> Result<Self> {
        std::fs::create_dir_all(&config.wal_dir).map_err(|e| {
            Error::Storage(StorageError::IoError(format!(
                "Failed to create WAL directory: {}",
                e
            )))
        })?;
        let coordinator = Self {
            config,
            current_segment_id: AtomicU64::new(0),
            current_segment_size: AtomicU64::new(0),
            writer: Mutex::new(None),
            sync_handle: Mutex::new(None),
            total_entries_flushed: AtomicU64::new(0),
            total_bytes_written: AtomicU64::new(0),
            total_flushes: AtomicU64::new(0),
            // u64::MAX sentinel: no entry has been recorded yet.
            current_segment_min_lsn: AtomicU64::new(u64::MAX),
            current_segment_max_lsn: AtomicU64::new(0),
            current_segment_entry_count: AtomicU64::new(0),
        };
        coordinator.initialize_from_existing()?;
        Ok(coordinator)
    }
fn initialize_from_existing(&self) -> Result<()> {
let mut max_segment_id = 0u64;
if let Ok(entries) = std::fs::read_dir(&self.config.wal_dir) {
for entry in entries.flatten() {
let path = entry.path();
if let Some(id) = path
.extension()
.filter(|ext| *ext == "log")
.and_then(|_| path.file_stem())
.and_then(|s| s.to_string_lossy().parse::<u64>().ok())
{
max_segment_id = max_segment_id.max(id);
}
}
}
self.current_segment_id
.store(max_segment_id, Ordering::Relaxed);
Ok(())
}
fn segment_path(&self, segment_id: u64) -> PathBuf {
self.config.wal_dir.join(format!("{:06}.log", segment_id))
}
fn segment_meta_path(&self, segment_id: u64) -> PathBuf {
self.config
.wal_dir
.join(format!("{:06}.log.meta", segment_id))
}
fn write_segment_metadata(
&self,
segment_id: u64,
min_lsn: u64,
max_lsn: u64,
entry_count: u64,
) -> Result<()> {
if min_lsn <= max_lsn && entry_count > 0 {
let metadata = SegmentMetadata::new(LSN(min_lsn), LSN(max_lsn), entry_count);
let meta_path = self.segment_meta_path(segment_id);
let bytes = metadata.to_bytes();
std::fs::write(&meta_path, bytes).map_err(|e| {
Error::Storage(StorageError::IoError(format!(
"Failed to write segment metadata: {}",
e
)))
})?;
}
Ok(())
}
pub fn read_segment_metadata(&self, segment_id: u64) -> Option<SegmentMetadata> {
let meta_path = self.segment_meta_path(segment_id);
let mut bytes = [0u8; 24];
File::open(&meta_path).ok()?.read_exact(&mut bytes).ok()?;
SegmentMetadata::from_bytes(&bytes)
}
    /// Open (or create) the next segment when no writer is currently open.
    ///
    /// A brand-new (empty) segment gets the magic + version header; the
    /// version byte records whether entries are encrypted. Also installs a
    /// cloned file handle used later for `sync_data`. The caller holds the
    /// writer lock and passes its guard in.
    fn ensure_segment_open(&self, writer_guard: &mut Option<BufWriter<File>>) -> Result<()> {
        if writer_guard.is_some() {
            return Ok(());
        }
        // Advance to a fresh segment id (previous id + 1). Note this means
        // a restart starts a new segment rather than appending to the last.
        let segment_id = self.current_segment_id.fetch_add(1, Ordering::Relaxed) + 1;
        let path = self.segment_path(segment_id);
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .map_err(|e| {
                Error::Storage(StorageError::IoError(format!(
                    "Failed to open WAL segment {}: {}",
                    path.display(),
                    e
                )))
            })?;
        // Second handle to the same file so syncing never needs the writer.
        let sync_file = file.try_clone().map_err(|e| {
            Error::Storage(StorageError::IoError(format!(
                "Failed to clone WAL file handle: {}",
                e
            )))
        })?;
        let mut writer = BufWriter::with_capacity(self.config.write_buffer_size, file);
        let current_len = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
        if current_len == 0 {
            // Fresh file: write the magic bytes and a version byte that
            // encodes whether payloads are encrypted.
            writer.write_all(&WAL_MAGIC).map_err(|e| {
                Error::Storage(StorageError::IoError(format!(
                    "Failed to write WAL header: {}",
                    e
                )))
            })?;
            let version = if self.config.wal_cipher.is_some() {
                WAL_VERSION_ENCRYPTED
            } else {
                WAL_VERSION
            };
            writer.write_all(&[version]).map_err(|e| {
                Error::Storage(StorageError::IoError(format!(
                    "Failed to write WAL version: {}",
                    e
                )))
            })?;
            self.current_segment_size
                .store(WAL_HEADER_SIZE as u64, Ordering::Relaxed);
        } else {
            // Pre-existing file: resume size accounting from its length.
            self.current_segment_size
                .store(current_len, Ordering::Relaxed);
        }
        *writer_guard = Some(writer);
        let mut sync_guard = self.sync_handle.lock().unwrap_or_else(|e| e.into_inner());
        *sync_guard = Some(sync_file);
        // NOTE(review): the fetch_add above already advanced the counter to
        // `segment_id`, so this store is redundant (but harmless).
        self.current_segment_id.store(segment_id, Ordering::Relaxed);
        Ok(())
    }
    /// Close the current segment when it has reached `segment_size`.
    ///
    /// On rotation: flush the buffered writer, snapshot the per-segment
    /// LSN stats, reset the per-segment counters, sync the data (when
    /// configured), then write the sidecar metadata and drop the sync
    /// handle — i.e. the metadata is written only after the data it
    /// describes has been synced. Returns true when a rotation happened.
    fn maybe_rotate_segment(&self, writer_guard: &mut Option<BufWriter<File>>) -> Result<bool> {
        let current_size = self.current_segment_size.load(Ordering::Relaxed);
        if current_size >= self.config.segment_size as u64 {
            let closing_segment_id = self.current_segment_id.load(Ordering::Relaxed);
            if let Some(writer) = writer_guard {
                writer.flush().map_err(|e| {
                    Error::Storage(StorageError::IoError(format!(
                        "Failed to flush WAL segment: {}",
                        e
                    )))
                })?;
            }
            // Snapshot the closing segment's stats before resetting them.
            let min_lsn = self.current_segment_min_lsn.load(Ordering::Relaxed);
            let max_lsn = self.current_segment_max_lsn.load(Ordering::Relaxed);
            let entry_count = self.current_segment_entry_count.load(Ordering::Relaxed);
            // Dropping the writer closes out this segment; the next flush
            // will open a new one via ensure_segment_open.
            *writer_guard = None;
            self.current_segment_size.store(0, Ordering::Relaxed);
            self.current_segment_min_lsn
                .store(u64::MAX, Ordering::Relaxed);
            self.current_segment_max_lsn.store(0, Ordering::Relaxed);
            self.current_segment_entry_count.store(0, Ordering::Relaxed);
            if self.config.sync_on_flush {
                let sync_guard = self.sync_handle.lock().unwrap_or_else(|e| e.into_inner());
                if let Some(ref sync_file) = *sync_guard {
                    sync_file.sync_data().map_err(|e| {
                        Error::Storage(StorageError::IoError(format!(
                            "Failed to sync WAL segment: {}",
                            e
                        )))
                    })?;
                }
            }
            self.write_segment_metadata(closing_segment_id, min_lsn, max_lsn, entry_count)?;
            {
                // Release the file handle of the closed segment.
                let mut sync_guard = self.sync_handle.lock().unwrap_or_else(|e| e.into_inner());
                *sync_guard = None;
            }
            self.cleanup_old_segments()?;
            return Ok(true);
        }
        Ok(false)
    }
fn cleanup_old_segments(&self) -> Result<()> {
let current_id = self.current_segment_id.load(Ordering::Relaxed);
let retain_from = current_id.saturating_sub(self.config.segments_to_retain as u64);
if let Ok(entries) = std::fs::read_dir(&self.config.wal_dir) {
for entry in entries.flatten() {
let path = entry.path();
let is_old_segment = path.extension().is_some_and(|ext| ext == "log")
&& path
.file_stem()
.and_then(|s| s.to_string_lossy().parse::<u64>().ok())
.is_some_and(|id| id < retain_from);
let is_old_meta = path
.file_name()
.and_then(|n| n.to_str())
.is_some_and(|name| {
name.ends_with(".log.meta")
&& name
.strip_suffix(".log.meta")
.and_then(|s| s.parse::<u64>().ok())
.is_some_and(|id| id < retain_from)
});
if is_old_segment || is_old_meta {
let _ = std::fs::remove_file(&path);
}
}
}
Ok(())
}
pub fn truncate_to_lsn(&self, truncate_lsn: LSN) -> Result<usize> {
let current_id = self.current_segment_id.load(Ordering::Relaxed);
let mut removed_count = 0;
if let Ok(entries) = std::fs::read_dir(&self.config.wal_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "log")
&& let Some(segment_id) = path
.file_stem()
.and_then(|s| s.to_string_lossy().parse::<u64>().ok())
&& segment_id < current_id
{
let should_remove =
if let Some(metadata) = self.read_segment_metadata(segment_id) {
metadata.max_lsn.0 < truncate_lsn.0
} else {
false
};
if should_remove {
if std::fs::remove_file(&path).is_ok() {
removed_count += 1;
}
let meta_path = self.segment_meta_path(segment_id);
let _ = std::fs::remove_file(&meta_path);
}
}
}
}
Ok(removed_count)
}
pub fn list_segments_with_metadata(&self) -> Vec<(u64, SegmentMetadata)> {
let mut segments = Vec::with_capacity(16);
if let Ok(entries) = std::fs::read_dir(&self.config.wal_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "log")
&& let Some(segment_id) = path
.file_stem()
.and_then(|s| s.to_string_lossy().parse::<u64>().ok())
&& let Some(metadata) = self.read_segment_metadata(segment_id)
{
segments.push((segment_id, metadata));
}
}
}
segments.sort_by_key(|(id, _)| *id);
segments
}
pub fn get_min_lsn(&self) -> Option<LSN> {
self.list_segments_with_metadata()
.into_iter()
.map(|(_, meta)| meta.min_lsn)
.min()
}
    /// Write a batch of pending entries to the active segment.
    ///
    /// On success every entry is notified of completion; on failure every
    /// entry is notified with the error message. When `sync` is true and
    /// `sync_on_flush` is enabled, the data is `sync_data`'d; if that sync
    /// fails, the just-written bytes are truncated away so that an entry
    /// is never acknowledged without being durable ("phantom commit"
    /// prevention).
    pub fn flush(&self, entries: Vec<PendingEntry>, sync: bool) -> Result<FlushStats> {
        if entries.is_empty() {
            return Ok(FlushStats::default());
        }
        let start = Instant::now();
        // Inner closure so the completion/error notifications below can
        // run uniformly over `entries` after the write attempt.
        let do_flush = || -> Result<FlushStats> {
            let mut writer_guard = self.writer.lock().unwrap_or_else(|e| e.into_inner());
            self.ensure_segment_open(&mut writer_guard)?;
            // Pre-batch size, used to roll back if the sync fails.
            let start_size = self.current_segment_size.load(Ordering::Relaxed);
            let mut bytes_written = 0usize;
            let mut batch_min_lsn = u64::MAX;
            let mut batch_max_lsn = 0u64;
            {
                let writer = writer_guard.as_mut().ok_or_else(|| {
                    Error::Storage(StorageError::WalError {
                        reason: "WAL writer not initialized".to_string(),
                    })
                })?;
                for entry in &entries {
                    batch_min_lsn = batch_min_lsn.min(entry.lsn.0);
                    batch_max_lsn = batch_max_lsn.max(entry.lsn.0);
                    // With a cipher configured, entries are encrypted and
                    // prefixed with a u32 LE length; otherwise the raw
                    // bytes are written as-is (Cow avoids a copy).
                    let write_data: Cow<'_, [u8]> = if let Some(ref cipher) = self.config.wal_cipher
                    {
                        let encrypted = crate::encryption::wal_encryption::encrypt_wal_payload(
                            &entry.data,
                            cipher,
                        )
                        .map_err(|e| Error::Storage(StorageError::Encryption(e.to_string())))?;
                        let len_bytes = (encrypted.len() as u32).to_le_bytes();
                        let mut framed = Vec::with_capacity(4 + encrypted.len());
                        framed.extend_from_slice(&len_bytes);
                        framed.extend_from_slice(&encrypted);
                        Cow::Owned(framed)
                    } else {
                        Cow::Borrowed(&entry.data)
                    };
                    writer.write_all(&write_data).map_err(|e| {
                        Error::Storage(StorageError::IoError(format!(
                            "Failed to write WAL entry: {}",
                            e
                        )))
                    })?;
                    bytes_written += write_data.len();
                }
                // Drain the BufWriter to the OS before any sync attempt.
                writer.flush().map_err(|e| {
                    Error::Storage(StorageError::IoError(format!(
                        "Failed to flush WAL buffer: {}",
                        e
                    )))
                })?;
            }
            // Fold this batch's LSN range into the per-segment stats.
            self.current_segment_min_lsn
                .fetch_min(batch_min_lsn, Ordering::Relaxed);
            self.current_segment_max_lsn
                .fetch_max(batch_max_lsn, Ordering::Relaxed);
            self.current_segment_entry_count
                .fetch_add(entries.len() as u64, Ordering::Relaxed);
            if sync && self.config.sync_on_flush {
                let sync_guard = self.sync_handle.lock().unwrap_or_else(|e| e.into_inner());
                if let Some(ref sync_file) = *sync_guard
                    && let Err(e) = sync_file.sync_data()
                {
                    // Sync failed: truncate the file back to its pre-batch
                    // length so unsynced bytes cannot be replayed later.
                    if let Some(writer) = writer_guard.as_mut() {
                        let file = writer.get_mut();
                        if let Err(trunc_err) = file.set_len(start_size) {
                            eprintln!(
                                "CRITICAL: Failed to truncate WAL segment after sync failure. \
                                Data consistency may be compromised. Error: {}",
                                trunc_err
                            );
                        } else {
                            if let Err(seek_err) = file.seek(SeekFrom::Start(start_size)) {
                                eprintln!(
                                    "CRITICAL: Failed to seek WAL segment after sync failure. \
                                Data consistency may be compromised. Error: {}",
                                    seek_err
                                );
                            }
                            self.current_segment_size
                                .store(start_size, Ordering::Relaxed);
                        }
                    }
                    return Err(Error::Storage(StorageError::IoError(format!(
                        "Failed to sync WAL: {}",
                        e
                    ))));
                }
            }
            // Only count the bytes once the (optional) sync has succeeded.
            self.current_segment_size
                .fetch_add(bytes_written as u64, Ordering::Relaxed);
            let segment_rotated = self.maybe_rotate_segment(&mut writer_guard)?;
            self.total_entries_flushed
                .fetch_add(entries.len() as u64, Ordering::Relaxed);
            self.total_bytes_written
                .fetch_add(bytes_written as u64, Ordering::Relaxed);
            self.total_flushes.fetch_add(1, Ordering::Relaxed);
            Ok(FlushStats {
                entries_flushed: entries.len(),
                bytes_written,
                flush_duration: start.elapsed(),
                segment_rotated,
            })
        };
        match do_flush() {
            Ok(stats) => {
                for entry in entries {
                    entry.notify_completion();
                }
                Ok(stats)
            }
            Err(e) => {
                let error_msg = e.to_string();
                for entry in entries {
                    entry.notify_error(&error_msg);
                }
                Err(e)
            }
        }
    }
    /// Lifetime number of entries flushed by this coordinator.
    #[inline]
    pub fn total_entries_flushed(&self) -> u64 {
        self.total_entries_flushed.load(Ordering::Relaxed)
    }
    /// Lifetime number of bytes written (including encryption framing).
    #[inline]
    pub fn total_bytes_written(&self) -> u64 {
        self.total_bytes_written.load(Ordering::Relaxed)
    }
    /// Lifetime number of successful flush batches.
    #[inline]
    pub fn total_flushes(&self) -> u64 {
        self.total_flushes.load(Ordering::Relaxed)
    }
    /// Id of the segment currently being written.
    #[inline]
    pub fn current_segment_id(&self) -> u64 {
        self.current_segment_id.load(Ordering::Relaxed)
    }
    /// Bytes written to the current segment so far (0 right after rotation).
    #[inline]
    pub fn current_segment_size(&self) -> u64 {
        self.current_segment_size.load(Ordering::Relaxed)
    }
    /// The directory holding all WAL segment files.
    pub fn wal_dir(&self) -> &Path {
        &self.config.wal_dir
    }
}
/// Wakeup latch used by the background flush thread: writers set a flag
/// and notify; the flush thread waits with a timeout so it also flushes
/// periodically even without an explicit request.
pub struct FlushSignal {
    /// Set when a flush has been requested and not yet consumed.
    requested: AtomicBool,
    /// Paired with `condvar`; carries no data of its own.
    mutex: Mutex<()>,
    condvar: Condvar,
}
impl FlushSignal {
    /// New signal with no pending request.
    pub fn new() -> Self {
        Self {
            requested: AtomicBool::new(false),
            mutex: Mutex::new(()),
            condvar: Condvar::new(),
        }
    }
    /// Mark a flush as requested and wake any waiter.
    ///
    /// The store happens while holding the mutex so a concurrent
    /// `wait_for_request` cannot miss the notification.
    pub fn request_flush(&self) {
        let _guard = self.mutex.lock().unwrap_or_else(|e| e.into_inner());
        self.requested.store(true, Ordering::Release);
        self.condvar.notify_all();
    }
    /// Consume a pending request, returning whether one was pending.
    pub fn take_request(&self) -> bool {
        self.requested.swap(false, Ordering::AcqRel)
    }
    /// Block until a flush is requested or `timeout` elapses. Returns true
    /// when a request is pending (the request is NOT consumed; follow up
    /// with `take_request`).
    ///
    /// Fix: uses `wait_timeout_while` so a spurious condvar wakeup
    /// re-enters the wait instead of returning `false` early — the
    /// previous single `wait_timeout` call could return before the
    /// timeout elapsed with no request pending.
    pub fn wait_for_request(&self, timeout: Duration) -> bool {
        let guard = self.mutex.lock().unwrap_or_else(|e| e.into_inner());
        let (_guard, _result) = self
            .condvar
            .wait_timeout_while(guard, timeout, |_| !self.requested.load(Ordering::Acquire))
            .unwrap_or_else(|e| e.into_inner());
        self.requested.load(Ordering::Acquire)
    }
}
impl Default for FlushSignal {
    fn default() -> Self {
        Self::new()
    }
}
/// Handle to the background thread that periodically drains pending
/// entries and flushes them through a [`FlushCoordinator`].
pub struct FlushThread {
    /// Join handle; taken (set to `None`) once the thread is joined.
    handle: Option<JoinHandle<()>>,
    /// Cooperative stop flag observed by the thread loop.
    shutdown: Arc<AtomicBool>,
    /// Wakes the thread for an immediate flush.
    flush_signal: Arc<FlushSignal>,
}
impl FlushThread {
    /// Spawn the background flush thread.
    ///
    /// `drain_fn` pulls the currently pending entries; the thread flushes
    /// whenever a flush is requested or `flush_interval` elapses, and
    /// performs one final drain+flush after shutdown is signalled so
    /// nothing still queued is lost.
    pub fn start<F>(
        coordinator: Arc<FlushCoordinator>,
        drain_fn: F,
        flush_interval: Duration,
    ) -> Self
    where
        F: Fn() -> Vec<PendingEntry> + Send + 'static,
    {
        let shutdown = Arc::new(AtomicBool::new(false));
        let flush_signal = Arc::new(FlushSignal::new());
        let shutdown_clone = Arc::clone(&shutdown);
        let signal_clone = Arc::clone(&flush_signal);
        let handle = thread::spawn(move || {
            while !shutdown_clone.load(Ordering::Acquire) {
                // Wake on request or after the interval; either way the
                // queue is drained and flushed (periodic flush is
                // intentional, so the wait result is ignored).
                let _ = signal_clone.wait_for_request(flush_interval);
                signal_clone.take_request();
                let entries = drain_fn();
                if let Err(e) = coordinator.flush(entries, true) {
                    eprintln!("WAL flush error: {:?}", e);
                }
            }
            // Final drain after shutdown so queued entries are not lost.
            let entries = drain_fn();
            let _ = coordinator.flush(entries, true);
        });
        Self {
            handle: Some(handle),
            shutdown,
            flush_signal,
        }
    }
    /// Ask the background thread to flush as soon as possible.
    pub fn request_flush(&self) {
        self.flush_signal.request_flush();
    }
    /// Signal shutdown, wake the thread, and join it. Idempotent: the
    /// handle is taken on the first call, so later calls are no-ops.
    pub fn shutdown(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        self.flush_signal.request_flush();
        if let Some(handle) = self.handle.take() {
            let _ = handle.join();
        }
    }
}
// Dropping the handle stops and joins the thread so no flush loop is
// left running detached.
impl Drop for FlushThread {
    fn drop(&mut self) {
        self.shutdown();
    }
}
#[cfg(test)]
mod tests {
    use super::super::LSN;
    use super::*;
    use tempfile::tempdir;

    /// Build a fire-and-forget pending entry for tests.
    fn create_test_entry(lsn: u64, data: &[u8]) -> PendingEntry {
        PendingEntry::new_async(LSN(lsn), data.to_vec())
    }

    #[test]
    fn test_flush_coordinator_creation() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        // A fresh coordinator starts with zeroed lifetime counters.
        assert_eq!(coordinator.total_entries_flushed(), 0);
        assert_eq!(coordinator.total_bytes_written(), 0);
        assert_eq!(coordinator.total_flushes(), 0);
    }

    #[test]
    fn test_flush_empty_entries() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        // Empty batch is a no-op returning default stats.
        let stats = coordinator.flush(vec![], true).unwrap();
        assert_eq!(stats.entries_flushed, 0);
        assert_eq!(stats.bytes_written, 0);
        assert!(!stats.segment_rotated);
    }

    #[test]
    fn test_flush_single_entry() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entry = create_test_entry(1, b"test data");
        let data_len = entry.data.len();
        let stats = coordinator.flush(vec![entry], true).unwrap();
        assert_eq!(stats.entries_flushed, 1);
        // No cipher configured, so bytes_written equals the raw data size.
        assert_eq!(stats.bytes_written, data_len);
        assert_eq!(coordinator.total_entries_flushed(), 1);
    }

    #[test]
    fn test_flush_multiple_entries() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entries: Vec<_> = (1..=10)
            .map(|i| create_test_entry(i, &[i as u8; 100]))
            .collect();
        let total_bytes: usize = entries.iter().map(|e| e.data.len()).sum();
        let stats = coordinator.flush(entries, true).unwrap();
        assert_eq!(stats.entries_flushed, 10);
        assert_eq!(stats.bytes_written, total_bytes);
    }

    #[test]
    fn test_segment_rotation() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        // Header (magic + version) plus 95 bytes reaches exactly 100,
        // so rotation must trigger (>= comparison).
        let entry1 = create_test_entry(1, &[0u8; 95]);
        let stats1 = coordinator.flush(vec![entry1], true).unwrap();
        assert!(
            stats1.segment_rotated,
            "Should rotate at exactly segment_size (100). Size was {}",
            coordinator.current_segment_size()
        );
        assert_eq!(coordinator.current_segment_size(), 0);
        // 94 bytes + header = 99 < 100: no rotation.
        let entry2 = create_test_entry(2, &[0u8; 94]);
        let stats2 = coordinator.flush(vec![entry2], true).unwrap();
        assert!(
            !stats2.segment_rotated,
            "Should NOT rotate at segment_size - 1 (99)"
        );
        assert_eq!(coordinator.current_segment_size(), 99);
    }

    #[test]
    fn test_completion_notification() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        // Sync entries carry a completion handle the flusher must signal.
        let (entry, handle) = PendingEntry::new_sync(LSN(1), vec![1, 2, 3]);
        assert!(!handle.is_complete());
        coordinator.flush(vec![entry], true).unwrap();
        assert!(handle.is_complete());
        assert!(handle.wait().is_ok());
    }

    #[test]
    fn test_flush_signal() {
        let signal = FlushSignal::new();
        assert!(!signal.take_request());
        signal.request_flush();
        assert!(signal.take_request());
        // take_request consumes the request: second take returns false.
        assert!(!signal.take_request());
    }

    #[test]
    fn test_flush_signal_wait_timeout() {
        let signal = FlushSignal::new();
        let result = signal.wait_for_request(Duration::from_millis(10));
        assert!(!result);
    }

    #[test]
    fn test_flush_signal_wait_immediate() {
        let signal = FlushSignal::new();
        signal.request_flush();
        // A pending request should satisfy the wait without blocking.
        let result = signal.wait_for_request(Duration::from_secs(10));
        assert!(result);
    }

    #[test]
    fn test_segment_file_creation() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entry = create_test_entry(1, b"test");
        coordinator.flush(vec![entry], true).unwrap();
        let segment_path = coordinator.segment_path(coordinator.current_segment_id());
        assert!(segment_path.exists());
    }

    #[test]
    fn test_wal_header() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entry = create_test_entry(1, b"test");
        coordinator.flush(vec![entry], true).unwrap();
        let segment_path = coordinator.segment_path(coordinator.current_segment_id());
        let data = std::fs::read(&segment_path).unwrap();
        // Layout: magic bytes, then the (unencrypted) version byte.
        assert!(data.len() >= WAL_HEADER_SIZE);
        assert_eq!(&data[0..4], &WAL_MAGIC);
        assert_eq!(data[4], WAL_VERSION);
    }

    #[test]
    fn test_flush_thread_basic() {
        let dir = tempdir().unwrap();
        let config = FlushCoordinatorConfig::new(dir.path());
        let coordinator = Arc::new(FlushCoordinator::new(config).unwrap());
        let entries = Arc::new(Mutex::new(vec![
            create_test_entry(1, b"one"),
            create_test_entry(2, b"two"),
        ]));
        let entries_clone = Arc::clone(&entries);
        let mut thread = FlushThread::start(
            Arc::clone(&coordinator),
            move || {
                // Drain function hands over whatever is queued.
                let mut guard = entries_clone.lock().unwrap();
                std::mem::take(&mut *guard)
            },
            Duration::from_millis(10),
        );
        thread.request_flush();
        // Poll until the background thread has flushed both entries.
        let start = std::time::Instant::now();
        let timeout = Duration::from_secs(5);
        while coordinator.total_entries_flushed() < 2 {
            if start.elapsed() > timeout {
                panic!(
                    "Timeout waiting for flush: only {} entries flushed",
                    coordinator.total_entries_flushed()
                );
            }
            std::thread::sleep(Duration::from_millis(10));
        }
        assert!(coordinator.total_entries_flushed() >= 2);
        thread.shutdown();
    }

    #[test]
    fn test_cleanup_old_segments() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 50;
        config.segments_to_retain = 2;
        let coordinator = FlushCoordinator::new(config).unwrap();
        // Each 100-byte entry overfills a 50-byte segment => one segment
        // per flush, forcing repeated rotation + cleanup.
        for i in 1..=10 {
            let entry = create_test_entry(i, &[i as u8; 100]);
            coordinator.flush(vec![entry], true).unwrap();
        }
        let mut segments: Vec<u64> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter_map(|e| {
                let path = e.path();
                if path.extension().is_some_and(|ext| ext == "log") {
                    path.file_stem()
                        .and_then(|s| s.to_string_lossy().parse::<u64>().ok())
                } else {
                    None
                }
            })
            .collect();
        segments.sort();
        assert_eq!(
            segments.len(),
            3,
            "Should retain exactly 3 segments (2 + current). Found: {:?}",
            segments
        );
        assert_eq!(segments, vec![8, 9, 10]);
    }

    #[test]
    fn test_segment_metadata_serialization() {
        let metadata = SegmentMetadata::new(LSN(100), LSN(200), 50);
        let bytes = metadata.to_bytes();
        assert_eq!(bytes.len(), 24);
        let restored = SegmentMetadata::from_bytes(&bytes).unwrap();
        assert_eq!(restored.min_lsn, LSN(100));
        assert_eq!(restored.max_lsn, LSN(200));
        assert_eq!(restored.entry_count, 50);
    }

    #[test]
    fn test_segment_metadata_from_bytes_too_short() {
        // 10 bytes < required 24 => parse must fail.
        let bytes = vec![0u8; 10];
        assert!(SegmentMetadata::from_bytes(&bytes).is_none());
    }

    #[test]
    fn test_flush_tracks_lsn_range() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 50;
        let coordinator = FlushCoordinator::new(config).unwrap();
        // LSNs arrive out of order; the segment metadata must still record
        // the true min/max of the batch.
        let entries = vec![
            create_test_entry(10, &[1u8; 20]),
            create_test_entry(20, &[2u8; 20]),
            create_test_entry(15, &[3u8; 20]),
        ];
        coordinator.flush(entries, true).unwrap();
        // Oversized entry forces a rotation so metadata gets written.
        let entry = create_test_entry(100, &[4u8; 100]);
        coordinator.flush(vec![entry], true).unwrap();
        let segments = coordinator.list_segments_with_metadata();
        assert!(!segments.is_empty());
        if let Some((_, meta)) = segments.first() {
            assert_eq!(meta.min_lsn, LSN(10));
            assert_eq!(meta.max_lsn, LSN(20));
            assert_eq!(meta.entry_count, 3);
        }
    }

    #[test]
    fn test_truncate_to_lsn_removes_old_segments() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 200;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        // Three batches => three segments covering LSN 1-10, 11-20, 21-30.
        let entries1: Vec<_> = (1..=10)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries1, true).unwrap();
        let entries2: Vec<_> = (11..=20)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries2, true).unwrap();
        let entries3: Vec<_> = (21..=30)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries3, true).unwrap();
        let removed = coordinator.truncate_to_lsn(LSN(15)).unwrap();
        assert_eq!(removed, 1, "Should remove exactly 1 segment (LSN 1-10)");
        let segment_ids: Vec<u64> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter_map(|e| {
                let path = e.path();
                if path.extension().is_some_and(|ext| ext == "log") {
                    path.file_stem()
                        .and_then(|s| s.to_string_lossy().parse::<u64>().ok())
                } else {
                    None
                }
            })
            .collect();
        assert!(
            !segment_ids.contains(&1),
            "Segment 1 should be physically deleted"
        );
        assert!(segment_ids.contains(&2), "Segment 2 should remain");
        assert!(segment_ids.contains(&3), "Segment 3 should remain");
    }

    #[test]
    fn test_truncate_to_lsn_keeps_needed_segments() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 30;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        for i in 100..=110 {
            coordinator
                .flush(vec![create_test_entry(i, &[i as u8; 20])], true)
                .unwrap();
        }
        coordinator
            .flush(vec![create_test_entry(200, &[200u8; 100])], true)
            .unwrap();
        // Truncation point is below every stored LSN: nothing removable.
        let removed = coordinator.truncate_to_lsn(LSN(50)).unwrap();
        assert_eq!(removed, 0);
        let segments = coordinator.list_segments_with_metadata();
        assert!(!segments.is_empty());
    }

    #[test]
    fn test_truncate_to_lsn_never_removes_active_segment() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        coordinator
            .flush(vec![create_test_entry(1, b"test")], true)
            .unwrap();
        let current_id = coordinator.current_segment_id();
        // Even an LSN far beyond everything must not touch the active file.
        let removed = coordinator.truncate_to_lsn(LSN(1000)).unwrap();
        assert_eq!(removed, 0);
        let segment_path = coordinator.segment_path(current_id);
        assert!(segment_path.exists());
    }

    #[test]
    fn test_get_min_lsn() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 30;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        // No segment metadata yet => no min LSN.
        assert!(coordinator.get_min_lsn().is_none());
        for i in 50..=60 {
            coordinator
                .flush(vec![create_test_entry(i, &[i as u8; 20])], true)
                .unwrap();
        }
        coordinator
            .flush(vec![create_test_entry(100, &[100u8; 100])], true)
            .unwrap();
        let min_lsn = coordinator.get_min_lsn();
        assert!(min_lsn.is_some());
        assert_eq!(min_lsn.unwrap(), LSN(50));
    }

    #[test]
    fn test_cleanup_removes_meta_files() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 30;
        config.segments_to_retain = 1;
        let coordinator = FlushCoordinator::new(config).unwrap();
        for i in 1..=50 {
            coordinator
                .flush(vec![create_test_entry(i, &[i as u8; 30])], true)
                .unwrap();
        }
        // Meta sidecars must be cleaned alongside their segments.
        let meta_count = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| {
                e.path()
                    .file_name()
                    .and_then(|n| n.to_str())
                    .is_some_and(|s| s.ends_with(".log.meta"))
            })
            .count();
        assert!(meta_count <= 3);
    }

    #[test]
    #[cfg(unix)]
    fn test_phantom_commit_prevention() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.sync_on_flush = true;
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entry1 = create_test_entry(1, b"valid_data");
        coordinator.flush(vec![entry1], true).unwrap();
        let segment_path = coordinator.segment_path(coordinator.current_segment_id());
        let valid_size = std::fs::metadata(&segment_path).unwrap().len();
        {
            // Swap the sync handle for a socket fd: sync_data on a socket
            // fails, simulating a durable-write failure.
            use std::os::unix::io::{FromRawFd, IntoRawFd};
            use std::os::unix::net::UnixStream;
            let (s1, _s2) = UnixStream::pair().expect("Failed to create socket pair");
            let fd = s1.into_raw_fd();
            let bad_file = unsafe { File::from_raw_fd(fd) };
            let mut guard = coordinator.sync_handle.lock().unwrap();
            *guard = Some(bad_file);
        }
        let entry2 = create_test_entry(2, b"phantom_data");
        let result = coordinator.flush(vec![entry2], true);
        assert!(
            result.is_err(),
            "Flush should fail due to broken sync handle"
        );
        // The failed batch must have been truncated back off the file.
        let new_size = std::fs::metadata(&segment_path).unwrap().len();
        assert_eq!(
            new_size, valid_size,
            "File size increased despite sync failure! Phantom commit detected. \
Expected {} bytes (valid only), got {} bytes (valid + phantom).",
            valid_size, new_size
        );
    }

    #[test]
    #[cfg(unix)]
    fn test_sync_logic_correctness() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.sync_on_flush = true;
        let coordinator = FlushCoordinator::new(config).unwrap();
        {
            let mut writer_guard = coordinator.writer.lock().unwrap();
            coordinator.ensure_segment_open(&mut writer_guard).unwrap();
        }
        {
            // /dev/null is opened read-only here, so sync_data on it fails.
            let dev_null = File::open("/dev/null").expect("Failed to open /dev/null");
            let mut guard = coordinator.sync_handle.lock().unwrap();
            *guard = Some(dev_null);
        }
        let entry = create_test_entry(1, &[1, 2, 3]);
        let result = coordinator.flush(vec![entry], false);
        assert!(
            result.is_ok(),
            "flush(false) should NOT sync, so it should succeed even if sync handle is broken"
        );
        let entry2 = create_test_entry(2, &[4, 5, 6]);
        let result = coordinator.flush(vec![entry2], true);
        assert!(
            result.is_err(),
            "flush(true) SHOULD sync, so it should fail due to broken sync handle"
        );
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Failed to sync WAL"),
            "Error should be about syncing, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_truncate_safe_defaults_on_missing_metadata() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 50;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entries: Vec<_> = (1..=10)
            .map(|i| create_test_entry(i, &[i as u8; 10]))
            .collect();
        coordinator.flush(entries, true).unwrap();
        let entries2: Vec<_> = (11..=20)
            .map(|i| create_test_entry(i, &[i as u8; 10]))
            .collect();
        coordinator.flush(entries2, true).unwrap();
        let segments = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "log"))
            .collect::<Vec<_>>();
        assert!(segments.len() >= 2, "Expected at least 2 segments");
        // Delete the oldest segment's sidecar to simulate missing metadata.
        let oldest_segment_path = segments.iter().min_by_key(|e| e.path()).unwrap().path();
        let mut meta_path = oldest_segment_path.clone();
        if let Some(name) = meta_path.file_name() {
            let mut name = name.to_os_string();
            name.push(".meta");
            meta_path.set_file_name(name);
        }
        assert!(
            meta_path.exists(),
            "Metadata file should exist: {:?}",
            meta_path
        );
        std::fs::remove_file(&meta_path).unwrap();
        assert!(!meta_path.exists());
        // Without metadata the segment's range is unknown, so truncation
        // must conservatively keep it.
        let removed = coordinator.truncate_to_lsn(LSN(100)).unwrap();
        assert_eq!(
            removed, 0,
            "Should not remove segment if metadata is missing"
        );
        assert!(
            oldest_segment_path.exists(),
            "Segment log file should still exist"
        );
    }

    #[test]
    fn test_truncate_to_lsn_boundary_exact_match() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 50;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entries: Vec<_> = (10..=20)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries, true).unwrap();
        let entries2: Vec<_> = (21..=30)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries2, true).unwrap();
        let segments = coordinator.list_segments_with_metadata();
        assert!(segments.len() >= 2);
        let historical = segments
            .iter()
            .find(|(_, m)| m.max_lsn == LSN(20))
            .expect("Historical segment should exist");
        // Removal requires max_lsn strictly below the truncation point, so
        // a segment ending exactly at the boundary is kept.
        let removed = coordinator.truncate_to_lsn(LSN(20)).unwrap();
        assert_eq!(
            removed, 0,
            "Should NOT remove segment ending at LSN 20 when truncating to 20"
        );
        let segments_after = coordinator.list_segments_with_metadata();
        assert!(segments_after.iter().any(|(id, _)| *id == historical.0));
    }

    #[test]
    fn test_get_min_lsn_with_multiple_segments() {
        let dir = tempdir().unwrap();
        let mut config = FlushCoordinatorConfig::new(dir.path());
        config.segment_size = 30;
        config.segments_to_retain = 100;
        let coordinator = FlushCoordinator::new(config).unwrap();
        let entries1: Vec<_> = (10..=20)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries1, true).unwrap();
        let entries2: Vec<_> = (30..=40)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries2, true).unwrap();
        let entries3: Vec<_> = (50..=60)
            .map(|i| create_test_entry(i, &[i as u8; 20]))
            .collect();
        coordinator.flush(entries3, true).unwrap();
        coordinator
            .flush(vec![create_test_entry(100, &[100u8; 100])], true)
            .unwrap();
        // Min across all segment metadata is the first batch's lowest LSN.
        let min_lsn = coordinator.get_min_lsn();
        assert_eq!(min_lsn, Some(LSN(10)));
    }
}