use std::mem::MaybeUninit;
use super::Vec;
use slab::Slab;
/// Error returned by [`TimerWheel::start_timer_absolute`].
///
/// The rejected context is handed back so the caller keeps ownership of it.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StartTimerError<T> {
    /// The requested expiration time is not after the wheel's current time; carries the
    /// context that was passed in.
    Expired(T),
}

impl<T> std::fmt::Display for StartTimerError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StartTimerError::Expired(_) => {
                f.write_str("timer expiration time is not after the current time")
            }
        }
    }
}

impl<T: std::fmt::Debug> std::error::Error for StartTimerError<T> {}
/// Sentinel slab index meaning "no entry": marks the end of a slot list and an unlinked node.
const EMPTY_INDEX: u32 = !0;
/// Handle to a started timer, used to cancel it with `TimerWheel::stop_timer`.
///
/// The wrapped value is the slab index of the timer's entry. The type is not `Copy`/`Clone`
/// (the `missing_copy_implementations` lint is explicitly allowed), and `stop_timer` takes
/// it by value.
///
/// NOTE(review): because the handle is a bare slab index, once its timer has expired the
/// index is freed and may be reused by a later timer; stopping with such a stale handle can
/// then cancel the unrelated newer timer. Discard handles of timers that have expired.
#[derive(Debug)]
#[allow(missing_copy_implementations)]
pub struct TimerHandle(u32);
/// Payload of a slab entry: either the sentinel head of one slot's list, or a scheduled timer.
#[derive(Debug, Copy, Clone)]
enum TimerEntryData<T> {
    /// Sentinel list head owned by exactly one slot; created at initialisation and never
    /// removed. `stop_timer` refuses to remove it.
    Head,
    /// A scheduled timer.
    Timer {
        /// Absolute tick at which the timer is due (compared against `current_time`).
        expire_time: u64,
        /// Caller-supplied value returned on expiry or when the timer is stopped.
        context: T,
    },
}
/// One node of a slot's doubly linked list; the links are slab indices into the same slab.
#[derive(Debug, Copy, Clone)]
struct TimerEntry<T> {
    /// Slab index of the following node, or `EMPTY_INDEX` at the tail of the list.
    next: u32,
    /// Slab index of the preceding node (the slot's head for the first timer);
    /// `EMPTY_INDEX` for heads and for timers that are not yet linked.
    previous: u32,
    /// Whether this node is a slot head or a timer, plus the timer's data.
    data: TimerEntryData<T>,
}
/// The list of timers belonging to one wheel slot, identified by its sentinel head entry.
#[derive(Debug)]
struct EntryList {
    /// Slab index of the sentinel head; `EMPTY_INDEX` until the wheel links a real head.
    head: u32,
}

impl Default for EntryList {
    /// An unlinked list whose head is the `EMPTY_INDEX` sentinel.
    fn default() -> Self {
        EntryList { head: EMPTY_INDEX }
    }
}
/// One ring of the wheel: `NUM_SLOTS` timer lists and a cursor into them.
#[derive(Debug)]
struct Level<const NUM_SLOTS: usize> {
    /// One timer list per slot.
    slots: [EntryList; NUM_SLOTS],
    /// Index of the slot this level processes next; advanced by one (mod `NUM_SLOTS`) each
    /// time the level is processed. Wrapping to 0 triggers processing of the next level.
    position: usize,
}
/// A hierarchical timer wheel with `NUM_LEVELS` levels of `NUM_SLOTS` slots each.
///
/// A timer due `delta` ticks from now is placed on the lowest level `l` with
/// `delta < NUM_SLOTS^(l + 1)`, at an offset of `delta / NUM_SLOTS^l` slots from that level's
/// position. Timers further out than the whole wheel spans are parked on the top level and
/// re-evaluated when their slot comes round. Time advances only through `expire_timers`.
#[derive(Debug)]
pub struct TimerWheel<T, const NUM_LEVELS: usize, const NUM_SLOTS: usize> {
    /// Storage for both the slot heads and the timers; slab indices serve as list links and
    /// as the value inside `TimerHandle`.
    timers: Slab<TimerEntry<T>>,
    /// The wheel's rings, finest (level 0) first.
    levels: [Level<NUM_SLOTS>; NUM_LEVELS],
    /// Number of ticks advanced so far; starts at 0.
    current_time: u64,
}
impl<T, const NUM_LEVELS: usize, const NUM_SLOTS: usize> Default
for TimerWheel<T, NUM_LEVELS, NUM_SLOTS>
{
fn default() -> Self {
Self::new()
}
}
impl<T, const NUM_LEVELS: usize, const NUM_SLOTS: usize> TimerWheel<T, NUM_LEVELS, NUM_SLOTS> {
pub fn new() -> Self {
let mut uninit_self = MaybeUninit::uninit();
Self::init(&mut uninit_self);
unsafe { uninit_self.assume_init() }
}
pub fn init(uninit_self: &mut MaybeUninit<Self>) -> &mut Self {
let init_self = unsafe {
let ptr = uninit_self.as_mut_ptr();
std::ptr::addr_of_mut!((*ptr).timers)
.write(Slab::with_capacity(NUM_LEVELS * NUM_SLOTS));
for level in 0..NUM_LEVELS {
let level_ptr =
(std::ptr::addr_of_mut!((*ptr).levels) as *mut Level<NUM_SLOTS>).add(level);
for slot in 0..NUM_SLOTS {
let slot_ptr =
(std::ptr::addr_of_mut!((*level_ptr).slots) as *mut EntryList).add(slot);
slot_ptr.write(Default::default());
}
std::ptr::addr_of_mut!((*level_ptr).position).write(0);
}
std::ptr::addr_of_mut!((*ptr).current_time).write(0);
uninit_self.assume_init_mut()
};
for level in 0..NUM_LEVELS {
for slot in 0..NUM_SLOTS {
init_self.levels[level].slots[slot].head = init_self.timers.insert(TimerEntry {
next: EMPTY_INDEX,
previous: EMPTY_INDEX,
data: TimerEntryData::Head,
}) as u32;
}
}
init_self
}
pub fn start_timer(&mut self, interval: u64, context: T) -> TimerHandle {
let expire_time = self.current_time.saturating_add(interval);
let handle = TimerHandle(self.timers.insert(TimerEntry {
next: EMPTY_INDEX,
previous: EMPTY_INDEX,
data: TimerEntryData::Timer {
expire_time,
context,
},
}) as u32);
self.start_timer_unchecked(handle.0);
handle
}
pub fn start_timer_absolute(
&mut self,
expire_time: u64,
context: T,
) -> Result<TimerHandle, StartTimerError<T>> {
if expire_time <= self.current_time {
return Err(StartTimerError::Expired(context));
}
let handle = TimerHandle(self.timers.insert(TimerEntry {
next: EMPTY_INDEX,
previous: EMPTY_INDEX,
data: TimerEntryData::Timer {
expire_time,
context,
},
}) as u32);
self.start_timer_unchecked(handle.0);
Ok(handle)
}
fn start_timer_unchecked(&mut self, index: u32) {
let timer = &self.timers[index as usize];
let delta = match &timer.data {
TimerEntryData::Timer { expire_time, .. } => *expire_time - self.current_time,
TimerEntryData::Head => unreachable!(),
};
let mut level = NUM_LEVELS - 1;
let mut slot = (self.levels[level].position + NUM_SLOTS - 1) % NUM_SLOTS;
for l in 0..NUM_LEVELS {
if delta < (NUM_SLOTS as u64).pow((l + 1) as u32) {
level = l;
slot = (self.levels[l].position
+ (delta / (NUM_SLOTS as u64).pow(l as u32)) as usize)
.saturating_sub(1)
% NUM_SLOTS;
break;
}
}
let head = &mut self.timers[self.levels[level].slots[slot].head as usize];
let old_head_next = std::mem::replace(&mut head.next, index);
self.timers[index as usize].next = old_head_next;
self.timers[index as usize].previous = self.levels[level].slots[slot].head;
if old_head_next != EMPTY_INDEX {
self.timers[old_head_next as usize].previous = index;
}
}
pub fn stop_timer(&mut self, handle: TimerHandle) -> Option<T> {
let timer = self.timers.get(handle.0 as usize)?;
if matches!(timer.data, TimerEntryData::Head) {
return None;
}
let timer_next = timer.next;
let timer_previous = timer.previous;
self.timers[timer_previous as usize].next = timer_next;
if timer_next != EMPTY_INDEX {
self.timers[timer_next as usize].previous = timer_previous;
}
let timer = self.timers.remove(handle.0 as usize);
match timer.data {
TimerEntryData::Timer { context, .. } => Some(context),
TimerEntryData::Head => unreachable!(),
}
}
fn process_level(&mut self, level: usize) -> Vec<T> {
let mut expired_contexts = Vec::new();
if level >= NUM_LEVELS {
return expired_contexts;
}
let slot = self.levels[level].position;
let head = &mut self.timers[self.levels[level].slots[slot].head as usize];
let mut timer_index = std::mem::replace(&mut head.next, EMPTY_INDEX);
while timer_index != EMPTY_INDEX {
let timer = &self.timers[timer_index as usize];
let next_timer_index = timer.next;
match &timer.data {
TimerEntryData::Timer { expire_time, .. } => {
if *expire_time <= self.current_time {
let timer = self.timers.remove(timer_index as usize);
match timer.data {
TimerEntryData::Timer { context, .. } => {
expired_contexts.push(context);
}
TimerEntryData::Head => unreachable!(),
}
} else {
self.start_timer_unchecked(timer_index);
}
}
TimerEntryData::Head => unreachable!(),
}
timer_index = next_timer_index;
}
self.levels[level].position = (self.levels[level].position + 1) % NUM_SLOTS;
if self.levels[level].position == 0 {
let cascaded_expired_contexts = self.process_level(level + 1);
expired_contexts.extend(cascaded_expired_contexts);
}
expired_contexts
}
fn tick(&mut self) -> Vec<T> {
self.current_time += 1;
self.process_level(0)
}
pub fn expire_timers(&mut self, ticks: u64) -> Vec<T> {
let mut expired_contexts = Vec::new();
for _ in 0..ticks {
let tick_expired_contexts = self.tick();
expired_contexts.extend(tick_expired_contexts);
}
expired_contexts
}
pub fn next_expiration(&self) -> Option<u64> {
for level in 0..NUM_LEVELS {
let start_slot = self.levels[level].position;
for i in 0..NUM_SLOTS {
let slot = (start_slot + i) % NUM_SLOTS;
let head = &self.timers[self.levels[level].slots[slot].head as usize];
if head.next != EMPTY_INDEX {
let mut timer_index = head.next;
let mut min_expire_time: Option<u64> = None;
while timer_index != EMPTY_INDEX {
let timer = &self.timers[timer_index as usize];
timer_index = timer.next;
match &timer.data {
TimerEntryData::Timer { expire_time, .. } => {
if let Some(prev_min_expire_time) = min_expire_time {
min_expire_time = Some(prev_min_expire_time.min(*expire_time));
} else {
min_expire_time = Some(*expire_time);
}
}
TimerEntryData::Head => unreachable!(),
}
}
return min_expire_time;
}
}
}
None
}
}
#[cfg(test)]
mod tests {
use crate::vppinfra::clib_mem_init;
use super::*;
#[test]
fn test_immediate_timer() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
let e = wheel.start_timer_absolute(0, 1).expect_err("add timer");
assert!(
matches!(e, StartTimerError::Expired(1)),
"{:?} != AddTimerError::Expired",
e
);
}
#[test]
fn test_future_timer() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
wheel.start_timer_absolute(5, 1).expect("add timer");
let contexts = wheel.expire_timers(4);
assert_eq!(contexts, []);
let contexts = wheel.expire_timers(1);
assert_eq!(contexts, [1]);
}
#[test]
fn test_multiple_timers() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
wheel.start_timer_absolute(1, 1).expect("add timer");
wheel.start_timer(2, 2);
wheel.start_timer_absolute(3, 3).expect("add timer");
let contexts = wheel.expire_timers(1);
assert_eq!(contexts, [1]);
let contexts = wheel.expire_timers(2);
assert_eq!(contexts, [2, 3]);
}
#[test]
fn test_level_1() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
wheel.start_timer(257, 1);
let contexts = wheel.expire_timers(256);
assert_eq!(contexts, []);
let contexts = wheel.expire_timers(1);
assert_eq!(contexts, [1]);
}
#[test]
fn test_next_expiration() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
assert_eq!(wheel.next_expiration(), None);
wheel.start_timer(5, 1);
assert_eq!(wheel.next_expiration(), Some(5));
wheel.start_timer_absolute(3, 2).expect("add timer");
assert_eq!(wheel.next_expiration(), Some(3));
wheel.start_timer(10, 3);
assert_eq!(wheel.next_expiration(), Some(3));
let contexts = wheel.expire_timers(3); assert_eq!(contexts, [2]);
assert_eq!(wheel.next_expiration(), Some(5));
}
#[test]
fn test_stop_timers() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 4, 256> = Default::default();
let timer1 = wheel.start_timer(5, 1);
let timer2 = wheel.start_timer_absolute(3, 2).expect("add timer");
let timer3 = wheel.start_timer_absolute(5, 3).expect("add timer");
assert_eq!(wheel.stop_timer(timer1), Some(1));
let contexts = wheel.expire_timers(3);
assert_eq!(contexts, [2]);
assert_eq!(wheel.stop_timer(timer2), None);
assert_eq!(wheel.stop_timer(timer3), Some(3));
let contexts = wheel.expire_timers(2);
assert_eq!(contexts, []);
}
#[test]
fn test_timer_far_in_future() {
clib_mem_init();
let mut wheel: TimerWheel<u8, 2, 4> = Default::default();
let timer1 = wheel.start_timer_absolute(17, 1).expect("add timer");
let contexts = wheel.expire_timers(16);
assert_eq!(contexts, []);
assert_eq!(wheel.stop_timer(timer1), Some(1));
wheel.start_timer(18, 2);
wheel.start_timer(17, 1);
assert_eq!(wheel.next_expiration(), Some(16 + 17));
let contexts = wheel.expire_timers(17);
assert_eq!(contexts, [1]);
}
}