/// Key of the per-claimant filter: the signature carried by a signed message.
pub type ClaimantFilter = ::sodiumoxide::crypto::sign::Signature;
/// Key of the routing-message filter: SHA-256 digest of the encoded routing message.
pub type RoutingMessageFilter = ::sodiumoxide::crypto::hash::sha256::Digest;

/// Rejects signed messages whose signature was already seen, and messages whose routing
/// message has been explicitly blocked.
///
/// Both underlying filters are created with the same expiry duration.
// `threshold` is constructed but not consulted by any method yet, hence `unused`.
#[allow(unused)]
pub struct Filter {
    claimant_filter: ::message_filter::MessageFilter<ClaimantFilter>,
    message_filter: ::message_filter::MessageFilter<RoutingMessageFilter>,
    threshold: SimpleThresholdCalculator,
}

impl Filter {
    /// Builds a filter whose claimant and routing-message entries use `duration` as expiry.
    pub fn with_expiry_duration(duration: ::time::Duration) -> Filter {
        let claimant_filter = ::message_filter::MessageFilter::with_expiry_duration(duration);
        let message_filter = ::message_filter::MessageFilter::with_expiry_duration(duration);
        let threshold = SimpleThresholdCalculator::new(10000u16, ::types::QUORUM_SIZE / 2 + 1);
        Filter {
            claimant_filter: claimant_filter,
            message_filter: message_filter,
            threshold: threshold,
        }
    }

    /// Returns `true` when `signed_message` should be processed.
    ///
    /// Returns `false` if the signature was already recorded, if the routing message cannot
    /// be encoded, or if the routing message's digest is in the blocked set. A previously
    /// unseen signature is recorded even when a later step returns `false`.
    pub fn check(&mut self, signed_message: &::messages::SignedMessage) -> bool {
        // A repeated signature means this signed message was already seen.
        let seen_before = self.claimant_filter.check(signed_message.signature());
        if seen_before {
            return false;
        }
        self.claimant_filter.add(signed_message.signature().clone());
        match Filter::routing_message_digest(signed_message.get_routing_message()) {
            Some(digest) => !self.message_filter.check(&digest),
            None => false,
        }
    }

    /// Adds the digest of `routing_message` to the blocked set, so later `check` calls for
    /// messages carrying it return `false`. Silently does nothing if encoding fails.
    pub fn block(&mut self, routing_message: &::messages::RoutingMessage) {
        if let Some(digest) = Filter::routing_message_digest(routing_message) {
            self.message_filter.add(digest);
        }
    }

    /// SHA-256 of the encoded `routing_message`, or `None` if it cannot be encoded.
    fn routing_message_digest(routing_message: &::messages::RoutingMessage)
                              -> Option<RoutingMessageFilter> {
        ::utils::encode(routing_message)
            .ok()
            .map(|bytes| ::sodiumoxide::crypto::hash::sha256::hash(&bytes[..]))
    }
}
/// Collects message statistics over windows of `cap` messages.
///
/// When a window closes, its blocked ratio is fed into `blocked_percentage`. If any unique
/// messages were recorded, the window's multiplicity (total / unique) is fed into
/// `multiplicity`. `current_threshold` is only set at construction here.
// Several fields/methods are not wired into `Filter` yet, hence `unused`.
#[allow(unused)]
pub struct SimpleThresholdCalculator {
    total_messages: u32,
    total_blockedmessages: u32,
    total_uniquemessages: u32,
    blocked_percentage: RunningAverage,
    multiplicity: RunningAverage,
    cap: u32,
    current_threshold: usize,
}
#[allow(unused)]
impl SimpleThresholdCalculator {
    /// Creates a calculator that closes a statistics window every `cap` messages, starting
    /// from `start_threshold`.
    pub fn new(cap: u16, start_threshold: usize) -> SimpleThresholdCalculator {
        SimpleThresholdCalculator {
            total_messages: 0u32,
            total_blockedmessages: 0u32,
            total_uniquemessages: 0u32,
            blocked_percentage: RunningAverage::new(10000u32),
            multiplicity: RunningAverage::new(10000u32),
            cap: cap as u32,
            current_threshold: start_threshold,
        }
    }

    /// Records one message (`blocked` if it was rejected). Closes the window once `cap`
    /// messages have been recorded in it.
    pub fn hit_message(&mut self, blocked: bool) {
        if blocked {
            self.total_blockedmessages += 1u32;
        }
        self.total_messages += 1u32;
        if self.total_messages >= self.cap {
            self.calculate_average();
        }
    }

    /// Records one unique message in the current window.
    ///
    /// The multiplicity derived from this counter is folded in when the window closes
    /// (see `calculate_average`).
    pub fn hit_uniquemessage(&mut self) {
        self.total_uniquemessages += 1u32;
    }

    /// Folds the current window into the running averages and resets the window counters.
    ///
    /// Only called from `hit_message` right after incrementing `total_messages`, so the
    /// divisor below is never zero.
    fn calculate_average(&mut self) {
        let average_blocked = self.total_blockedmessages as f64 / self.total_messages as f64;
        self.blocked_percentage.add_value(average_blocked);
        if self.total_uniquemessages > 0u32 {
            let message_multiplicity = self.total_messages as f64
                / self.total_uniquemessages as f64;
            self.multiplicity.add_value(message_multiplicity);
        }
        self.total_messages = 0u32;
        self.total_blockedmessages = 0u32;
        self.total_uniquemessages = 0u32;
    }
}
/// Incremental arithmetic mean that groups values into blocks of `block_size`.
///
/// Values are averaged within the block being filled. A full block is committed at the start
/// of the next `add_value` call into `block_average`, the mean of all committed block means.
/// Once at least one block is committed, the reported average is `block_average` and does not
/// include the values of the block still being filled.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct RunningAverage {
    /// Mean of the values in the block currently being filled.
    average: f64,
    /// Mean of the per-block means of all committed blocks.
    block_average: f64,
    /// Number of values in the block currently being filled.
    counter: u32,
    /// Number of committed blocks.
    block_counter: u32,
    /// Values per block; never 0 (see `new`).
    block_size: u32,
}
// `get_average` and friends may be unused by the rest of the crate, hence `unused`.
#[allow(unused)]
impl RunningAverage {
    /// Creates an empty running average with blocks of `block_size` values.
    ///
    /// A `block_size` of 0 is treated as the largest possible block, so the value reported is a
    /// plain cumulative mean.
    pub fn new(block_size: u32) -> RunningAverage {
        // A zero block size would commit an empty block (mean 0) on the first call and never
        // match `counter` again, freezing the reported average at 0.
        let block_size = if block_size == 0 { ::std::u32::MAX } else { block_size };
        RunningAverage {
            average: 0f64,
            block_average: 0f64,
            counter: 0u32,
            block_counter: 0u32,
            block_size: block_size,
        }
    }

    /// Adds `value` and returns the updated average (identical to `get_average` afterwards).
    pub fn add_value(&mut self, value: f64) -> f64 {
        if self.counter == self.block_size {
            // Commit the full block: incremental mean new = x / n + (n - 1) / n * old.
            let blocks = self.block_counter as f64 + 1f64;
            self.block_average = self.average / blocks
                + (self.block_counter as f64 / blocks) * self.block_average;
            self.block_counter += 1u32;
            self.counter = 0u32;
            self.average = 0f64;
        }
        let count = self.counter as f64 + 1f64;
        self.average = value / count + (self.counter as f64 / count) * self.average;
        self.counter += 1u32;
        self.get_average()
    }

    /// Mean of committed block means if any block was committed, else the current block mean.
    pub fn get_average(&self) -> f64 {
        if self.block_counter > 0 {
            self.block_average
        } else {
            self.average
        }
    }
}
#[cfg(test)]
mod test {
    /// Signs an arbitrary routing message with a fresh key pair on behalf of a random node.
    fn random_signed_message() -> ::messages::SignedMessage {
        let claimant = ::types::Address::Node(::test_utils::Random::generate_random());
        let keys = ::sodiumoxide::crypto::sign::gen_keypair();
        let routing_message =
            ::test_utils::messages_util::arbitrary_routing_message(&keys.0, &keys.1);
        ::messages::SignedMessage::new(claimant, routing_message, &keys.1).unwrap()
    }

    /// Arithmetic mean of `values` (NaN for an empty slice).
    fn mean(values: &[f64]) -> f64 {
        values.iter().fold(0f64, |acc, &item| acc + item) / (values.len() as f64)
    }

    #[test]
    fn filter_check_before_duration_end() {
        let duration = ::time::Duration::seconds(3);
        let mut filter = super::Filter::with_expiry_duration(duration);
        let signed_message = random_signed_message();
        // First sighting passes; the repeated signature is rejected before expiry.
        assert!(filter.check(&signed_message));
        assert!(!filter.check(&signed_message));
    }

    #[test]
    fn filter_check_after_duration_end() {
        let duration = ::time::Duration::milliseconds(1);
        let mut filter = super::Filter::with_expiry_duration(duration);
        let signed_message = random_signed_message();
        assert!(filter.check(&signed_message));
        // Sleep past the 1 ms expiry (`thread::sleep_ms` is deprecated).
        ::std::thread::sleep(::std::time::Duration::from_millis(2));
        assert!(filter.check(&signed_message));
    }

    #[test]
    fn filter_block() {
        let duration = ::time::Duration::seconds(3);
        let mut filter = super::Filter::with_expiry_duration(duration);
        let signed_message = random_signed_message();
        filter.block(signed_message.get_routing_message());
        assert!(!filter.check(&signed_message));
    }

    #[test]
    fn running_average_exact() {
        use ::rand::Rng;
        let mut rng = ::rand::thread_rng();
        let mut running_average = super::RunningAverage::new(1000u32);
        let mut set: Vec<f64> = Vec::with_capacity(5000);
        for _ in 0..5000u32 {
            let new_value = rng.gen::<u8>() as f64;
            set.push(new_value);
            let result = running_average.add_value(new_value);
            let expected = mean(&set);
            // Skip the relative-error check while the exact mean is (near) zero.
            if expected > 0.0000001f64 {
                let error = (1f64 - (result / expected)).abs();
                assert!(error < 0.05f64);
            }
        }
    }

    #[test]
    fn running_average_long() {
        use ::rand::Rng;
        let mut rng = ::rand::thread_rng();
        let mut running_average = super::RunningAverage::new(1000u32);
        for _ in 0..100000u32 {
            let new_value = rng.gen::<u8>() as f64;
            let _ = running_average.add_value(new_value);
        }
        let new_value = rng.gen::<u8>() as f64;
        let result = running_average.add_value(new_value);
        // 127.5 is the mean of a uniform u8 sample.
        let error = (1f64 - (result / 127.5f64)).abs();
        assert!(error < 0.01f64);
    }
}