use crate::{
error::{Result, SframeError},
header::{Counter, SframeHeader},
};
use std::{cell::Cell, ops::Deref};
pub trait FrameValidation {
fn validate(&self, header: &SframeHeader) -> Result<()>;
}
pub type FrameValidationBox = Box<dyn FrameValidation>;
impl FrameValidation for FrameValidationBox {
fn validate(&self, header: &SframeHeader) -> Result<()> {
self.deref().validate(header)
}
}
pub struct ReplayAttackProtection {
tolerance: u64,
last_counter: Cell<Counter>,
}
impl ReplayAttackProtection {
pub fn with_tolerance(tolerance: u64) -> Self {
ReplayAttackProtection {
tolerance,
last_counter: Cell::new(0u64),
}
}
}
impl FrameValidation for ReplayAttackProtection {
fn validate(&self, header: &SframeHeader) -> Result<()> {
let last_counter = self.last_counter.get();
let current_counter = header.counter();
if current_counter > last_counter {
self.last_counter.set(current_counter);
Ok(())
} else {
let age = last_counter - current_counter;
if age <= self.tolerance {
self.last_counter.set(current_counter);
Ok(())
} else {
let dist_to_overflow = u64::MAX - last_counter;
let overflow_age = current_counter + dist_to_overflow;
if overflow_age <= self.tolerance {
self.last_counter.set(current_counter);
Ok(())
} else {
Err(SframeError::FrameValidationFailed(
"Replay check failed, frame counter too old".to_string(),
))
}
}
}
}
}
#[cfg(test)]
mod test {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn accept_newer_headers() {
let validator = ReplayAttackProtection::with_tolerance(128);
let first_header = SframeHeader::new(23456789u64, 2400);
let second_header = SframeHeader::new(23456789u64, 2480);
assert_eq!(validator.validate(&first_header), Ok(()));
assert_eq!(validator.validate(&second_header), Ok(()));
}
#[test]
fn accept_older_headers_in_tolerance() {
let validator = ReplayAttackProtection::with_tolerance(128);
let first_header = SframeHeader::new(23456789u64, 2480);
let late_header = SframeHeader::new(23456789u64, 2400);
assert_eq!(validator.validate(&first_header), Ok(()));
assert_eq!(validator.validate(&late_header), Ok(()));
}
#[test]
fn reject_too_old_headers() {
let validator = ReplayAttackProtection::with_tolerance(128);
let first_header = SframeHeader::new(23456789u64, 2480);
let too_late_header = SframeHeader::new(23456789u64, 1024);
assert_eq!(validator.validate(&first_header), Ok(()));
assert!(matches!(
validator.validate(&too_late_header),
Err(SframeError::FrameValidationFailed(_))
));
}
#[test]
fn handle_overflowing_counters() {
let validator = ReplayAttackProtection::with_tolerance(128);
let start_count = u64::MAX - 3;
let first_header = SframeHeader::new(23456789u64, start_count);
assert_eq!(validator.validate(&first_header), Ok(()));
for step in 0..10 {
let late_count = start_count.wrapping_add(step); let too_late_header = SframeHeader::new(23456789u64, late_count);
assert_eq!(validator.validate(&too_late_header), Ok(()));
}
}
}