use std::collections::{BTreeMap, HashMap};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::task::Waker;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use crate::Error;
/// Wakes waiters on "tick" events and on registered deadlines.
///
/// Tick waiters are either async wakers registered with `register_tick` or
/// threads blocked in `wait_tick_blocking`; `notify_tick` releases them.
/// Deadline waiters are woken by a dedicated timer thread (`timer_loop`) once
/// their `Instant` has passed. `shutdown` releases every waiter and makes the
/// timer thread return.
pub(crate) struct OutputNotifier {
    /// Pending tick registrations; also the mutex paired with `tick_cv`.
    tick: Mutex<TickState>,
    /// Signalled by `notify_tick` and `shutdown` to release blocking tick waiters.
    tick_cv: Condvar,
    /// Pending deadline registrations; also the mutex paired with `timer_cv`.
    timer: Mutex<TimerState>,
    /// Signalled when a deadline is registered or moved, and on shutdown, so the
    /// timer thread re-evaluates how long to sleep.
    timer_cv: Condvar,
    /// Set to `true` only by `shutdown`; read by registration, blocking waits and
    /// the timer thread.
    shutdown: AtomicBool,
}
/// Tick registrations, guarded by `OutputNotifier::tick`.
struct TickState {
    /// Next registration key handed out by `register_tick`; keys are never reused.
    #[cfg_attr(
        not(feature = "async"),
        expect(
            dead_code,
            reason = "key allocator only used by async tick registration"
        )
    )]
    next_key: u64,
    /// Registered wakers paired with their keys. The whole list is taken (and
    /// every waker woken) by `notify_tick` and by `shutdown`.
    wakers: Vec<(u64, Waker)>,
}
/// Deadline registrations, guarded by `OutputNotifier::timer`.
///
/// `by_deadline` and `by_key` index the same set of registrations and are
/// always updated together: every key in `by_key` maps to the deadline under
/// which its waker is stored in `by_deadline`.
struct TimerState {
    /// Next registration key handed out by `register_deadline`; keys are never reused.
    #[cfg_attr(
        not(feature = "async"),
        expect(
            dead_code,
            reason = "key allocator only used by async deadline registration"
        )
    )]
    next_key: u64,
    /// Wakers ordered by `(deadline, key)`, so the first entry is the earliest
    /// deadline (ties broken by registration key).
    by_deadline: BTreeMap<(Instant, u64), Waker>,
    /// Reverse index from registration key to its current deadline, used to
    /// find the `by_deadline` entry when a registration is updated or removed.
    by_key: HashMap<u64, Instant>,
}
impl OutputNotifier {
    /// Name given to the timer thread; also reported in [`Error::ThreadSpawn`]
    /// when spawning it fails.
    const TIMER_THREAD_NAME: &'static str = "tastty-driver-output-timer";

    /// Creates a notifier and spawns its timer thread.
    ///
    /// The returned `JoinHandle` belongs to the thread running `timer_loop`,
    /// which returns once it observes the flag set by [`Self::shutdown`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::ThreadSpawn`] if the timer thread cannot be spawned
    /// (under `cfg(test)`, also when a failure was injected with
    /// `fail_next_thread_spawn`).
    pub(crate) fn new() -> Result<(Arc<Self>, JoinHandle<()>), Error> {
        let notifier = Arc::new(Self {
            tick: Mutex::new(TickState {
                next_key: 0,
                wakers: Vec::new(),
            }),
            tick_cv: Condvar::new(),
            timer: Mutex::new(TimerState {
                next_key: 0,
                by_deadline: BTreeMap::new(),
                by_key: HashMap::new(),
            }),
            timer_cv: Condvar::new(),
            shutdown: AtomicBool::new(false),
        });
        let handle = {
            let notifier = Arc::clone(&notifier);
            #[cfg(test)]
            if consume_thread_spawn_failure(Self::TIMER_THREAD_NAME) {
                return Err(Error::ThreadSpawn {
                    name: Self::TIMER_THREAD_NAME,
                    source: std::io::Error::other("injected thread spawn failure"),
                });
            }
            std::thread::Builder::new()
                .name(Self::TIMER_THREAD_NAME.into())
                .spawn(move || timer_loop(notifier))
                .map_err(|source| Error::ThreadSpawn {
                    name: Self::TIMER_THREAD_NAME,
                    source,
                })?
        };
        Ok((notifier, handle))
    }

    /// Signals a tick: wakes every waker registered via `register_tick` and
    /// every thread blocked in [`Self::wait_tick_blocking`].
    ///
    /// Registrations are consumed; a task must register again to be woken by
    /// a later tick.
    pub(crate) fn notify_tick(&self) {
        let wakers = {
            let mut state = self
                .tick
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            std::mem::take(&mut state.wakers)
        };
        self.tick_cv.notify_all();
        // Wake outside the lock so a woken task may re-register immediately.
        for (_, waker) in wakers {
            waker.wake();
        }
    }

    /// Shuts the notifier down.
    ///
    /// Sets the shutdown flag (later registrations are refused and blocking
    /// waits return immediately), drops every pending tick and deadline
    /// registration, wakes all of their wakers, and signals both condition
    /// variables so blocked threads, including the timer thread, re-check
    /// the flag.
    pub(crate) fn shutdown(&self) {
        // Stored before either lock is taken, so a waiter that checks the flag
        // while holding the matching lock either sees it or is already waiting
        // when the `notify_all` calls below fire.
        self.shutdown.store(true, Ordering::Release);
        let tick_wakers = {
            let mut state = self
                .tick
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            std::mem::take(&mut state.wakers)
        };
        let deadline_wakers = {
            let mut state = self
                .timer
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            state.by_key.clear();
            std::mem::take(&mut state.by_deadline)
                .into_values()
                .collect::<Vec<_>>()
        };
        self.tick_cv.notify_all();
        self.timer_cv.notify_all();
        for (_, waker) in tick_wakers {
            waker.wake();
        }
        for waker in deadline_wakers {
            waker.wake();
        }
    }

    /// Blocks the calling thread until the next tick, shutdown, or until
    /// `max_wait` elapses, whichever comes first. Returns immediately once the
    /// notifier has been shut down.
    ///
    /// Spurious wakeups are not filtered out, and a tick signalled before this
    /// call starts waiting is not remembered. Callers should treat a return as
    /// "re-check your condition", not as proof that a tick happened.
    pub(crate) fn wait_tick_blocking(&self, max_wait: Duration) {
        let guard = self
            .tick
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Checked under the lock: `shutdown()` sets the flag before locking
        // `tick` and signalling `tick_cv`. Checking before taking the lock would
        // let a concurrent shutdown slip in between, and this call would then
        // sleep for the full `max_wait` after shutdown.
        if self.shutdown.load(Ordering::Acquire) {
            return;
        }
        drop(self.tick_cv.wait_timeout(guard, max_wait));
    }

    /// Registers `waker` to be woken by the next [`Self::notify_tick`] or
    /// [`Self::shutdown`].
    ///
    /// `slot` carries the caller's registration key between polls. If it
    /// names a registration that is still pending, that entry's waker is
    /// refreshed in place; otherwise a new key is allocated and stored in
    /// `slot`.
    ///
    /// Returns `false`, registering nothing, if the notifier has been shut down.
    #[cfg(feature = "async")]
    pub(crate) fn register_tick(&self, slot: &mut Option<u64>, waker: &Waker) -> bool {
        let mut state = self
            .tick
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if self.shutdown.load(Ordering::Acquire) {
            return false;
        }
        // Reuse the caller's registration if no tick has consumed it yet.
        if let Some(key) = *slot
            && let Some((_, registered)) = state.wakers.iter_mut().find(|(k, _)| *k == key)
        {
            if !registered.will_wake(waker) {
                *registered = waker.clone();
            }
            return true;
        }
        let key = state.next_key;
        state.next_key += 1;
        state.wakers.push((key, waker.clone()));
        *slot = Some(key);
        true
    }

    /// Removes the tick registration identified by `key`. This is a no-op if a
    /// tick or shutdown has already consumed it.
    #[cfg(feature = "async")]
    pub(crate) fn unregister_tick(&self, key: u64) {
        let mut state = self
            .tick
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state.wakers.retain(|(k, _)| *k != key);
    }

    /// Registers `waker` to be woken by the timer thread once `deadline` has
    /// passed, or by [`Self::shutdown`].
    ///
    /// `slot` carries the caller's registration key between polls. If it
    /// names a registration that is still pending, that registration is
    /// reused: its waker is refreshed and, when `deadline` differs, it is moved
    /// to the new deadline under the same key. Otherwise a new key is
    /// allocated and stored in `slot`.
    ///
    /// Returns `false`, registering nothing, if the notifier has been shut down.
    #[cfg(feature = "async")]
    pub(crate) fn register_deadline(
        &self,
        slot: &mut Option<u64>,
        deadline: Instant,
        waker: &Waker,
    ) -> bool {
        let mut state = self
            .timer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if self.shutdown.load(Ordering::Acquire) {
            return false;
        }
        if let Some(key) = *slot
            && let Some(&old_deadline) = state.by_key.get(&key)
        {
            if old_deadline == deadline {
                // Same deadline: only the waker may need refreshing. The timer
                // thread's schedule is unchanged, so it is not signalled.
                if let Some(registered) = state.by_deadline.get_mut(&(deadline, key))
                    && !registered.will_wake(waker)
                {
                    *registered = waker.clone();
                }
                return true;
            }
            // Move the registration to its new deadline. `insert` overwrites the
            // old `by_key` entry, so no separate removal is needed there.
            state.by_deadline.remove(&(old_deadline, key));
            state.by_deadline.insert((deadline, key), waker.clone());
            state.by_key.insert(key, deadline);
        } else {
            let key = state.next_key;
            state.next_key += 1;
            state.by_deadline.insert((deadline, key), waker.clone());
            state.by_key.insert(key, deadline);
            *slot = Some(key);
        }
        // The earliest deadline may have changed; let the timer thread re-plan.
        drop(state);
        self.timer_cv.notify_all();
        true
    }

    /// Removes the deadline registration identified by `key`. This is a no-op
    /// if it has already fired or been cleared by shutdown.
    ///
    /// The timer thread is not signalled. If it was sleeping until this
    /// deadline, it re-evaluates when that sleep ends.
    #[cfg(feature = "async")]
    pub(crate) fn unregister_deadline(&self, key: u64) {
        let mut state = self
            .timer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(deadline) = state.by_key.remove(&key) {
            state.by_deadline.remove(&(deadline, key));
        }
    }
}
// Test-only queue of thread names whose next spawn attempt should fail.
// Entries are pushed by `fail_next_thread_spawn` and popped by
// `consume_thread_spawn_failure`. Because the queue is thread-local, a failure
// queued on one thread is only consumed by `consume_thread_spawn_failure`
// calls made on that same thread.
#[cfg(test)]
thread_local! {
    static THREAD_SPAWN_FAILURES: std::cell::RefCell<Vec<&'static str>> =
        const { std::cell::RefCell::new(Vec::new()) };
}
/// Test hook: queues one injected failure for the thread named `name`.
///
/// The next `consume_thread_spawn_failure(name)` on the current thread will
/// report it. Call this repeatedly to fail several spawns.
#[cfg(test)]
fn fail_next_thread_spawn(name: &'static str) {
    THREAD_SPAWN_FAILURES.with_borrow_mut(|pending| pending.push(name));
}
/// Test hook: removes the first queued injected failure for `name`, if any.
///
/// Returns `true` when a failure was queued (and has now been consumed), in
/// which case the caller should report a spawn error instead of spawning.
#[cfg(test)]
fn consume_thread_spawn_failure(name: &'static str) -> bool {
    THREAD_SPAWN_FAILURES.with_borrow_mut(|pending| {
        match pending.iter().position(|&queued| queued == name) {
            Some(index) => {
                pending.remove(index);
                true
            }
            None => false,
        }
    })
}
fn timer_loop(notifier: Arc<OutputNotifier>) {
loop {
if notifier.shutdown.load(Ordering::Acquire) {
return;
}
let mut state = notifier
.timer
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let next = state.by_deadline.keys().next().copied();
match next {
None => {
state = notifier
.timer_cv
.wait(state)
.unwrap_or_else(|p| p.into_inner());
drop(state);
}
Some((deadline, _)) => {
let now = Instant::now();
if now < deadline {
let dur = deadline - now;
let (guard, _) = notifier
.timer_cv
.wait_timeout(state, dur)
.unwrap_or_else(|p| p.into_inner());
drop(guard);
} else {
let mut to_wake = Vec::new();
while let Some((&(d, k), _)) = state.by_deadline.iter().next() {
if d > now {
break;
}
let waker = state.by_deadline.remove(&(d, k)).expect("entry exists");
state.by_key.remove(&k);
to_wake.push(waker);
}
drop(state);
for waker in to_wake {
waker.wake();
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An injected spawn failure for the timer thread must surface from
    /// `OutputNotifier::new` as `Error::ThreadSpawn` carrying the timer
    /// thread's name.
    #[test]
    fn new_returns_thread_spawn_errors() {
        fail_next_thread_spawn(OutputNotifier::TIMER_THREAD_NAME);
        let err = match OutputNotifier::new() {
            Ok(_) => panic!("output notifier unexpectedly spawned timer thread"),
            Err(err) => err,
        };
        let is_timer_spawn_failure = matches!(
            err,
            Error::ThreadSpawn {
                name: OutputNotifier::TIMER_THREAD_NAME,
                ..
            }
        );
        assert!(
            is_timer_spawn_failure,
            "expected output timer thread spawn failure, got {err:?}",
        );
    }
}