#[cfg(test)]
#[macro_use]
extern crate assert_matches;
use std::collections::{BinaryHeap, HashMap};
use std::cmp::{Ordering, Ord, PartialOrd, PartialEq};
use std::time::{Instant, Duration};
use std::hash::Hash;
use std::fmt::{self, Debug};
use std::convert::From;
/// Errors returned by `TimerHeap` operations.
#[derive(Debug)]
pub enum Error {
    /// `TimerHeap::insert` was called for a key that already has an active timer.
    AlreadyExists
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Error::AlreadyExists => f.write_str("a timer with this key already exists"),
        }
    }
}

// Lets callers box this as `dyn std::error::Error` or convert it with `?`.
impl ::std::error::Error for Error {}
/// Whether a timer fires once or re-arms itself after each expiry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimerType {
    /// Fires once; its key is then removed from the heap's active set.
    Oneshot,
    /// After firing, is re-armed `duration` after its previous deadline and
    /// keeps firing until removed or replaced.
    Recurring
}
/// Iterator over the keys of timers whose deadline is at or before `now`.
///
/// Returned by `TimerHeap::expired`. It holds the heap mutably borrowed
/// for as long as it lives, because iterating pops and re-pushes entries.
pub struct Expired<'a, T: 'a> {
    /// Cut-off instant: entries expiring after this are left in place.
    now: Instant,
    /// Heap being drained.
    heap: &'a mut TimerHeap<T>
}
impl<'a, T> Iterator for Expired<'a, T> where T: Eq + Clone + Hash {
    type Item = T;

    /// Yields the key of the next live timer whose deadline is at or before `now`.
    ///
    /// - Stale entries (removed, or superseded by `upsert`) are dropped as they
    ///   surface.
    /// - A oneshot timer is removed from the active set before its key is returned.
    /// - A recurring timer is re-armed `duration` after its previous deadline
    ///   and pushed back. A recurring timer that is several periods behind is
    ///   therefore yielded once per missed period.
    ///
    /// NOTE(review): a recurring timer with a zero `duration` is re-armed at the
    /// same instant and never stops being yielded. Draining the iterator, e.g.
    /// with `count()`, will not terminate in that case.
    fn next(&mut self) -> Option<T> {
        loop {
            let mut entry = match self.heap.timers.pop() {
                Some(entry) => entry,
                None => return None,
            };

            if entry.expires_at > self.now {
                // The earliest remaining deadline is in the future: restore it and stop.
                self.heap.timers.push(entry);
                return None;
            }

            let live = self.heap.active.get(&entry.key) == Some(&entry.counter);
            if !live {
                continue;
            }

            if !entry.recurring {
                let _ = self.heap.active.remove(&entry.key);
                return Some(entry.key);
            }

            let fired = entry.key.clone();
            entry.expires_at += entry.duration;
            self.heap.timers.push(entry);
            return Some(fired);
        }
    }
}
/// A collection of keyed timers ordered by deadline.
///
/// Removing or replacing a timer does not touch the binary heap. Instead the
/// old entry becomes stale: an entry is live only while `active` maps its key
/// to the entry's own `counter`.
pub struct TimerHeap<T> {
    /// Every scheduled entry, live or stale. `TimerEntry`'s `Ord` puts the
    /// earliest deadline on top.
    timers: BinaryHeap<TimerEntry<T>>,
    /// Maps each key to the counter value of its live entry.
    active: HashMap<T, u64>,
    /// Counter value that will be assigned to the next scheduled entry.
    counter: u64
}
/// Formats the live timers as a map from key to schedule.
///
/// Stale entries are skipped. Entries appear in the binary heap's internal
/// order, which is not sorted by deadline. No `Ord` bound on `T` is needed:
/// the heap orders entries by deadline, not by key.
impl<T: Debug + Eq + Clone + Hash> Debug for TimerHeap<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let live = self.timers
            .iter()
            .filter(|entry| self.is_active(entry))
            .map(|entry| (&entry.key, DebugEntry::from(entry)));
        fmt.debug_map().entries(live).finish()
    }
}
impl<T: Eq + Clone + Hash> TimerHeap<T> {
    /// Creates an empty timer heap.
    pub fn new() -> TimerHeap<T> {
        TimerHeap {
            timers: BinaryHeap::new(),
            active: HashMap::new(),
            counter: 0
        }
    }

    /// Returns the number of active timers.
    ///
    /// `remove` and `upsert` leave stale entries in the internal binary heap
    /// until `expired` pops them. Those entries are not counted, so a key that
    /// was upserted several times counts once and a removed key counts zero.
    pub fn len(&self) -> usize {
        self.active.len()
    }

    /// Returns `true` when no timers are active.
    pub fn is_empty(&self) -> bool {
        self.active.is_empty()
    }

    /// Schedules `key` to expire `duration` after `Instant::now()`.
    ///
    /// # Errors
    ///
    /// Returns `Error::AlreadyExists` if `key` already has an active timer.
    /// Use `upsert` to replace an existing timer instead.
    pub fn insert(&mut self, key: T, duration: Duration, ty: TimerType) -> Result<(), Error> {
        self._insert(key, duration, ty, Instant::now())
    }

    /// `insert` with an explicit current instant, so callers (e.g. the tests)
    /// can control time.
    fn _insert(&mut self, key: T, duration: Duration, ty: TimerType, now: Instant) -> Result<(), Error> {
        if self.active.contains_key(&key) {
            return Err(Error::AlreadyExists);
        }
        let entry = TimerEntry::new(key.clone(), duration, ty, now, self.counter);
        self.timers.push(entry);
        self.active.insert(key, self.counter);
        self.counter += 1;
        Ok(())
    }

    /// Schedules `key` to expire `duration` after `Instant::now()`, replacing
    /// any active timer for the same key.
    ///
    /// Returns `true` if an active timer for `key` was replaced. The replaced
    /// entry stays in the heap as a stale entry until `expired` discards it.
    pub fn upsert(&mut self, key: T, duration: Duration, ty: TimerType) -> bool {
        let entry = TimerEntry::new(key.clone(), duration, ty, Instant::now(), self.counter);
        self.timers.push(entry);
        let existed = self.active.insert(key, self.counter).is_some();
        self.counter += 1;
        existed
    }

    /// Cancels the active timer for `key`.
    ///
    /// Returns `true` if `key` had an active timer. The heap entry itself is
    /// not removed here; `expired` discards it once it reaches the top.
    pub fn remove(&mut self, key: T) -> bool {
        self.active.remove(&key).is_some()
    }

    /// Returns the time until the earliest active timer expires, measured from
    /// `Instant::now()`. Returns zero if that timer is already due, and `None`
    /// if no timer is active.
    pub fn time_remaining(&self) -> Option<Duration> {
        self._time_remaining(Instant::now())
    }

    /// `time_remaining` relative to an explicit `now`.
    fn _time_remaining(&self, now: Instant) -> Option<Duration> {
        let earliest = match self.timers.peek() {
            // The top of the heap has the smallest deadline of all entries, so
            // if it is live it is also the earliest active timer.
            Some(top) if self.is_active(top) => Some(top.expires_at),
            // The top is stale (or the heap is empty). `BinaryHeap::iter` visits
            // entries in arbitrary order, so the first live entry found is not
            // necessarily the earliest one. Take the minimum over all live entries.
            _ => {
                self.timers
                    .iter()
                    .filter(|e| self.is_active(e))
                    .map(|e| e.expires_at)
                    .min()
            }
        };
        earliest.map(|expires_at| {
            if now > expires_at {
                Duration::new(0, 0)
            } else {
                expires_at - now
            }
        })
    }

    /// Returns the smaller of `user_timeout` and `time_remaining()`.
    /// Returns `user_timeout` when no timer is active.
    pub fn earliest_timeout(&self, user_timeout: Duration) -> Duration {
        match self.time_remaining() {
            Some(remaining) if remaining < user_timeout => remaining,
            _ => user_timeout,
        }
    }

    /// Returns an iterator over the keys of timers due at `Instant::now()`.
    ///
    /// Iterating mutates the heap: oneshot timers are removed, recurring
    /// timers are re-armed, and stale entries that have expired are dropped.
    pub fn expired(&mut self) -> Expired<T> {
        self._expired(Instant::now())
    }

    /// `expired` relative to an explicit `now`.
    fn _expired(&mut self, now: Instant) -> Expired<T> {
        Expired {
            now: now,
            heap: self
        }
    }

    /// An entry is live only if `active` still maps its key to its counter.
    fn is_active(&self, entry: &TimerEntry<T>) -> bool {
        self.active.get(&entry.key) == Some(&entry.counter)
    }
}
/// One scheduled deadline stored in a `TimerHeap`'s binary heap.
///
/// Equality and ordering look only at `expires_at` (see the hand-written
/// `PartialEq`/`Ord` impls). `Eq` is derived on top of that `PartialEq`.
#[derive(Debug, Eq)]
struct TimerEntry<T> {
    /// Caller-supplied key identifying the timer.
    key: T,
    /// When `true`, the timer is re-armed after firing instead of being dropped.
    recurring: bool,
    /// Instant at which the timer is due.
    expires_at: Instant,
    /// Period used for the initial deadline and for each re-arm.
    duration: Duration,
    /// The entry is live only while the heap's `active` map holds this value for `key`.
    counter: u64
}
/// Key-less copy of a `TimerEntry`'s schedule, formatted by the `Debug`
/// impl of `TimerHeap` as the value side of each map entry.
struct DebugEntry {
    /// Copied from `TimerEntry::recurring`.
    recurring: bool,
    /// Copied from `TimerEntry::expires_at`.
    expires_at: Instant,
    /// Copied from `TimerEntry::duration`.
    duration: Duration,
}
impl Debug for DebugEntry {
    /// Renders as `Timer { recurring: .., expires_at: .., duration: .. }`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Timer");
        out.field("recurring", &self.recurring);
        out.field("expires_at", &self.expires_at);
        out.field("duration", &self.duration);
        out.finish()
    }
}
impl<'a, T> From<&'a TimerEntry<T>> for DebugEntry {
    /// Copies the scheduling fields of `entry` and leaves its key out.
    fn from(entry: &'a TimerEntry<T>) -> Self {
        DebugEntry {
            duration: entry.duration,
            expires_at: entry.expires_at,
            recurring: entry.recurring,
        }
    }
}
impl<T> TimerEntry<T> {
    /// Builds an entry for `key` that is due `duration` after `now`.
    ///
    /// `ty` becomes the `recurring` flag. `counter` is the value the heap
    /// records in its `active` map to mark this entry as live.
    pub fn new(key: T,
               duration: Duration,
               ty: TimerType,
               now: Instant,
               counter: u64) -> TimerEntry<T> {
        let deadline = now + duration;
        let is_recurring = match ty {
            TimerType::Recurring => true,
            TimerType::Oneshot => false
        };
        TimerEntry {
            counter: counter,
            duration: duration,
            expires_at: deadline,
            recurring: is_recurring,
            key: key
        }
    }
}
impl<T: Eq> Ord for TimerEntry<T> {
    /// Orders entries by *reversed* deadline. `BinaryHeap` is a max-heap, so
    /// this makes the entry that expires first the one it pops first.
    fn cmp(&self, other: &TimerEntry<T>) -> Ordering {
        other.expires_at.cmp(&self.expires_at)
    }
}
impl<T: Eq> PartialOrd for TimerEntry<T> {
    /// Delegates to `Ord::cmp` so the partial and total orders always agree.
    fn partial_cmp(&self, other: &TimerEntry<T>) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<T: Eq> PartialEq for TimerEntry<T> {
    /// Entries are equal when they share a deadline. Key, counter, duration
    /// and recurrence are ignored, matching the deadline-only ordering.
    fn eq(&self, other: &TimerEntry<T>) -> bool {
        let (mine, theirs) = (self.expires_at, other.expires_at);
        mine == theirs
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TimerHeap`. Most of them drive the private `_`-prefixed
    //! entry points with a fixed starting instant so deadlines are explicit.
    use super::{TimerHeap, TimerType, Error};
    use std::time::{Instant, Duration};

    /// Timer period shared by every test below.
    fn half_second() -> Duration {
        Duration::from_millis(500)
    }

    #[test]
    fn time_remaining() {
        let start = Instant::now();
        let period = half_second();
        let mut timers = TimerHeap::<u64>::new();
        timers._insert(1, period, TimerType::Oneshot, start).unwrap();
        println!("Active Oneshot Timer: {:?}", timers);

        let zero = Duration::new(0, 0);
        assert_eq!(timers._time_remaining(start), Some(Duration::from_millis(500)));
        // At the deadline, and past it, the remaining time is clamped to zero.
        assert_eq!(timers._time_remaining(start + period), Some(zero));
        println!("Expired Oneshot Timer: {:?}", timers);
        let late = start + period + Duration::from_millis(100);
        assert_eq!(timers._time_remaining(late), Some(zero));

        // Removing an unknown key fails; removing the live one succeeds.
        assert!(!timers.remove(2));
        assert!(timers.remove(1));
        println!("Empty heap: {:?}", timers);
        assert_eq!(timers._time_remaining(start), None);
    }

    #[test]
    fn expired_non_recurring() {
        let start = Instant::now();
        let period = half_second();
        let deadline = start + period;
        let mut timers = TimerHeap::<u64>::new();
        timers._insert(1, period, TimerType::Oneshot, start).unwrap();

        assert_eq!(timers._expired(start).count(), 0);
        let fired = timers._expired(deadline).count();
        // A oneshot timer is forgotten once it fires.
        assert_eq!(timers.active.len(), 0);
        assert_eq!(fired, 1);
        assert_eq!(timers.len(), 0);
        assert_eq!(timers._expired(deadline).next(), None);
    }

    #[test]
    fn expired_recurring() {
        let start = Instant::now();
        let period = half_second();
        let first = start + period;
        let second = first + period;
        let mut timers = TimerHeap::<u64>::new();
        timers._insert(1, period, TimerType::Recurring, start).unwrap();

        assert_eq!(timers._expired(start).count(), 0);
        // Fires once per period and stays scheduled afterwards.
        assert_eq!(timers._expired(first).count(), 1);
        assert_eq!(timers.len(), 1);
        assert_eq!(timers._expired(first + Duration::from_millis(1)).count(), 0);
        assert_eq!(timers._expired(second).count(), 1);
        assert_eq!(timers.len(), 1);
        assert_eq!(timers._expired(second).count(), 0);
    }

    #[test]
    fn insert_twice_fails() {
        let period = half_second();
        let mut timers = TimerHeap::<u64>::new();
        timers.insert(1, period, TimerType::Recurring).unwrap();
        let second_attempt = timers.insert(1, period, TimerType::Recurring);
        assert_matches!(second_attempt, Err(Error::AlreadyExists));
    }

    #[test]
    fn remove_causes_no_expiration() {
        let start = Instant::now();
        let period = half_second();
        let mut timers = TimerHeap::<u64>::new();
        timers._insert(1, period, TimerType::Recurring, start).unwrap();
        assert!(timers.remove(1));
        // The stale heap entry is discarded without being reported.
        assert_eq!(timers._expired(start + period).count(), 0);
        assert_eq!(timers.len(), 0);
    }

    #[test]
    fn remove_then_reinsert_only_causes_one_expiration() {
        let start = Instant::now();
        let period = half_second();
        let mut timers = TimerHeap::<u64>::new();
        timers._insert(1, period, TimerType::Oneshot, start).unwrap();
        assert!(timers.remove(1));
        timers._insert(1, period, TimerType::Oneshot, start + period).unwrap();
        // Only the re-inserted timer fires; the removed one is skipped.
        assert_eq!(timers._expired(start + period + period).count(), 1);
        assert_eq!(timers.active.len(), 0);
        assert_eq!(timers.len(), 0);
    }

    #[test]
    fn upsert() {
        let period = half_second();
        let mut timers = TimerHeap::<u64>::new();
        timers.insert(1, period, TimerType::Oneshot).unwrap();
        // Upserting over a live timer reports that it replaced one.
        assert!(timers.upsert(1, period, TimerType::Oneshot));
        assert!(timers.remove(1));
        // After removal, the first upsert creates and the second replaces.
        assert!(!timers.upsert(1, period, TimerType::Oneshot));
        assert!(timers.upsert(1, period, TimerType::Oneshot));
    }
}