use super::WakerEntry;
use std::cell::Cell;
use std::collections::HashSet;
use std::rc::Weak;
use std::task::Waker;
/// A set of weakly-referenced waker slots.
///
/// Each slot is a `Weak<Cell<Option<Waker>>>` wrapped in a `WakerEntry`. The set
/// never holds a strong reference, so whoever owns the `Rc` controls how long the
/// slot itself lives.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WakerSet {
    /// Registered slots; equal entries (per `WakerEntry`'s `Eq`/`Hash`) collapse
    /// into one.
    wakers: HashSet<WakerEntry>,
}
impl WakerSet {
    /// Creates an empty `WakerSet`; no allocation happens until the first insert.
    #[inline(always)]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns how many entries the set can hold before it has to reallocate.
    ///
    /// [`insert`](Self::insert) compares this against [`len`](Self::len) to
    /// decide when to purge entries that are no longer alive.
    #[inline(always)]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.wakers.capacity()
    }

    /// Reserves room for at least `additional` more entries.
    ///
    /// # Panics
    ///
    /// Panics if the new allocation size overflows `usize`, as
    /// `HashSet::reserve` does.
    #[inline(always)]
    pub fn reserve(&mut self, additional: usize) {
        self.wakers.reserve(additional);
    }

    /// Shrinks the capacity as far as the underlying `HashSet` allows.
    #[inline(always)]
    pub fn shrink_to_fit(&mut self) {
        self.wakers.shrink_to_fit();
    }

    /// Shrinks the capacity, but never below `min_capacity`.
    ///
    /// Does nothing if the current capacity is already below `min_capacity`.
    #[inline(always)]
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.wakers.shrink_to(min_capacity);
    }

    /// Returns the number of stored entries.
    ///
    /// Entries that are no longer alive still count until they are purged by a
    /// full-capacity [`insert`](Self::insert), or removed by
    /// [`wake_all`](Self::wake_all) or [`clear`](Self::clear).
    #[inline(always)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.wakers.len()
    }

    /// Returns `true` when the set holds no entries, alive or not.
    #[inline(always)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Registers a weak waker slot.
    ///
    /// Returns `true` if the slot was added. Returns `false` if it is not alive
    /// (per `WakerEntry::is_alive`) or an equal entry is already present.
    ///
    /// When the set is exactly at capacity, this first does three things:
    /// 1. Drops entries that are no longer alive.
    /// 2. Shrinks the table to no less than `max(8, 2 * survivors)`.
    /// 3. Reserves room for `survivors` more entries.
    ///
    /// This cleanup runs before the new slot is examined, so it also happens
    /// when the slot ends up rejected.
    #[inline]
    pub fn insert(&mut self, waker_cell: Weak<Cell<Option<Waker>>>) -> bool {
        let at_capacity = self.wakers.len() == self.wakers.capacity();
        if at_capacity {
            self.wakers.retain(WakerEntry::is_alive);
            let survivors = self.wakers.len();
            // Keep a small floor when shrinking, then make sure the capacity is at
            // least twice the number of live entries.
            self.wakers.shrink_to((survivors * 2).max(8));
            self.wakers.reserve(survivors);
        }

        let entry = WakerEntry(waker_cell);
        if !entry.is_alive() {
            return false;
        }
        self.wakers.insert(entry)
    }

    /// Empties the set, waking every slot that still upgrades and holds a waker.
    ///
    /// Each waker is taken out of its slot, leaving `None`, and then woken. Two
    /// kinds of entries are simply discarded:
    /// - entries whose `Weak` no longer upgrades;
    /// - entries whose slot is already empty.
    ///
    /// The set keeps its allocation for reuse.
    pub fn wake_all(&mut self) {
        let ready = self
            .wakers
            .drain()
            .filter_map(|entry| entry.0.upgrade()?.take());
        for waker in ready {
            waker.wake();
        }
    }

    /// Removes every entry without waking anything; the allocation is kept.
    #[inline(always)]
    pub fn clear(&mut self) {
        self.wakers.clear();
    }
}
impl FromIterator<Weak<Cell<Option<Waker>>>> for WakerSet {
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = Weak<Cell<Option<Waker>>>>,
{
let wakers = HashSet::from_iter(iter.into_iter().map(WakerEntry));
Self { wakers }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::WakeFlag;
    use std::rc::Rc;
    use std::sync::Arc;

    /// Builds a strongly-held waker slot containing a no-op waker.
    fn noop_cell() -> Rc<Cell<Option<Waker>>> {
        Rc::new(Cell::new(Some(Waker::noop().clone())))
    }

    #[test]
    fn waking_all_inserted_wakers() {
        let mut set = WakerSet::new();
        let wake_flag_1 = Arc::new(WakeFlag::new());
        let wake_flag_2 = Arc::new(WakeFlag::new());
        let waker_1 = Rc::new(Cell::new(Some(Waker::from(wake_flag_1.clone()))));
        let waker_2 = Rc::new(Cell::new(Some(Waker::from(wake_flag_2.clone()))));
        assert!(set.insert(Rc::downgrade(&waker_1)));
        assert!(set.insert(Rc::downgrade(&waker_2)));
        set.wake_all();
        assert!(wake_flag_1.is_woken());
        assert!(wake_flag_2.is_woken());
        assert!(set.is_empty());
        // `wake_all` moves each waker out of its slot rather than cloning it.
        assert!(waker_1.take().is_none());
        assert!(waker_2.take().is_none());
    }

    #[test]
    fn duplicate_wakers_are_not_inserted() {
        let mut set = WakerSet::new();
        let waker_1 = noop_cell();
        let waker_1_clone = Rc::clone(&waker_1);
        let waker_2 = noop_cell();
        assert!(set.insert(Rc::downgrade(&waker_1)));
        assert!(set.insert(Rc::downgrade(&waker_2)));
        assert_eq!(set.len(), 2);
        assert!(!set.insert(Rc::downgrade(&waker_1)));
        assert!(!set.insert(Rc::downgrade(&waker_1_clone)));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn dropped_wakers_are_not_inserted() {
        let mut set = WakerSet::new();
        let waker = noop_cell();
        let weak = Rc::downgrade(&waker);
        drop(waker);
        assert!(!set.insert(weak));
        assert!(set.is_empty());
    }

    #[test]
    fn clearing_waker_set() {
        let mut set = WakerSet::new();
        let waker_1 = noop_cell();
        let waker_2 = noop_cell();
        assert!(set.insert(Rc::downgrade(&waker_1)));
        assert!(set.insert(Rc::downgrade(&waker_2)));
        assert_eq!(set.len(), 2);
        set.clear();
        assert!(set.is_empty());
        assert_eq!(Rc::weak_count(&waker_1), 0);
        assert_eq!(Rc::weak_count(&waker_2), 0);
    }

    #[test]
    fn dead_wakers_are_removed_before_insertion_if_full() {
        let mut set = WakerSet::new();
        let waker_1 = noop_cell();
        let waker_2 = noop_cell();
        let waker_3 = noop_cell();
        let waker_4 = noop_cell();
        let waker_5 = noop_cell();
        set.reserve(10);
        assert!(set.insert(Rc::downgrade(&waker_1)));
        assert!(set.insert(Rc::downgrade(&waker_2)));
        // Empty waker_2's slot while its `Rc` is still held; the purge below is
        // expected to drop it anyway.
        waker_2.take();
        assert!(set.insert(Rc::downgrade(&waker_3)));
        // Fill to one short of capacity with slots whose `Rc` is dropped at the end
        // of each iteration, leaving dead entries behind.
        while set.len() < set.capacity() - 1 {
            let waker = noop_cell();
            assert!(set.insert(Rc::downgrade(&waker)));
        }
        assert!(set.insert(Rc::downgrade(&waker_4)));
        assert_eq!(set.len(), set.capacity());
        // The set is now full, so this insert purges the dead entries first.
        assert!(set.insert(Rc::downgrade(&waker_5)));
        assert_eq!(Rc::weak_count(&waker_1), 1);
        assert_eq!(Rc::weak_count(&waker_2), 0);
        assert_eq!(Rc::weak_count(&waker_3), 1);
        assert_eq!(Rc::weak_count(&waker_4), 1);
        assert_eq!(Rc::weak_count(&waker_5), 1);
        assert_eq!(set.len(), 4);
    }
}