use crate::timer::timer_future::TimerFuture;
use crate::timer::timer_state::TimerState;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};
/// Handle to a timer queue that is shared between every clone of this handle and the
/// background thread spawned by `launch`.
pub struct Timer {
// Queue state behind a mutex; `wait` inserts entries, the background thread pops them,
// and `Drop` clears them.
state: Arc<Mutex<TimerState>>,
// Condition variable the background thread sleeps on (`wait_timeout`); `wait` and `Drop`
// call `notify_one` on it after changing the queue.
condvar: Arc<Condvar>,
}
impl Timer {
pub fn new() -> Timer {
Timer {
state: Arc::new(Mutex::new(TimerState::new())),
condvar: Arc::new(Condvar::new()),
}
}
pub fn wait(&mut self, duration: Duration) -> TimerFuture {
let state = &mut *self.state.lock().unwrap();
let expiration = Instant::now() + duration;
let launched = state.queue.len() > 0;
let (time_future, shortest) = state.add_to_queue(expiration);
if !launched {
self.launch();
} else if shortest {
self.condvar.notify_one();
}
time_future
}
fn launch(&self) {
let lock = self.state.clone();
let condvar = self.condvar.clone();
thread::spawn(move || {
loop {
let mut state = lock.lock().unwrap();
let expiration = state.current_expiration();
if expiration.is_none() {
break;
}
let expiration = expiration.unwrap();
let now = Instant::now();
let duration = if expiration > now {
expiration - now
} else {
Duration::ZERO
};
let result = match condvar.wait_timeout(state, duration) {
Ok(result) => result,
Err(err) => {
eprintln!("Err condvar.wait_timeout: {:?}", &err);
err.into_inner()
}
};
state = result.0;
if result.1.timed_out() {
let guard = state.queue.pop().unwrap();
let mut future_state = guard.lock().unwrap();
future_state.completed = true;
if let Some(waker) = future_state.waker.take() {
waker.wake()
}
} else {
}
}
});
}
}
impl Clone for Timer {
fn clone(&self) -> Self {
Timer {
state: self.state.clone(),
condvar: self.condvar.clone(),
}
}
}
impl Drop for Timer {
    /// Cancels the pending timers when the last handle goes away.
    ///
    /// The queue is shared by every clone of this `Timer`, so it is only cleared when no
    /// other handle appears to be alive. Cleared entries are not completed or woken here.
    fn drop(&mut self) {
        // Never panic inside `drop`: a poisoned mutex (e.g. after a worker panic) would
        // otherwise turn an unwinding panic into a process abort.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Besides `Timer` handles, the only holder of the state `Arc` in this file is the
        // worker thread spawned by `launch` (one clone per worker). A count of at most two
        // therefore means "this handle plus at most one worker". If extra workers exist, the
        // count is higher and the queue is simply left to run out, which errs on the side of
        // not cancelling timers that another handle may still be awaiting.
        if Arc::strong_count(&self.state) <= 2 {
            state.queue.clear();
            // Wake the worker so it notices the empty queue and exits.
            self.condvar.notify_one();
        }
    }
}
// Unit tests for `Timer`. Most of them are wall-clock based: they measure elapsed time with
// `Instant` and assert it falls inside a tolerance window, so they can be sensitive to
// scheduler load.
#[cfg(test)]
mod tests {
use super::*;
use futures::executor::block_on;
use futures::join;
use std::time::Instant;
// Registers three timers (100 ms, then after a 30 ms blocking sleep 50 ms and 100 ms), joins
// them, and expects everything to finish within 140 ms. The same sequence is then repeated
// on the same timer to check that it can be reused after its queue has drained.
#[test]
fn test_timer() {
let mut timer = Timer::new();
let future = async {
let benchmark = Instant::now();
let future1 = timer.wait(Duration::from_millis(100));
// Blocking sleep on purpose: later timers are registered 30 ms after the first one.
thread::sleep(Duration::from_millis(30));
let future2 = timer.wait(Duration::from_millis(50));
let future3 = timer.wait(Duration::from_millis(100));
join!(future1, future2, future3);
// The latest deadline is future3 at roughly 30 ms + 100 ms.
assert!(benchmark.elapsed() <= Duration::from_millis(140));
let benchmark = Instant::now();
let future1 = timer.wait(Duration::from_millis(100));
thread::sleep(Duration::from_millis(30));
let future2 = timer.wait(Duration::from_millis(50));
let future3 = timer.wait(Duration::from_millis(100));
join!(future1, future2, future3);
let elapsed = benchmark.elapsed();
assert!(elapsed <= Duration::from_millis(140));
};
block_on(future);
}
// Registers 100 timers with durations cycling through 1..=10 ms, waits for all of them,
// and checks that the queue is empty afterwards.
#[test]
fn test_many_timers() {
let mut timer = Timer::new();
let futures = (0..100)
.map(|i| timer.wait(Duration::from_millis(i % 10 + 1)))
.collect::<Vec<_>>();
block_on(futures::future::join_all(futures));
assert!(timer.state.lock().unwrap().queue.is_empty());
}
// A freshly constructed timer has an empty queue.
#[test]
fn test_timer_new() {
let timer = Timer::new();
assert!(timer.state.lock().unwrap().queue.is_empty());
}
// A single 50 ms timer should complete no earlier than 45 ms and no later than 100 ms.
#[test]
fn test_timer_wait_single() {
let mut timer = Timer::new();
let start = Instant::now();
let future = timer.wait(Duration::from_millis(50));
block_on(future);
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(45));
assert!(elapsed <= Duration::from_millis(100));
}
// Three 10 ms timers awaited one after another: the queue drains between calls, so each
// `wait` starts from an empty queue. Total time is expected between 25 ms and 80 ms.
#[test]
fn test_timer_wait_multiple_sequential() {
let mut timer = Timer::new();
let start = Instant::now();
block_on(timer.wait(Duration::from_millis(10)));
block_on(timer.wait(Duration::from_millis(10)));
block_on(timer.wait(Duration::from_millis(10)));
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(25));
assert!(elapsed <= Duration::from_millis(80));
}
// Three timers (50, 30, 40 ms) registered back to back and joined: total time is governed
// by the longest one, so between 45 ms and 100 ms is expected.
#[test]
fn test_timer_wait_multiple_concurrent() {
let mut timer = Timer::new();
let start = Instant::now();
let future1 = timer.wait(Duration::from_millis(50));
let future2 = timer.wait(Duration::from_millis(30));
let future3 = timer.wait(Duration::from_millis(40));
block_on(async {
join!(future1, future2, future3);
});
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(45));
assert!(elapsed <= Duration::from_millis(100));
}
// A zero-length timer should complete within 10 ms.
#[test]
fn test_timer_wait_zero_duration() {
let mut timer = Timer::new();
let start = Instant::now();
let future = timer.wait(Duration::ZERO);
block_on(future);
let elapsed = start.elapsed();
assert!(elapsed <= Duration::from_millis(10));
}
// A 1 ms timer should complete within 20 ms.
#[test]
fn test_timer_wait_very_short_duration() {
let mut timer = Timer::new();
let start = Instant::now();
let future = timer.wait(Duration::from_millis(1));
block_on(future);
let elapsed = start.elapsed();
assert!(elapsed <= Duration::from_millis(20));
}
// Timers registered out of deadline order (100, 50, 75 ms): joining all of them should take
// about as long as the longest one, i.e. between 90 ms and 150 ms.
#[test]
fn test_timer_queue_ordering() {
let mut timer = Timer::new();
let start = Instant::now();
let future1 = timer.wait(Duration::from_millis(100));
let future2 = timer.wait(Duration::from_millis(50));
let future3 = timer.wait(Duration::from_millis(75));
block_on(async {
join!(future1, future2, future3);
});
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(90));
assert!(elapsed <= Duration::from_millis(150));
}
// Checks that the queue is non-empty after registering two timers, then drops the timer.
// Nothing is asserted after the drop; the test only verifies that dropping does not panic.
#[test]
fn test_timer_drop_clears_queue() {
let mut timer = Timer::new();
let _future1 = timer.wait(Duration::from_millis(100));
let _future2 = timer.wait(Duration::from_millis(200));
assert!(!timer.state.lock().unwrap().queue.is_empty());
drop(timer);
}
// Five threads share one `Timer` through an outer `Arc<Mutex<_>>`. Each thread keeps the
// outer lock while it registers and blocks on its timer (10 + i * 10 ms), so the waits are
// serialized across threads. The test passes if every thread joins without panicking.
#[test]
fn test_timer_concurrent_access() {
let timer = Timer::new();
let timer_arc = std::sync::Arc::new(std::sync::Mutex::new(timer));
let mut handles = vec![];
for i in 0..5 {
let timer_clone = timer_arc.clone();
let handle = std::thread::spawn(move || {
let mut timer = timer_clone.lock().unwrap();
let future = timer.wait(Duration::from_millis(10 + i * 10));
futures::executor::block_on(future);
i
});
handles.push(handle);
}
for handle in handles {
let id = handle.join().expect("Thread panicked");
println!("Timer thread {} completed", id);
}
}
// Registers a 100 ms timer followed by a shorter 20 ms one, which exercises the branch of
// `wait` that notifies the background thread. Joining both should take between 90 ms and
// 150 ms.
#[test]
fn test_timer_shortest_detection() {
let mut timer = Timer::new();
let start = Instant::now();
let future1 = timer.wait(Duration::from_millis(100));
let future2 = timer.wait(Duration::from_millis(20));
block_on(async {
join!(future1, future2);
});
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(90));
assert!(elapsed <= Duration::from_millis(150));
}
}