use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use parking_lot::{self, UnparkResult};
use mutex::{MutexGuard, guard_lock};
/// Outcome of a timed wait (`Condvar::wait_until` / `Condvar::wait_for`).
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct WaitTimeoutResult(bool);

impl WaitTimeoutResult {
    /// Returns `true` if the timed wait gave up because its deadline
    /// elapsed, and `false` if it was woken before the deadline.
    #[inline]
    pub fn timed_out(&self) -> bool {
        let WaitTimeoutResult(expired) = *self;
        expired
    }
}
/// A condition variable: lets threads block until notified, releasing the
/// lock held by their `MutexGuard` while they sleep.
#[derive(Debug)]
pub struct Condvar {
    // Waiter hint. Set to `true` by a thread about to park in `wait` /
    // `wait_until`. Cleared by `notify_all`, and by `notify_one` when the
    // unpark result is not `UnparkedNotLast`. While it is `false`, the
    // notify methods return immediately without touching the parking lot.
    state: AtomicBool,
}
impl Condvar {
    /// Creates a new condition variable with no recorded waiters.
    ///
    /// With the `nightly` feature this is a `const fn`, so it can initialize
    /// a `static`.
    #[cfg(feature = "nightly")]
    #[inline]
    pub const fn new() -> Condvar {
        Condvar { state: AtomicBool::new(false) }
    }

    /// Creates a new condition variable with no recorded waiters.
    #[cfg(not(feature = "nightly"))]
    #[inline]
    pub fn new() -> Condvar {
        Condvar { state: AtomicBool::new(false) }
    }

    /// Wakes up one thread blocked on this condvar, if there is one.
    ///
    /// Returns immediately, without touching the parking lot, when the
    /// waiter hint is clear.
    #[inline]
    pub fn notify_one(&self) {
        // Fast path: no thread has parked since the hint was last cleared.
        if !self.state.load(Ordering::Relaxed) {
            return;
        }
        // SAFETY: the parking key is this condvar's own address, so only
        // threads parked by this condvar are affected.
        // NOTE(review): confirm the full `unpark_one` contract for the
        // parking_lot version in use.
        unsafe {
            let addr = self as *const _ as usize;
            let callback = &mut |result| {
                // Keep the hint set only when the result says the woken
                // thread was not the last one, presumably meaning others
                // are still parked on `addr`.
                if result != UnparkResult::UnparkedNotLast {
                    self.state.store(false, Ordering::Relaxed);
                }
            };
            parking_lot::unpark_one(addr, callback);
        }
    }

    /// Wakes up every thread currently blocked on this condvar.
    #[inline]
    pub fn notify_all(&self) {
        if !self.state.load(Ordering::Relaxed) {
            return;
        }
        // Clear the hint before unparking. A thread that parks after this
        // store sets it again in its `validate` callback, so at worst the
        // hint stays spuriously `true`; a real waiter is never hidden.
        self.state.store(false, Ordering::Relaxed);
        // SAFETY: the key is this condvar's own address (see `notify_one`).
        unsafe {
            let addr = self as *const _ as usize;
            parking_lot::unpark_all(addr);
        }
    }

    /// Blocks the current thread until this condvar is notified.
    ///
    /// The lock held by `guard` is released just before the thread sleeps
    /// and re-acquired before this returns. As with any condition variable,
    /// re-check your predicate after waking.
    #[inline]
    pub fn wait<T: ?Sized>(&self, guard: &mut MutexGuard<T>) {
        // SAFETY: parks on this condvar's own address; the lock is released
        // only inside `before_sleep` and re-acquired right after `park`.
        // NOTE(review): relies on parking_lot's ordering of `validate`,
        // queueing and `before_sleep` so that a notify issued after the
        // unlock cannot be missed. Confirm for the version in use.
        unsafe {
            let addr = self as *const _ as usize;
            let validate = &mut || {
                // Advertise a waiter so `notify_*` skip their fast path.
                self.state.store(true, Ordering::Relaxed);
                true
            };
            let before_sleep = &mut || {
                guard_lock(guard).unlock();
            };
            parking_lot::park(addr, validate, before_sleep, None);
            guard_lock(guard).lock();
        }
    }

    /// Blocks until this condvar is notified or the deadline `timeout` is
    /// reached, whichever comes first.
    ///
    /// The lock held by `guard` is released while waiting and re-acquired
    /// before returning. If `timeout` has already passed, the thread does
    /// not park: the lock is released and immediately re-acquired, and a
    /// timeout is reported.
    #[inline]
    pub fn wait_until<T: ?Sized>(&self,
                                 guard: &mut MutexGuard<T>,
                                 timeout: Instant)
                                 -> WaitTimeoutResult {
        // SAFETY: same reasoning as in `wait`; the only difference is the
        // deadline passed to `park`.
        unsafe {
            // `true` when `park` reports the wait ended before the deadline.
            let result;
            if timeout <= Instant::now() {
                guard_lock(guard).unlock();
                result = false;
            } else {
                let addr = self as *const _ as usize;
                let validate = &mut || {
                    self.state.store(true, Ordering::Relaxed);
                    true
                };
                let before_sleep = &mut || {
                    guard_lock(guard).unlock();
                };
                result = parking_lot::park(addr, validate, before_sleep, Some(timeout));
            }
            guard_lock(guard).lock();
            WaitTimeoutResult(!result)
        }
    }

    /// Blocks until this condvar is notified or `timeout` has elapsed.
    ///
    /// Behaves like `wait_until` with a deadline of
    /// `Instant::now() + timeout`. If that deadline cannot be represented
    /// (the addition would overflow `Instant`), this waits with no timeout
    /// instead of panicking, and the result reports no timeout.
    #[inline]
    pub fn wait_for<T: ?Sized>(&self,
                               guard: &mut MutexGuard<T>,
                               timeout: Duration)
                               -> WaitTimeoutResult {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.wait_until(guard, deadline),
            // `Instant + Duration` panics on overflow. A deadline that far
            // out is effectively "never", so just wait for a notification.
            None => {
                self.wait(guard);
                WaitTimeoutResult(false)
            }
        }
    }
}
impl Default for Condvar {
#[inline]
fn default() -> Condvar {
Condvar::new()
}
}
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, Instant};
    use {Condvar, Mutex};

    /// Notifying a condvar that nobody waits on is a no-op.
    #[test]
    fn smoke() {
        let cv = Condvar::new();
        cv.notify_one();
        cv.notify_all();
    }

    /// A single waiter is released by `notify_one` from another thread.
    #[test]
    fn notify_one() {
        lazy_static! {
            static ref C: Condvar = Condvar::new();
            static ref M: Mutex<()> = Mutex::new(());
        }
        let mut guard = M.lock();
        let _notifier = thread::spawn(move || {
            let _held = M.lock();
            C.notify_one();
        });
        C.wait(&mut guard);
    }

    /// `notify_all` releases every one of `N` waiters.
    #[test]
    fn notify_all() {
        const N: usize = 10;
        let shared = Arc::new((Mutex::new(0), Condvar::new()));
        let (tx, rx) = channel();
        for _ in 0..N {
            let shared = shared.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                let lock = &shared.0;
                let cond = &shared.1;
                let mut count = lock.lock();
                *count += 1;
                // The last thread to arrive tells the main thread that all
                // waiters have registered.
                if *count == N {
                    tx.send(()).unwrap();
                }
                while *count != 0 {
                    cond.wait(&mut count);
                }
                tx.send(()).unwrap();
            });
        }
        drop(tx);
        let lock = &shared.0;
        let cond = &shared.1;
        rx.recv().unwrap();
        {
            let mut count = lock.lock();
            *count = 0;
            cond.notify_all();
        }
        for _ in 0..N {
            rx.recv().unwrap();
        }
    }

    /// A short relative timeout expires; a long one is cut short by a notify.
    #[test]
    fn wait_for() {
        lazy_static! {
            static ref C: Condvar = Condvar::new();
            static ref M: Mutex<()> = Mutex::new(());
        }
        let mut guard = M.lock();
        let expired = C.wait_for(&mut guard, Duration::from_millis(1));
        assert!(expired.timed_out());
        let _notifier = thread::spawn(move || {
            let _held = M.lock();
            C.notify_one();
        });
        let woken = C.wait_for(&mut guard, Duration::from_millis(u32::max_value() as u64));
        assert!(!woken.timed_out());
        drop(guard);
    }

    /// A near deadline expires; a far one is cut short by a notify.
    #[test]
    fn wait_until() {
        lazy_static! {
            static ref C: Condvar = Condvar::new();
            static ref M: Mutex<()> = Mutex::new(());
        }
        let mut guard = M.lock();
        let near = Instant::now() + Duration::from_millis(1);
        let expired = C.wait_until(&mut guard, near);
        assert!(expired.timed_out());
        let _notifier = thread::spawn(move || {
            let _held = M.lock();
            C.notify_one();
        });
        let far = Instant::now() + Duration::from_millis(u32::max_value() as u64);
        let woken = C.wait_until(&mut guard, far);
        assert!(!woken.timed_out());
        drop(guard);
    }
}