use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::time::{Duration, SystemTime};
/// Returns the 64-bit digest of `value`, computed with a freshly created `DefaultHasher`.
///
/// Every call starts from a new hasher, so equal values always produce equal digests; the
/// filter relies on this to recognise repeated messages.
fn hash<T: Hash>(value: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
/// De-duplicating filter that remembers each inserted message for a fixed time-to-live.
///
/// Only the 64-bit hash of a message is stored, so `Message` is used purely as a type marker.
pub struct MessageFilter<Message> {
    /// How long an entry stays tracked after its most recent insertion.
    time_to_live: Duration,
    /// Tracked entries, one per distinct message hash; expired ones are purged lazily.
    entries: Vec<TimestampedMessage>,
    /// Ties the filter to its `Message` type without storing any messages.
    phantom: PhantomData<Message>,
}
impl<Message: Hash> MessageFilter<Message> {
    /// Creates an empty filter whose entries expire `time_to_live` after their most recent
    /// insertion.
    pub fn with_expiry_duration(time_to_live: Duration) -> MessageFilter<Message> {
        MessageFilter {
            entries: vec![],
            time_to_live: time_to_live,
            phantom: PhantomData,
        }
    }

    /// Records `message` and returns how many times it has been inserted while its entry was
    /// live. The result is 1 for a new or previously expired message.
    ///
    /// Re-inserting a live message resets its expiry to `time_to_live` from now.
    ///
    /// Messages are identified only by their 64-bit hash, so two distinct messages whose
    /// hashes collide are treated as the same message.
    pub fn insert(&mut self, message: &Message) -> usize {
        self.remove_expired();
        let hash_code = hash(message);
        if let Some(index) = self.entries.iter().position(|t| t.hash_code == hash_code) {
            // Refresh in place. `remove_expired` does not depend on the entries being ordered
            // by expiry, so the entry does not need to move to the back.
            let entry = &mut self.entries[index];
            entry.update_expiry_point(self.time_to_live);
            entry.increment_count()
        } else {
            self.entries.push(TimestampedMessage::new(hash_code, self.time_to_live));
            1
        }
    }

    /// Returns the recorded insertion count for `message`, or 0 if it is not tracked.
    ///
    /// Unlike `contains`, this does not purge expired entries first.
    #[cfg(test)]
    pub fn count(&self, message: &Message) -> usize {
        let hash_code = hash(message);
        self.entries
            .iter()
            .find(|t| t.hash_code == hash_code)
            .map_or(0, |t| t.count)
    }

    /// Returns whether `message` is currently tracked, after purging expired entries.
    pub fn contains(&mut self, message: &Message) -> bool {
        self.remove_expired();
        let hash_code = hash(message);
        self.entries.iter().any(|entry| entry.hash_code == hash_code)
    }

    /// Stops tracking `message`, if it is present. This call does not purge expired entries.
    pub fn remove(&mut self, message: &Message) {
        let hash_code = hash(message);
        if let Some(index) = self.entries.iter().position(|t| t.hash_code == hash_code) {
            let _old_val = self.entries.remove(index);
        }
    }

    /// Drops every tracked entry.
    #[cfg(feature = "use-mock-crust")]
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Drops every entry whose expiry point is not strictly after the current time.
    fn remove_expired(&mut self) {
        // `SystemTime` is not monotonic. If the clock steps backwards, an entry added later can
        // expire before one added earlier, so the vector cannot be assumed sorted by expiry.
        // Check every entry rather than cutting at the first live one.
        let now = SystemTime::now();
        self.entries.retain(|entry| entry.expiry_point > now);
    }
}
/// Bookkeeping for one distinct message hash tracked by `MessageFilter`.
struct TimestampedMessage {
    /// How many times the message has been inserted while this entry was live.
    pub count: usize,
    /// Digest of the message, as produced by `hash`.
    pub hash_code: u64,
    /// Wall-clock time from which the entry counts as expired.
    pub expiry_point: SystemTime,
}
impl TimestampedMessage {
    /// Creates an entry for `hash_code` that has been seen once and expires `time_to_live`
    /// from now.
    pub fn new(hash_code: u64, time_to_live: Duration) -> TimestampedMessage {
        let expiry_point = SystemTime::now() + time_to_live;
        TimestampedMessage {
            count: 1,
            expiry_point: expiry_point,
            hash_code: hash_code,
        }
    }

    /// Pushes the expiry point out to `time_to_live` from the current time.
    pub fn update_expiry_point(&mut self, time_to_live: Duration) {
        let now = SystemTime::now();
        self.expiry_point = now + time_to_live;
    }

    /// Records one more sighting and returns the updated total.
    pub fn increment_count(&mut self) -> usize {
        let updated = self.count + 1;
        self.count = updated;
        updated
    }
}
// Unit tests for `MessageFilter`. Several tests sleep past the TTL, so they take real
// wall-clock time; `insert_resets_timeout` alone sleeps about 5.4 s in total.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): the `rand` calls below (`gen_range(low, high)`, `rand::sample`) follow an older
// `rand` API; confirm against the pinned crate version before upgrading it.
use rand::{self, Rng};
use std::thread;
use std::time::Duration;
// Inserts 0..10 under a randomised TTL, waits until the TTL has elapsed, and checks that every
// earlier message is treated as new again (insert returns 1).
#[test]
fn timeout() {
let time_to_live = Duration::from_millis(rand::thread_rng().gen_range(50, 150));
let mut msg_filter = MessageFilter::<usize>::with_expiry_duration(time_to_live);
assert_eq!(time_to_live, msg_filter.time_to_live);
for i in 0..10 {
assert_eq!(1, msg_filter.insert(&i));
}
for i in 0..10 {
assert!(msg_filter.contains(&i));
}
// Sleep slightly longer than the TTL so that every entry inserted above has expired.
let sleep_duration = time_to_live + Duration::from_millis(10);
thread::sleep(sleep_duration);
// `insert` purges expired entries before recording the new one.
assert_eq!(1, msg_filter.insert(&11));
assert!(msg_filter.contains(&11));
for i in 0..10 {
assert_eq!(1, msg_filter.insert(&i));
assert!(msg_filter.contains(&i));
}
}
// Same expiry check as `timeout`, but keyed on a struct whose hash covers a `Vec<u8>`.
#[test]
fn struct_value() {
#[derive(PartialEq, PartialOrd, Ord, Clone, Eq, Hash)]
struct Temp {
id: Vec<u8>,
}
impl Default for Temp {
// Random id: 64 values drawn from 0u8..255 via rand's `sample` (old API — confirm semantics).
fn default() -> Temp {
let mut rng = rand::thread_rng();
Temp { id: rand::sample(&mut rng, 0u8..255, 64) }
}
}
let time_to_live = Duration::from_millis(rand::thread_rng().gen_range(50, 150));
let mut msg_filter = MessageFilter::<Temp>::with_expiry_duration(time_to_live);
let values: Vec<Temp> = (0..10).map(|_| Temp::default()).collect();
for temp in &values {
assert_eq!(1, msg_filter.insert(temp));
assert!(msg_filter.contains(temp));
}
let sleep_duration = time_to_live + Duration::from_millis(10);
thread::sleep(sleep_duration);
let temp: Temp = Default::default();
assert_eq!(1, msg_filter.insert(&temp));
assert!(msg_filter.contains(&temp));
// Every value inserted before the sleep must now be reported as absent.
for temp in &values {
assert!(!msg_filter.contains(temp));
}
}
// Uses a 99 s TTL so nothing expires during the test; re-inserting 0 bumps its count to 2.
#[test]
fn add_duplicate() {
let size = 10;
let time_to_live = Duration::from_secs(99);
let mut msg_filter = MessageFilter::<usize>::with_expiry_duration(time_to_live);
for i in 0..size {
assert_eq!(1, msg_filter.insert(&i));
}
assert!((0..size).all(|index| msg_filter.contains(&index)));
assert_eq!(1, msg_filter.count(&0));
assert_eq!(2, msg_filter.insert(&0));
assert_eq!(2, msg_filter.count(&0));
}
// Timeline with TTL 3.0 s and 1.8 s sleeps:
// t=0: insert, expires at 3.0 s.
// t~1.8 s: re-insert, expiry reset to ~4.8 s.
// t~3.6 s: still present, even though the original 3.0 s expiry has passed.
// t~5.4 s: gone.
#[test]
fn insert_resets_timeout() {
let time_to_live = Duration::from_millis(3000);
let sleep_duration = Duration::from_millis(1800); let mut msg_filter = MessageFilter::<usize>::with_expiry_duration(time_to_live);
assert_eq!(1, msg_filter.insert(&0));
thread::sleep(sleep_duration);
assert_eq!(2, msg_filter.insert(&0));
thread::sleep(sleep_duration);
assert!(msg_filter.contains(&0));
thread::sleep(sleep_duration);
assert!(!msg_filter.contains(&0));
}
}