#![no_std]
use core::fmt;
use slab::Slab;
#[cfg(feature = "serde")]
use serde::ser::SerializeSeq;
/// A hierarchical timing wheel that maps `u64` expiry ticks to values of type `T`.
///
/// Pending timers are stored in a [`Slab`]. Each timer is threaded into an
/// intrusive doubly linked list hanging off one slot of one wheel level.
/// A slot on level `n` covers `SLOTS.pow(n)` consecutive ticks.
#[derive(Debug, Clone)]
pub struct TimerQueue<T> {
    // Storage for every pending timer; `Timer` handles are keys into this slab.
    timers: Slab<TimerState<T>>,
    // Wheel levels, finest (one tick per slot) first.
    levels: [Level; LEVELS],
    // The wheel's clock. Inserted/reset expiries are clamped to at least this
    // value, and slot positions are computed relative to it.
    next_tick: u64,
}
impl<T> TimerQueue<T> {
    /// Creates an empty queue whose clock starts at tick 0.
    pub const fn new() -> Self {
        Self {
            timers: Slab::new(),
            levels: [Level::new(); LEVELS],
            next_tick: 0,
        }
    }

    /// Creates an empty queue whose backing slab is pre-allocated for `n` timers.
    pub fn with_capacity(n: usize) -> Self {
        Self {
            timers: Slab::with_capacity(n),
            levels: [Level::new(); LEVELS],
            next_tick: 0,
        }
    }

    /// Removes and returns the value of one timer whose expiry is `<= now`, or
    /// `None` if no timer is due.
    ///
    /// Call repeatedly with the same `now` to drain every due timer; the
    /// earliest pending expiry is returned first.
    ///
    /// `now` should never decrease between calls. A `now` earlier than the
    /// queue's internal clock trips a debug assertion. In release builds it
    /// returns `None` and leaves the queue untouched.
    pub fn poll(&mut self, now: u64) -> Option<T> {
        debug_assert!(now >= self.next_tick, "time advances monotonically");
        if now < self.next_tick {
            // Every pending expiry is >= `next_tick`, so nothing can be due.
            // Continuing would rewind `next_tick` inside `advance_towards`, and
            // timers would no longer sit where `timer_index` expects them.
            return None;
        }
        loop {
            self.advance_towards(now);
            if let Some(value) = self.scan_bottom(now) {
                return Some(value);
            }
            if self.next_tick >= now {
                return None;
            }
        }
    }

    /// Pops the head of the lowest occupied level-0 slot if that slot starts at
    /// or before `now`, and moves the clock to the popped timer's expiry.
    fn scan_bottom(&mut self, now: u64) -> Option<T> {
        let index = self.levels[0].first_index()?;
        if slot_start(self.next_tick, 0, index) > now {
            return None;
        }
        let timer = self.levels[0].slots[index];
        let state = self.timers.remove(timer.0);
        debug_assert_eq!(state.prev, None, "head of list has no predecessor");
        debug_assert!(state.expiry <= now);
        if let Some(next) = state.next {
            debug_assert_eq!(
                self.timers[next.0].prev,
                Some(timer),
                "successor links to head"
            );
            self.timers[next.0].prev = None;
        }
        self.levels[0].set(index, state.next);
        self.next_tick = state.expiry;
        self.maybe_shrink();
        Some(state.value)
    }

    /// Moves the clock toward `now`.
    ///
    /// Looks at the lowest non-empty level. If its first occupied slot starts at
    /// or before `now`, the clock jumps to that slot (cascading its timers when
    /// the level is above 0). Otherwise the clock is set to `now`.
    fn advance_towards(&mut self, now: u64) {
        for level in 0..LEVELS {
            if let Some(slot) = self.levels[level].first_index() {
                if slot_start(self.next_tick, level, slot) > now {
                    break;
                }
                self.advance_to(level, slot);
                return;
            }
        }
        self.next_tick = now;
    }

    /// Sets the clock to the start of `slot` on `level`.
    ///
    /// For `level > 0`, every timer in that slot is rescheduled. Relative to
    /// the new clock, each of them lands on a lower level.
    fn advance_to(&mut self, level: usize, slot: usize) {
        debug_assert!(
            self.levels[..level].iter().all(|level| level.is_empty()),
            "lower levels are empty"
        );
        debug_assert!(
            self.levels[level].first_index().map_or(true, |x| x >= slot),
            "lower slots in this level are empty"
        );
        self.next_tick = slot_start(self.next_tick, level, slot);
        if level == 0 {
            return;
        }
        while let Some(timer) = self.levels[level].take(slot) {
            let next = self.timers[timer.0].next;
            self.levels[level].set(slot, next);
            if let Some(next) = next {
                self.timers[next.0].prev = None;
            }
            self.list_unlink(timer);
            self.schedule(timer);
        }
    }

    /// Pushes a currently unlinked timer onto the head of the slot list chosen by
    /// `timer_index` for its expiry and the current clock.
    fn schedule(&mut self, timer: Timer) {
        debug_assert_eq!(
            self.timers[timer.0].next, None,
            "timer isn't already scheduled"
        );
        debug_assert_eq!(
            self.timers[timer.0].prev, None,
            "timer isn't already scheduled"
        );
        let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
        let head = self.levels[level].get(slot);
        self.timers[timer.0].next = head;
        if let Some(head) = head {
            self.timers[head.0].prev = Some(timer);
        }
        self.levels[level].set(slot, Some(timer));
    }

    /// Returns a tick that is never later than the earliest pending expiry, or
    /// `None` if the queue is empty.
    ///
    /// The result is exact when the earliest timer sits on level 0. Otherwise it
    /// is the start of the coarser slot that contains that timer.
    pub fn next_timeout(&self) -> Option<u64> {
        (0..LEVELS).find_map(|level| {
            // Only slots at or after the clock's own digit on this level are
            // considered. `start < 64`, so the shift below cannot overflow.
            let start = ((self.next_tick >> (level * LOG_2_SLOTS)) & (SLOTS - 1) as u64) as usize;
            let pending = self.levels[level].occupied >> start;
            if pending == 0 {
                return None;
            }
            let slot = start + pending.trailing_zeros() as usize;
            Some(slot_start(self.next_tick, level, slot))
        })
    }

    /// Schedules `value` to expire at `timeout` and returns its handle.
    ///
    /// A `timeout` earlier than the queue's clock is clamped to the clock, so the
    /// timer is due on the next `poll`.
    pub fn insert(&mut self, timeout: u64, value: T) -> Timer {
        let timer = Timer(self.timers.insert(TimerState {
            expiry: timeout.max(self.next_tick),
            prev: None,
            next: None,
            value,
        }));
        self.schedule(timer);
        timer
    }

    /// Moves a pending timer to a new `timeout`, clamped like [`Self::insert`].
    ///
    /// # Panics
    ///
    /// Panics if `timer` does not refer to a pending timer in this queue.
    pub fn reset(&mut self, timer: Timer, timeout: u64) {
        self.unlink(timer);
        self.timers[timer.0].expiry = timeout.max(self.next_tick);
        self.schedule(timer);
    }

    /// Cancels a pending timer and returns its value.
    ///
    /// # Panics
    ///
    /// Panics if `timer` does not refer to a pending timer in this queue.
    pub fn remove(&mut self, timer: Timer) -> T {
        self.unlink(timer);
        let state = self.timers.remove(timer.0);
        self.maybe_shrink();
        state.value
    }

    /// Shrinks the slab once live timers drop below 1/16 of its capacity.
    fn maybe_shrink(&mut self) {
        if self.timers.capacity() / 16 > self.timers.len() {
            self.timers.shrink_to_fit();
        }
    }

    /// Iterates over every pending timer as `(expiry, &value)`, in storage order
    /// rather than expiry order.
    pub fn iter(&self) -> impl ExactSizeIterator<Item = (u64, &T)> {
        self.timers.iter().map(|(_, x)| (x.expiry, &x.value))
    }

    /// Like [`Self::iter`], but yields mutable references to the values.
    pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (u64, &mut T)> {
        self.timers
            .iter_mut()
            .map(|(_, x)| (x.expiry, &mut x.value))
    }

    /// Borrows the value of a pending timer.
    ///
    /// # Panics
    ///
    /// Panics if `timer` does not refer to a pending timer in this queue.
    pub fn get(&self, timer: Timer) -> &T {
        &self.timers[timer.0].value
    }

    /// Mutably borrows the value of a pending timer.
    ///
    /// # Panics
    ///
    /// Panics if `timer` does not refer to a pending timer in this queue.
    pub fn get_mut(&mut self, timer: Timer) -> &mut T {
        &mut self.timers[timer.0].value
    }

    /// Number of pending timers.
    pub fn len(&self) -> usize {
        self.timers.len()
    }

    /// Whether no timers are pending.
    pub fn is_empty(&self) -> bool {
        self.timers.is_empty()
    }

    /// Detaches `timer` from its slot list, updating the slot head if `timer` is
    /// the head.
    fn unlink(&mut self, timer: Timer) {
        let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry);
        let slot_head = self.levels[level]
            .get(slot)
            .expect("a scheduled timer's slot is occupied");
        if slot_head == timer {
            self.levels[level].set(slot, self.timers[slot_head.0].next);
            debug_assert_eq!(
                self.timers[timer.0].prev, None,
                "head of list has no predecessor"
            );
        }
        self.list_unlink(timer);
    }

    /// Splices `timer` out of its doubly linked list and clears its own links.
    /// Slot heads are not touched; callers handle those.
    fn list_unlink(&mut self, timer: Timer) {
        let prev = self.timers[timer.0].prev.take();
        let next = self.timers[timer.0].next.take();
        if let Some(prev) = prev {
            self.timers[prev.0].next = next;
        }
        if let Some(next) = next {
            self.timers[next.0].prev = prev;
        }
    }
}
/// Returns the first tick of `slot` on wheel `level`, relative to the clock `base`.
///
/// Bits of `base` above `level`'s digit are kept, `slot` becomes that digit,
/// and every finer bit is zero.
fn slot_start(base: u64, level: usize, slot: usize) -> u64 {
    let digit_shift = (level * LOG_2_SLOTS) as u64;
    // The shift is split in two so the top level (60 + 6) yields an all-zero
    // mask instead of a single out-of-range `<< 66`.
    let keep_above = (!0u64 << digit_shift) << LOG_2_SLOTS as u64;
    let digit = (slot as u64) << digit_shift;
    digit | (base & keep_above)
}
/// Maps `expiry` (which must be `>= base`) to the `(level, slot)` it belongs in
/// while the wheel's clock reads `base`.
///
/// The level is chosen by the most significant bit in which the two ticks
/// differ. The slot is `expiry`'s digit on that level.
fn timer_index(base: u64, expiry: u64) -> (usize, usize) {
    // `| 1` keeps `leading_zeros` below 64 when `base == expiry`, which places
    // the timer on level 0.
    let highest_differing_bit = 63 - ((base ^ expiry) | 1).leading_zeros();
    let level = highest_differing_bit as usize / LOG_2_SLOTS;
    debug_assert!(level < LEVELS, "every possible expiry is in range");
    let shift = level * LOG_2_SLOTS;
    // `base`'s digits above this level, expressed in this level's units.
    let window_start = (base >> shift) & (!0 << LOG_2_SLOTS);
    let slot = (expiry >> shift) - window_start;
    debug_assert!(slot < SLOTS as u64);
    (level, slot as usize)
}
impl<T> Default for TimerQueue<T> {
fn default() -> Self {
Self::new()
}
}
/// Per-timer bookkeeping stored in the queue's slab.
#[derive(Debug, Clone)]
struct TimerState<T> {
    /// Absolute tick at which the timer is due; clamped to the queue's clock
    /// when inserted or reset.
    expiry: u64,
    /// Caller payload handed back by `poll` or `remove`.
    value: T,
    /// Previous timer in the same slot's list; `None` for the list head.
    prev: Option<Timer>,
    /// Next timer in the same slot's list; `None` for the list tail.
    next: Option<Timer>,
}
/// One level of the timing wheel: `SLOTS` list heads plus an occupancy bitmap.
///
/// The bitmap is a `u64`, so `SLOTS` must not exceed 64.
#[derive(Copy, Clone)]
struct Level {
    /// Head of each slot's timer list; `Timer(usize::MAX)` fills empty slots.
    slots: [Timer; SLOTS],
    /// Bit `i` is set exactly when `slots[i]` holds a live list head.
    occupied: u64,
}

impl Level {
    /// An empty level with every slot vacant.
    const fn new() -> Self {
        Self {
            slots: [Timer(usize::MAX); SLOTS],
            occupied: 0,
        }
    }

    /// Index of the lowest occupied slot, or `None` if the level is empty.
    fn first_index(&self) -> Option<usize> {
        // Test the bitmap directly. An empty `u64` reports 64 trailing zeros,
        // which only coincides with `SLOTS` when `LOG_2_SLOTS == 6`; comparing
        // against `SLOTS` would yield an out-of-range index for smaller wheels.
        if self.occupied == 0 {
            return None;
        }
        Some(self.occupied.trailing_zeros() as usize)
    }

    /// Head of `slot`'s list, if the slot is occupied.
    fn get(&self, slot: usize) -> Option<Timer> {
        if self.occupied & (1 << slot) == 0 {
            return None;
        }
        Some(self.slots[slot])
    }

    /// Detaches and returns the head of `slot`'s list, leaving the slot empty.
    fn take(&mut self, slot: usize) -> Option<Timer> {
        let head = self.get(slot)?;
        self.set(slot, None);
        Some(head)
    }

    /// Installs (`Some`) or clears (`None`) the head of `slot`'s list, keeping the
    /// occupancy bitmap in sync.
    fn set(&mut self, slot: usize, timer: Option<Timer>) {
        match timer {
            None => {
                self.slots[slot] = Timer(usize::MAX);
                self.occupied &= !(1 << slot);
            }
            Some(x) => {
                self.slots[slot] = x;
                self.occupied |= 1 << slot;
            }
        }
    }

    /// Whether every slot on this level is empty.
    fn is_empty(&self) -> bool {
        self.occupied == 0
    }
}
impl fmt::Debug for Level {
    /// Formats the level as a map from occupied slot index to the raw key of
    /// that slot's head timer; empty slots are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let occupied = self.occupied;
        f.debug_map()
            .entries(
                self.slots
                    .iter()
                    .enumerate()
                    .filter(|&(index, _)| occupied & (1u64 << index) != 0)
                    .map(|(index, &Timer(key))| (index, key)),
            )
            .finish()
    }
}
/// log2 of the number of slots per wheel level: each level resolves 6 bits of a tick.
const LOG_2_SLOTS: usize = 6;
/// Levels needed to cover every `u64` tick: 10 full 6-bit levels plus a top
/// level that only uses the remaining 4 bits (16 of its slots).
const LEVELS: usize = 1 + 64 / LOG_2_SLOTS;
/// Slots per level. `Level` tracks occupancy in a `u64` bitmap, so changing
/// `LOG_2_SLOTS` requires revisiting that code.
const SLOTS: usize = 1 << LOG_2_SLOTS;
/// Handle to a timer in a [`TimerQueue`], returned by [`TimerQueue::insert`].
///
/// NOTE: the handle is a raw slab index with no generation counter. A handle
/// kept after its timer was polled or removed may alias a later timer.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Timer(usize);
#[cfg(feature = "serde")]
impl<T: serde::Serialize> serde::Serialize for TimerQueue<T> {
    /// Serializes the queue as a length-prefixed sequence of `(expiry, value)`
    /// pairs in storage order. The clock position itself is not recorded.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
        self.iter()
            .try_for_each(|entry| seq.serialize_element(&entry))?;
        seq.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for TimerQueue<T>
where
    T: serde::Deserialize<'de>,
{
    /// Rebuilds a queue from a sequence of `(expiry, value)` pairs by inserting
    /// each pair into a fresh queue whose clock starts at 0.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that collects `(u64, T)` pairs into a `TimerQueue<T>`.
        struct QueueVisitor<T>(core::marker::PhantomData<T>);

        impl<'de, T> serde::de::Visitor<'de> for QueueVisitor<T>
        where
            T: serde::Deserialize<'de>,
        {
            type Value = TimerQueue<T>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    formatter,
                    "a sequence of (u64, {}) tuples",
                    core::any::type_name::<T>()
                )
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // Pre-size the slab when the format reports the sequence length.
                let mut queue = seq
                    .size_hint()
                    .map_or_else(TimerQueue::<T>::new, TimerQueue::<T>::with_capacity);
                while let Some((expiry, value)) = seq.next_element::<(u64, T)>()? {
                    queue.insert(expiry, value);
                }
                Ok(queue)
            }
        }

        deserializer.deserialize_seq(QueueVisitor(core::marker::PhantomData))
    }
}
#[cfg(test)]
mod tests {
    // The crate is `no_std`; the tests pull `std` back in for collections.
    extern crate alloc;
    extern crate std;
    use std::{collections::HashMap, vec::Vec};
    use super::*;
    use proptest::prelude::*;

    /// A timer at the largest representable tick is not due one tick earlier,
    /// and is due at `u64::MAX`.
    #[test]
    fn max_timeout() {
        let mut queue = TimerQueue::new();
        queue.insert(u64::MAX, ());
        assert!(queue.poll(u64::MAX - 1).is_none());
        assert!(queue.poll(u64::MAX).is_some());
    }

    /// `slot_start` keeps the bits of `base` above the level, substitutes the
    /// slot as that level's digit, and zeroes the finer bits.
    #[test]
    fn slot_starts() {
        for i in 0..SLOTS {
            assert_eq!(slot_start(0, 0, i), i as u64);
            assert_eq!(slot_start(SLOTS as u64, 0, i), SLOTS as u64 + i as u64);
            assert_eq!(slot_start(SLOTS as u64 + 1, 0, i), SLOTS as u64 + i as u64);
            for j in 1..LEVELS {
                assert_eq!(
                    slot_start(0, j, i),
                    (SLOTS as u64).pow(j as u32).wrapping_mul(i as u64)
                );
            }
        }
    }

    /// `timer_index` picks the level of the highest bit in which `base` and
    /// `expiry` differ, and the slot is `expiry`'s digit on that level.
    #[test]
    fn indexes() {
        assert_eq!(timer_index(0, 0), (0, 0));
        assert_eq!(timer_index(0, SLOTS as u64 - 1), (0, SLOTS - 1));
        assert_eq!(
            timer_index(SLOTS as u64 - 1, SLOTS as u64 - 1),
            (0, SLOTS - 1)
        );
        assert_eq!(timer_index(0, SLOTS as u64), (1, 1));
        for i in 0..LEVELS {
            // The first tick of each level lands in that level's slot 1.
            assert_eq!(timer_index(0, (SLOTS as u64).pow(i as u32)), (i, 1));
            if i < LEVELS - 1 {
                // The last tick covered by level `i` lands in its final slot.
                assert_eq!(
                    timer_index(0, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
                assert_eq!(
                    timer_index(SLOTS as u64 - 1, (SLOTS as u64).pow(i as u32 + 1) - 1),
                    (i, SLOTS - 1)
                );
            }
        }
    }

    /// `next_timeout` is exact for level-0 timers and only a lower bound for
    /// coarser ones (hence the `> 12` check for 1234).
    #[test]
    fn next_timeout() {
        let mut queue = TimerQueue::new();
        assert_eq!(queue.next_timeout(), None);
        let k = queue.insert(0, ());
        assert_eq!(queue.next_timeout(), Some(0));
        queue.remove(k);
        assert_eq!(queue.next_timeout(), None);
        queue.insert(1234, ());
        assert!(queue.next_timeout().unwrap() > 12);
        queue.insert(12, ());
        assert_eq!(queue.next_timeout(), Some(12));
    }

    /// Timers on either side of the level-0/level-1 boundary become due exactly
    /// on their own ticks.
    #[test]
    fn poll_boundary() {
        let mut queue = TimerQueue::new();
        queue.insert(SLOTS as u64 - 1, 'a');
        queue.insert(SLOTS as u64, 'b');
        assert_eq!(queue.poll(SLOTS as u64 - 2), None);
        assert_eq!(queue.poll(SLOTS as u64 - 1), Some('a'));
        assert_eq!(queue.poll(SLOTS as u64 - 1), None);
        assert_eq!(queue.poll(SLOTS as u64), Some('b'));
    }

    /// Resetting the middle element of a slot list relinks its neighbours and
    /// moves the timer alone into its new slot. Inserts push onto the list head,
    /// so the list is initially c -> b -> a.
    #[test]
    fn reset_list_middle() {
        let mut queue = TimerQueue::new();
        let slot = SLOTS as u64 / 2;
        let a = queue.insert(slot, ());
        let b = queue.insert(slot, ());
        let c = queue.insert(slot, ());
        queue.reset(b, slot + 1);
        assert_eq!(queue.levels[0].get(slot as usize + 1), Some(b));
        assert_eq!(queue.timers[b.0].prev, None);
        assert_eq!(queue.timers[b.0].next, None);
        assert_eq!(queue.levels[0].get(slot as usize), Some(c));
        assert_eq!(queue.timers[c.0].prev, None);
        assert_eq!(queue.timers[c.0].next, Some(a));
        assert_eq!(queue.timers[a.0].prev, Some(c));
        assert_eq!(queue.timers[a.0].next, None);
    }

    proptest! {
        // Polling at each distinct expiry, in ascending order, yields exactly the
        // values inserted at that tick and nothing one tick earlier.
        #[test]
        fn poll(ts in times()) {
            let mut queue = TimerQueue::new();
            let mut time_values = HashMap::<u64, Vec<usize>>::new();
            for (i, t) in ts.into_iter().enumerate() {
                queue.insert(t, i);
                time_values.entry(t).or_default().push(i);
            }
            let mut time_values = time_values.into_iter().collect::<Vec<(u64, Vec<usize>)>>();
            time_values.sort_unstable_by_key(|&(t, _)| t);
            for &(t, ref is) in &time_values {
                assert!(queue.next_timeout().unwrap() <= t);
                if t > 0 {
                    assert_eq!(queue.poll(t-1), None);
                }
                let mut values = Vec::new();
                while let Some(i) = queue.poll(t) {
                    values.push(i);
                }
                assert_eq!(values.len(), is.len());
                for i in is {
                    assert!(values.contains(i));
                }
            }
        }

        // After resetting every timer to a random new tick, draining the queue
        // yields as many values as were inserted.
        #[test]
        fn reset(ts_a in times(), ts_b in times()) {
            let mut queue = TimerQueue::new();
            let timers = ts_a.map(|t| queue.insert(t, ()));
            for (timer, t) in timers.into_iter().zip(ts_b) {
                queue.reset(timer, t);
            }
            let mut n = 0;
            while let Some(()) = queue.poll(u64::MAX) {
                n += 1;
            }
            assert_eq!(n, timers.len());
        }

        // `timer_index` and `slot_start` agree: the chosen slot's tick range
        // contains `t`. Only ranges reaching the very top of `u64` may overflow,
        // and then the slot must be its level's last one (15 on the 4-bit top level).
        #[test]
        fn index_start_consistency(a in time(), b in time()) {
            let base = a.min(b);
            let t = a.max(b);
            let (level, slot) = timer_index(base, t);
            let start = slot_start(base, level, slot);
            assert!(start <= t);
            if let Some(end) = start.checked_add((SLOTS as u64).pow(level as u32)) {
                assert!(end > t);
            } else {
                assert!(start >= slot_start(0, LEVELS - 1, 15));
                if level == LEVELS - 1 {
                    assert_eq!(slot, 15);
                } else {
                    assert_eq!(slot, SLOTS - 1);
                }
            }
        }
    }

    /// A bincode round trip preserves the full sequence of polled values.
    #[test]
    #[cfg(feature = "serde")]
    fn serialization() {
        const VALUES: [(u64, usize); 17] = [
            (23, 5132),
            (87, 6),
            (45, 7839),
            (122, 345),
            (67, 12333),
            (34, 8),
            (90, 234),
            (151, 82290),
            (56, 32),
            (78, 567),
            (19, 345),
            (22, 78),
            (33, 890),
            (44, 123),
            (51235, 6),
            (66, 89),
            (727, 890),
        ];
        let mut queue = TimerQueue::<usize>::new();
        for (t, v) in VALUES {
            queue.insert(t, v);
        }
        let serialized: Vec<u8> = bincode::serialize(&queue).expect("Serialization failed");
        let mut deserialized: TimerQueue<usize> =
            bincode::deserialize(&serialized).expect("Deserialization failed");
        loop {
            let r1 = queue.poll(u64::MAX);
            let r2 = deserialized.poll(u64::MAX);
            assert!(r1 == r2);
            if r1.is_none() {
                break;
            }
        }
    }

    /// Strategy for a tick whose magnitude falls in a random slot of a random
    /// level, so every level gets exercised. The top level is limited to its 16
    /// usable slots.
    fn time() -> impl Strategy<Value = u64> {
        ((0..LEVELS as u32), (0..SLOTS as u64)).prop_perturb(|(level, mut slot), mut rng| {
            if level == LEVELS as u32 - 1 {
                slot %= 16;
            }
            let slot_size = (SLOTS as u64).pow(level);
            let slot_start = slot * slot_size;
            // Saturate so the top level's last slot ends at `u64::MAX`.
            let slot_end = (slot + 1).saturating_mul(slot_size);
            rng.gen_range(slot_start..slot_end)
        })
    }

    /// Strategy for 16 independent ticks drawn from [`time`].
    #[rustfmt::skip]
    fn times() -> impl Strategy<Value = [u64; 16]> {
        [time(), time(), time(), time(), time(), time(), time(), time(),
         time(), time(), time(), time(), time(), time(), time(), time()]
    }
}