use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thread_local::ThreadLocal;
use tokio::sync::Notify;
/// Timer granularity in milliseconds.
///
/// Deadlines are rounded up to a multiple of this value, and the clock thread
/// wakes once per period to fire due timers.
const RESOLUTION_MS: u64 = 10;

/// [`RESOLUTION_MS`] as a [`Duration`]; the clock thread's sleep interval.
const RESOLUTION_DURATION: Duration = Duration::from_millis(RESOLUTION_MS);
/// Rounds `raw` up to the nearest multiple of `resolution`.
///
/// Values that are already a multiple (including `0`) are returned unchanged,
/// e.g. `round_to(30, 10) == 30` and `round_to(31, 10) == 40`.
///
/// # Panics
///
/// Panics if `resolution` is zero (division by zero).
#[inline]
fn round_to(raw: u128, resolution: u128) -> u128 {
    // Take the remainder first so `raw == 0` needs no special case; a
    // `raw - 1`-based formula underflows there (panic in debug, a wrapped
    // non-multiple result in release).
    match raw % resolution {
        0 => raw,
        rem => raw + (resolution - rem),
    }
}
/// Deadline key for the per-thread timer maps, in milliseconds since the
/// manager's reference instant.
///
/// The `From` conversions round up to `RESOLUTION_MS`, so deadlines that fall
/// into the same slot compare equal and share a single timer entry. Ordering
/// is the plain numeric ordering of the wrapped value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Time(u128);
impl From<u128> for Time {
fn from(raw_ms: u128) -> Self {
Time(round_to(raw_ms, RESOLUTION_MS as u128))
}
}
impl From<Duration> for Time {
fn from(d: Duration) -> Self {
Time(round_to(d.as_millis(), RESOLUTION_MS as u128))
}
}
impl Time {
pub fn not_after(&self, ts: u128) -> bool {
self.0 <= ts
}
}
/// Handle returned by `TimerManager::register_timer`; resolves once the
/// underlying timer has fired.
///
/// Field `0` is the notifier the timer signals on, and field `1` is the sticky
/// "already fired" flag.
pub struct TimerStub(Arc<Notify>, Arc<AtomicBool>);

impl TimerStub {
    /// Waits until the timer fires, returning immediately if it already has.
    pub async fn poll(self) {
        // Create the `Notified` future *before* checking the flag: tokio
        // guarantees a `Notified` receives `notify_waiters()` wakeups from the
        // moment it is created. Checking the flag first leaves a window where
        // `Timer::fire` can set the flag and notify between the check and the
        // subscription, and this future would then wait forever.
        let notified = self.0.notified();
        if self.1.load(Ordering::SeqCst) {
            return;
        }
        notified.await;
    }
}
/// A one-shot timer stored in the timer maps.
///
/// Firing sets a sticky flag and wakes every task currently waiting on one of
/// the [`TimerStub`]s handed out by [`Timer::subscribe`].
struct Timer(Arc<Notify>, Arc<AtomicBool>);

impl Timer {
    /// Creates a timer that has not fired yet.
    pub fn new() -> Self {
        let notify = Arc::new(Notify::new());
        let fired = Arc::new(AtomicBool::new(false));
        Self(notify, fired)
    }

    /// Marks the timer as fired, then wakes all current waiters.
    ///
    /// The flag is set before notifying, so a stub that checks the flag
    /// after this call returns without waiting.
    pub fn fire(&self) {
        let (notify, fired) = (&self.0, &self.1);
        fired.store(true, Ordering::SeqCst);
        notify.notify_waiters();
    }

    /// Returns a new handle sharing this timer's notifier and fired flag.
    pub fn subscribe(&self) -> TimerStub {
        TimerStub(Arc::clone(&self.0), Arc::clone(&self.1))
    }
}
/// Coarse-grained timer service.
///
/// Each thread registers timers in its own map; a separate clock thread
/// periodically fires every timer whose deadline has passed.
pub struct TimerManager {
    /// Per-thread map from rounded deadline to the timer that fires at it.
    timers: ThreadLocal<RwLock<BTreeMap<Time, Timer>>>,
    /// Reference instant; deadlines and heartbeats are offsets from it.
    zero: Instant,
    /// Whole seconds since `zero` at the clock thread's most recent tick.
    clock_watchdog: AtomicI64,
    /// While set, the clock stops firing timers and new timers fire at once.
    paused: AtomicBool,
}
/// Seconds the clock may go without a heartbeat before
/// `TimerManager::is_clock_running` reports it as stopped.
const DELAYS_SEC: i64 = 2;

impl Default for TimerManager {
    /// Creates an empty manager whose reference instant is "now".
    ///
    /// The watchdog starts at `-DELAYS_SEC`, so the clock initially reads as
    /// not running and the first `should_i_start_clock` call can claim it.
    fn default() -> Self {
        Self {
            timers: ThreadLocal::new(),
            zero: Instant::now(),
            clock_watchdog: AtomicI64::new(-DELAYS_SEC),
            paused: AtomicBool::new(false),
        }
    }
}
impl TimerManager {
pub fn new() -> Self {
Self::default()
}
pub(crate) fn clock_thread(&self) {
loop {
std::thread::sleep(RESOLUTION_DURATION);
let now = Instant::now() - self.zero;
self.clock_watchdog
.store(now.as_secs() as i64, Ordering::Relaxed);
if self.is_paused_for_fork() {
continue;
}
let now = now.as_millis();
for thread_timer in self.timers.iter() {
let mut timers = thread_timer.write();
loop {
let key_to_remove = timers.iter().next().and_then(|(k, _)| {
if k.not_after(now) {
Some(*k)
} else {
None
}
});
if let Some(k) = key_to_remove {
let timer = timers.remove(&k);
timer.unwrap().fire();
} else {
break;
}
}
}
}
}
pub(crate) fn should_i_start_clock(&self) -> bool {
let Err(prev) = self.is_clock_running() else {
return false;
};
let now = Instant::now().duration_since(self.zero).as_secs() as i64;
let res =
self.clock_watchdog
.compare_exchange(prev, now, Ordering::SeqCst, Ordering::SeqCst);
res.is_ok()
}
pub(crate) fn is_clock_running(&self) -> Result<(), i64> {
let now = Instant::now().duration_since(self.zero).as_secs() as i64;
let prev = self.clock_watchdog.load(Ordering::SeqCst);
if now < prev + DELAYS_SEC {
Ok(())
} else {
Err(prev)
}
}
pub fn register_timer(&self, duration: Duration) -> TimerStub {
if self.is_paused_for_fork() {
let timer = Timer::new();
timer.fire();
return timer.subscribe();
}
let now: Time = (Instant::now() + duration - self.zero).into();
{
let timers = self.timers.get_or(|| RwLock::new(BTreeMap::new())).read();
if let Some(t) = timers.get(&now) {
return t.subscribe();
}
}
let timer = Timer::new();
let mut timers = self.timers.get_or(|| RwLock::new(BTreeMap::new())).write();
let stub = timer.subscribe();
timers.insert(now, timer);
stub
}
fn is_paused_for_fork(&self) -> bool {
self.paused.load(Ordering::SeqCst)
}
pub fn pause_for_fork(&self) {
self.paused.store(true, Ordering::SeqCst);
std::thread::sleep(RESOLUTION_DURATION * 2);
}
pub fn unpause(&self) {
self.paused.store(false, Ordering::SeqCst)
}
}
// Unit tests. Several of these spawn a real clock thread and assert on elapsed
// wall-clock seconds, so they take a few seconds to run and assume the OS
// scheduler is reasonably prompt.
#[cfg(test)]
mod tests {
    use super::*;

    // `round_to` rounds up to the next multiple of the resolution; exact
    // multiples are returned unchanged.
    #[test]
    fn test_round() {
        assert_eq!(round_to(30, 10), 30);
        assert_eq!(round_to(31, 10), 40);
        assert_eq!(round_to(29, 10), 30);
    }

    // Both `From` conversions round up to the 10 ms resolution, and
    // `not_after` is an inclusive `<=` comparison against the rounded value.
    #[test]
    fn test_time() {
        let t: Time = 128.into();
        assert_eq!(t, Duration::from_millis(130).into());
        assert!(!t.not_after(128));
        assert!(!t.not_after(129));
        assert!(t.not_after(130));
        assert!(t.not_after(131));
    }

    // `t1` completes after about 1 s, once the clock thread fires it. `t2` is
    // registered right afterwards with the same duration, so by the time `t1`
    // has completed `t2` is expected to have fired as well (normally it shares
    // `t1`'s slot) and to complete immediately.
    #[tokio::test]
    async fn test_timer_manager() {
        let tm_a = Arc::new(TimerManager::new());
        let tm = tm_a.clone();
        std::thread::spawn(move || tm_a.clock_thread());
        let now = Instant::now();
        let t1 = tm.register_timer(Duration::from_secs(1));
        let t2 = tm.register_timer(Duration::from_secs(1));
        t1.poll().await;
        assert_eq!(now.elapsed().as_secs(), 1);
        let now = Instant::now();
        t2.poll().await;
        assert_eq!(now.elapsed().as_secs(), 0);
    }

    // The first caller claims the clock; a second caller immediately afterwards
    // sees the fresh heartbeat and is refused.
    #[test]
    fn test_timer_manager_start_check() {
        let tm = Arc::new(TimerManager::new());
        assert!(tm.should_i_start_clock());
        assert!(!tm.should_i_start_clock());
        assert!(tm.is_clock_running().is_ok());
    }

    // No clock thread is spawned here, so nothing refreshes the heartbeat the
    // first call stamped. After more than `DELAYS_SEC` seconds the clock reads
    // as stopped and can be claimed again.
    #[test]
    fn test_timer_manager_watchdog() {
        let tm = Arc::new(TimerManager::new());
        assert!(tm.should_i_start_clock());
        assert!(!tm.should_i_start_clock());
        std::thread::sleep(Duration::from_secs(DELAYS_SEC as u64 + 1));
        assert!(tm.is_clock_running().is_err());
        assert!(tm.should_i_start_clock());
    }

    // `t1` is registered before the pause and must wait for the clock. `t2` is
    // registered while paused, so it is handed back already fired. After
    // unpausing, the clock fires `t1` at its original ~2 s deadline.
    #[tokio::test]
    async fn test_timer_manager_pause() {
        let tm_a = Arc::new(TimerManager::new());
        let tm = tm_a.clone();
        std::thread::spawn(move || tm_a.clock_thread());
        let now = Instant::now();
        let t1 = tm.register_timer(Duration::from_secs(2));
        tm.pause_for_fork();
        let t2 = tm.register_timer(Duration::from_secs(2));
        t2.poll().await;
        assert_eq!(now.elapsed().as_secs(), 0);
        // Blocking sleep inside the async test. It stalls this runtime thread,
        // but the clock runs on its own OS thread spawned above.
        std::thread::sleep(Duration::from_secs(1));
        tm.unpause();
        t1.poll().await;
        assert_eq!(now.elapsed().as_secs(), 2);
    }
}