use std::cell::UnsafeCell;
use std::collections::{HashMap, LinkedList};
use std::future::Future;
use std::pin::Pin;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll, Waker};
pub use std::time::Duration;
pub use std::time::Instant;
/// Length of one wheel tick, in milliseconds.
const TICK_MS: u64 = 1;

// Every level's slot count is derived from its shift, so `SIZE == 1 << SHIFT` holds
// by construction. The slot arithmetic in `TimerWheel` mixes `>> SHIFT` indexing with
// `SIZE`-based masks and boundaries, and is only correct when the two agree.

/// log2 of the slot count of level 0, the finest level (one slot per tick).
const WHEEL0_SHIFT: usize = 8;
/// Number of slots in level 0.
const WHEEL0_SIZE: usize = 1 << WHEEL0_SHIFT;
/// Reduces a tick to a level-0 slot index.
const WHEEL0_MASK: usize = WHEEL0_SIZE - 1;

/// log2 of the slot count of level 1; each level-1 slot spans all of level 0.
const WHEEL1_SHIFT: usize = 6;
/// Number of slots in level 1.
const WHEEL1_SIZE: usize = 1 << WHEEL1_SHIFT;
/// Reduces a shifted tick to a level-1 slot index.
const WHEEL1_MASK: usize = WHEEL1_SIZE - 1;

/// log2 of the slot count of level 2; each level-2 slot spans all of level 1.
const WHEEL2_SHIFT: usize = 6;
/// Number of slots in level 2.
const WHEEL2_SIZE: usize = 1 << WHEEL2_SHIFT;
/// Reduces a shifted tick to a level-2 slot index.
const WHEEL2_MASK: usize = WHEEL2_SIZE - 1;

/// log2 of the slot count of level 3; each level-3 slot spans all of level 2.
const WHEEL3_SHIFT: usize = 6;
/// Number of slots in level 3.
const WHEEL3_SIZE: usize = 1 << WHEEL3_SHIFT;
/// Reduces a shifted tick to a level-3 slot index.
const WHEEL3_MASK: usize = WHEEL3_SIZE - 1;

/// Longest delay the four levels cover without wrapping: 2^26 ticks (about 18.6 hours).
// Currently read only by this file's tests, hence the allowance.
#[allow(dead_code)]
const MAX_TIMEOUT_MS: u64 =
    (WHEEL0_SIZE * WHEEL1_SIZE * WHEEL2_SIZE * WHEEL3_SIZE) as u64 * TICK_MS;
/// A pending timer as stored in a wheel slot.
struct TimerEntry {
    /// Id unique within its wheel; the key of the wheel's registry entry.
    id: u64,
    /// Absolute expiration time, in milliseconds on the wheel's clock.
    expiration_ms: u64,
    /// Woken when the timer fires; `None` for timers created without one.
    waker: Option<Waker>,
    // Never read: cancellation is tracked by the wheel's registry instead. Kept
    // because the constructors in `TimerWheel` still initialize it.
    #[allow(dead_code)]
    canceled: Mutex<bool>,
}

// `TimerEntry` and `TimerSlot` are `Send + Sync` automatically because every field is.
// No `unsafe impl` is written, so adding a non-thread-safe field later becomes a
// compile error instead of undefined behavior.

/// One bucket of a wheel level: the timers currently filed under one slot index.
///
/// The list sits behind a `Mutex` because `TimerWheel` is shared across threads (it
/// backs a global `OnceLock`), and both inserting and advancing touch slots.
struct TimerSlot {
    timers: Mutex<LinkedList<TimerEntry>>,
}

impl TimerSlot {
    /// Creates an empty slot.
    fn new() -> Self {
        Self {
            timers: Mutex::new(LinkedList::new()),
        }
    }

    /// Locks this slot's list.
    fn list(&self) -> std::sync::MutexGuard<'_, LinkedList<TimerEntry>> {
        // The only work done under this lock is `push_back` or `mem::take`. Neither can
        // leave the list half-updated, so a poisoned lock is still safe to reuse.
        self.timers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Appends `timer` to this slot.
    ///
    /// # Safety
    ///
    /// There are no caller obligations. Access is serialized by an internal mutex, so
    /// this may run on any thread, concurrently with [`TimerSlot::take_all`]. It stays
    /// an `unsafe fn` only so existing `unsafe { … }` call sites keep compiling without
    /// `unused_unsafe` warnings.
    unsafe fn push(&self, timer: TimerEntry) {
        self.list().push_back(timer);
    }

    /// Removes and returns every timer in this slot, leaving it empty.
    ///
    /// # Safety
    ///
    /// No caller obligations; see [`TimerSlot::push`].
    unsafe fn take_all(&self) -> LinkedList<TimerEntry> {
        std::mem::take(&mut *self.list())
    }
}

impl std::fmt::Debug for TimerSlot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Contents are not printed, because that would require taking the lock.
        f.debug_struct("TimerSlot").finish_non_exhaustive()
    }
}
/// Four-level hierarchical timing wheel with a tick of `TICK_MS` milliseconds.
///
/// Level 0 has one slot per tick. Each slot of level `k > 0` spans a full rotation of
/// level `k - 1`. Timers are filed by absolute expiration tick and are moved to finer
/// levels as the clock approaches them.
pub struct TimerWheel {
    /// Ticks the clock has advanced since the wheel was created.
    current_ticks: AtomicU64,
    /// Level 0: one slot per tick.
    wheel0: Box<[TimerSlot; WHEEL0_SIZE]>,
    /// Level 1: each slot spans all of level 0.
    wheel1: Box<[TimerSlot; WHEEL1_SIZE]>,
    /// Level 2: each slot spans all of level 1.
    wheel2: Box<[TimerSlot; WHEEL2_SIZE]>,
    /// Level 3: each slot spans all of level 2.
    wheel3: Box<[TimerSlot; WHEEL3_SIZE]>,
    /// Next timer id to hand out. Ids start at 1 and are unique only within this wheel.
    next_id: AtomicU64,
    /// Ids of pending timers and where each is filed. Removing an id cancels that timer.
    timer_registry: Mutex<HashMap<u64, TimerLocation>>,
}

/// Where a pending timer is currently filed.
// Written on insert but never read yet, hence the `dead_code` allowances.
#[derive(Clone, Copy, Debug)]
struct TimerLocation {
    #[allow(dead_code)]
    wheel_level: u8,
    #[allow(dead_code)]
    slot_index: usize,
}

// `TimerWheel` is `Send + Sync` automatically because every field is. This is asserted
// at compile time rather than forced with `unsafe impl`, so a future non-thread-safe
// field becomes a compile error instead of silent unsoundness.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<TimerWheel>();
};
impl TimerWheel {
pub fn new() -> Self {
Self {
current_ticks: AtomicU64::new(0),
wheel0: (0..WHEEL0_SIZE)
.map(|_| TimerSlot::new())
.collect::<Vec<_>>()
.into_boxed_slice()
.try_into()
.unwrap(),
wheel1: (0..WHEEL1_SIZE)
.map(|_| TimerSlot::new())
.collect::<Vec<_>>()
.into_boxed_slice()
.try_into()
.unwrap(),
wheel2: (0..WHEEL2_SIZE)
.map(|_| TimerSlot::new())
.collect::<Vec<_>>()
.into_boxed_slice()
.try_into()
.unwrap(),
wheel3: (0..WHEEL3_SIZE)
.map(|_| TimerSlot::new())
.collect::<Vec<_>>()
.into_boxed_slice()
.try_into()
.unwrap(),
next_id: AtomicU64::new(1),
timer_registry: Mutex::new(HashMap::new()),
}
}
pub fn cancel_timer(&self, id: u64) -> bool {
let mut registry = self.timer_registry.lock().unwrap();
if let Some(_location) = registry.remove(&id) {
true
} else {
false
}
}
#[inline]
pub fn current_ticks(&self) -> u64 {
self.current_ticks.load(Ordering::Acquire)
}
pub fn advance(&self, ticks: u64) -> usize {
let mut expired = 0;
let _start = self.current_ticks.load(Ordering::Acquire);
for _ in 0..ticks {
let tick = self.current_ticks.fetch_add(1, Ordering::AcqRel);
let pos0 = (tick & WHEEL0_MASK as u64) as usize;
unsafe {
let timers = self.wheel0[pos0].take_all();
for timer in timers {
let is_active = {
let mut registry = self.timer_registry.lock().unwrap();
registry.remove(&timer.id).is_some()
};
if is_active {
if let Some(waker) = timer.waker {
waker.wake();
}
expired += 1;
}
}
}
if tick & (WHEEL0_SIZE as u64 - 1) == 0 {
let pos1 = ((tick >> WHEEL0_SHIFT) & WHEEL1_MASK as u64) as usize;
unsafe {
let timers = self.wheel1[pos1].take_all();
for timer in timers {
let is_active = {
let registry = self.timer_registry.lock().unwrap();
registry.contains_key(&timer.id)
};
if is_active {
self.insert_timer_inner(timer);
}
}
}
}
if tick & ((WHEEL0_SIZE * WHEEL1_SIZE) as u64 - 1) == 0 {
let pos2 = ((tick >> (WHEEL0_SHIFT + WHEEL1_SHIFT)) & WHEEL2_MASK as u64) as usize;
unsafe {
let timers = self.wheel2[pos2].take_all();
for timer in timers {
let is_active = {
let registry = self.timer_registry.lock().unwrap();
registry.contains_key(&timer.id)
};
if is_active {
self.insert_timer_inner(timer);
}
}
}
}
if tick & ((WHEEL0_SIZE * WHEEL1_SIZE * WHEEL2_SIZE) as u64 - 1) == 0 {
let pos3 = ((tick >> (WHEEL0_SHIFT + WHEEL1_SHIFT + WHEEL2_SHIFT))
& WHEEL3_MASK as u64) as usize;
unsafe {
let timers = self.wheel3[pos3].take_all();
for timer in timers {
let is_active = {
let registry = self.timer_registry.lock().unwrap();
registry.contains_key(&timer.id)
};
if is_active {
self.insert_timer_inner(timer);
}
}
}
}
}
expired
}
fn insert_timer_inner(&self, timer: TimerEntry) {
let current = self.current_ticks.load(Ordering::Acquire);
let expiration = timer.expiration_ms / TICK_MS;
let id = timer.id;
if expiration <= current {
if let Some(waker) = timer.waker {
waker.wake();
}
return;
}
let ticks = expiration - current;
let (wheel_level, pos) = if ticks < WHEEL0_SIZE as u64 {
(0u8, ((current + ticks) & WHEEL0_MASK as u64) as usize)
} else if ticks < (WHEEL0_SIZE * WHEEL1_SIZE) as u64 {
(1u8, (((current + ticks) >> WHEEL0_SHIFT) & WHEEL1_MASK as u64) as usize)
} else if ticks < (WHEEL0_SIZE * WHEEL1_SIZE * WHEEL2_SIZE) as u64 {
(
2u8,
(((current + ticks) >> (WHEEL0_SHIFT + WHEEL1_SHIFT)) & WHEEL2_MASK as u64)
as usize,
)
} else {
(
3u8,
(((current + ticks) >> (WHEEL0_SHIFT + WHEEL1_SHIFT + WHEEL2_SHIFT))
& WHEEL3_MASK as u64) as usize,
)
};
{
let mut registry = self.timer_registry.lock().unwrap();
registry.insert(
id,
TimerLocation {
wheel_level,
slot_index: pos,
},
);
}
match wheel_level {
0 => unsafe { self.wheel0[pos].push(timer) },
1 => unsafe { self.wheel1[pos].push(timer) },
2 => unsafe { self.wheel2[pos].push(timer) },
_ => unsafe { self.wheel3[pos].push(timer) },
}
}
pub fn insert_timer(&self, duration: Duration) -> TimerHandle {
let duration_ms = duration.as_millis() as u64;
let current = self.current_ticks.load(Ordering::Acquire);
let expiration_ms = (current * TICK_MS) + duration_ms;
let id = self.next_id.fetch_add(1, Ordering::AcqRel);
let timer = TimerEntry {
id,
expiration_ms,
waker: None,
canceled: Mutex::new(false),
};
self.insert_timer_inner(timer);
TimerHandle::new(id)
}
pub fn insert_timer_with_waker(&self, duration: Duration, waker: Waker) -> TimerHandle {
let duration_ms = duration.as_millis() as u64;
let current = self.current_ticks.load(Ordering::Acquire);
let expiration_ms = (current * TICK_MS) + duration_ms;
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let timer = TimerEntry {
id,
expiration_ms,
waker: Some(waker),
canceled: Mutex::new(false),
};
self.insert_timer_inner(timer);
TimerHandle::new(id)
}
pub fn next_expiration(&self) -> Option<u64> {
None
}
}
impl Default for TimerWheel {
fn default() -> Self {
Self::new()
}
}
/// Handle to a timer created by [`TimerWheel::insert_timer`] or
/// [`TimerWheel::insert_timer_with_waker`].
// `Send`/`Sync` come automatically from the `u64` field; no `unsafe impl` is needed.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TimerHandle {
    id: u64,
}

impl TimerHandle {
    /// Cancels this timer on the process-wide [`global_timer`].
    ///
    /// The handle stores only the timer id, so this is only correct for handles issued
    /// by `global_timer()`. Ids are allocated per wheel, so a handle from another
    /// wheel may match an unrelated global timer. For such wheels, call
    /// [`TimerWheel::cancel_timer`] on the issuing wheel instead.
    pub fn cancel(&self) {
        global_timer().cancel_timer(self.id);
    }

    /// Wraps a timer id issued by a wheel.
    fn new(id: u64) -> Self {
        Self { id }
    }
}
/// Lazily created process-wide wheel behind [`global_timer`].
static GLOBAL_TIMER: OnceLock<TimerWheel> = OnceLock::new();

/// Returns the process-wide timer wheel, creating it on first use.
///
/// [`Sleep`] arms its timers on this wheel and [`TimerHandle::cancel`] cancels on it.
#[inline]
pub fn global_timer() -> &'static TimerWheel {
    GLOBAL_TIMER.get_or_init(TimerWheel::new)
}
pub struct Sleep {
duration: Duration,
registered: bool,
start: Option<Instant>,
}
impl Sleep {
pub fn new(duration: Duration) -> Self {
Self {
duration,
registered: false,
start: None,
}
}
}
impl Future for Sleep {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
if self.registered {
if let Some(start) = self.start
&& start.elapsed() >= self.duration
{
return Poll::Ready(());
}
Poll::Pending
} else {
self.registered = true;
self.start = Some(Instant::now());
global_timer().insert_timer_with_waker(self.duration, cx.waker().clone());
Poll::Pending
}
}
}
/// Returns a future that completes `duration` after it is first polled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub fn sleep(duration: Duration) -> Sleep {
    Sleep::new(duration)
}
/// Returns a future that completes at `instant`, or promptly if `instant` has already
/// passed.
///
/// The time left is computed when the future is first polled. Creating the future
/// ahead of time therefore does not push completion past `instant`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub fn sleep_until(instant: Instant) -> SleepUntil {
    SleepUntil {
        deadline: instant,
        sleep: None,
    }
}

/// Future returned by [`sleep_until`].
pub struct SleepUntil {
    /// Instant at which the future completes.
    deadline: Instant,
    /// Created on the first poll, sized to the time left until `deadline` at that moment.
    sleep: Option<Sleep>,
}

impl Future for SleepUntil {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let deadline = self.deadline;
        let inner = self
            .sleep
            .get_or_insert_with(|| sleep(deadline.saturating_duration_since(Instant::now())));
        Pin::new(inner).poll(cx)
    }
}
/// Creates an [`Interval`] with period `duration`.
///
/// The initial deadline is the creation instant itself. The first
/// [`Interval::tick`] therefore schedules a deadline one full period after it is
/// called.
pub fn interval(duration: Duration) -> Interval {
    let next = Instant::now();
    Interval { next, duration }
}
/// Periodic deadline generator returned by [`interval`].
pub struct Interval {
    /// Period between deadlines.
    duration: Duration,
    /// Deadline of the most recent tick (initially the creation instant).
    next: Instant,
}

impl Interval {
    /// Waits for the next deadline and returns it.
    ///
    /// When the stored deadline is not in the future, a new one is set a full
    /// `duration` after the current instant. This is the case on the first call and
    /// normally after a completed tick. Time spent between calls is not caught up.
    pub async fn tick(&mut self) -> Instant {
        let current = Instant::now();
        let deadline = if self.next > current {
            self.next
        } else {
            current + self.duration
        };
        self.next = deadline;
        sleep_until(deadline).await;
        deadline
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Wake, Waker};

    /// Waker that records how many times it was woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns a fresh counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    /// Number of wake-ups recorded so far.
    fn wakes(counter: &CountingWaker) -> usize {
        counter.0.load(Ordering::SeqCst)
    }

    #[test]
    fn test_timer_wheel_creation() {
        let wheel = TimerWheel::new();
        assert_eq!(wheel.current_ticks(), 0);
    }

    #[test]
    fn test_timer_constants() {
        assert_eq!(TICK_MS, 1);
        assert_eq!(WHEEL0_SIZE, 256);
        assert_eq!(WHEEL1_SIZE, 64);
        assert_eq!(WHEEL2_SIZE, 64);
        assert_eq!(WHEEL3_SIZE, 64);
    }

    #[test]
    fn test_global_timer() {
        // Relies on no test advancing the shared global wheel.
        let timer = global_timer();
        assert_eq!(timer.current_ticks(), 0);
    }

    #[test]
    fn test_max_timeout() {
        assert!(MAX_TIMEOUT_MS > 60_000 * 60 * 18);
        assert!(MAX_TIMEOUT_MS < 60_000 * 60 * 20);
    }

    #[test]
    fn test_timer_fires_once_due() {
        let wheel = TimerWheel::new();
        let (counter, waker) = counting_waker();
        wheel.insert_timer_with_waker(Duration::from_millis(5), waker);
        assert_eq!(wheel.advance(3), 0);
        assert_eq!(wakes(&counter), 0);
        assert_eq!(wheel.advance(3), 1);
        assert_eq!(wakes(&counter), 1);
    }

    #[test]
    fn test_canceled_timer_never_fires() {
        let wheel = TimerWheel::new();
        let (counter, waker) = counting_waker();
        let handle = wheel.insert_timer_with_waker(Duration::from_millis(10), waker);
        assert!(wheel.cancel_timer(handle.id));
        assert!(!wheel.cancel_timer(handle.id));
        assert_eq!(wheel.advance(20), 0);
        assert_eq!(wakes(&counter), 0);
    }

    #[test]
    fn test_fired_timer_is_unregistered() {
        let wheel = TimerWheel::new();
        let handle = wheel.insert_timer(Duration::from_millis(2));
        assert_eq!(wheel.advance(3), 1);
        assert!(!wheel.cancel_timer(handle.id));
    }

    #[test]
    fn test_level1_timer_cascades_and_fires() {
        let wheel = TimerWheel::new();
        let (counter, waker) = counting_waker();
        wheel.insert_timer_with_waker(Duration::from_millis(300), waker);
        assert_eq!(wheel.advance(299), 0);
        assert_eq!(wheel.advance(2), 1);
        assert_eq!(wakes(&counter), 1);
    }

    #[test]
    fn test_level2_timer_cascades_and_fires() {
        let wheel = TimerWheel::new();
        let (counter, waker) = counting_waker();
        wheel.insert_timer_with_waker(Duration::from_millis(20_000), waker);
        assert_eq!(wheel.advance(19_999), 0);
        assert_eq!(wheel.advance(2), 1);
        assert_eq!(wakes(&counter), 1);
    }
}