use std::sync::atomic::{AtomicUsize, Ordering};
use sys::sync as sys;
use poison::LockResult;
use {mutex, MutexGuard, PoisonError};
/// Result value for a timed condition-variable wait.
///
/// Wraps a single flag that [`WaitTimeoutResult::timed_out`] hands back.
/// NOTE(review): nothing in this section constructs it; by its name the flag
/// records whether the wait ended because the timeout elapsed — confirm
/// against the timed-wait implementation.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct WaitTimeoutResult(bool);

impl WaitTimeoutResult {
    /// Returns the wrapped flag: `true` means the wait timed out.
    pub fn timed_out(&self) -> bool {
        let WaitTimeoutResult(elapsed) = *self;
        elapsed
    }
}
/// A condition variable that owns its state.
///
/// The state is heap-allocated in a `Box`, so its address does not change
/// when the `Condvar` value itself is moved.
pub struct Condvar {
    /// Boxed platform condvar plus its mutex-pairing record.
    inner: Box<StaticCondvar>,
}
/// Condition-variable state: the platform condvar together with a record of
/// which mutex it has been paired with.
pub struct StaticCondvar {
    /// Underlying platform condition variable.
    inner: sys::Condvar,
    /// Address of the `sys::Mutex` this condvar is bound to, or `0` while it
    /// has not yet been used with any mutex (set on first wait by `verify`).
    mutex: AtomicUsize,
}
impl Condvar {
    /// Creates a condition variable that is not yet paired with any mutex.
    ///
    /// The platform state is boxed, so its address stays fixed even if the
    /// returned `Condvar` is moved.
    pub fn new() -> Condvar {
        Condvar {
            inner: Box::new(StaticCondvar {
                inner: sys::Condvar::new(),
                mutex: AtomicUsize::new(0),
            }),
        }
    }

    /// Blocks the current thread on this condition variable.
    ///
    /// The mutex behind `guard` is handed to the platform wait, which releases
    /// it while blocked and holds it again when this returns. On return the
    /// same guard is handed back. It is wrapped in `Err(PoisonError)` if the
    /// mutex is poisoned.
    ///
    /// # Panics
    ///
    /// Panics if this condition variable was previously used with a different
    /// mutex.
    pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>)
                       -> LockResult<MutexGuard<'a, T>> {
        // The state is boxed and owned by `self`, so a plain borrow is enough.
        // There is no need to forge a `'static` reference to reach
        // `StaticCondvar::wait`.
        self.inner.wait_impl(guard)
    }

    /// Wakes one thread blocked in `wait` on this condition variable, if any.
    pub fn notify_one(&self) {
        // SAFETY: `inner.inner` was created by `sys::Condvar::new()` in `new`
        // and is destroyed only in `Drop`, so it is live for any `&self` call.
        unsafe { self.inner.inner.notify_one() }
    }

    /// Wakes every thread blocked in `wait` on this condition variable.
    pub fn notify_all(&self) {
        // SAFETY: same invariant as `notify_one`. The platform condvar is
        // live until `Drop`.
        unsafe { self.inner.inner.notify_all() }
    }
}

impl Drop for Condvar {
    /// Releases the platform condition variable.
    fn drop(&mut self) {
        // SAFETY: `&mut self` guarantees there are no outstanding borrows, so
        // no thread can currently be waiting on or notifying this condvar.
        unsafe { self.inner.inner.destroy() }
    }
}

impl StaticCondvar {
    /// Blocks the current thread on this condition variable.
    ///
    /// Behaves like [`Condvar::wait`]. The `&'static self` receiver is kept
    /// for existing callers; the work is done by the private `wait_impl`.
    ///
    /// # Panics
    ///
    /// Panics if this condition variable was previously used with a different
    /// mutex.
    pub fn wait<'a, T>(&'static self, guard: MutexGuard<'a, T>)
                       -> LockResult<MutexGuard<'a, T>> {
        self.wait_impl(guard)
    }

    /// Shared wait implementation for `Condvar` and `StaticCondvar`.
    ///
    /// Private, so it is only reachable through receivers with a fixed address:
    /// a `&'static StaticCondvar`, or the `Box` owned by a `Condvar`.
    fn wait_impl<'a, T>(&self, guard: MutexGuard<'a, T>)
                        -> LockResult<MutexGuard<'a, T>> {
        let poisoned = unsafe {
            // SAFETY: `guard` is a live `MutexGuard`, so this thread holds
            // `lock` for the platform wait. `verify` has just checked that
            // this condvar is only ever paired with that one mutex.
            let lock = mutex::guard_lock(&guard);
            self.verify(lock);
            self.inner.wait(lock);
            mutex::guard_poison(&guard).get()
        };
        if poisoned {
            Err(PoisonError::new(guard))
        } else {
            Ok(guard)
        }
    }

    /// Binds this condvar to `mutex` on first use.
    ///
    /// # Panics
    ///
    /// Panics if the condvar is already bound to a different mutex.
    fn verify(&self, mutex: &sys::Mutex) {
        let addr = mutex as *const _ as usize;
        // `compare_exchange` with SeqCst/SeqCst replaces the deprecated
        // `compare_and_swap(.., SeqCst)` with identical semantics.
        match self.mutex.compare_exchange(0, addr, Ordering::SeqCst, Ordering::SeqCst) {
            // First wait: `mutex` is now the bound mutex.
            Ok(_) => {}
            // Already bound to this same mutex.
            Err(current) if current == addr => {}
            Err(_) => panic!("attempted to use a condition variable with two \
                              mutexes"),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;
    use {Mutex, Condvar};

    /// Notifying a condvar that nobody waits on must not block or panic.
    #[test]
    fn smoke() {
        let c = Condvar::new();
        c.notify_one();
        c.notify_all();
    }

    /// A waiter is released once another thread takes the mutex and calls
    /// `notify_one`.
    #[test]
    fn notify_one() {
        let c = Arc::new(Condvar::new());
        let m = Arc::new(Mutex::new(()));
        let g = m.lock().unwrap();
        let m2 = m.clone();
        let c2 = c.clone();
        let t = thread::spawn(move|| {
            // Can only acquire the lock once the main thread is inside `wait`.
            let _g = m2.lock().unwrap();
            c2.notify_one();
        });
        let g = c.wait(g).unwrap();
        drop(g);
        // Surface any panic from the notifying thread instead of ignoring it.
        t.join().unwrap();
    }

    /// N threads increment a shared counter and wait until it is reset to 0.
    /// The last one to arrive signals the main thread. Main resets the counter
    /// and calls `notify_all`, and every waiter must then report in.
    #[test]
    fn notify_all() {
        const N: usize = 10;
        let data = Arc::new((Mutex::new(0), Condvar::new()));
        let (tx, rx) = channel();
        for _ in 0..N {
            let data = data.clone();
            let tx = tx.clone();
            thread::spawn(move|| {
                let &(ref lock, ref cond) = &*data;
                let mut cnt = lock.lock().unwrap();
                *cnt += 1;
                if *cnt == N {
                    tx.send(()).unwrap();
                }
                // Loop guards against spurious wakeups.
                while *cnt != 0 {
                    cnt = cond.wait(cnt).unwrap();
                }
                tx.send(()).unwrap();
            });
        }
        drop(tx);
        let &(ref lock, ref cond) = &*data;
        rx.recv().unwrap();
        let mut cnt = lock.lock().unwrap();
        *cnt = 0;
        cond.notify_all();
        drop(cnt);
        for _ in 0..N {
            rx.recv().unwrap();
        }
    }

    /// Waiting on one condvar with two different mutexes must trip the
    /// pairing check. `expected` pins the panic to that check, so an unrelated
    /// panic (e.g. an `unwrap` on a poisoned lock) cannot make this test pass.
    #[test]
    #[should_panic(expected = "two mutexes")]
    fn two_mutexes() {
        let m1 = Arc::new(Mutex::new(()));
        let m2 = Arc::new(Mutex::new(()));
        let c = Arc::new(Condvar::new());
        let mut g = m1.lock().unwrap();
        let m1_2 = m1.clone();
        let c2 = c.clone();
        let _t = thread::spawn(move|| {
            let _g = m1_2.lock().unwrap();
            c2.notify_one();
        });
        // First wait pairs the condvar with `m1`.
        g = c.wait(g).unwrap();
        drop(g);
        // Second wait uses `m2` and must panic.
        let _ = c.wait(m2.lock().unwrap()).unwrap();
    }
}