use crate::current;
use crate::runtime::execution::ExecutionState;
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use crate::sync::MutexGuard;
use assoc::AssocExt;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::{LockResult, PoisonError};
use std::time::Duration;
use tracing::trace;
/// A condition variable that works with this crate's `MutexGuard`.
///
/// Its API is modeled on `std::sync::Condvar`. All waiter bookkeeping lives behind a
/// `RefCell` in `CondvarState` and is updated by `wait` and the `notify_*` methods.
#[derive(Debug)]
pub struct Condvar {
    // Interior-mutable bookkeeping; borrowed mutably by `wait`, `notify_one` and `notify_all`.
    state: RefCell<CondvarState>,
}
/// Mutable bookkeeping for a `Condvar`.
#[derive(Debug)]
struct CondvarState {
    /// Tasks that have entered `Condvar::wait` and not yet consumed their wakeup.
    ///
    /// Each entry pairs a task with its wakeup status. The list is used as an association list
    /// with at most one entry per task (checked by a `debug_assert!` in `wait`).
    waiters: Vec<(TaskId, CondvarWaitStatus)>,
    /// Epoch number to attach to the next `notify_one` signal; incremented by every
    /// `notify_one` call.
    next_epoch: usize,
}
/// Whether, and by what kind of notification, a waiting task has been woken.
#[derive(PartialEq, Eq, Debug)]
enum CondvarWaitStatus {
    /// Not yet notified; the task remains blocked.
    Waiting,
    /// Notified by one or more `notify_one` calls, oldest first.
    ///
    /// Each entry is `(epoch, notifier's vector clock)`. The same epoch can be pending in
    /// several waiters' queues at once. Whichever of those waiters runs first consumes it and
    /// removes it from the others.
    Signal(VecDeque<(usize, VectorClock)>),
    /// Notified by `notify_all`; carries the notifier's vector clock.
    Broadcast(VectorClock),
}
impl Condvar {
    /// Creates a condition variable with no waiters.
    ///
    /// This is a `const fn`, so it can be used in constant/static initializers.
    pub const fn new() -> Self {
        let state = CondvarState {
            waiters: Vec::new(),
            next_epoch: 0,
        };
        Self {
            state: RefCell::new(state),
        }
    }

    /// Blocks the current task until this condition variable is notified.
    ///
    /// The mutex held by `guard` is released while blocked and re-acquired before returning.
    ///
    /// The task adds itself to the waiter list in `Waiting` status and marks itself blocked
    /// *before* unlocking the mutex. Any `notify_one`/`notify_all` issued after the unlock
    /// therefore finds it in the list.
    ///
    /// No spurious wakeups are produced: waking in `Waiting` status is treated as a bug. On
    /// wakeup the task merges the notifier's vector clock into its own and then re-locks the
    /// mutex. The returned `LockResult` is whatever that re-lock produces.
    ///
    /// # Panics
    ///
    /// Panics if, after waking, the task is missing from the waiter list, is still in
    /// `Waiting` status, or is in `Signal` status with an empty signal queue.
    pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();
        trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "waiting on condvar {:p}", self);
        // A task may appear in the waiter list at most once.
        debug_assert!(<_ as AssocExt<_, _>>::get(&state.waiters, &me).is_none());
        // Register and mark ourselves blocked while still holding the mutex.
        state.waiters.push((me, CondvarWaitStatus::Waiting));
        ExecutionState::with(|s| s.current_mut().block(false));
        // Release the RefCell borrow before unlocking. Notifiers call `self.state.borrow_mut()`,
        // so the borrow must not be held while other tasks may run.
        drop(state);
        // NOTE(review): presumably `unlock` is where this (now blocked) task is switched away
        // from, and execution resumes here once a notifier has unblocked it — confirm against
        // `MutexGuard::unlock`.
        let mutex = guard.unlock();
        let mut state = self.state.borrow_mut();
        trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "woken from condvar {:p}", self);
        // Leave the waiter list and inspect what woke us.
        let my_status = <_ as AssocExt<_, _>>::remove(&mut state.waiters, &me).expect("should be waiting");
        match my_status {
            CondvarWaitStatus::Broadcast(clock) => {
                // Woken by `notify_all`: only the clock needs to be inherited.
                ExecutionState::with(|s| s.update_clock(&clock));
            }
            CondvarWaitStatus::Signal(mut epochs) => {
                // Woken by `notify_one`: consume our oldest pending signal. Any later signals
                // still in our queue are discarded along with our waiter entry.
                let (epoch, clock) = epochs.pop_front().expect("should be a pending signal");
                // The same epoch was delivered to every task that was waiting when it was sent,
                // but only one task may consume it. Remove it from the other waiters, and
                // re-block any waiter left with no pending signals.
                for (tid, status) in state.waiters.iter_mut() {
                    if let CondvarWaitStatus::Signal(epochs) = status {
                        if let Some(i) = epochs.iter().position(|e| epoch == e.0) {
                            epochs.remove(i);
                            if epochs.is_empty() {
                                *status = CondvarWaitStatus::Waiting;
                                ExecutionState::with(|s| s.get_mut(*tid).block(false));
                            }
                        }
                    }
                }
                ExecutionState::with(|s| s.update_clock(&clock));
            }
            // A task in `Waiting` status is blocked and should never reach this point.
            CondvarWaitStatus::Waiting => panic!("should not have been woken while in Waiting status"),
        }
        // Drop the borrow before re-locking, since `lock` may block and let other tasks run.
        drop(state);
        mutex.lock()
    }

    /// Blocks while `condition` returns `true` for the guarded data.
    ///
    /// The condition is checked before the first wait, so this returns immediately if it is
    /// already `false`. A poison error from `wait` is propagated with `?`, ending the loop.
    pub fn wait_while<'a, T, F>(&self, mut guard: MutexGuard<'a, T>, mut condition: F) -> LockResult<MutexGuard<'a, T>>
    where
        F: FnMut(&mut T) -> bool,
    {
        while condition(&mut *guard) {
            guard = self.wait(guard)?;
        }
        Ok(guard)
    }

    /// Like [`Condvar::wait`], but with a timeout that this implementation ignores.
    ///
    /// `_dur` is unused: the call waits exactly like `wait`. The returned
    /// `WaitTimeoutResult` therefore always reports `timed_out() == false`, both on success
    /// and inside a `PoisonError`.
    pub fn wait_timeout<'a, T>(
        &self,
        guard: MutexGuard<'a, T>,
        _dur: Duration,
    ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> {
        self.wait(guard)
            .map(|guard| (guard, WaitTimeoutResult(false)))
            .map_err(|e| PoisonError::new((e.into_inner(), WaitTimeoutResult(false))))
    }

    /// Like [`Condvar::wait_while`], but with a timeout that this implementation ignores.
    ///
    /// `_dur` is unused: the call behaves exactly like `wait_while`. The returned
    /// `WaitTimeoutResult` always reports `timed_out() == false`, both on success and inside a
    /// `PoisonError`.
    pub fn wait_timeout_while<'a, T, F>(
        &self,
        guard: MutexGuard<'a, T>,
        _dur: Duration,
        condition: F,
    ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)>
    where
        F: FnMut(&mut T) -> bool,
    {
        self.wait_while(guard, condition)
            .map(|guard| (guard, WaitTimeoutResult(false)))
            .map_err(|e| PoisonError::new((e.into_inner(), WaitTimeoutResult(false))))
    }

    /// Wakes one task waiting on this condition variable, if there is one.
    ///
    /// Which waiter is woken is not decided here. The signal is tagged with a fresh epoch
    /// (`next_epoch`) and the notifier's vector clock. It is appended to the pending-signal
    /// queue of *every* current waiter, and all waiters are unblocked. The first of them to run
    /// consumes the epoch in `wait` and re-blocks any others left with no pending signal.
    /// NOTE(review): presumably this is done so the scheduler can explore each possible wakeup
    /// target.
    ///
    /// Waiters already in `Broadcast` status keep that status. Finishes by calling
    /// `thread::switch()` (presumably a scheduling point).
    ///
    /// # Panics
    ///
    /// Panics if the calling task is itself registered as a waiter.
    pub fn notify_one(&self) {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();
        trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "notifying one on condvar {:p}", self);
        let epoch = state.next_epoch;
        for (tid, status) in state.waiters.iter_mut() {
            assert_ne!(*tid, me);
            // NOTE(review): `current::clock()` looks loop-invariant here; it could be hoisted
            // (cloning per waiter) if `VectorClock: Clone` and the call has no side effects.
            let clock = current::clock();
            match status {
                CondvarWaitStatus::Waiting => {
                    // First pending signal for this waiter.
                    let mut epochs = VecDeque::new();
                    epochs.push_back((epoch, clock));
                    *status = CondvarWaitStatus::Signal(epochs);
                }
                CondvarWaitStatus::Signal(epochs) => {
                    // Queue behind earlier signals; epochs stay in increasing order.
                    epochs.push_back((epoch, clock));
                }
                CondvarWaitStatus::Broadcast(_) => {
                    // Already woken by `notify_all`; nothing to record.
                }
            }
            // NOTE(review): this also unblocks waiters that were already unblocked
            // (`Signal`/`Broadcast` status).
            ExecutionState::with(|s| s.get_mut(*tid).unblock());
        }
        state.next_epoch += 1;
        // Release the borrow before `thread::switch()` so the next task can borrow the state.
        drop(state);
        thread::switch();
    }

    /// Wakes every task currently waiting on this condition variable.
    ///
    /// Each waiter's status is overwritten with `Broadcast(clock)`, where `clock` is the
    /// notifier's vector clock. Any pending `notify_one` signals that waiter held are
    /// discarded. Every waiter is then unblocked. `next_epoch` is not advanced. Finishes by
    /// calling `thread::switch()`.
    ///
    /// # Panics
    ///
    /// Panics if the calling task is itself registered as a waiter.
    pub fn notify_all(&self) {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();
        trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "notifying all on condvar {:p}", self);
        for (tid, status) in state.waiters.iter_mut() {
            assert_ne!(*tid, me);
            *status = CondvarWaitStatus::Broadcast(current::clock());
            ExecutionState::with(|s| s.get_mut(*tid).unblock());
        }
        // Release the borrow before `thread::switch()` so the next task can borrow the state.
        drop(state);
        thread::switch();
    }
}
// SAFETY: `Condvar` keeps its state in a `RefCell`, which is not `Sync`, so these impls are an
// unchecked promise that the state is never accessed concurrently.
// NOTE(review): soundness presumably relies on the runtime (`ExecutionState`) running only one
// task at a time, so `self.state` is never borrowed from two OS threads at once — confirm
// against the scheduler before relying on this.
unsafe impl Send for Condvar {}
unsafe impl Sync for Condvar {}
impl Default for Condvar {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a timed wait on a condition variable.
///
/// It wraps a single flag that records whether the wait ended because the timeout elapsed.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct WaitTimeoutResult(bool);

impl WaitTimeoutResult {
    /// Returns `true` if the wait finished because its timeout elapsed, `false` otherwise.
    pub fn timed_out(&self) -> bool {
        let WaitTimeoutResult(expired) = *self;
        expired
    }
}