use crate::MutexGuard;
use crate::futex::{futex_wait, futex_wake};
use lock_api;
use lock_api::RawMutex as RawMutexTrait;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
/// Outcome of [`Condvar::wait_for`]: records whether the wait ended because
/// its timeout elapsed rather than because a notification was observed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WaitTimeoutResult(pub(crate) bool);

impl WaitTimeoutResult {
    /// Returns `true` if the wait gave up because the timeout elapsed.
    #[inline]
    pub fn timed_out(self) -> bool {
        let WaitTimeoutResult(elapsed) = self;
        elapsed
    }
}
/// A condition variable for use with this crate's [`MutexGuard`].
///
/// Waiters block on a 32-bit sequence counter through the crate's futex
/// helpers; every notification bumps the counter and then wakes sleepers.
pub struct Condvar {
    /// Notification counter. Waiters snapshot it while still holding the mutex
    /// and pass the snapshot to `futex_wait` as the expected value, so a
    /// notification that lands between releasing the mutex and going to sleep
    /// changes the value and is not lost (assuming `futex_wait` only sleeps
    /// while the counter still equals the expected value).
    seq: AtomicU32,
}

impl Condvar {
    /// Creates a condition variable with its counter at zero.
    pub const fn new() -> Self {
        Condvar { seq: AtomicU32::new(0) }
    }

    /// Releases the mutex behind `guard`, blocks until woken, then re-acquires
    /// the mutex before returning.
    ///
    /// This returns after a single `futex_wait` without confirming that a
    /// notification actually happened, so wakeups may be spurious: callers
    /// should re-check their condition in a loop.
    ///
    /// The mutex is re-locked even if the wait unwinds, so `guard` always
    /// holds the lock again when control leaves this function.
    pub fn wait<T>(&self, guard: &mut MutexGuard<'_, T>) {
        // Snapshot while the mutex is still held, so a notify racing with the
        // unlock below changes the counter and `futex_wait` does not sleep on it.
        let seq = self.seq.load(Ordering::SeqCst);
        // `unlocked` releases the mutex for the duration of the closure and
        // re-locks it afterwards, including during unwinding, so the guard is
        // never left referring to an unlocked mutex.
        lock_api::MutexGuard::unlocked(guard, || {
            futex_wait(&self.seq, seq, None);
        });
    }

    /// Like [`wait`](Self::wait), but gives up once `timeout` has elapsed.
    ///
    /// Spurious wakeups are absorbed internally: the thread goes back to sleep
    /// until either the counter changes (a notification) or the deadline
    /// passes. A `timeout` too large to add to `Instant::now()` (for example
    /// `Duration::MAX`) is treated as "no deadline" instead of panicking.
    ///
    /// The result reports `timed_out() == true` when the deadline passed or
    /// `futex_wait` reported a timeout, and `false` when a notification was
    /// observed. The mutex is re-locked in every case, including unwinding.
    pub fn wait_for<T>(
        &self,
        guard: &mut MutexGuard<'_, T>,
        timeout: Duration,
    ) -> WaitTimeoutResult {
        // Snapshot under the lock for the same reason as in `wait`.
        let seq = self.seq.load(Ordering::SeqCst);
        // `None` means the deadline is not representable as an `Instant`;
        // such a wait has no time limit.
        let deadline = Instant::now().checked_add(timeout);
        let timed_out = lock_api::MutexGuard::unlocked(guard, || loop {
            let remaining = match deadline {
                Some(deadline) => {
                    let now = Instant::now();
                    if now >= deadline {
                        break true;
                    }
                    Some(deadline - now)
                }
                None => None,
            };
            // A `false` return from `futex_wait` is treated as its timeout expiring.
            if !futex_wait(&self.seq, seq, remaining) {
                break true;
            }
            if self.seq.load(Ordering::Relaxed) != seq {
                break false;
            }
            // Counter unchanged: spurious wakeup. Loop to re-check the
            // deadline and sleep for whatever time is left.
        });
        WaitTimeoutResult(timed_out)
    }

    /// Bumps the counter and asks `futex_wake` to wake at most one waiter.
    #[inline]
    pub fn notify_one(&self) {
        self.seq.fetch_add(1, Ordering::SeqCst);
        futex_wake(&self.seq, 1);
    }

    /// Bumps the counter and asks `futex_wake` to wake every waiter.
    #[inline]
    pub fn notify_all(&self) {
        self.seq.fetch_add(1, Ordering::SeqCst);
        // `i32::MAX` is the conventional "wake all" count for FUTEX_WAKE.
        futex_wake(&self.seq, i32::MAX as u32);
    }
}
impl std::fmt::Debug for Condvar {
    /// Renders as `Condvar { .. }` without exposing the internal counter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Condvar");
        repr.finish_non_exhaustive()
    }
}
impl Default for Condvar {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Mutex;
    use std::sync::Arc;
    use std::time::Duration;

    /// A single waiter blocked on a predicate is released by `notify_one`.
    #[test]
    fn test_notify_one_wakes_waiter() {
        let mutex = Arc::new(Mutex::new(false));
        let condvar = Arc::new(Condvar::new());
        let m2 = mutex.clone();
        let cv2 = condvar.clone();
        let handle = std::thread::spawn(move || {
            let mut guard = m2.lock();
            while !*guard {
                cv2.wait(&mut guard);
            }
            true
        });
        // Give the waiter a chance to block; correctness does not depend on
        // this because the waiter re-checks the predicate.
        std::thread::sleep(Duration::from_millis(20));
        {
            let mut guard = mutex.lock();
            *guard = true;
            condvar.notify_one();
        }
        assert!(handle.join().unwrap());
    }

    /// Several waiters blocked on a predicate are all released by `notify_all`.
    #[test]
    fn test_notify_all_wakes_all_waiters() {
        let mutex = Arc::new(Mutex::new(0usize));
        let condvar = Arc::new(Condvar::new());
        let mut handles = Vec::new();
        for _ in 0..4 {
            let m = mutex.clone();
            let cv = condvar.clone();
            handles.push(std::thread::spawn(move || {
                let mut guard = m.lock();
                while *guard == 0 {
                    cv.wait(&mut guard);
                }
            }));
        }
        std::thread::sleep(Duration::from_millis(30));
        {
            let mut guard = mutex.lock();
            *guard = 1;
            condvar.notify_all();
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    /// With nobody notifying, `wait_for` reports a timeout.
    #[test]
    fn test_wait_for_times_out() {
        let mutex = Arc::new(Mutex::new(()));
        let condvar = Arc::new(Condvar::new());
        let mut guard = mutex.lock();
        let result = condvar.wait_for(&mut guard, Duration::from_millis(30));
        assert!(result.timed_out(), "should have timed out");
    }

    /// A zero timeout returns immediately as timed out.
    #[test]
    fn test_wait_for_zero_timeout_returns_immediately() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();
        let mut guard = mutex.lock();
        let result = condvar.wait_for(&mut guard, Duration::from_millis(0));
        assert!(result.timed_out(), "zero timeout should time out");
    }

    /// `wait_for` reports a notification that arrives before the timeout.
    ///
    /// The main thread holds the mutex until `wait_for` has snapshotted the
    /// counter and released it, and the notifier must take the mutex before
    /// notifying. The notification therefore cannot race ahead of the wait
    /// and be lost.
    #[test]
    fn test_wait_for_notified_before_timeout() {
        let mutex = Arc::new(Mutex::new(()));
        let condvar = Arc::new(Condvar::new());
        let mut guard = mutex.lock();
        let m2 = mutex.clone();
        let cv2 = condvar.clone();
        let handle = std::thread::spawn(move || {
            let _held = m2.lock();
            cv2.notify_one();
        });
        // Generous limit: the notified path returns as soon as the notifier runs.
        let result = condvar.wait_for(&mut guard, Duration::from_secs(5));
        assert!(!result.timed_out(), "should have been notified");
        drop(guard);
        handle.join().unwrap();
    }
}