use std::sync::{Arc, Condvar, Mutex};
const NOTIFY_TIMEOUT_DEFAULT_MILLIS: u32 = 2000;
struct NotifyState {
count: u16,
pending: u16,
}
pub(super) struct Notify {
state: Mutex<NotifyState>,
cov: Condvar,
timeout: u32,
}
impl NotifyState {
}
impl Notify {
pub(super) fn new() -> Arc<Self> {
Self {
state: Mutex::new(NotifyState {
count: 0,
pending: 0,
}),
cov: Condvar::new(),
timeout: NOTIFY_TIMEOUT_DEFAULT_MILLIS,
}
.into()
}
#[allow(dead_code)]
pub(super) fn with_timeout(timeout_millis: u32) -> Arc<Self> {
Self {
state: Mutex::new(NotifyState {
count: 0,
pending: 0,
}),
cov: Condvar::new(),
timeout: timeout_millis,
}
.into()
}
pub(super) fn notified(&self) -> Result<(), ()> {
let mut state = self.state.lock().unwrap();
if state.pending > 0 {
state.pending -= 1;
return Ok(());
}
let current = state.count;
let (mut state, timeout) = self
.cov
.wait_timeout_while(
state,
std::time::Duration::from_millis(self.timeout as u64),
|s| s.count == current,
)
.unwrap();
if timeout.timed_out() {
return Err(());
}
if state.pending > 0 {
state.pending -= 1;
}
Ok(())
}
pub(super) fn notify_one(&self) {
let mut state = self.state.lock().unwrap();
state.count += 1;
state.pending = state.pending.saturating_add(1);
self.cov.notify_one();
}
pub(super) fn notify_waiters(&self) {
let mut generation = self.state.lock().unwrap();
generation.count += 1;
self.cov.notify_all();
}
}
impl std::fmt::Debug for Notify {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Notify").finish()
}
}
mod tests {
#[test]
fn test_notify_one() {
use super::*;
use std::sync::{
Arc, Barrier,
atomic::{AtomicBool, Ordering},
};
use std::thread;
use std::time::Duration;
let notify = Notify::new();
let flag = Arc::new(AtomicBool::new(false));
let barrier = Arc::new(Barrier::new(2));
let notify_clone = notify.clone();
let flag_clone = flag.clone();
let barrier_clone = barrier.clone();
let handle = thread::spawn(move || {
barrier_clone.wait();
notify_clone.notified().expect("notified failed");
flag_clone.store(true, Ordering::SeqCst);
});
barrier.wait();
thread::sleep(Duration::from_millis(50));
notify.notify_one();
handle.join().expect("Thread panicked");
assert!(
flag.load(Ordering::SeqCst),
"notify_one did not wake the waiting thread"
);
}
#[test]
fn test_notify_all() {
use super::*;
use std::sync::{
Arc, Barrier,
atomic::{AtomicUsize, Ordering},
};
use std::thread;
use std::time::Duration;
const THREAD_COUNT: usize = 50;
let notify = Arc::new(Notify::new());
let barrier = Arc::new(Barrier::new(THREAD_COUNT + 1));
let counter = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::with_capacity(THREAD_COUNT);
for _ in 0..THREAD_COUNT {
let notify_clone = notify.clone();
let barrier_clone = barrier.clone();
let counter_clone = counter.clone();
handles.push(thread::spawn(move || {
barrier_clone.wait();
notify_clone.notified().expect("notified failed");
counter_clone.fetch_add(1, Ordering::SeqCst);
}));
}
barrier.wait();
thread::sleep(Duration::from_millis(50));
notify.notify_waiters();
for handle in handles {
handle.join().expect("Thread panicked");
}
assert_eq!(
counter.load(Ordering::SeqCst),
THREAD_COUNT,
"not all waiters were notified"
);
}
}