use std::collections::{BTreeSet, VecDeque};
use std::sync::{Condvar, Mutex};
use std::time::Duration;
use crate::core::error::{Error, StorageError};
/// Coordinates batches ("epochs") of transactions that are flushed together.
///
/// Transactions join the current epoch via `register_transaction`, a flusher
/// seals an epoch with `start_flush` and reports its outcome with
/// `finish_flush`, and committers block in `wait_for_flush` until their epoch
/// has been reported.
pub struct GroupCommitCoordinator {
/// All mutable bookkeeping, guarded by a single mutex.
state: Mutex<GroupCommitState>,
/// Notified (all waiters) every time `finish_flush` runs.
flush_complete: Condvar,
/// Tuning parameters supplied at construction; never mutated afterwards.
config: GroupCommitConfig,
}
/// Tuning parameters for [`GroupCommitCoordinator`].
///
/// `wait_for_flush` derives its timeout as
/// `max_delay_ms * timeout_multiplier + timeout_base_ms`, raised to at least
/// `timeout_min_ms` and then capped at `timeout_max_ms`.
#[derive(Debug, Clone)]
pub struct GroupCommitConfig {
/// Value returned (as a `Duration`) by `max_delay`; also the scaled term of the wait timeout.
pub max_delay_ms: u64,
/// Number of registered transactions at which `register_transaction` / `should_flush` report `true`.
pub max_batch_size: usize,
/// Factor applied to `max_delay_ms` when deriving the wait timeout.
pub timeout_multiplier: u32,
/// Constant number of milliseconds added to the scaled delay.
pub timeout_base_ms: u64,
/// Lower bound, in milliseconds, applied to the derived wait timeout.
pub timeout_min_ms: u64,
/// Upper bound, in milliseconds, applied to the derived wait timeout (applied last).
pub timeout_max_ms: u64,
/// Maximum number of failed-flush records kept; older records are evicted first.
pub recent_errors_capacity: usize,
}
impl Default for GroupCommitConfig {
fn default() -> Self {
Self {
max_delay_ms: 10,
max_batch_size: 200,
timeout_multiplier: 50,
timeout_base_ms: 5000,
timeout_min_ms: 10000,
timeout_max_ms: 60000,
recent_errors_capacity: 1024,
}
}
}
/// Mutable bookkeeping shared by all coordinator methods (kept behind the mutex).
struct GroupCommitState {
/// Epoch handed to newly registered transactions; starts at 1 and is bumped by `start_flush`.
current_epoch: u64,
/// Transactions registered since the last `start_flush` (which resets it to 0).
batch_count: usize,
/// Highest epoch `E` such that every epoch in `1..=E` has been reported to `finish_flush`; starts at 0.
flushed_epoch: u64,
/// `(epoch, error message)` records of failed flushes, oldest first, bounded by `recent_errors_capacity`.
recent_errors: VecDeque<(u64, String)>,
/// Updated when a record is evicted from `recent_errors`; `wait_for_flush` reports epochs
/// below this value as having an unknown outcome.
oldest_error_epoch: u64,
/// Epochs reported to `finish_flush` that are still ahead of `flushed_epoch`
/// (they finished before an earlier epoch did).
completed_epochs: BTreeSet<u64>,
}
impl GroupCommitCoordinator {
/// Builds a coordinator from the two most commonly tuned knobs.
///
/// The error-history capacity is pinned to 1024 and every other setting is
/// taken from `GroupCommitConfig::default()`.
pub fn new(max_delay_ms: u64, max_batch_size: usize) -> Self {
    let config = GroupCommitConfig {
        recent_errors_capacity: 1024,
        max_batch_size,
        max_delay_ms,
        ..Default::default()
    };
    Self::with_config(config)
}
/// Builds a coordinator that uses `config` verbatim.
///
/// The coordinator starts at epoch 1 with nothing flushed (`flushed_epoch`
/// is 0), an empty batch and an empty error history.
pub fn with_config(config: GroupCommitConfig) -> Self {
    let initial_state = GroupCommitState {
        completed_epochs: BTreeSet::new(),
        oldest_error_epoch: 0,
        recent_errors: VecDeque::new(),
        flushed_epoch: 0,
        batch_count: 0,
        current_epoch: 1,
    };
    Self {
        config,
        flush_complete: Condvar::new(),
        state: Mutex::new(initial_state),
    }
}
pub fn with_defaults() -> Self {
let config = GroupCommitConfig::default();
Self::new(config.max_delay_ms, config.max_batch_size)
}
pub fn register_transaction(&self) -> Result<(u64, bool), Error> {
let mut state = self.state.lock().map_err(|_| {
Error::Storage(StorageError::LockPoisoned {
resource: "group_commit_state".to_string(),
})
})?;
state.batch_count += 1;
let epoch = state.current_epoch;
let should_flush = state.batch_count >= self.config.max_batch_size;
Ok((epoch, should_flush))
}
/// Blocks until `epoch` (and every epoch before it) has been reported to
/// `finish_flush`, then returns the outcome recorded for `epoch`.
///
/// The wait is bounded by
/// `max_delay_ms * timeout_multiplier + timeout_base_ms` milliseconds, raised
/// to at least `timeout_min_ms` and then capped at `timeout_max_ms`. The
/// arithmetic saturates, so extreme configuration values cannot overflow.
///
/// # Errors
///
/// * `StorageError::LockPoisoned` if the state mutex is poisoned.
/// * `StorageError::WalError` ("Group commit timeout ...") if the epoch was
///   not flushed before the timeout elapsed.
/// * `StorageError::WalError` ("Group commit flush failed ...") if a failure
///   was recorded for `epoch`.
/// * `StorageError::WalError` ("Group commit status unknown ...") if `epoch`
///   is older than the retained error history, so its outcome can no longer
///   be determined.
pub fn wait_for_flush(&self, epoch: u64) -> Result<(), Error> {
    // Derive the timeout before touching the lock. Saturating arithmetic keeps
    // an oversized configuration from panicking (debug) or wrapping (release).
    let scaled_delay_ms = self
        .config
        .max_delay_ms
        .saturating_mul(u64::from(self.config.timeout_multiplier));
    let timeout_ms = scaled_delay_ms
        .saturating_add(self.config.timeout_base_ms)
        .max(self.config.timeout_min_ms)
        .min(self.config.timeout_max_ms);
    let timeout = Duration::from_millis(timeout_ms);

    let state = self.state.lock().map_err(|_| {
        Error::Storage(StorageError::LockPoisoned {
            resource: "group_commit_state".to_string(),
        })
    })?;

    // `wait_timeout_while` re-checks the predicate after every wakeup
    // (including spurious ones and notifications for unrelated epochs) and
    // tracks the overall time budget across those wakeups.
    let (state, wait_result) = self
        .flush_complete
        .wait_timeout_while(state, timeout, |s| s.flushed_epoch < epoch)
        .map_err(|_| {
            Error::Storage(StorageError::LockPoisoned {
                resource: "group_commit_state".to_string(),
            })
        })?;

    if wait_result.timed_out() {
        return Err(Error::Storage(StorageError::WalError {
            reason: format!(
                "Group commit timeout waiting for epoch {} (current flushed: {})",
                epoch, state.flushed_epoch
            ),
        }));
    }

    // The epoch has been reported; surface a recorded failure if there is one.
    if let Some((_, error_msg)) = state
        .recent_errors
        .iter()
        .find(|(failed_epoch, _)| *failed_epoch == epoch)
    {
        return Err(Error::Storage(StorageError::WalError {
            reason: format!("Group commit flush failed: {}", error_msg),
        }));
    }

    // A failure for this epoch may have been evicted from the bounded
    // history, so success cannot be claimed.
    if epoch < state.oldest_error_epoch {
        return Err(Error::Storage(StorageError::WalError {
            reason: format!(
                "Group commit status unknown: epoch {} evicted from error history (history starts at {})",
                epoch, state.oldest_error_epoch
            ),
        }));
    }

    Ok(())
}
/// Seals the current epoch so it can be flushed and opens the next one.
///
/// Returns the sealed epoch. Transactions registered after this call join
/// the following epoch, and the batch counter restarts at zero.
///
/// # Errors
///
/// Returns `StorageError::LockPoisoned` if the state mutex is poisoned.
pub fn start_flush(&self) -> Result<u64, Error> {
    let mut guard = self.state.lock().map_err(|_| {
        Error::Storage(StorageError::LockPoisoned {
            resource: "group_commit_state".to_string(),
        })
    })?;
    let sealed_epoch = guard.current_epoch;
    guard.batch_count = 0;
    guard.current_epoch = sealed_epoch + 1;
    Ok(sealed_epoch)
}
/// Records the outcome of flushing `epoch` and wakes every waiter.
///
/// A failed `result` is stored (as its display string) in the bounded error
/// history so that `wait_for_flush` can report it; the call itself still
/// returns `Ok(())`. Epochs may be reported out of order: `flushed_epoch`
/// only advances across a gap-free run of reported epochs.
///
/// # Errors
///
/// Returns `StorageError::LockPoisoned` if the state mutex is poisoned.
pub fn finish_flush(&self, epoch: u64, result: Result<(), Error>) -> Result<(), Error> {
    let mut state = self.state.lock().map_err(|_| {
        Error::Storage(StorageError::LockPoisoned {
            resource: "group_commit_state".to_string(),
        })
    })?;

    if let Err(e) = result {
        state.recent_errors.push_back((epoch, e.to_string()));
        while state.recent_errors.len() > self.config.recent_errors_capacity {
            if let Some((evicted_epoch, _)) = state.recent_errors.pop_front() {
                // Failures arrive in completion order, not epoch order, so the
                // evicted epoch can be lower than one evicted earlier. The
                // watermark must only ever rise; otherwise a previously evicted
                // (failed) epoch would later be reported as successful.
                let watermark = evicted_epoch.saturating_add(1);
                if watermark > state.oldest_error_epoch {
                    state.oldest_error_epoch = watermark;
                }
            }
        }
    }

    // Remember epochs that finish ahead of the contiguous flushed prefix.
    if epoch > state.flushed_epoch {
        state.completed_epochs.insert(epoch);
    }

    // Advance the flushed prefix across every consecutively completed epoch.
    let mut next_epoch = state.flushed_epoch + 1;
    while state.completed_epochs.remove(&next_epoch) {
        state.flushed_epoch = next_epoch;
        next_epoch += 1;
    }

    self.flush_complete.notify_all();
    Ok(())
}
/// Seals the current epoch and immediately records `result` as its outcome.
///
/// Equivalent to calling `start_flush` followed by `finish_flush` on the
/// epoch it returned.
///
/// # Errors
///
/// Propagates any error returned by `start_flush` or `finish_flush`.
pub fn mark_flushed(&self, result: Result<(), Error>) -> Result<(), Error> {
    self.start_flush()
        .and_then(|sealed_epoch| self.finish_flush(sealed_epoch, result))
}
pub fn current_batch_size(&self) -> Result<usize, Error> {
let state = self.state.lock().map_err(|_| {
Error::Storage(StorageError::LockPoisoned {
resource: "group_commit_state".to_string(),
})
})?;
Ok(state.batch_count)
}
/// Returns the epoch that newly registered transactions are assigned to.
///
/// # Errors
///
/// Returns `StorageError::LockPoisoned` if the state mutex is poisoned.
pub fn current_epoch(&self) -> Result<u64, Error> {
    self.state
        .lock()
        .map(|guard| guard.current_epoch)
        .map_err(|_| {
            Error::Storage(StorageError::LockPoisoned {
                resource: "group_commit_state".to_string(),
            })
        })
}
pub fn flushed_epoch(&self) -> Result<u64, Error> {
let state = self.state.lock().map_err(|_| {
Error::Storage(StorageError::LockPoisoned {
resource: "group_commit_state".to_string(),
})
})?;
Ok(state.flushed_epoch)
}
/// Reports whether the current batch has reached `max_batch_size`.
///
/// # Errors
///
/// Returns `StorageError::LockPoisoned` if the state mutex is poisoned.
pub fn should_flush(&self) -> Result<bool, Error> {
    let guard = self.state.lock().map_err(|_| {
        Error::Storage(StorageError::LockPoisoned {
            resource: "group_commit_state".to_string(),
        })
    })?;
    let threshold = self.config.max_batch_size;
    let pending = guard.batch_count;
    Ok(threshold <= pending)
}
/// Returns the configured `max_delay_ms` as a [`Duration`].
pub fn max_delay(&self) -> Duration {
    let millis = self.config.max_delay_ms;
    Duration::from_millis(millis)
}
}
// Unit tests for `GroupCommitCoordinator`. Several of them rely on real
// threads and wall-clock timing, so their upper bounds are deliberately loose.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
// A fresh coordinator starts at epoch 1 with nothing flushed and an empty batch.
#[test]
fn test_new_coordinator() {
let coord = GroupCommitCoordinator::new(10, 200);
assert_eq!(coord.current_epoch().unwrap(), 1); assert_eq!(coord.flushed_epoch().unwrap(), 0);
assert_eq!(coord.current_batch_size().unwrap(), 0);
}
// With a batch size of 5, only the fifth registration reports that a flush is due.
#[test]
fn test_register_transaction() {
let coord = GroupCommitCoordinator::new(10, 5);
for i in 0..4 {
let (epoch, should_flush) = coord.register_transaction().unwrap();
assert_eq!(epoch, 1); assert!(!should_flush, "should not flush at batch size {}", i + 1);
}
let (epoch, should_flush) = coord.register_transaction().unwrap();
assert_eq!(epoch, 1); assert!(should_flush, "should flush when batch is full");
}
// `mark_flushed` opens the next epoch, records the old one as flushed and resets the batch.
#[test]
fn test_mark_flushed_advances_epoch() {
let coord = GroupCommitCoordinator::new(10, 100);
coord.register_transaction().unwrap();
coord.register_transaction().unwrap();
assert_eq!(coord.current_epoch().unwrap(), 1); assert_eq!(coord.current_batch_size().unwrap(), 2);
coord.mark_flushed(Ok(())).unwrap();
assert_eq!(coord.current_epoch().unwrap(), 2); assert_eq!(coord.flushed_epoch().unwrap(), 1); assert_eq!(coord.current_batch_size().unwrap(), 0);
}
// A waiter is released when another thread reports a successful flush ~10 ms later.
#[test]
fn test_wait_for_flush_success() {
let coord = Arc::new(GroupCommitCoordinator::new(100, 100));
let coord_clone = Arc::clone(&coord);
let (epoch, _) = coord.register_transaction().unwrap();
let handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10));
coord_clone.mark_flushed(Ok(())).unwrap();
});
let result = coord.wait_for_flush(epoch);
assert!(result.is_ok());
handle.join().unwrap();
}
// A flush failure reported by another thread surfaces to the waiter with its message.
#[test]
fn test_wait_for_flush_error_propagation() {
let coord = Arc::new(GroupCommitCoordinator::new(100, 100));
let coord_clone = Arc::clone(&coord);
let (epoch, _) = coord.register_transaction().unwrap();
let handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10));
coord_clone
.mark_flushed(Err(Error::Storage(StorageError::WalError {
reason: "disk full".to_string(),
})))
.unwrap();
});
let result = coord.wait_for_flush(epoch);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("disk full"));
handle.join().unwrap();
}
// Nobody flushes, so the wait must time out. Derived timeout: 10 * 2 + 10 = 30 ms,
// already inside the 20..=100 ms bounds.
#[test]
fn test_wait_for_flush_timeout() {
let config = GroupCommitConfig {
max_delay_ms: 10,
max_batch_size: 100,
timeout_multiplier: 2, timeout_base_ms: 10, timeout_min_ms: 20, timeout_max_ms: 100, recent_errors_capacity: 1024,
};
let coord = GroupCommitCoordinator::with_config(config);
let (epoch, _) = coord.register_transaction().unwrap();
let start = std::time::Instant::now();
let result = coord.wait_for_flush(epoch);
let elapsed = start.elapsed();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("timeout"));
assert!(elapsed < Duration::from_millis(500));
}
// Five threads waiting on the same epoch are all released by a single flush.
#[test]
fn test_multiple_waiters() {
let coord = Arc::new(GroupCommitCoordinator::new(100, 100));
let mut epochs = Vec::new();
for _ in 0..5 {
let (epoch, _) = coord.register_transaction().unwrap();
epochs.push(epoch);
}
assert!(epochs.iter().all(|&e| e == 1));
let mut handles = Vec::new();
for _ in 0..5 {
let coord_clone = Arc::clone(&coord);
handles.push(thread::spawn(move || coord_clone.wait_for_flush(1))); }
thread::sleep(Duration::from_millis(10));
coord.mark_flushed(Ok(())).unwrap();
for handle in handles {
let result = handle.join().unwrap();
assert!(result.is_ok());
}
}
// Each flush moves registrations on to the next epoch number.
#[test]
fn test_multiple_epochs() {
let coord = GroupCommitCoordinator::new(10, 100);
coord.register_transaction().unwrap();
coord.register_transaction().unwrap();
coord.mark_flushed(Ok(())).unwrap();
assert_eq!(coord.current_epoch().unwrap(), 2);
let (epoch, _) = coord.register_transaction().unwrap();
assert_eq!(epoch, 2);
coord.mark_flushed(Ok(())).unwrap();
assert_eq!(coord.current_epoch().unwrap(), 3); }
// `should_flush` turns true exactly when the batch reaches the configured size (3).
#[test]
fn test_should_flush() {
let coord = GroupCommitCoordinator::new(10, 3);
assert!(!coord.should_flush().unwrap());
coord.register_transaction().unwrap();
assert!(!coord.should_flush().unwrap());
coord.register_transaction().unwrap();
assert!(!coord.should_flush().unwrap());
coord.register_transaction().unwrap();
assert!(coord.should_flush().unwrap());
}
// `max_delay` echoes the configured delay as a `Duration`.
#[test]
fn test_max_delay() {
let coord = GroupCommitCoordinator::new(42, 100);
assert_eq!(coord.max_delay(), Duration::from_millis(42));
}
// The default configuration uses a 10 ms delay.
#[test]
fn test_with_defaults() {
let coord = GroupCommitCoordinator::with_defaults();
assert_eq!(coord.max_delay(), Duration::from_millis(10));
}
// Derived timeout: 1 * 2 + 10 = 12 ms, raised to the 20 ms minimum. The lower bound
// asserted here (5 ms) is intentionally looser than that.
#[test]
fn test_custom_timeout_config() {
let config = GroupCommitConfig {
max_delay_ms: 1,
max_batch_size: 100,
timeout_multiplier: 2,
timeout_base_ms: 10,
timeout_min_ms: 20,
timeout_max_ms: 100,
recent_errors_capacity: 1024,
};
let coord = GroupCommitCoordinator::with_config(config);
let (epoch, _) = coord.register_transaction().unwrap();
let start = std::time::Instant::now();
let result = coord.wait_for_flush(epoch);
let elapsed = start.elapsed();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout"));
assert!(elapsed >= Duration::from_millis(5));
assert!(elapsed < Duration::from_millis(500));
}
// With room for two error records, the third failure evicts the first: epoch 1 is then
// reported as "evicted from error history" while epochs 2 and 3 keep their messages.
#[test]
fn test_error_history_eviction() {
let config = GroupCommitConfig {
max_delay_ms: 10,
max_batch_size: 100,
timeout_multiplier: 2,
timeout_base_ms: 10,
timeout_min_ms: 20,
timeout_max_ms: 1000,
recent_errors_capacity: 2, };
let coord = GroupCommitCoordinator::with_config(config);
let epoch1 = coord.start_flush().unwrap();
coord
.finish_flush(
epoch1,
Err(Error::Storage(StorageError::WalError {
reason: "Fail 1".to_string(),
})),
)
.unwrap();
let epoch2 = coord.start_flush().unwrap();
coord
.finish_flush(
epoch2,
Err(Error::Storage(StorageError::WalError {
reason: "Fail 2".to_string(),
})),
)
.unwrap();
let epoch3 = coord.start_flush().unwrap();
coord
.finish_flush(
epoch3,
Err(Error::Storage(StorageError::WalError {
reason: "Fail 3".to_string(),
})),
)
.unwrap();
let result1 = coord.wait_for_flush(epoch1);
assert!(result1.is_err());
assert!(
result1
.unwrap_err()
.to_string()
.contains("evicted from error history")
);
let result2 = coord.wait_for_flush(epoch2);
assert!(result2.is_err());
assert!(result2.unwrap_err().to_string().contains("Fail 2"));
let result3 = coord.wait_for_flush(epoch3);
assert!(result3.is_err());
assert!(result3.unwrap_err().to_string().contains("Fail 3"));
}
// A transaction registered after `start_flush` belongs to the next epoch; finishing the
// in-flight flush marks only epoch 1 as flushed, and epoch 2 needs its own flush.
#[test]
fn test_flush_race_condition() {
let coord = GroupCommitCoordinator::new(100, 100);
let (epoch_a, _) = coord.register_transaction().unwrap();
assert_eq!(epoch_a, 1);
let flushing_epoch = coord.start_flush().unwrap();
assert_eq!(flushing_epoch, 1);
assert_eq!(coord.current_epoch().unwrap(), 2);
let (epoch_b, _) = coord.register_transaction().unwrap();
assert_eq!(epoch_b, 2);
coord.finish_flush(flushing_epoch, Ok(())).unwrap();
assert!(coord.wait_for_flush(epoch_a).is_ok());
assert_eq!(coord.flushed_epoch().unwrap(), 1);
let flushing_epoch_2 = coord.start_flush().unwrap();
assert_eq!(flushing_epoch_2, 2);
coord.finish_flush(flushing_epoch_2, Ok(())).unwrap();
assert!(coord.wait_for_flush(epoch_b).is_ok());
}
// A background thread keeps notifying the condvar (by finishing unrelated epoch 100, which
// never advances the flushed prefix) to check that repeated wakeups do not extend the wait.
// Derived timeout: 10 * 1 + 10 = 20 ms, raised to the 50 ms minimum.
// NOTE(review): the spawned thread is not joined and keeps running for up to 500 ms.
#[test]
fn test_wait_for_flush_deadline_enforcement() {
let config = GroupCommitConfig {
max_delay_ms: 10,
max_batch_size: 100,
timeout_multiplier: 1, timeout_base_ms: 10, timeout_min_ms: 50, timeout_max_ms: 200, recent_errors_capacity: 1024,
};
let coord = Arc::new(GroupCommitCoordinator::with_config(config));
let coord_clone = Arc::clone(&coord);
let (epoch, _) = coord.register_transaction().unwrap();
thread::spawn(move || {
let start = std::time::Instant::now();
while start.elapsed() < Duration::from_millis(500) {
thread::sleep(Duration::from_millis(10));
let _ = coord_clone.finish_flush(100, Ok(()));
}
});
let start = std::time::Instant::now();
let result = coord.wait_for_flush(epoch);
let elapsed = start.elapsed();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout"));
assert!(
elapsed < Duration::from_millis(150),
"Wait took {:?}, expected < 150ms",
elapsed
);
}
}