use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};
/// One pending execution held by the scheduler's queue.
struct Schedule<T> {
    /// Instant at which this entry becomes due.
    date: Instant,
    /// Payload handed to the executor when the entry fires.
    data: T,
    /// Cancellation flag shared with the `Guard` returned to the caller.
    guard: Guard,
    /// If `Some`, the entry is re-queued this long after `date` each time it fires.
    repeat: Option<Duration>,
}

// Entries compare by `date` only. The comparison is reversed so that
// `BinaryHeap` (a max-heap) surfaces the earliest date first.
impl<T> Ord for Schedule<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        other.date.cmp(&self.date)
    }
}

impl<T> PartialOrd for Schedule<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Eq for Schedule<T> {}

impl<T> PartialEq for Schedule<T> {
    fn eq(&self, other: &Self) -> bool {
        self.date == other.date
    }
}
/// A request travelling from a `TimerBase` handle to the scheduler thread.
enum Op<T> {
    /// Add this entry to the scheduler's queue.
    Schedule(Schedule<T>),
    /// Terminate the scheduler loop.
    Stop,
}

/// Mailbox shared between the thread that forwards `Op`s and the scheduler.
///
/// Producers push onto `messages` and signal `condvar`. The scheduler drains
/// `messages` and, when it has nothing due, waits on `condvar`.
struct WaiterChannel<T> {
    messages: Mutex<Vec<Op<T>>>,
    condvar: Condvar,
}

impl<T> WaiterChannel<T> {
    /// Builds an empty mailbox pre-sized for `cap` queued operations.
    fn with_capacity(cap: usize) -> Self {
        let queue = Vec::with_capacity(cap);
        Self {
            messages: Mutex::new(queue),
            condvar: Condvar::new(),
        }
    }
}
/// Something that can "fire" a scheduled payload of type `T`.
trait Executor<T> {
    /// Fires a one-shot payload, consuming it.
    fn execute(&mut self, data: T);
    /// Fires a repeating payload and hands it back so it can be re-queued.
    fn execute_clone(&mut self, data: T) -> T;
}

/// Executor used by `Timer`: each payload is a callback that is invoked.
struct CallbackExecutor;

impl Executor<Box<dyn FnMut() + Send>> for CallbackExecutor {
    fn execute(&mut self, mut data: Box<dyn FnMut() + Send>) {
        (*data)();
    }

    fn execute_clone(&mut self, data: Box<dyn FnMut() + Send>) -> Box<dyn FnMut() + Send> {
        // The callback itself is reused for the next tick; nothing is cloned.
        let mut callback = data;
        callback();
        callback
    }
}

/// Executor used by `MessageTimer`: each payload is sent over an mpsc channel.
struct DeliveryExecutor<T>
where
    T: 'static + Send,
{
    tx: Sender<T>,
}

impl<T> Executor<T> for DeliveryExecutor<T>
where
    T: 'static + Send + Clone,
{
    fn execute(&mut self, data: T) {
        // A disconnected receiver is not the timer's problem; the payload is discarded.
        let _ = self.tx.send(data);
    }

    fn execute_clone(&mut self, data: T) -> T {
        // Repeating schedules send a copy and keep the original for the next tick.
        let copy = data.clone();
        let _ = self.tx.send(copy);
        data
    }
}
/// State owned by the timer thread: the pending entries, the executor that
/// fires them, and the mailbox through which new entries (and the stop
/// request) arrive.
struct Scheduler<T, E>
where
    E: Executor<T>,
{
    /// Mailbox of incoming `Op`s, shared with the thread that feeds it.
    waiter: Arc<WaiterChannel<T>>,
    /// Pending entries; `Schedule`'s reversed `Ord` makes this a min-heap on `date`.
    heap: BinaryHeap<Schedule<T>>,
    /// Fires due payloads.
    executor: E,
}

impl<T, E> Scheduler<T, E>
where
    E: Executor<T>,
{
    /// Creates a scheduler that reads from `waiter` and fires through `executor`.
    /// The heap is pre-sized for `capacity` entries.
    fn with_capacity(waiter: Arc<WaiterChannel<T>>, executor: E, capacity: usize) -> Self {
        Scheduler {
            waiter,
            executor,
            heap: BinaryHeap::with_capacity(capacity),
        }
    }

    /// Main loop of the timer thread; returns once an `Op::Stop` is received.
    ///
    /// Each iteration:
    /// 1. fires at most one due entry, re-queuing it if it repeats;
    /// 2. drains newly received operations under the mailbox lock;
    /// 3. sleeps on the condvar: until woken, until the next deadline, or not at all.
    ///
    /// Draining (2) and waiting (3) happen under the same lock, so an operation
    /// pushed between them cannot be missed. Spurious wakeups are harmless
    /// because every iteration re-derives what to do from the heap.
    fn run(&mut self) {
        enum Sleep {
            NotAtAll,
            UntilAwakened,
            AtMost(Duration),
        }

        // Borrow only the `waiter` field so `heap` and `executor` stay mutably usable.
        let waiter = &*self.waiter;
        loop {
            let mut sleep = match self.heap.peek() {
                None => Sleep::UntilAwakened,
                Some(next) => {
                    let now = Instant::now();
                    if next.date > now {
                        Sleep::AtMost(next.date.duration_since(now))
                    } else {
                        let due = self.heap.pop().expect("peek() just returned Some");
                        if due.guard.should_execute() {
                            match due.repeat {
                                Some(period) => {
                                    let data = self.executor.execute_clone(due.data);
                                    // Re-arm relative to the previous deadline (not `now`),
                                    // so late executions catch up rather than drift. An
                                    // unrepresentable deadline can never fire; drop the
                                    // entry instead of panicking the timer thread, which
                                    // would silently stop every other schedule.
                                    if let Some(date) = due.date.checked_add(period) {
                                        self.heap.push(Schedule {
                                            date,
                                            data,
                                            guard: due.guard,
                                            repeat: Some(period),
                                        });
                                    }
                                }
                                None => self.executor.execute(due.data),
                            }
                        }
                        Sleep::NotAtAll
                    }
                }
            };

            let mut inbox = waiter.messages.lock().unwrap();
            for op in inbox.drain(..) {
                match op {
                    Op::Stop => return,
                    Op::Schedule(sched) => {
                        self.heap.push(sched);
                        // The new entry may be due before the deadline computed above.
                        sleep = Sleep::NotAtAll;
                    }
                }
            }

            match sleep {
                Sleep::UntilAwakened => {
                    let _ = waiter.condvar.wait(inbox);
                }
                Sleep::AtMost(delay) => {
                    let _ = waiter.condvar.wait_timeout(inbox, delay);
                }
                Sleep::NotAtAll => {}
            }
        }
    }
}
/// Shared engine behind `Timer` and `MessageTimer`. It forwards entries of
/// type `T` to background threads and returns a `Guard` for each.
pub struct TimerBase<T>
where
    T: 'static + Send,
{
    /// Feeds operations to the background relay thread.
    tx: Sender<Op<T>>,
}

impl<T> Drop for TimerBase<T>
where
    T: 'static + Send,
{
    /// Asks the background threads to stop.
    fn drop(&mut self) {
        // Best effort: if the relay thread is already gone, there is nothing
        // left to stop. Panicking here could abort the process if this drop
        // runs during unwinding.
        let _ = self.tx.send(Op::Stop);
    }
}
impl<T> TimerBase<T>
where
T: 'static + Send,
{
fn new<E>(executor: E) -> Self
where
E: 'static + Executor<T> + Send,
{
Self::with_capacity(executor, 32)
}
fn with_capacity<E>(executor: E, capacity: usize) -> Self
where
E: 'static + Executor<T> + Send,
{
let waiter_send = Arc::new(WaiterChannel::with_capacity(capacity));
let waiter_recv = waiter_send.clone();
let (tx, rx) = channel();
thread::spawn(move || {
use Op::*;
let waiter = &(*waiter_send);
for msg in rx.iter() {
let mut vec = waiter.messages.lock().unwrap();
match msg {
Schedule(sched) => {
vec.push(Schedule(sched));
waiter.condvar.notify_one();
}
Stop => {
vec.clear();
vec.push(Stop);
waiter.condvar.notify_one();
return;
}
}
}
});
thread::Builder::new()
.name("Timer thread".to_owned())
.spawn(move || {
let mut scheduler = Scheduler::with_capacity(waiter_recv, executor, capacity);
scheduler.run()
})
.unwrap();
TimerBase { tx }
}
pub fn schedule_with_delay(&self, delay: Duration, data: T) -> Guard {
self.schedule(Instant::now() + delay, None, data)
}
pub fn schedule_repeating(&self, repeat: Duration, data: T) -> Guard {
self.schedule(Instant::now() + repeat, Some(repeat), data)
}
pub fn schedule(&self, date: Instant, repeat: Option<Duration>, data: T) -> Guard {
let guard = Guard::new();
self.tx
.send(Op::Schedule(Schedule {
date,
data,
guard: guard.clone(),
repeat,
}))
.unwrap();
guard
}
}
/// A timer that runs callbacks on its own background thread.
///
/// Callbacks run one after another on that single thread, so a slow callback
/// delays the ones queued behind it.
pub struct Timer {
    base: TimerBase<Box<dyn FnMut() + Send>>,
}

impl Timer {
    /// Creates a timer with the default capacity.
    pub fn new() -> Self {
        Self {
            base: TimerBase::new(CallbackExecutor),
        }
    }

    /// Creates a timer whose internal queues are pre-sized for `capacity` entries.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            base: TimerBase::with_capacity(CallbackExecutor, capacity),
        }
    }

    /// Runs `cb` once, `delay` from now. Dropping the returned guard cancels it.
    pub fn schedule_with_delay<F>(&self, delay: Duration, cb: F) -> Guard
    where
        F: 'static + FnMut() + Send,
    {
        let callback: Box<dyn FnMut() + Send> = Box::new(cb);
        self.base.schedule_with_delay(delay, callback)
    }

    /// Runs `cb` every `repeat`, the first time one period from now.
    pub fn schedule_repeating<F>(&self, repeat: Duration, cb: F) -> Guard
    where
        F: 'static + FnMut() + Send,
    {
        let callback: Box<dyn FnMut() + Send> = Box::new(cb);
        self.base.schedule_repeating(repeat, callback)
    }

    /// Runs `cb` at `date`, and then every `repeat` afterwards if given.
    pub fn schedule<F>(&self, date: Instant, repeat: Option<Duration>, cb: F) -> Guard
    where
        F: 'static + FnMut() + Send,
    {
        let callback: Box<dyn FnMut() + Send> = Box::new(cb);
        self.base.schedule(date, repeat, callback)
    }
}

impl Default for Timer {
    /// Equivalent to `Timer::new()`.
    fn default() -> Self {
        Timer::new()
    }
}
/// A timer that delivers scheduled messages to an mpsc `Sender`.
///
/// One-shot entries send the message itself. Repeating entries send a clone on
/// each tick and keep the original for the next one.
pub struct MessageTimer<T>
where
    T: 'static + Send + Clone,
{
    base: TimerBase<T>,
}

impl<T> MessageTimer<T>
where
    T: 'static + Send + Clone,
{
    /// Creates a timer delivering to `tx`, with the default capacity.
    pub fn new(tx: Sender<T>) -> Self {
        let executor = DeliveryExecutor { tx };
        Self {
            base: TimerBase::new(executor),
        }
    }

    /// Creates a timer delivering to `tx`, pre-sized for `capacity` entries.
    pub fn with_capacity(tx: Sender<T>, capacity: usize) -> Self {
        let executor = DeliveryExecutor { tx };
        Self {
            base: TimerBase::with_capacity(executor, capacity),
        }
    }

    /// Sends `msg` once, `delay` from now. Dropping the returned guard cancels it.
    pub fn schedule_with_delay(&self, delay: Duration, msg: T) -> Guard {
        let base = &self.base;
        base.schedule_with_delay(delay, msg)
    }

    /// Sends a clone of `msg` every `repeat`, the first time one period from now.
    pub fn schedule_repeating(&self, repeat: Duration, msg: T) -> Guard {
        let base = &self.base;
        base.schedule_repeating(repeat, msg)
    }

    /// Sends `msg` at `date`, and then every `repeat` afterwards if given.
    ///
    /// NOTE(review): the type parameter `D` is unused, so it cannot be inferred.
    /// Callers must name it with a turbofish, e.g. `timer.schedule::<()>(..)`.
    /// It is kept only so that existing callers keep compiling; consider
    /// removing it in the next breaking release.
    pub fn schedule<D>(&self, date: Instant, repeat: Option<Duration>, msg: T) -> Guard {
        let base = &self.base;
        base.schedule(date, repeat, msg)
    }
}
/// Cancellation handle returned for every scheduled entry.
///
/// All clones share one flag. Dropping any clone clears the flag, except a
/// guard consumed by `Guard::ignore`. The scheduler skips (and stops
/// re-queuing) an entry whose flag is clear. Keep the guard alive for as long
/// as the entry should stay active, or call `Guard::ignore` to detach it.
///
/// Discarding the return value of a `schedule*` call drops the guard
/// immediately and therefore cancels the entry.
#[derive(Clone, Debug)]
pub struct Guard {
    /// Shared "still wanted" flag, checked by the scheduler before firing.
    should_execute: Arc<AtomicBool>,
    /// When set, dropping this handle leaves the flag untouched.
    ignore_drop: bool,
}

impl Guard {
    /// Creates a guard whose flag is set, i.e. the entry is active.
    fn new() -> Self {
        Guard {
            should_execute: Arc::new(AtomicBool::new(true)),
            ignore_drop: false,
        }
    }

    /// Returns whether the guarded entry should still fire.
    fn should_execute(&self) -> bool {
        // Relaxed is enough: the flag is a standalone signal and no other
        // memory is published through it.
        self.should_execute.load(AtomicOrdering::Relaxed)
    }

    /// Consumes this guard without cancelling the entry.
    ///
    /// The entry keeps running (repeating entries indefinitely) unless another
    /// clone of this guard is dropped or the timer itself is dropped.
    pub fn ignore(mut self) {
        self.ignore_drop = true;
    }
}

impl Drop for Guard {
    fn drop(&mut self) {
        if !self.ignore_drop {
            self.should_execute.store(false, AtomicOrdering::Relaxed)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::channel;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_schedule_with_delay() {
        let timer = Timer::new();
        let (tx, rx) = channel();
        let mut guards = vec![];
        let mut delays = vec![1, 5, 3, 0];
        let start = Instant::now();
        for &i in &delays {
            println!("Scheduling for execution in {} seconds", i);
            let tx = tx.clone();
            guards.push(timer.schedule_with_delay(Duration::from_secs(i), move || {
                println!("Callback {}", i);
                tx.send(i).unwrap();
            }));
        }
        delays.sort();
        // Callbacks must arrive in deadline order, each no earlier than its delay.
        for (&expected, msg) in delays.iter().zip(rx.iter()) {
            let elapsed = start.elapsed().as_secs();
            println!("Received message {} after {} seconds", msg, elapsed);
            assert_eq!(msg, expected);
            assert!(
                expected <= elapsed && elapsed <= expected + 3,
                "We have waited {} seconds, expecting [{}, {}]",
                elapsed,
                expected,
                expected + 3
            );
        }
        // A long pending entry must not hold back a shorter one scheduled later.
        let start = Instant::now();
        for &i in &[10, 0] {
            println!("Scheduling for execution in {} seconds", i);
            let tx = tx.clone();
            guards.push(timer.schedule_with_delay(Duration::from_secs(i), move || {
                println!("Callback {}", i);
                tx.send(i).unwrap();
            }));
        }
        assert_eq!(rx.recv().unwrap(), 0);
        assert!(start.elapsed() <= Duration::from_secs(1));
    }

    #[test]
    fn test_message_timer() {
        let (tx, rx) = channel();
        let timer = MessageTimer::new(tx);
        let start = Instant::now();
        let mut delays = vec![400, 300, 100, 500, 200];
        for delay in delays.clone() {
            timer
                .schedule_with_delay(Duration::from_millis(delay), delay)
                .ignore();
        }
        delays.sort();
        for delay in delays {
            assert_eq!(rx.recv().unwrap(), delay);
        }
        assert!(start.elapsed() <= Duration::from_secs(1));
    }

    #[test]
    fn test_guards() {
        println!("Testing that callbacks aren't called if the guard is dropped");
        let timer = Timer::new();
        let called = Arc::new(Mutex::new(false));
        for i in 0..10 {
            let called = called.clone();
            // The guard is dropped at the end of this statement. With a
            // 0 ms deadline the scheduler thread could fire the callback
            // before that drop happens, so keep every deadline well in the
            // future (but well inside the 1 s observation window).
            let _ = timer.schedule_with_delay(Duration::from_millis(100 + i), move || {
                *called.lock().unwrap() = true;
            });
        }
        thread::sleep(Duration::from_secs(1));
        assert!(!*called.lock().unwrap());
    }

    #[test]
    fn test_guard_ignore() {
        let timer = Timer::new();
        let called = Arc::new(Mutex::new(false));
        {
            let called = called.clone();
            timer
                .schedule_with_delay(Duration::from_millis(1), move || {
                    *called.lock().unwrap() = true;
                })
                .ignore();
        }
        thread::sleep(Duration::from_secs(1));
        assert!(*called.lock().unwrap());
    }

    struct NoCloneMessage;

    impl Clone for NoCloneMessage {
        fn clone(&self) -> Self {
            panic!("TestMessage should not be cloned");
        }
    }

    #[test]
    fn test_no_clone() {
        let (tx, rx) = channel();
        let timer = MessageTimer::new(tx);
        timer
            .schedule_with_delay(Duration::from_millis(0), NoCloneMessage)
            .ignore();
        timer
            .schedule_with_delay(Duration::from_millis(0), NoCloneMessage)
            .ignore();
        for _ in 0..2 {
            let _ = rx.recv();
        }
    }

    #[test]
    fn test_too_much_work() {
        let timer = Timer::new();
        let was_called = Arc::new(Mutex::new(false));
        let was_called_2 = Arc::new(Mutex::new(false));
        {
            let was_called = was_called.clone();
            timer
                .schedule(Instant::now(), Some(Duration::from_millis(10)), move || {
                    thread::sleep(Duration::from_millis(30));
                    *was_called.lock().unwrap() = true;
                })
                .ignore();
            let was_called_2 = was_called_2.clone();
            timer
                .schedule(Instant::now(), None, move || {
                    thread::sleep(Duration::from_millis(30));
                    *was_called_2.lock().unwrap() = true;
                })
                .ignore();
        }
        thread::sleep(Duration::from_millis(150));
        assert!(
            *was_called.lock().unwrap(),
            "Periodic task should have been called"
        );
        assert!(
            *was_called_2.lock().unwrap(),
            "One-time task should have been called"
        );
        drop(timer);
        thread::sleep(Duration::from_millis(150));
        *was_called.lock().unwrap() = false;
        thread::sleep(Duration::from_millis(200));
        assert!(
            !*was_called.lock().unwrap(),
            "Task should have been stopped when the timer dropped"
        );
    }
}