use crate::runtime::execution::ExecutionState;
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;
use tracing::trace;
/// Value handed back by [`Barrier::wait`], telling the caller whether it was the task the
/// barrier designated as leader.
#[derive(Clone, Copy, Debug)]
pub struct BarrierWaitResult {
    /// Whether the task that received this result was the designated leader.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Returns `true` when the task that received this result was the designated leader.
    pub fn is_leader(&self) -> bool {
        let Self { is_leader } = *self;
        is_leader
    }
}
/// Mutable bookkeeping stored behind a [`Barrier`].
#[derive(Debug)]
struct BarrierState {
    /// Number of tasks that must call `wait` before the barrier releases them.
    bound: usize,
    /// Task designated as leader; set by the first `wait` call that finds it unset.
    leader: Option<TaskId>,
    /// Tasks currently blocked in `wait`, unblocked (and drained) by the releasing task.
    waiters: HashSet<TaskId>,
    /// Merged (via `VectorClock::update`) with the clock of every task that calls `wait`;
    /// a copy is merged into each participant's clock when the barrier releases.
    clock: VectorClock,
}
/// A barrier for tasks driven by the runtime's `ExecutionState`: callers of `wait` are blocked
/// until `n` tasks (the bound passed to `Barrier::new`) have arrived, then all are released.
#[derive(Debug)]
pub struct Barrier {
    // Single-threaded shared ownership + interior mutability; thread-safety is asserted by the
    // manual `Send`/`Sync` impls rather than provided by the types.
    state: Rc<RefCell<BarrierState>>,
}
impl Barrier {
    /// Creates a barrier that releases its waiters once `n` tasks have called [`Barrier::wait`].
    ///
    /// In each generation the first `n - 1` callers block and the `n`th caller releases them
    /// all. With `n == 0` or `n == 1`, `wait` never blocks.
    pub fn new(n: usize) -> Self {
        let state = BarrierState {
            bound: n,
            leader: None,
            waiters: HashSet::new(),
            clock: VectorClock::new(),
        };
        Self {
            state: Rc::new(RefCell::new(state)),
        }
    }

    /// Waits until `bound` tasks (the `n` given to [`Barrier::new`]) have called `wait` in the
    /// current generation, then releases all of them.
    ///
    /// The first task to arrive in a generation is that generation's leader. It is the only
    /// participant whose result reports [`BarrierWaitResult::is_leader`] as `true`. Releasing a
    /// generation clears the leader, so a reused barrier elects a fresh leader every generation.
    ///
    /// Each arriving task bumps its own clock and merges it into the barrier's clock. On release,
    /// every participant's clock is updated with the barrier's clock. This is intended to order
    /// all participants' pre-`wait` events before their post-`wait` events.
    ///
    /// # Panics
    ///
    /// Panics if the current task is already recorded as a waiter on this barrier, or if the
    /// barrier's state is already borrowed (re-entrant use).
    pub fn wait(&self) -> BarrierWaitResult {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();

        trace!(leader=?state.leader, waiters=?state.waiters, "waiting on barrier {:p}", self);

        // The first arrival of each generation becomes its leader. Decide the result now:
        // blocked waiters must know it before suspending, and the releaser clears `leader` below.
        let is_leader = *state.leader.get_or_insert(me) == me;

        // Record this task's arrival in the barrier's clock.
        ExecutionState::with(|s| {
            let clock = s.increment_clock();
            state.clock.update(clock);
        });

        if state.waiters.len() + 1 < state.bound {
            // Not the last arrival: register as a waiter and mark this task blocked.
            let newly_inserted = state.waiters.insert(me);
            assert!(newly_inserted, "task {:?} is already waiting on barrier {:p}", me, self);
            ExecutionState::with(|s| s.current_mut().block());

            trace!(leader=?state.leader, waiters=?state.waiters, "blocked on barrier {:?}", self);
        } else {
            // Last arrival: unblock every waiter and propagate the barrier's clock to all
            // participants, including this task.
            trace!(leader=?state.leader, waiters=?state.waiters, "releasing waiters on barrier {:?}", self);

            let clock = state.clock.clone();
            ExecutionState::with(|s| {
                for tid in state.waiters.drain() {
                    debug_assert_ne!(tid, me);
                    let t = s.get_mut(tid);
                    debug_assert!(t.blocked());
                    t.clock.increment(tid);
                    t.clock.update(&clock);
                    t.unblock();
                }
                let t = s.current_mut();
                t.clock.increment(me);
                t.clock.update(&clock);
            });

            // Begin a new generation. Without this reset, a reused barrier keeps the first
            // generation's leader forever. Later generations would then report no leader at all
            // unless that same task happens to participate again.
            state.leader = None;
        }

        // Release the borrow before yielding, so other tasks scheduled by `thread::switch` can
        // borrow the barrier state themselves.
        drop(state);
        thread::switch();

        BarrierWaitResult { is_leader }
    }
}
// SAFETY: `Barrier` holds an `Rc<RefCell<BarrierState>>`, which is neither `Send` nor `Sync`,
// so both impls are asserted manually. NOTE(review): this is only sound if every task that
// touches a `Barrier` runs on a single OS thread under the runtime's scheduler (suggested by the
// cooperative `thread::switch()` in `wait`) — confirm against the runtime's execution model.
unsafe impl Send for Barrier {}
unsafe impl Sync for Barrier {}