use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::ThreadId;
use std::collections::HashSet;
/// A reusable rendezvous point that completes once a fixed number of participants have
/// signalled the current frame.
///
/// Participants call [`FrameBarrier::signal_frame_complete`]. The call that brings the count
/// to `thread_count` marks the barrier complete and wakes every thread blocked in
/// [`FrameBarrier::wait_all`] / [`FrameBarrier::wait_timeout`]. [`FrameBarrier::reset`] opens a
/// new generation and also releases any waiters still blocked on the previous one.
///
/// Signalling and waiting are independent: a thread may signal without waiting, and a
/// coordinating thread may wait without signalling.
pub struct FrameBarrier {
    /// Number of `signal_frame_complete` calls needed to complete a generation.
    thread_count: usize,
    /// Signals received in the current generation.
    arrived: AtomicUsize,
    /// Incremented by every `reset`; lets waiters notice that their generation has ended.
    generation: AtomicUsize,
    /// Set once `arrived` reaches `thread_count` (or from the start when `thread_count == 0`).
    all_arrived: AtomicBool,
    /// Pairs with `cvar`. Every state change that waiters observe (`all_arrived`,
    /// `generation`) is published while this lock is held, so a notification can never land
    /// between a waiter's predicate check and its call to `Condvar::wait` (a lost wakeup).
    lock: Mutex<()>,
    cvar: Condvar,
    /// Threads that called `register_thread`. Bookkeeping only: it does not influence
    /// `thread_count` or completion.
    registered_threads: Mutex<HashSet<ThreadId>>,
}

impl FrameBarrier {
    /// Creates a shared barrier that completes after `thread_count` calls to
    /// [`FrameBarrier::signal_frame_complete`].
    ///
    /// A barrier with `thread_count == 0` has nothing to wait for and starts out complete;
    /// otherwise `wait_all` on it could never return.
    pub fn new(thread_count: usize) -> Arc<Self> {
        Arc::new(Self {
            thread_count,
            arrived: AtomicUsize::new(0),
            generation: AtomicUsize::new(0),
            all_arrived: AtomicBool::new(thread_count == 0),
            lock: Mutex::new(()),
            cvar: Condvar::new(),
            registered_threads: Mutex::new(HashSet::new()),
        })
    }

    /// Adds the calling thread to the registration set.
    pub fn register_thread(&self) {
        let mut threads = self.registered_threads.lock().unwrap();
        threads.insert(std::thread::current().id());
    }

    /// Removes the calling thread from the registration set (no-op if absent).
    pub fn unregister_thread(&self) {
        let mut threads = self.registered_threads.lock().unwrap();
        threads.remove(&std::thread::current().id());
    }

    /// Returns whether `thread_id` is currently in the registration set.
    pub fn is_registered(&self, thread_id: ThreadId) -> bool {
        let threads = self.registered_threads.lock().unwrap();
        threads.contains(&thread_id)
    }

    /// Number of signals required to complete a generation.
    pub fn thread_count(&self) -> usize {
        self.thread_count
    }

    /// Signals received so far in the current generation.
    pub fn arrived_count(&self) -> usize {
        self.arrived.load(Ordering::SeqCst)
    }

    /// Current generation number; incremented (wrapping) by each [`FrameBarrier::reset`].
    pub fn generation(&self) -> usize {
        self.generation.load(Ordering::SeqCst)
    }

    /// Records that one participant finished the current frame.
    ///
    /// The call that brings the count to exactly `thread_count` marks the barrier complete
    /// and wakes all waiters; further calls only increase the counter.
    pub fn signal_frame_complete(&self) {
        let prev = self.arrived.fetch_add(1, Ordering::SeqCst);
        if prev + 1 == self.thread_count {
            // A waiter holds `lock` from its predicate check until `Condvar::wait` atomically
            // releases it, so doing the store + notify under the lock cannot race past it.
            let _guard = self.lock.lock().unwrap();
            self.all_arrived.store(true, Ordering::SeqCst);
            self.cvar.notify_all();
        }
    }

    /// Blocks until the generation observed on entry completes or is ended by `reset`.
    pub fn wait_all(&self) {
        let guard = self.lock.lock().unwrap();
        let observed = self.generation.load(Ordering::SeqCst);
        let _guard = self
            .cvar
            .wait_while(guard, |_| self.still_waiting(observed))
            .unwrap();
    }

    /// Like [`FrameBarrier::wait_all`], but gives up after `timeout`.
    ///
    /// Returns `true` if the generation completed (or was reset) before the timeout, and
    /// `false` if the timeout elapsed first. Any `Duration`, including `Duration::MAX`, is
    /// accepted without overflowing.
    pub fn wait_timeout(&self, timeout: std::time::Duration) -> bool {
        let guard = self.lock.lock().unwrap();
        let observed = self.generation.load(Ordering::SeqCst);
        // `wait_timeout_while` tracks elapsed time itself, avoiding `Instant + Duration`,
        // which panics on overflow for very large timeouts.
        let (_guard, outcome) = self
            .cvar
            .wait_timeout_while(guard, timeout, |_| self.still_waiting(observed))
            .unwrap();
        !outcome.timed_out()
    }

    /// Starts a new generation: clears the arrival count and completion flag, bumps
    /// [`FrameBarrier::generation`], and wakes any thread still waiting on the old one.
    pub fn reset(&self) {
        let _guard = self.lock.lock().unwrap();
        self.arrived.store(0, Ordering::SeqCst);
        // A zero-participant barrier stays complete across generations.
        self.all_arrived.store(self.thread_count == 0, Ordering::SeqCst);
        self.generation.fetch_add(1, Ordering::SeqCst);
        // Waiters treat a generation change as a release; without this they would sleep on.
        self.cvar.notify_all();
    }

    /// Waits for the current generation, then resets the barrier.
    ///
    /// The two steps are not atomic: the lock is released between them, so another thread
    /// may signal or reset in between.
    pub fn wait_and_reset(&self) {
        self.wait_all();
        self.reset();
    }

    /// Returns whether the current generation has completed.
    pub fn is_complete(&self) -> bool {
        self.all_arrived.load(Ordering::SeqCst)
    }

    /// `true` while generation `observed` is still open and not yet complete.
    fn still_waiting(&self, observed: usize) -> bool {
        !self.all_arrived.load(Ordering::SeqCst)
            && self.generation.load(Ordering::SeqCst) == observed
    }
}
/// Fluent builder for a [`FrameBarrier`].
///
/// The participant count is either accumulated one name at a time with
/// [`FrameBarrierBuilder::with_thread`] or set outright with [`FrameBarrierBuilder::with_count`].
/// Call order matters when mixing them: `with_count` overwrites the running total (keeping any
/// names already recorded), and every later `with_thread` adds one to it.
#[derive(Debug, Clone, Default)]
pub struct FrameBarrierBuilder {
    /// Participant count passed to `FrameBarrier::new` by `build`.
    thread_count: usize,
    /// Names recorded by `with_thread`, in call order; `build` does not use them.
    thread_names: Vec<String>,
}

impl FrameBarrierBuilder {
    /// Creates an empty builder (zero participants, no names).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one named participant, incrementing the participant count by one.
    #[must_use]
    pub fn with_thread(mut self, name: &str) -> Self {
        self.thread_count += 1;
        self.thread_names.push(name.to_string());
        self
    }

    /// Sets the participant count directly, replacing whatever earlier `with_thread` calls
    /// accumulated. Recorded names are left untouched.
    #[must_use]
    pub fn with_count(mut self, count: usize) -> Self {
        self.thread_count = count;
        self
    }

    /// Participant names recorded so far via [`FrameBarrierBuilder::with_thread`], in call
    /// order. Informational only: [`FrameBarrierBuilder::build`] does not forward them.
    #[must_use]
    pub fn thread_names(&self) -> &[String] {
        &self.thread_names
    }

    /// Consumes the builder and creates a shared barrier for the configured participant count.
    #[must_use]
    pub fn build(self) -> Arc<FrameBarrier> {
        FrameBarrier::new(self.thread_count)
    }
}
/// Aggregate wait statistics for a barrier.
///
/// Plain data: nothing in `FrameBarrier` as shown here updates it, so whoever records waits is
/// expected to fill it in. The `_us` fields are presumably microseconds (per the suffix) —
/// confirm with the code that populates them.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct BarrierStats {
    /// Number of waits recorded.
    pub total_waits: u64,
    /// Sum of all recorded wait durations.
    pub total_wait_time_us: u64,
    /// Longest single recorded wait duration.
    pub max_wait_time_us: u64,
    /// Number of recorded waits that timed out.
    pub timeout_count: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A one-participant barrier completes on its single signal and clears on reset.
    #[test]
    fn test_barrier_single_thread() {
        let barrier = FrameBarrier::new(1);
        barrier.signal_frame_complete();
        assert!(barrier.is_complete());
        barrier.wait_all();
        barrier.reset();
        assert!(!barrier.is_complete());
    }

    /// Two worker threads plus the main thread complete a three-participant barrier.
    #[test]
    fn test_barrier_multi_thread() {
        let barrier = FrameBarrier::new(3);
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let b = Arc::clone(&barrier);
                thread::spawn(move || b.signal_frame_complete())
            })
            .collect();
        barrier.signal_frame_complete();
        barrier.wait_all();
        for worker in workers {
            worker.join().unwrap();
        }
        assert!(barrier.is_complete());
        assert_eq!(barrier.arrived_count(), 3);
    }

    /// `wait_timeout` reports `false` while incomplete and `true` once complete.
    #[test]
    fn test_wait_timeout_expires_when_incomplete() {
        let barrier = FrameBarrier::new(2);
        barrier.signal_frame_complete();
        assert!(!barrier.wait_timeout(Duration::from_millis(5)));
        barrier.signal_frame_complete();
        assert!(barrier.wait_timeout(Duration::from_millis(5)));
    }

    /// `reset` clears the arrival count and advances the generation.
    #[test]
    fn test_reset_starts_new_generation() {
        let barrier = FrameBarrier::new(1);
        assert_eq!(barrier.generation(), 0);
        barrier.signal_frame_complete();
        barrier.reset();
        assert_eq!(barrier.generation(), 1);
        assert_eq!(barrier.arrived_count(), 0);
        assert!(!barrier.is_complete());
    }

    /// Registration adds and removes the calling thread's id.
    #[test]
    fn test_thread_registration() {
        let barrier = FrameBarrier::new(1);
        let me = thread::current().id();
        assert!(!barrier.is_registered(me));
        barrier.register_thread();
        assert!(barrier.is_registered(me));
        barrier.unregister_thread();
        assert!(!barrier.is_registered(me));
    }

    /// Each `with_thread` call adds one participant.
    #[test]
    fn test_barrier_builder() {
        let barrier = FrameBarrierBuilder::new()
            .with_thread("main")
            .with_thread("worker1")
            .with_thread("worker2")
            .build();
        assert_eq!(barrier.thread_count(), 3);
    }

    /// `with_count` sets the participant count directly; the default builder has none.
    #[test]
    fn test_builder_with_count() {
        let barrier = FrameBarrierBuilder::new().with_count(4).build();
        assert_eq!(barrier.thread_count(), 4);
        assert_eq!(FrameBarrierBuilder::default().build().thread_count(), 0);
    }
}