use std::sync::{
atomic::{AtomicUsize, Ordering},
Condvar, Mutex,
};
use crate::{Error, Result};
/// A single-use thread barrier that can be broken to release blocked waiters with an error.
///
/// Participants call [`wait`](FailFastBarrier::wait); once `count` of them have arrived the
/// barrier trips and they are released. Any thread may instead call
/// [`break_barrier`](FailFastBarrier::break_barrier) to record an error message and wake
/// blocked waiters so they fail fast instead of blocking forever.
#[derive(Debug)]
pub struct FailFastBarrier {
    /// Number of arrivals required before the barrier trips (always > 0).
    count: usize,
    /// Number of `wait` calls that have registered so far; it is never reset.
    arrived: AtomicUsize,
    /// Wakes blocked waiters when the barrier trips or is broken.
    condvar: Condvar,
    /// `Some(message)` once the barrier has been broken; also the mutex the condvar waits on.
    state: Mutex<Option<String>>,
}
impl FailFastBarrier {
pub fn new(count: usize) -> Result<Self> {
if count == 0 {
return Err(Error::LockError(
"FailFastBarrier count must be greater than 0".to_string(),
));
}
Ok(Self {
count,
arrived: AtomicUsize::new(0),
condvar: Condvar::new(),
state: Mutex::new(None),
})
}
pub fn wait(&self) -> Result<()> {
{
let guard = self
.state
.lock()
.map_err(|e| Error::LockError(format!("barrier lock poisoned: {e}")))?;
if let Some(msg) = guard.as_ref() {
return Err(Error::LockError(format!("Barrier was broken: {msg}")));
}
}
let arrived_count = self.arrived.fetch_add(1, Ordering::AcqRel) + 1;
if arrived_count == self.count {
let _guard = self
.state
.lock()
.map_err(|e| Error::LockError(format!("barrier lock poisoned: {e}")))?;
self.condvar.notify_all();
Ok(())
} else {
let guard = self
.state
.lock()
.map_err(|e| Error::LockError(format!("barrier lock poisoned: {e}")))?;
let guard = self
.condvar
.wait_while(guard, |state| {
state.is_none() && self.arrived.load(Ordering::Acquire) < self.count
})
.map_err(|e| Error::LockError(format!("barrier condvar poisoned: {e}")))?;
if let Some(msg) = guard.as_ref() {
Err(Error::LockError(format!("Barrier was broken: {msg}")))
} else {
Ok(())
}
}
}
pub fn break_barrier(&self, error_message: impl Into<String>) {
if let Ok(mut guard) = self.state.lock() {
if guard.is_none() {
*guard = Some(error_message.into());
}
}
self.condvar.notify_all();
}
pub fn is_broken(&self) -> bool {
self.state
.lock()
.map(|guard| guard.is_some())
.unwrap_or(true) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::thread;

    /// Spins until at least `n` callers have registered an arrival inside `wait`. Tests use
    /// this to act once participants are really in the barrier, instead of guessing with a
    /// sleep.
    fn wait_for_arrivals(barrier: &FailFastBarrier, n: usize) {
        while barrier.arrived.load(Ordering::SeqCst) < n {
            thread::yield_now();
        }
    }

    #[test]
    fn test_normal_barrier_operation() {
        let barrier = Arc::new(FailFastBarrier::new(3).unwrap());
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let barrier_clone = Arc::clone(&barrier);
                thread::spawn(move || barrier_clone.wait())
            })
            .collect();
        for handle in handles {
            assert!(handle.join().unwrap().is_ok());
        }
    }

    #[test]
    fn test_barrier_break() {
        let barrier = Arc::new(FailFastBarrier::new(3).unwrap());
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let barrier_clone = Arc::clone(&barrier);
                thread::spawn(move || barrier_clone.wait())
            })
            .collect();
        // Make sure both participants have arrived before breaking, so the wake-up path
        // (not the early "already broken" return) is the one being exercised.
        wait_for_arrivals(&barrier, 2);
        barrier.break_barrier("Test failure");
        for handle in handles {
            let result = handle.join().unwrap();
            assert!(result.is_err());
            let error_msg = result.unwrap_err().to_string();
            assert!(error_msg.contains("Test failure"));
        }
    }

    #[test]
    fn test_is_broken() {
        let barrier = FailFastBarrier::new(2).unwrap();
        assert!(!barrier.is_broken());
        barrier.break_barrier("Test break");
        assert!(barrier.is_broken());
    }

    #[test]
    fn test_zero_count_returns_error() {
        let result = FailFastBarrier::new(0);
        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("count must be greater than 0"));
    }

    #[test]
    fn test_break_barrier_with_string() {
        let barrier = FailFastBarrier::new(2).unwrap();
        barrier.break_barrier("test error");
        assert!(barrier.is_broken());
        let barrier2 = FailFastBarrier::new(2).unwrap();
        barrier2.break_barrier(String::from("owned error"));
        assert!(barrier2.is_broken());
    }

    #[test]
    fn test_wait_after_break_fails_fast() {
        let barrier = FailFastBarrier::new(2).unwrap();
        barrier.break_barrier("early failure");
        let error_msg = barrier.wait().unwrap_err().to_string();
        assert!(error_msg.contains("Barrier was broken: early failure"));
    }

    #[test]
    fn test_first_break_message_wins() {
        let barrier = FailFastBarrier::new(2).unwrap();
        barrier.break_barrier("first");
        barrier.break_barrier("second");
        let error_msg = barrier.wait().unwrap_err().to_string();
        assert!(error_msg.contains("first"));
        assert!(!error_msg.contains("second"));
    }

    #[test]
    fn test_extra_arrivals_pass_through_tripped_barrier() {
        let barrier = FailFastBarrier::new(1).unwrap();
        assert!(barrier.wait().is_ok());
        assert!(barrier.wait().is_ok());
    }
}