use core::fmt;
use hybrid_array::{Array, ArraySize};
/// Reason a received sequence number was rejected by
/// [`SequenceNumber::is_replay`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReplayInfo {
    /// The forward distance from the current value to the received one is
    /// zero, i.e. the received sequence number does not advance the counter.
    EqualOrLess,
    /// The received sequence number lies ahead of the current value by more
    /// than the configured window.
    Greater {
        /// How far past the end of the window the received value is
        /// (forward distance minus `window`).
        diff: u64,
        /// The window size that was in effect for the check.
        window: u64,
    },
}
impl fmt::Display for ReplayInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::EqualOrLess => write!(f, "received SN equal to or behind current"),
Self::Greater { diff, window } => {
write!(f, "received SN {diff} beyond window of {window}")
}
}
}
}
/// Fixed-width sequence-number counter with an acceptance window.
///
/// The counter is stored as `N` bytes with the most significant byte first:
/// the increment logic in `next`/`advance` starts at the last byte and
/// carries towards index 0.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SequenceNumber<N>
where
    N: ArraySize,
{
    /// Current counter value, big-endian.
    current: Array<u8, N>,
    /// Largest forward distance from `current` that `is_replay` accepts.
    window: u16,
}
impl<N: ArraySize> SequenceNumber<N> {
    /// Creates a counter with a default-initialised (all-zero) value and the
    /// given acceptance `window`.
    #[inline]
    pub fn new(window: u16) -> Self {
        Self { current: Array::default(), window }
    }

    /// Creates a counter with a default-initialised value and a window of 0.
    ///
    /// NOTE(review): with a zero window [`Self::is_replay`] rejects every
    /// received value; presumably callers skip the check for a disabled
    /// counter — confirm against the call sites.
    #[inline]
    pub fn disabled() -> Self {
        Self { current: Array::default(), window: 0 }
    }

    /// Returns a copy of the current counter value (big-endian bytes).
    #[inline]
    #[must_use]
    pub fn get_current(&self) -> Array<u8, N> {
        self.current.clone()
    }

    /// Returns the value that follows the current one, without modifying
    /// the counter. Wraps to all-zero after the all-`0xFF` value.
    #[inline]
    #[must_use]
    pub fn next(&self) -> Array<u8, N> {
        let mut successor = self.current.clone();
        Self::increment(&mut successor);
        successor
    }

    /// Overwrites the current counter value with `val`.
    #[inline]
    pub fn set_current(&mut self, val: Array<u8, N>) {
        self.current = val;
    }

    /// Returns the configured acceptance window.
    #[inline]
    #[must_use]
    pub fn window(&self) -> u16 {
        self.window
    }

    /// Replaces the acceptance window.
    #[inline]
    pub fn set_window(&mut self, window: u16) {
        self.window = window;
    }

    /// Increments the counter in place, wrapping to all-zero on overflow.
    #[inline]
    pub fn advance(&mut self) {
        Self::increment(&mut self.current);
    }

    /// Checks a received sequence number against the current value.
    ///
    /// Despite the name, `Ok(())` means the value is acceptable: it lies
    /// strictly ahead of the current value by at most `window`.
    ///
    /// # Errors
    ///
    /// * [`ReplayInfo::EqualOrLess`] when `received_sn` equals the current
    ///   value (forward distance 0).
    /// * [`ReplayInfo::Greater`] when the forward distance exceeds the
    ///   window. The distance is computed modulo the counter width, so a
    ///   value numerically behind the current one shows up as a very large
    ///   forward distance and is reported through this variant as well.
    #[inline]
    pub fn is_replay(&self, received_sn: &Array<u8, N>) -> Result<(), ReplayInfo> {
        let window = u64::from(self.window);
        match Self::forward_dist(received_sn, &self.current) {
            0 => Err(ReplayInfo::EqualOrLess),
            dist if dist <= window => Ok(()),
            dist => Err(ReplayInfo::Greater { diff: dist - window, window }),
        }
    }

    /// Adds one to a big-endian byte counter, propagating the carry from the
    /// last byte towards the first and wrapping to zero on overflow.
    fn increment(bytes: &mut Array<u8, N>) {
        for byte in bytes.iter_mut().rev() {
            *byte = byte.wrapping_add(1);
            // A non-zero result means this byte did not overflow: no carry.
            if *byte != 0 {
                break;
            }
        }
    }

    /// Forward distance `received - current`, modulo the full counter width,
    /// saturated to `u64::MAX` when it does not fit in 64 bits.
    ///
    /// The subtraction is carried through every byte. Stopping at the low
    /// eight bytes and merely comparing the rest for equality would lose the
    /// borrow: a counter wider than 64 bits would then reject its own
    /// successor at a 2^64 boundary and could accept a value far behind it.
    fn forward_dist(received: &Array<u8, N>, current: &Array<u8, N>) -> u64 {
        let mut dist = 0u64;
        let mut borrow = false;
        let mut exceeds_u64 = false;

        // Walk from the least significant (last) byte to the most significant.
        for (pos, (&r, &c)) in received.iter().zip(current.iter()).rev().enumerate() {
            let (partial, underflow_a) = r.overflowing_sub(c);
            let (digit, underflow_b) = partial.overflowing_sub(u8::from(borrow));
            borrow = underflow_a || underflow_b;

            if pos < 8 {
                dist |= u64::from(digit) << (8 * pos);
            } else if digit != 0 {
                // Any non-zero digit above the low 64 bits means the true
                // distance cannot be represented in a u64.
                exceeds_u64 = true;
            }
        }
        // The borrow out of the most significant byte is intentionally
        // dropped: the distance is modular in the counter width.

        if exceeds_u64 {
            u64::MAX
        } else {
            dist
        }
    }
}
/// Unit tests for [`SequenceNumber`] increment and replay-window checks.
#[cfg(test)]
mod tests {
    use super::*;
    use hybrid_array::Array;
    use typenum::{U4, U12};

    /// A fresh counter starts at zero and advances to one.
    #[test]
    fn advance_basic() {
        let mut sn = SequenceNumber::<U4>::new(16);
        sn.advance();
        assert_eq!(sn.get_current(), Array::from([0x00, 0x00, 0x00, 0x01]));
    }

    /// Overflow of the last byte carries into the preceding byte.
    #[test]
    fn advance_carry() {
        let mut sn = SequenceNumber::<U4>::new(16);
        sn.set_current(Array::from([0x00, 0x00, 0x00, 0xFF]));
        sn.advance();
        assert_eq!(sn.get_current(), Array::from([0x00, 0x00, 0x01, 0x00]));
    }

    /// The all-`0xFF` value wraps around to zero.
    #[test]
    fn advance_overflow() {
        let mut sn = SequenceNumber::<U4>::new(16);
        sn.set_current(Array::from([0xFF, 0xFF, 0xFF, 0xFF]));
        sn.advance();
        assert_eq!(sn.get_current(), Array::from([0x00, 0x00, 0x00, 0x00]));
    }

    /// `next` only computes the successor; the stored value is untouched.
    #[test]
    fn next_does_not_mutate() {
        let mut sn = SequenceNumber::<U4>::new(16);
        sn.advance();
        let before = sn.get_current();
        let _ = sn.next();
        assert_eq!(sn.get_current(), before);
    }

    /// A received value equal to the current one is rejected.
    #[test]
    fn is_replay_equal_rejected() {
        let mut sn = SequenceNumber::<U4>::new(4);
        sn.advance();
        let current = sn.get_current();
        assert_eq!(sn.is_replay(&current), Err(ReplayInfo::EqualOrLess));
    }

    /// Current is 1, received is 4: distance 3 fits in a window of 4.
    #[test]
    fn is_replay_within_window() {
        let mut sn = SequenceNumber::<U4>::new(4);
        sn.advance();
        let received = Array::from([0x00, 0x00, 0x00, 0x04]);
        assert!(sn.is_replay(&received).is_ok());
    }

    /// Current is 1, received is 11: distance 10 exceeds a window of 4.
    #[test]
    fn is_replay_beyond_window() {
        let mut sn = SequenceNumber::<U4>::new(4);
        sn.advance();
        let received = Array::from([0x00, 0x00, 0x00, 0x0B]);
        assert!(matches!(
            sn.is_replay(&received),
            Err(ReplayInfo::Greater { .. })
        ));
    }

    /// A received value numerically behind the current one is rejected.
    #[test]
    fn is_replay_behind_current() {
        let mut sn = SequenceNumber::<U4>::new(4);
        sn.set_current(Array::from([0x00, 0x00, 0x00, 0x05]));
        let received = Array::from([0x00, 0x00, 0x00, 0x02]);
        assert!(sn.is_replay(&received).is_err());
    }

    /// A disabled (zero-window) counter still rejects an equal value.
    #[test]
    fn disabled_counter_rejects_equal() {
        let sn = SequenceNumber::<U4>::disabled();
        let received = Array::from([0x00, 0x00, 0x00, 0x00]);
        assert_eq!(sn.is_replay(&received), Err(ReplayInfo::EqualOrLess));
    }

    /// For counters wider than 8 bytes, a difference in the most significant
    /// byte is far outside any window and must be rejected.
    #[test]
    fn forward_dist_large_array_high_bytes_differ() {
        let sn = SequenceNumber::<U12>::new(4);
        let mut received = Array::<u8, U12>::default();
        received[0] = 0x01;
        assert!(sn.is_replay(&received).is_err());
    }
}