use super::WakerEntry;
use std::cell::Cell;
use std::collections::{hash_map::Entry, BTreeSet, HashMap};
use std::rc::Weak;
use std::task::Waker;
use std::time::Instant;
#[derive(Clone, Debug, Default)]
pub struct ScheduledWakerQueue {
wakers_by_time: BTreeSet<(Instant, WakerEntry)>,
waker_to_time: HashMap<WakerEntry, Instant>,
}
impl ScheduledWakerQueue {
#[inline(always)]
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[inline(always)]
#[must_use]
pub fn len(&self) -> usize {
self.wakers_by_time.len()
}
#[inline(always)]
#[must_use]
pub fn is_empty(&self) -> bool {
self.wakers_by_time.is_empty()
}
pub fn clear(&mut self) {
self.wakers_by_time.clear();
self.waker_to_time.clear();
#[cfg(debug_assertions)]
self.validate();
}
pub fn push(&mut self, wake_time: Instant, waker: Weak<Cell<Option<Waker>>>) -> bool {
let waker_entry = WakerEntry(waker);
if !waker_entry.is_alive() {
return false;
}
self.trim_to_next_wake_time();
if self.len() == self.waker_to_time.capacity() {
self.wakers_by_time.retain(|(_wake_time, waker_entry)| {
if waker_entry.is_alive() {
true
} else {
self.waker_to_time.remove(waker_entry);
false
}
});
self.waker_to_time
.shrink_to(std::cmp::max(8, self.len() * 2));
self.waker_to_time.reserve(self.len());
}
let pushed = match self.waker_to_time.get_mut(&waker_entry) {
None => {
self.waker_to_time.insert(waker_entry.clone(), wake_time);
self.wakers_by_time.insert((wake_time, waker_entry))
}
Some(wake_time_entry) => {
if *wake_time_entry <= wake_time {
false
} else {
let old_wake_time = std::mem::replace(wake_time_entry, wake_time);
let waker_entry = self
.wakers_by_time
.take(&(old_wake_time, waker_entry))
.unwrap()
.1;
self.wakers_by_time.insert((wake_time, waker_entry))
}
}
};
#[cfg(debug_assertions)]
self.validate();
pushed
}
pub fn next_wake_time(&self) -> Option<Instant> {
self.wakers_by_time
.iter()
.find(|&(_, entry)| entry.is_alive())
.map(|(wake_time, _)| *wake_time)
}
pub fn trim_to_next_wake_time(&mut self) -> Option<Instant> {
let mut next_wake_time = None;
while let Some((wake_time, waker_entry)) = self.wakers_by_time.first() {
if waker_entry.is_alive() {
next_wake_time = Some(*wake_time);
break;
}
self.waker_to_time.remove(waker_entry);
self.wakers_by_time.pop_first();
}
#[cfg(debug_assertions)]
self.validate();
next_wake_time
}
pub fn wake(&mut self, now: Instant) {
while let Some((wake_time, waker_entry)) = self.wakers_by_time.first() {
if *wake_time > now {
break;
}
self.waker_to_time.remove(waker_entry);
let waker_entry = self.wakers_by_time.pop_first().unwrap().1;
if let Some(waker) = waker_entry.0.upgrade().and_then(|cell| cell.take()) {
waker.wake();
}
}
#[cfg(debug_assertions)]
self.validate();
}
#[cfg(debug_assertions)]
fn validate(&self) {
assert_eq!(self.wakers_by_time.len(), self.waker_to_time.len());
for (wake_time, entry) in &self.wakers_by_time {
assert_eq!(*wake_time, self.waker_to_time[entry]);
}
for (entry, wake_time) in &self.waker_to_time {
assert!(self.wakers_by_time.contains(&(*wake_time, entry.clone())));
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::WakeFlag;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::time::Duration;

    /// Creates a live waker cell holding a no-op waker.
    fn dummy_waker() -> Rc<Cell<Option<Waker>>> {
        Rc::new(Cell::new(Some(Waker::noop().clone())))
    }

    #[test]
    fn queue_is_initially_empty() {
        let queue = ScheduledWakerQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
    }

    #[test]
    fn pushed_wakers_are_stored_in_queue() {
        let mut queue = ScheduledWakerQueue::new();
        let waker = dummy_waker();
        let pushed = queue.push(Instant::now(), Rc::downgrade(&waker));
        assert!(pushed);
        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 1);
        // The queue only holds weak references: one per index.
        assert_eq!(Rc::strong_count(&waker), 1);
        assert_eq!(Rc::weak_count(&waker), 2);
        let another_waker = dummy_waker();
        let pushed = queue.push(Instant::now(), Rc::downgrade(&another_waker));
        assert!(pushed);
        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 2);
        assert_eq!(Rc::strong_count(&another_waker), 1);
        assert_eq!(Rc::weak_count(&another_waker), 2);
    }

    #[test]
    fn queue_is_empty_after_cleared() {
        let mut queue = ScheduledWakerQueue::new();
        let waker = dummy_waker();
        queue.push(Instant::now(), Rc::downgrade(&waker));
        queue.clear();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
    }

    #[test]
    fn pushing_existing_waker_with_earlier_wake_time_discards_existing_waker() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker = dummy_waker();
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker));
        let pushed = queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker));
        assert!(pushed);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(3)));
    }

    #[test]
    fn pushing_existing_waker_with_later_wake_time_discards_new_waker() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker = dummy_waker();
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker));
        let pushed = queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker));
        assert!(!pushed);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(3)));
    }

    #[test]
    fn pushing_dead_waker_is_noop() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let pushed = queue.push(now, Weak::new());
        assert!(!pushed);
        assert!(queue.is_empty());
        // A live cell whose waker has been taken also counts as dead.
        let waker = dummy_waker();
        waker.take();
        let pushed = queue.push(now, Rc::downgrade(&waker));
        assert!(!pushed);
        assert!(queue.is_empty());
    }

    #[test]
    fn next_wake_time_returns_none_if_empty() {
        let queue = ScheduledWakerQueue::new();
        assert_eq!(queue.next_wake_time(), None);
    }

    #[test]
    fn next_wake_time_returns_earliest_pending_waker_time() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        assert_eq!(queue.next_wake_time(), None);
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_1));
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(5)));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_2));
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(3)));
        queue.push(now + Duration::from_secs(10), Rc::downgrade(&waker_3));
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(3)));
    }

    #[test]
    fn next_wake_time_ignores_dead_wakers() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        queue.push(now, Rc::downgrade(&waker_1));
        queue.push(now, Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_3));
        drop(waker_1);
        waker_2.take();
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(5)));
    }

    #[test]
    fn next_wake_time_is_none_when_all_wakers_are_dead() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker = dummy_waker();
        queue.push(now, Rc::downgrade(&waker));
        drop(waker);
        assert_eq!(queue.next_wake_time(), None);
        // next_wake_time only skips dead entries; it does not remove them.
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn trim_to_next_wake_time_removes_leading_dead_wakers() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        queue.push(now, Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_3));
        drop(waker_1);
        waker_2.take();
        queue.trim_to_next_wake_time();
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(5)));
    }

    #[test]
    fn trim_to_next_wake_time_returns_next_wake_time() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        queue.push(now, Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_3));
        drop(waker_1);
        waker_2.take();
        let next_wake_time = queue.trim_to_next_wake_time();
        assert_eq!(next_wake_time, Some(now + Duration::from_secs(5)));
        drop(waker_3);
        let next_wake_time = queue.trim_to_next_wake_time();
        assert_eq!(next_wake_time, None);
    }

    #[test]
    fn wake_removes_all_wakers_up_to_given_time() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(6), Rc::downgrade(&waker_3));
        queue.wake(now + Duration::from_secs(5));
        assert_eq!(queue.len(), 1);
        // Woken entries are removed from both indexes, releasing their weak references.
        assert_eq!(Rc::weak_count(&waker_1), 0);
        assert_eq!(Rc::weak_count(&waker_2), 0);
        assert_eq!(Rc::weak_count(&waker_3), 2);
    }

    #[test]
    fn wake_activates_all_wakers_up_to_given_time() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let wake_flag_1 = Arc::new(WakeFlag::new());
        let wake_flag_2 = Arc::new(WakeFlag::new());
        let wake_flag_3 = Arc::new(WakeFlag::new());
        let waker_1 = Rc::new(Cell::new(Some(Waker::from(wake_flag_1.clone()))));
        let waker_2 = Rc::new(Cell::new(Some(Waker::from(wake_flag_2.clone()))));
        let waker_3 = Rc::new(Cell::new(Some(Waker::from(wake_flag_3.clone()))));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(6), Rc::downgrade(&waker_3));
        queue.wake(now + Duration::from_secs(5));
        assert!(wake_flag_1.is_woken());
        assert!(wake_flag_2.is_woken());
        assert!(!wake_flag_3.is_woken());
    }

    #[test]
    fn wake_drops_due_wakers_whose_cell_is_empty() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let wake_flag = Arc::new(WakeFlag::new());
        let waker = Rc::new(Cell::new(Some(Waker::from(wake_flag.clone()))));
        queue.push(now, Rc::downgrade(&waker));
        // Empty the cell after scheduling; the queue must not wake the taken waker.
        let taken = waker.take();
        queue.wake(now);
        assert!(queue.is_empty());
        assert!(!wake_flag.is_woken());
        assert!(taken.is_some());
    }

    #[test]
    fn complex_pushes_and_wakes() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let wake_flag_1 = Arc::new(WakeFlag::new());
        let wake_flag_2 = Arc::new(WakeFlag::new());
        let wake_flag_3 = Arc::new(WakeFlag::new());
        let waker_1 = Rc::new(Cell::new(Some(Waker::from(wake_flag_1.clone()))));
        let waker_2 = Rc::new(Cell::new(Some(Waker::from(wake_flag_2.clone()))));
        let waker_3 = Rc::new(Cell::new(Some(Waker::from(wake_flag_3.clone()))));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(10), Rc::downgrade(&waker_3));
        queue.wake(now + Duration::from_secs(5));
        assert!(wake_flag_1.is_woken());
        assert!(wake_flag_2.is_woken());
        assert!(!wake_flag_3.is_woken());
        assert_eq!(queue.next_wake_time(), Some(now + Duration::from_secs(10)));
        queue.wake(now + Duration::from_secs(15));
        assert!(wake_flag_3.is_woken());
        assert_eq!(queue.next_wake_time(), None);
    }

    #[test]
    fn rescheduling_keeps_a_single_entry_per_waker() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker = dummy_waker();
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker));
        queue.push(now + Duration::from_secs(4), Rc::downgrade(&waker));
        assert_eq!(queue.len(), 1);
        // Still one weak reference per index; none leaked by the repeated pushes.
        assert_eq!(Rc::weak_count(&waker), 2);
        queue.wake(now + Duration::from_secs(3));
        assert!(queue.is_empty());
        assert_eq!(Rc::weak_count(&waker), 0);
    }

    #[test]
    fn push_trims_earliest_dead_entries() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let wake_flag_1 = Arc::new(WakeFlag::new());
        let wake_flag_2 = Arc::new(WakeFlag::new());
        let wake_flag_3 = Arc::new(WakeFlag::new());
        let wake_flag_4 = Arc::new(WakeFlag::new());
        let waker_1 = Rc::new(Cell::new(Some(Waker::from(wake_flag_1))));
        let waker_2 = Rc::new(Cell::new(Some(Waker::from(wake_flag_2))));
        let waker_3 = Rc::new(Cell::new(Some(Waker::from(wake_flag_3.clone()))));
        let waker_4 = Rc::new(Cell::new(Some(Waker::from(wake_flag_4.clone()))));
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_1));
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_2));
        queue.push(now + Duration::from_secs(7), Rc::downgrade(&waker_3));
        // Taking the wakers out of their cells kills the two earliest entries.
        waker_1.take().unwrap().wake();
        waker_2.take().unwrap().wake();
        queue.push(now + Duration::from_secs(1), Rc::downgrade(&waker_4));
        assert_eq!(queue.len(), 2);
        queue.wake(now + Duration::from_secs(10));
        assert!(wake_flag_3.is_woken());
        assert!(wake_flag_4.is_woken());
    }

    #[test]
    fn push_cleans_up_all_dead_entries_if_full() {
        let mut queue = ScheduledWakerQueue::new();
        let now = Instant::now();
        let waker_1 = dummy_waker();
        let waker_2 = dummy_waker();
        let waker_3 = dummy_waker();
        let waker_4 = dummy_waker();
        // Start from a known non-trivial capacity so the fill loop below is meaningful.
        queue.waker_to_time.reserve(10);
        queue.push(now + Duration::from_secs(3), Rc::downgrade(&waker_1));
        // Fill with entries that die immediately. They sit behind the live
        // `waker_1`, so the front trim in `push` cannot remove them.
        while queue.len() + 1 < queue.waker_to_time.capacity() {
            let waker = dummy_waker();
            queue.push(
                now + Duration::new(3, queue.len() as u32),
                Rc::downgrade(&waker),
            );
        }
        queue.push(now + Duration::from_secs(4), Rc::downgrade(&waker_2));
        assert_eq!(queue.len(), queue.waker_to_time.capacity());
        // The queue is now full: this push sweeps every dead entry first.
        queue.push(now + Duration::from_secs(5), Rc::downgrade(&waker_3));
        assert_eq!(queue.len(), 3);
        waker_3.take().unwrap().wake();
        // No longer full, so the now-dead `waker_3` entry is not swept.
        queue.push(now + Duration::from_secs(6), Rc::downgrade(&waker_4));
        assert_eq!(queue.len(), 4);
    }
}