use crate::list;
use slab::Slab;
use std::cmp;
/// Number of timestamp bits resolved by each wheel level.
const WHEEL_BITS: usize = 6;
/// Number of wheel levels.
const WHEEL_NUM: usize = 4;
/// Slots per wheel (`2^WHEEL_BITS`).
const WHEEL_LEN: usize = 1 << WHEEL_BITS;
/// Highest valid slot index within a wheel.
const WHEEL_MAX: usize = WHEEL_LEN - 1;
/// Mask that extracts one wheel's slot index from a right-shifted timestamp.
const WHEEL_MASK: u64 = WHEEL_MAX as u64;
/// Largest delay the wheels can represent; `sched` clamps longer delays to this.
const TIMEOUT_MAX: u64 = (1u64 << (WHEEL_BITS * WHEEL_NUM)) - 1;
/// Counts the leading (most significant) zero bits of `x`.
///
/// Returns 64 when `x` is 0.
fn clz64(x: u64) -> u8 {
    // `leading_zeros` never exceeds 64, so narrowing to u8 is lossless.
    x.leading_zeros() as u8
}
/// Counts the trailing (least significant) zero bits of `x`.
///
/// Returns 64 when `x` is 0.
fn ctz64(x: u64) -> u8 {
    // `trailing_zeros` never exceeds 64, so narrowing to u8 is lossless.
    x.trailing_zeros() as u8
}
/// "Find last set": the 1-based index of the most significant set bit of `x`,
/// or 0 when no bit is set.
fn fls64(x: u64) -> u8 {
    let leading = x.leading_zeros() as u8;
    64 - leading
}
/// Rotates the bits of `x` left by `count` positions.
///
/// The rotation amount is taken modulo 64, so every `count` is valid; rotating
/// by 0 or 64 returns `x` unchanged. (The previous shift-based version panicked
/// on overflow in debug builds for counts of 65 and above.)
fn rotl64(x: u64, count: u8) -> u64 {
    x.rotate_left(u32::from(count))
}
/// Rotates the bits of `x` right by `count` positions.
///
/// The rotation amount is taken modulo 64, so every `count` is valid; rotating
/// by 0 or 64 returns `x` unchanged. (The previous shift-based version panicked
/// on overflow in debug builds for counts of 65 and above.)
fn rotr64(x: u64, count: u8) -> u64 {
    x.rotate_right(u32::from(count))
}
/// Computes which wheel slots must be drained when time advances from
/// `curtime` to `newtime`.
///
/// Returns one bitmask per wheel: bit `s` of `result[w]` is set when slot `s`
/// of wheel `w` has to be emptied and its timers rescheduled. If `newtime` is
/// not after `curtime`, every mask is zero.
///
/// Wheels are examined from the finest (index 0) to the coarsest. The walk
/// stops after the first wheel whose marked range does not include that
/// wheel's "finished" slot (slot 0 for wheel 0, slot `WHEEL_LEN - 1` for the
/// higher wheels), i.e. the first wheel that did not roll over; the masks of
/// all coarser wheels are left at zero.
fn need_resched(curtime: u64, newtime: u64) -> [u64; WHEEL_NUM] {
    let mut result = [0; WHEEL_NUM];

    if newtime <= curtime {
        return result;
    }

    let mut elapsed = newtime - curtime;

    for wheel in 0..WHEEL_NUM {
        // Number of low timestamp bits below this wheel's resolution.
        let trunc_bits = (wheel * WHEEL_BITS) as u64;

        let pending;

        if (elapsed >> trunc_bits) > (WHEEL_MAX as u64) {
            // More than a full rotation of this wheel elapsed: every slot is due.
            pending = !0;
        } else {
            let old_slot = (curtime >> trunc_bits) & WHEEL_MASK;
            let new_slot = (newtime >> trunc_bits) & WHEEL_MASK;

            // Forward distance from old_slot to new_slot, wrapping around the wheel.
            let d = if new_slot >= old_slot {
                new_slot - old_slot
            } else {
                (WHEEL_LEN as u64) - old_slot + new_slot
            };

            // Mark `d` consecutive slots (wrapping). Wheel 0 marks
            // old_slot+1 ..= new_slot; higher wheels mark old_slot .. new_slot-1,
            // matching `sched`, which stores timers on those wheels one slot
            // below the slot their expiry maps to.
            pending = if wheel > 0 {
                rotl64((1 << d) - 1, old_slot as u8)
            } else {
                rotl64((1 << d) - 1, (old_slot + 1) as u8)
            };
        }

        result[wheel] = pending;

        // Slot that falls inside the marked range only when this wheel wrapped
        // around, which is what advances the next coarser wheel.
        let finished_bit = if wheel > 0 {
            1 << (WHEEL_LEN - 1)
        } else {
            1
        };

        if pending & finished_bit == 0 {
            // This wheel did not wrap, so no coarser wheel is affected.
            break;
        }

        // NOTE(review): this raises `elapsed` to at most one full rotation of the
        // current wheel. That can never push `elapsed >> trunc_bits` above
        // WHEEL_MAX on a coarser wheel, so it appears not to affect the result.
        // Confirm before relying on or removing it.
        elapsed = cmp::max(elapsed, (WHEEL_LEN << (wheel * WHEEL_BITS)) as u64);
    }

    result
}
/// Which list a timer node is currently linked into: a wheel slot or the
/// expired list.
enum InList {
    /// The slot list `wheel[.0][.1]`, i.e. `(wheel index, slot index)`.
    Wheel(usize, usize),
    /// The list of due timers waiting to be returned by `take_expired`.
    Expired,
}
/// Per-timer state stored as the value of each list node.
struct Timer {
    /// Absolute expiry time, in the same tick units as `TimerWheel::curtime`.
    expires: u64,
    /// The list this timer is linked into. It is `None` only while the timer is
    /// detached, i.e. between being created or unlinked and being rescheduled.
    list: Option<InList>,
    /// Opaque caller value, handed back by `take_expired`.
    user_data: usize,
}
/// Hierarchical timer wheel holding a bounded number of timers.
///
/// Time is an abstract `u64` tick count, advanced with `update`. There are
/// `WHEEL_NUM` wheels of `WHEEL_LEN` slots each. Every wheel resolves
/// `WHEEL_BITS` more bits of delay than the wheel below it.
pub struct TimerWheel {
    /// Storage for timer nodes; a node's slab key is the timer's public key.
    nodes: Slab<list::Node<Timer>>,
    /// Slot lists, indexed `[wheel][slot]`.
    wheel: [[list::List; WHEEL_LEN]; WHEEL_NUM],
    /// Due timers, appended in the order `sched` found them due.
    expired: list::List,
    /// Per-wheel occupancy bitmaps: bit `s` of `pending[w]` is set while
    /// `wheel[w][s]` holds timers.
    pending: [u64; WHEEL_NUM],
    /// Current time as of the last `update` (starts at 0).
    curtime: u64,
}
impl TimerWheel {
    /// Creates an empty timer wheel with storage for `capacity` timers.
    ///
    /// The current time starts at 0.
    pub fn new(capacity: usize) -> Self {
        Self {
            nodes: Slab::with_capacity(capacity),
            wheel: [[list::List::default(); WHEEL_LEN]; WHEEL_NUM],
            expired: list::List::default(),
            pending: [0; WHEEL_NUM],
            curtime: 0,
        }
    }

    /// Adds a timer that expires at absolute time `expires`, tagged with
    /// `user_data`, and returns its key.
    ///
    /// A timer whose `expires` is not after the current time goes straight to
    /// the expired list.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` without allocating when the number of live timers has
    /// reached the slab's capacity (`Slab::capacity`).
    pub fn add(&mut self, expires: u64, user_data: usize) -> Result<usize, ()> {
        // Refuse rather than let the slab grow past its preallocated storage.
        if self.nodes.len() == self.nodes.capacity() {
            return Err(());
        }

        let timer = Timer {
            expires,
            list: None,
            user_data,
        };
        let key = self.nodes.insert(list::Node::new(timer));
        self.sched(key);

        Ok(key)
    }

    /// Cancels the timer identified by `key`, whether it is still pending or
    /// already on the expired list.
    ///
    /// Unknown keys are ignored. Keys are recycled by the slab, so a key whose
    /// timer was already removed or taken may now refer to a newer timer.
    pub fn remove(&mut self, key: usize) {
        let n = match self.nodes.get(key) {
            Some(n) => n,
            None => return,
        };

        match n.value.list {
            Some(InList::Wheel(wheel, slot)) => {
                let l = &mut self.wheel[wheel][slot];
                l.remove(&mut self.nodes, key);

                // Keep the occupancy bitmap in sync once the slot is empty.
                if l.is_empty() {
                    self.pending[wheel] &= !(1 << slot);
                }
            }
            Some(InList::Expired) => {
                self.expired.remove(&mut self.nodes, key);
            }
            None => {}
        }

        self.nodes.remove(key);
    }

    /// Returns how long the caller may wait before it must call `update` again.
    ///
    /// - `Some(0)` when expired timers are waiting to be taken.
    /// - `None` when no timers are scheduled at all.
    /// - Otherwise, the smallest per-wheel distance to the next occupied slot.
    ///
    /// For coarser wheels that distance is the time until the slot must be
    /// cascaded, which may be earlier than any timer's actual expiry.
    pub fn timeout(&self) -> Option<u64> {
        if !self.expired.is_empty() {
            return Some(0);
        }

        (0..WHEEL_NUM)
            .filter(|&wheel| self.pending[wheel] != 0)
            .map(|wheel| {
                let trunc_bits = (wheel * WHEEL_BITS) as u64;

                // Rotate so that bit 0 corresponds to the current slot; the
                // trailing-zero count is then the distance in slots.
                let slot = ((self.curtime >> trunc_bits) & WHEEL_MASK) as usize;
                let pending = rotr64(self.pending[wheel], slot as u8);

                // Higher wheels store timers one slot early (see `sched`).
                let offset = if wheel > 0 { 1 } else { 0 };

                // Time already elapsed inside the current slot of this wheel:
                // the timestamp bits below its resolution (0 for wheel 0).
                let relmask = (1u64 << trunc_bits) - 1;

                (((ctz64(pending) as u64) + offset) << trunc_bits) - (relmask & self.curtime)
            })
            .min()
    }

    /// Advances the current time to `curtime`.
    ///
    /// Does nothing if `curtime` is not later than the current time. Otherwise
    /// every slot flagged by `need_resched` is drained, and its timers are
    /// rescheduled relative to the new time; timers that are now due land on
    /// the expired list.
    pub fn update(&mut self, curtime: u64) {
        if curtime <= self.curtime {
            return;
        }

        let need = need_resched(self.curtime, curtime);

        // Gather every timer from the affected, occupied slots into one list.
        let mut l = list::List::default();
        for (wheel, &slots) in need.iter().enumerate() {
            loop {
                let ready = slots & self.pending[wheel];
                if ready == 0 {
                    break;
                }

                let slot = ctz64(ready) as usize;
                l.concat(&mut self.nodes, &mut self.wheel[wheel][slot]);
                self.pending[wheel] &= !(1 << slot);
            }
        }

        // Advance the clock first so `sched` places timers relative to the new time.
        self.curtime = curtime;

        while let Some(key) = l.head {
            l.remove(&mut self.nodes, key);
            self.nodes[key].value.list = None;
            self.sched(key);
        }
    }

    /// Pops the timer at the front of the expired list and returns
    /// `(key, user_data)`, or `None` if no timer is due.
    ///
    /// The timer's storage is released, so `key` may be handed out again by a
    /// later `add`.
    pub fn take_expired(&mut self) -> Option<(usize, usize)> {
        let key = self.expired.pop_front(&mut self.nodes)?;
        let node = self.nodes.remove(key);

        Some((key, node.value.user_data))
    }

    /// Links timer `key` into the list matching its expiry relative to `curtime`.
    ///
    /// Timers already due go on the expired list. Otherwise the delay, clamped
    /// to `TIMEOUT_MAX`, selects the wheel, and the expiry's bits at that
    /// wheel's resolution select the slot. On wheels above 0 the timer goes one
    /// slot below the one its expiry maps to, matching the slot ranges
    /// `need_resched` marks for those wheels.
    fn sched(&mut self, key: usize) {
        let expires = self.nodes[key].value.expires;

        if expires <= self.curtime {
            self.expired.push_back(&mut self.nodes, key);
            self.nodes[key].value.list = Some(InList::Expired);
            return;
        }

        let t = cmp::min(expires - self.curtime, TIMEOUT_MAX);
        assert!(t > 0);

        // Each wheel covers WHEEL_BITS more bits of delay than the one below it.
        let wheel = ((fls64(t) - 1) as usize) / WHEEL_BITS;
        assert!(wheel < WHEEL_NUM);

        let trunc_bits = (wheel * WHEEL_BITS) as u64;
        let offset = if wheel > 0 { 1 } else { 0 };
        let slot = (((expires >> trunc_bits) - offset) & WHEEL_MASK) as usize;

        self.wheel[wheel][slot].push_back(&mut self.nodes, key);
        self.pending[wheel] |= 1 << slot;
        self.nodes[key].value.list = Some(InList::Wheel(wheel, slot));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Parses a colon-separated timestamp such as "01:02:03" into a u64.
    // Each part is one WHEEL_BITS-wide digit (at most WHEEL_MAX), and the
    // rightmost part is the wheel-0 digit.
    fn ts(s: &str) -> u64 {
        let mut result = 0;
        for (i, part) in s.rsplit(":").enumerate() {
            let x: u64 = part.parse().unwrap();
            assert!(x <= (WHEEL_MAX as u64));
            result |= x << (i * WHEEL_BITS);
        }
        result
    }

    // Builds a slot bitmask from an inclusive range "start-end". The range
    // wraps past slot WHEEL_MAX back to 0 when end < start (e.g. "03-01").
    fn r2b(s: &str) -> u64 {
        let mut it = s.split("-");
        let start = it.next().unwrap();
        let end = it.next().unwrap();
        assert_eq!(it.next(), None);
        let mut pos: u64 = start.parse().unwrap();
        let end: u64 = end.parse().unwrap();
        let mut result = 0;
        loop {
            result |= 1 << pos;
            if pos == end {
                break;
            }
            pos = (pos + 1) & WHEEL_MASK;
        }
        result
    }

    // Reverses a per-wheel array so expectations can be written coarsest
    // wheel first, mirroring the digit order used by `ts`. Hard-codes
    // WHEEL_NUM == 4.
    fn rev(a: [u64; WHEEL_NUM]) -> [u64; WHEEL_NUM] {
        [a[3], a[2], a[1], a[0]]
    }

    #[test]
    fn test_clz() {
        assert_eq!(clz64(0), 64);
        assert_eq!(clz64(0b1), 63);
        assert_eq!(clz64(0b10), 62);
        assert_eq!(clz64(0x4000000000000000), 1);
        assert_eq!(clz64(0x8000000000000000), 0);
    }

    #[test]
    fn test_ctz() {
        assert_eq!(ctz64(0), 64);
        assert_eq!(ctz64(0b1), 0);
        assert_eq!(ctz64(0b10), 1);
        assert_eq!(ctz64(0x4000000000000000), 62);
        assert_eq!(ctz64(0x8000000000000000), 63);
    }

    #[test]
    fn test_fls() {
        assert_eq!(fls64(0), 0);
        assert_eq!(fls64(0b1), 1);
        assert_eq!(fls64(0b10), 2);
        assert_eq!(fls64(0x4000000000000000), 63);
        assert_eq!(fls64(0x8000000000000000), 64);
    }

    #[test]
    fn test_rotl() {
        assert_eq!(rotl64(0x0ffff00000000000, 0), 0x0ffff00000000000);
        assert_eq!(rotl64(0x0ffff00000000000, 64), 0x0ffff00000000000);
        assert_eq!(rotl64(0x0ffff00000000000, 4), 0xffff000000000000);
        assert_eq!(rotl64(0x0ffff00000000000, 8), 0xfff000000000000f);
        assert_eq!(rotl64(0x0ffff00000000000, 16), 0xf000000000000fff);
    }

    #[test]
    fn test_rotr() {
        assert_eq!(rotr64(0x00000000000ffff0, 0), 0x00000000000ffff0);
        assert_eq!(rotr64(0x00000000000ffff0, 64), 0x00000000000ffff0);
        assert_eq!(rotr64(0x00000000000ffff0, 4), 0x000000000000ffff);
        assert_eq!(rotr64(0x00000000000ffff0, 8), 0xf000000000000fff);
        assert_eq!(rotr64(0x00000000000ffff0, 16), 0xfff000000000000f);
    }

    // Checks which list `sched` places timers in, relative to curtime = 7.
    #[test]
    fn test_sched() {
        let mut w = TimerWheel::new(10);
        w.update(7);
        let t1 = w.add(0b0_000000, 1).unwrap();
        let t2 = w.add(0b0_001000, 1).unwrap();
        let t3 = w.add(0b0_111111, 1).unwrap();
        let t4 = w.add(0b1_000000, 1).unwrap();
        let t5 = w.add(0b1_001000, 1).unwrap();
        let t6 = w.add(0b1_000000_000000_000000_000000, 1).unwrap();
        // Expiry 0 is not after curtime, so the timer goes straight to the expired list.
        assert_eq!(w.expired.head, Some(t1));
        // Delays under WHEEL_LEN land on wheel 0, in slot = expiry & WHEEL_MASK.
        assert_eq!(w.wheel[0][8].head, Some(t2));
        assert_eq!(w.wheel[0][63].head, Some(t3));
        assert_eq!(w.wheel[0][0].head, Some(t4));
        // Higher wheels store the timer one slot below its expiry's slot.
        assert_eq!(w.wheel[1][0].head, Some(t5));
        // A delay beyond TIMEOUT_MAX is clamped onto the top wheel.
        assert_eq!(w.wheel[3][63].head, Some(t6));
    }

    // Expected masks are written coarsest wheel first (see `rev`).
    #[test]
    fn test_need_resched() {
        assert_eq!(need_resched(ts("00:00"), ts("00:00")), [0, 0, 0, 0]);
        assert_eq!(
            need_resched(ts("00:00"), ts("00:01")),
            rev([0, 0, 0, r2b("01-01")])
        );
        assert_eq!(
            need_resched(ts("00:01"), ts("00:02")),
            rev([0, 0, 0, r2b("02-02")])
        );
        assert_eq!(
            need_resched(ts("00:02"), ts("00:63")),
            rev([0, 0, 0, r2b("03-63")])
        );
        assert_eq!(
            need_resched(ts("00:63"), ts("01:00")),
            rev([0, 0, r2b("00-00"), r2b("00-00")])
        );
        assert_eq!(
            need_resched(ts("01:00"), ts("01:02")),
            rev([0, 0, 0, r2b("01-02")])
        );
        assert_eq!(
            need_resched(ts("01:02"), ts("05:01")),
            rev([0, 0, r2b("01-04"), r2b("00-63")])
        );
        assert_eq!(
            need_resched(ts("05:01"), ts("05:02")),
            rev([0, 0, 0, r2b("02-02")])
        );
        assert_eq!(
            need_resched(ts("05:02"), ts("06:01")),
            rev([0, 0, r2b("05-05"), r2b("03-01")])
        );
        assert_eq!(
            need_resched(ts("00:63:63"), ts("01:00:00")),
            rev([0, r2b("00-00"), r2b("63-63"), r2b("00-00")])
        );
    }

    // Schedules one timer per tick across a full wheel-1 rotation (plus one)
    // and checks that each timer expires exactly when the clock reaches it.
    #[test]
    fn test_rotate() {
        let count = (64 * 64) + 1;
        let mut w = TimerWheel::new(count);
        for i in 0..count {
            w.add(i as u64, i).unwrap();
        }
        for i in 0..count {
            let (_, v) = w.take_expired().unwrap();
            assert_eq!(v, i);
            assert_eq!(w.take_expired(), None);
            w.update((i + 1) as u64);
        }
        assert_eq!(w.take_expired(), None);
    }

    // End-to-end checks of add/remove/update/timeout/take_expired, including
    // timers that must cascade down from higher wheels.
    #[test]
    fn test_wheel() {
        let mut w = TimerWheel::new(10);
        assert_eq!(w.timeout(), None);
        assert_eq!(w.take_expired(), None);
        let t1 = w.add(4, 1).unwrap();
        assert_eq!(w.timeout(), Some(4));
        w.remove(t1);
        assert_eq!(w.timeout(), None);
        w.update(5);
        assert_eq!(w.take_expired(), None);
        let t2 = w.add(8, 2).unwrap();
        assert_eq!(w.timeout(), Some(3));
        w.update(7);
        assert_eq!(w.timeout(), Some(1));
        assert_eq!(w.take_expired(), None);
        w.update(8);
        assert_eq!(w.timeout(), Some(0));
        assert_eq!(w.take_expired(), Some((t2, 2)));
        assert_eq!(w.take_expired(), None);
        // Run twice; the second pass uses a base offset so the wheels hold
        // state left over from the first.
        for i in 0..2 {
            let base = i * 20_000_000;
            let t1 = w.add(base + 1, 1).unwrap();
            let t2 = w.add(base + 10, 2).unwrap();
            let t3 = w.add(base + 1_000, 3).unwrap();
            let t4 = w.add(base + 100_000, 4).unwrap();
            let t5 = w.add(base + 10_000_000, 5).unwrap();
            w.update(base + 100);
            assert_eq!(w.timeout(), Some(0));
            assert_eq!(w.take_expired(), Some((t1, 1)));
            assert_eq!(w.take_expired(), Some((t2, 2)));
            assert_eq!(w.take_expired(), None);
            // Coarse-wheel timeouts may fire early, so only upper bounds are checked.
            assert!(w.timeout().unwrap() <= 900);
            w.update(base + 2_000);
            assert_eq!(w.timeout(), Some(0));
            assert_eq!(w.take_expired(), Some((t3, 3)));
            assert_eq!(w.take_expired(), None);
            assert!(w.timeout().unwrap() <= 98_000);
            w.update(base + 200_000);
            assert_eq!(w.timeout(), Some(0));
            assert_eq!(w.take_expired(), Some((t4, 4)));
            assert_eq!(w.take_expired(), None);
            assert!(w.timeout().unwrap() <= 9_800_000);
            w.update(base + 12_000_000);
            assert_eq!(w.timeout(), Some(0));
            assert_eq!(w.take_expired(), Some((t5, 5)));
            assert_eq!(w.take_expired(), None);
            assert_eq!(w.timeout(), None);
        }
    }
}