use std::collections::HashMap;
/// Width of the default replay window, in sequence numbers; also the bit
/// width of one bitmap word.
pub const DEFAULT_WINDOW_BITS: u32 = 64;
/// Largest accepted replay-window width: sixteen 64-bit words.
pub const MAX_WINDOW_BITS: u32 = DEFAULT_WINDOW_BITS * 16;
/// [`DEFAULT_WINDOW_BITS`] widened to `u64` for sequence-number arithmetic.
pub const WINDOW_BITS: u64 = DEFAULT_WINDOW_BITS as u64;
/// Replay-tracking state for one stream.
#[derive(Debug, Clone)]
struct StreamWindow {
    /// Highest sequence number recorded so far (meaningful once `initialized`).
    highest: u64,
    /// Seen-flags: bit `k % 64` of word `k / 64` marks `highest - k` as seen.
    bitmap: Vec<u64>,
    /// Set once the first sequence number has been recorded.
    initialized: bool,
}

impl StreamWindow {
    /// Builds a not-yet-initialized window with `bitmap_words` zeroed words.
    fn new(bitmap_words: usize) -> Self {
        let bitmap = vec![0; bitmap_words];
        Self {
            initialized: false,
            highest: 0,
            bitmap,
        }
    }
}
/// Anti-replay filter that keeps an independent sliding window per stream.
///
/// A stream is identified by the tuple `(source_id, payload_type, key_epoch)`.
/// Each stream remembers the highest sequence number it has accepted plus a
/// bitmap covering the `window_bits` sequence numbers at and below it. Stream
/// state is created lazily by [`ReplayWindow::accept`].
#[derive(Debug, Clone)]
pub struct ReplayWindow {
    /// Per-stream state keyed by `(source_id, payload_type, key_epoch)`.
    streams: HashMap<(u64, u8, u8), StreamWindow>,
    /// Width of every stream's window, in sequence numbers.
    window_bits: u32,
    /// Number of 64-bit words needed to hold `window_bits` bits.
    bitmap_words: usize,
}

impl Default for ReplayWindow {
    /// Equivalent to [`ReplayWindow::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl ReplayWindow {
    /// Creates an empty filter whose windows are [`DEFAULT_WINDOW_BITS`] wide.
    #[must_use]
    pub fn new() -> Self {
        Self::with_window_bits(DEFAULT_WINDOW_BITS)
    }

    /// Creates an empty filter whose windows span `bits` sequence numbers.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is not a multiple of [`DEFAULT_WINDOW_BITS`] lying in
    /// `DEFAULT_WINDOW_BITS..=MAX_WINDOW_BITS`.
    #[must_use]
    pub fn with_window_bits(bits: u32) -> Self {
        assert!(
            (DEFAULT_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits)
                && bits % DEFAULT_WINDOW_BITS == 0,
            "replay window bits must be a multiple of 64 between 64 and 1024; got {bits}",
        );
        Self {
            streams: HashMap::new(),
            window_bits: bits,
            // At most MAX_WINDOW_BITS / DEFAULT_WINDOW_BITS words, so the cast
            // cannot truncate; the assert above also guarantees at least one word.
            bitmap_words: (bits / DEFAULT_WINDOW_BITS) as usize,
        }
    }

    /// Returns the configured window width, in sequence numbers.
    #[must_use]
    pub fn window_bits(&self) -> u32 {
        self.window_bits
    }

    /// Forgets every stream. The configured window width is kept.
    pub fn clear(&mut self) {
        self.streams.clear();
    }

    /// Checks `seq` against the window of stream
    /// `(source_id, payload_type, key_epoch)` and records it when fresh.
    ///
    /// Returns `true` when `seq` is accepted and `false` when it must be
    /// dropped. `seq` is dropped when it was already recorded or is
    /// `window_bits` or more below the stream's highest sequence number.
    ///
    /// The first sequence number seen on a stream is always accepted and
    /// becomes its high-water mark. A number above the current highest slides
    /// the window forward.
    #[must_use = "ignoring the result of `accept` defeats replay protection"]
    pub fn accept(&mut self, source_id: u64, payload_type: u8, key_epoch: u8, seq: u64) -> bool {
        let window_bits = u64::from(self.window_bits);
        let word_bits = u64::from(DEFAULT_WINDOW_BITS);
        let bitmap_words = self.bitmap_words;
        let win = self
            .streams
            .entry((source_id, payload_type, key_epoch))
            .or_insert_with(|| StreamWindow::new(bitmap_words));

        // bitmap_words >= 1 (enforced by the constructor), so index 0 exists.
        if !win.initialized {
            win.highest = seq;
            win.bitmap.fill(0);
            win.bitmap[0] = 1;
            win.initialized = true;
            return true;
        }

        if seq > win.highest {
            let shift = seq - win.highest;
            if shift >= window_bits {
                // Jumped past the whole window: no older entry stays representable.
                win.bitmap.fill(0);
                win.bitmap[0] = 1;
            } else {
                // shift < window_bits <= MAX_WINDOW_BITS, so it fits in u32.
                shift_bitmap_left(&mut win.bitmap, shift as u32);
                win.bitmap[0] |= 1;
            }
            win.highest = seq;
            return true;
        }

        let offset = win.highest - seq;
        if offset >= window_bits {
            // Too old to be tracked: treat as a replay.
            return false;
        }
        // offset < window_bits, hence word_idx < bitmap_words; the cast is lossless.
        let word_idx = (offset / word_bits) as usize;
        let mask = 1u64 << (offset % word_bits);
        let slot = &mut win.bitmap[word_idx];
        if *slot & mask != 0 {
            false
        } else {
            *slot |= mask;
            true
        }
    }

    /// Removes every stream whose key epoch equals `key_epoch`.
    pub fn drop_epoch(&mut self, key_epoch: u8) {
        self.streams.retain(|(_, _, epoch), _| *epoch != key_epoch);
    }

    /// Returns the number of streams currently tracked.
    #[must_use]
    pub fn stream_count(&self) -> usize {
        self.streams.len()
    }
}
/// Shifts a multi-word bitmap toward higher bit positions by `shift` bits.
///
/// Word 0 holds the lowest-order bits. Bits pushed past the top word are
/// discarded, and vacated low positions become zero. A zero shift or an empty
/// slice leaves the bitmap unchanged. In release builds, a shift of
/// `64 * bitmap.len()` or more clears the bitmap.
///
/// # Panics
///
/// In debug builds, panics when `shift >= 64 * bitmap.len()`. This includes
/// every call on an empty slice.
#[inline]
fn shift_bitmap_left(bitmap: &mut [u64], shift: u32) {
    debug_assert!(
        (shift as usize) < bitmap.len() * 64,
        "shift {} out of range for {}-word bitmap",
        shift,
        bitmap.len()
    );
    if shift == 0 || bitmap.is_empty() {
        return;
    }
    let words = (shift / 64) as usize;
    let bits = shift % 64;
    // Walk from the top word down so every source word (index <= dst) is
    // still unmodified when it is read.
    for dst in (0..bitmap.len()).rev() {
        let Some(src) = dst.checked_sub(words) else {
            bitmap[dst] = 0;
            continue;
        };
        // High bits of the word below `src` carry into the bottom of `dst`.
        let carry = match (bits, src.checked_sub(1)) {
            (0, _) | (_, None) => 0,
            (_, Some(below)) => bitmap[below] >> (64 - bits),
        };
        bitmap[dst] = (bitmap[src] << bits) | carry;
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `ReplayWindow`. The windowed-offset tests cover
    //! both the default 64-bit window and wider multi-word windows.
    use super::*;

    const SRC: u64 = 0xDEAD_BEEF_CAFE_BABE;
    const EPOCH: u8 = 0;

    /// Feeds `seq` to the stream `(SRC, pld, EPOCH)`.
    fn accept_default(w: &mut ReplayWindow, pld: u8, seq: u64) -> bool {
        w.accept(SRC, pld, EPOCH, seq)
    }

    #[test]
    fn first_sequence_accepted() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 0));
    }

    #[test]
    fn sequential_sequences_accepted() {
        let mut w = ReplayWindow::new();
        for seq in 0..100 {
            assert!(accept_default(&mut w, 0x10, seq), "seq {seq} should accept");
        }
    }

    #[test]
    fn duplicate_at_highest_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(
            !accept_default(&mut w, 0x10, 5),
            "duplicate at highest should reject"
        );
    }

    #[test]
    fn duplicate_within_window_rejected() {
        let mut w = ReplayWindow::new();
        for seq in 0..=5 {
            assert!(accept_default(&mut w, 0x10, seq));
        }
        assert!(!accept_default(&mut w, 0x10, 2));
        assert!(accept_default(&mut w, 0x10, 6));
    }

    #[test]
    fn out_of_order_within_window_accepted() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 10));
        assert!(accept_default(&mut w, 0x10, 7));
        assert!(!accept_default(&mut w, 0x10, 7));
        assert!(accept_default(&mut w, 0x10, 8));
    }

    #[test]
    fn too_old_sequence_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 100));
        // Offsets 65 and 64 fall outside a 64-wide window; 63 is the oldest slot.
        assert!(!accept_default(&mut w, 0x10, 35));
        assert!(!accept_default(&mut w, 0x10, 36));
        assert!(accept_default(&mut w, 0x10, 37));
    }

    #[test]
    fn future_arrival_shifts_window_correctly() {
        let mut w = ReplayWindow::new();
        for seq in 0..=5 {
            assert!(accept_default(&mut w, 0x10, seq));
        }
        assert!(accept_default(&mut w, 0x10, 1000));
        for seq in 0..=5 {
            assert!(
                !accept_default(&mut w, 0x10, seq),
                "old seq {seq} after jump should reject"
            );
        }
        assert!(accept_default(&mut w, 0x10, 999));
        assert!(accept_default(&mut w, 0x10, 950));
        // 1000 - 936 = 64: exactly one window width below the highest.
        assert!(!accept_default(&mut w, 0x10, 936));
    }

    #[test]
    fn independent_streams_dont_interfere() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 5));
        assert!(!accept_default(&mut w, 0x10, 5));
        assert!(!accept_default(&mut w, 0x11, 5));
        assert_eq!(w.stream_count(), 2);
    }

    #[test]
    fn different_source_ids_dont_interfere() {
        let mut w = ReplayWindow::new();
        assert!(w.accept(0xAAAA_AAAA_AAAA_AAAA, 0x10, EPOCH, 100));
        assert!(w.accept(0xBBBB_BBBB_BBBB_BBBB, 0x10, EPOCH, 100));
        assert!(!w.accept(0xAAAA_AAAA_AAAA_AAAA, 0x10, EPOCH, 100));
        assert_eq!(w.stream_count(), 2);
    }

    #[test]
    fn window_edge_exactly_window_bits_below_rejected() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 100));
        assert!(!accept_default(&mut w, 0x10, 36));
        assert!(accept_default(&mut w, 0x10, 37));
    }

    #[test]
    fn clear_resets_all_streams() {
        let mut w = ReplayWindow::new();
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 7));
        assert_eq!(w.stream_count(), 2);
        w.clear();
        assert_eq!(w.stream_count(), 0);
        assert!(accept_default(&mut w, 0x10, 5));
        assert!(accept_default(&mut w, 0x11, 7));
    }

    #[test]
    fn independent_epochs_dont_interfere_even_with_same_stream() {
        let mut w = ReplayWindow::new();
        for seq in 0..=1000 {
            assert!(w.accept(SRC, 0x10, 0, seq));
        }
        assert!(
            w.accept(SRC, 0x10, 1, 0),
            "new-epoch seq=0 must be accepted despite old-epoch highest=1000"
        );
        assert!(w.accept(SRC, 0x10, 1, 1));
        assert!(w.accept(SRC, 0x10, 1, 2));
    }

    #[test]
    fn drop_epoch_removes_only_that_epoch() {
        let mut w = ReplayWindow::new();
        assert!(w.accept(SRC, 0x10, 0, 5));
        assert!(w.accept(SRC, 0x10, 1, 5));
        assert!(w.accept(SRC, 0x11, 0, 5));
        assert_eq!(w.stream_count(), 3);
        w.drop_epoch(0);
        assert_eq!(w.stream_count(), 1);
        assert!(w.accept(SRC, 0x10, 0, 5));
        assert!(!w.accept(SRC, 0x10, 1, 5));
    }

    #[test]
    fn drop_epoch_with_no_entries_is_noop() {
        let mut w = ReplayWindow::new();
        w.drop_epoch(42);
        assert_eq!(w.stream_count(), 0);
    }

    #[test]
    fn default_matches_new() {
        assert_eq!(ReplayWindow::default().window_bits(), DEFAULT_WINDOW_BITS);
        assert_eq!(ReplayWindow::new().window_bits(), DEFAULT_WINDOW_BITS);
        assert_eq!(ReplayWindow::default().stream_count(), 0);
    }

    #[test]
    fn wide_window_accepts_offsets_beyond_first_word() {
        let mut w = ReplayWindow::with_window_bits(128);
        assert_eq!(w.window_bits(), 128);
        assert!(accept_default(&mut w, 0x10, 200));
        // Offset 100 lives in the second bitmap word.
        assert!(accept_default(&mut w, 0x10, 100));
        assert!(!accept_default(&mut w, 0x10, 100));
        // Offset 127 is the oldest representable slot; 128 is outside.
        assert!(accept_default(&mut w, 0x10, 73));
        assert!(!accept_default(&mut w, 0x10, 72));
    }

    #[test]
    fn wide_window_shift_carries_bits_across_words() {
        let mut w = ReplayWindow::with_window_bits(128);
        assert!(accept_default(&mut w, 0x10, 1000));
        assert!(accept_default(&mut w, 0x10, 990)); // offset 10, word 0
        // Advancing by 60 moves seq 990 from offset 10 to offset 70 (word 1).
        assert!(accept_default(&mut w, 0x10, 1060));
        assert!(
            !accept_default(&mut w, 0x10, 990),
            "replay after cross-word shift must reject"
        );
        assert!(!accept_default(&mut w, 0x10, 1000)); // offset 60, word 0
        assert!(accept_default(&mut w, 0x10, 991)); // offset 69, never seen
    }

    #[test]
    fn wide_window_whole_word_shift() {
        let mut w = ReplayWindow::with_window_bits(192);
        assert!(accept_default(&mut w, 0x10, 500));
        assert!(accept_default(&mut w, 0x10, 450)); // offset 50
        // Advancing by exactly 64 exercises the word-only shift path.
        assert!(accept_default(&mut w, 0x10, 564));
        assert!(!accept_default(&mut w, 0x10, 500)); // offset 64
        assert!(!accept_default(&mut w, 0x10, 450)); // offset 114
        assert!(accept_default(&mut w, 0x10, 451)); // offset 113, never seen
    }

    #[test]
    fn max_window_edge() {
        let mut w = ReplayWindow::with_window_bits(MAX_WINDOW_BITS);
        let max = u64::from(MAX_WINDOW_BITS);
        assert!(accept_default(&mut w, 0x10, 5000));
        assert!(accept_default(&mut w, 0x10, 5000 - (max - 1)));
        assert!(!accept_default(&mut w, 0x10, 5000 - max));
    }

    #[test]
    #[should_panic(expected = "replay window bits must be a multiple of 64")]
    fn rejects_non_multiple_window_bits() {
        let _ = ReplayWindow::with_window_bits(100);
    }

    #[test]
    #[should_panic(expected = "replay window bits must be a multiple of 64")]
    fn rejects_zero_window_bits() {
        let _ = ReplayWindow::with_window_bits(0);
    }

    #[test]
    #[should_panic(expected = "replay window bits must be a multiple of 64")]
    fn rejects_oversized_window_bits() {
        let _ = ReplayWindow::with_window_bits(MAX_WINDOW_BITS + DEFAULT_WINDOW_BITS);
    }
}