use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::error::{ClusterError, Result};
/// How many sequence numbers, counting down from a peer's highest accepted
/// one, the inbound replay filter remembers. Each remembered slot is one bit of
/// a `u64` bitmap, so the window is exactly that bitmap's width.
pub const REPLAY_WINDOW: u64 = u64::BITS as u64;
/// Issues outbound sequence numbers, starting at 1 and counting upward.
#[derive(Default, Debug)]
pub struct PeerSeqSender {
    /// Most recently issued sequence number; 0 until the first `next` call.
    counter: AtomicU64,
}

impl PeerSeqSender {
    /// Creates a sender whose first [`next`](Self::next) call yields 1.
    pub fn new() -> Self {
        Self {
            counter: AtomicU64::new(0),
        }
    }

    /// Reserves and returns the next sequence number.
    ///
    /// Takes `&self` and bumps the counter with a single atomic read-modify-write,
    /// so concurrent callers each receive a distinct value.
    pub fn next(&self) -> u64 {
        let previous = self.counter.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }

    /// Returns the most recently issued sequence number (0 if none yet).
    #[cfg(test)]
    pub fn peek(&self) -> u64 {
        self.counter.load(Ordering::Relaxed)
    }
}
/// Inbound replay filter keeping one sliding window of recently accepted
/// sequence numbers per peer.
#[derive(Default, Debug)]
pub struct PeerSeqWindow {
    /// Window state keyed by peer id; an entry is created on a peer's first `accept`.
    windows: RwLock<HashMap<u64, WindowState>>,
}

/// Sliding-window bookkeeping for a single peer.
#[derive(Default, Debug, Clone, Copy)]
struct WindowState {
    /// Highest sequence number accepted so far (0 before the first accept).
    high: u64,
    /// Bit `i` is set once sequence `high - i` has been accepted.
    mask: u64,
}

impl PeerSeqWindow {
    /// Creates a filter that is not yet tracking any peer.
    pub fn new() -> Self {
        Self {
            windows: RwLock::new(HashMap::new()),
        }
    }

    /// Accepts `seq` from `peer_id` if it is fresh, and records it in that peer's window.
    ///
    /// A sequence above the peer's high-water mark slides the window forward.
    /// A sequence at or below the mark is accepted only if it is fewer than
    /// `REPLAY_WINDOW` below the mark and has not been accepted before.
    ///
    /// # Errors
    ///
    /// Returns `ClusterError::Codec` in three cases:
    /// - `seq` is the reserved value 0;
    /// - `seq` is `REPLAY_WINDOW` or more below the high-water mark (stale);
    /// - `seq` was already accepted within the window (replay).
    pub fn accept(&self, peer_id: u64, seq: u64) -> Result<()> {
        if seq == 0 {
            return Err(ClusterError::Codec {
                detail: format!("peer {peer_id} sent reserved sequence 0"),
            });
        }

        // A poisoned lock is recovered rather than propagated, so a panic in
        // another holder does not make this filter unusable.
        let mut windows = self
            .windows
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let window = windows.entry(peer_id).or_default();
        let high = window.high;

        if seq <= high {
            // Behind (or at) the high-water mark: `offset` selects its history bit.
            let offset = high - seq;
            if offset >= REPLAY_WINDOW {
                return Err(ClusterError::Codec {
                    detail: format!(
                        "peer {peer_id} sent stale sequence {seq}, window high is {high}"
                    ),
                });
            }
            let bit = 1u64 << offset;
            if window.mask & bit != 0 {
                return Err(ClusterError::Codec {
                    detail: format!("peer {peer_id} replayed sequence {seq} (window high {high})"),
                });
            }
            window.mask |= bit;
            return Ok(());
        }

        // Ahead of the mark: shift history up by the gap and record `seq` as bit 0.
        // A gap of a whole window or more leaves no surviving history.
        let advance = seq - high;
        window.mask = if advance < REPLAY_WINDOW {
            (window.mask << advance) | 1
        } else {
            1
        };
        window.high = seq;
        Ok(())
    }

    /// Returns the highest sequence accepted from `peer_id`, or 0 if none.
    #[cfg(test)]
    pub fn highest(&self, peer_id: u64) -> u64 {
        self.windows
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(&peer_id)
            .map_or(0, |window| window.high)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn outbound_counter_starts_at_one() {
        let s = PeerSeqSender::new();
        assert_eq!(s.peek(), 0);
        assert_eq!(s.next(), 1);
        assert_eq!(s.next(), 2);
        assert_eq!(s.next(), 3);
        assert_eq!(s.peek(), 3);
    }

    #[test]
    fn outbound_counter_is_single_across_all_targets() {
        // One sender is shared by every outbound path, so concurrent callers must
        // still receive a gap-free, duplicate-free run of sequence numbers.
        const THREADS: u64 = 4;
        const PER_THREAD: u64 = 250;
        let sender = Arc::new(PeerSeqSender::new());
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let sender = Arc::clone(&sender);
                thread::spawn(move || {
                    (0..PER_THREAD)
                        .map(|_| sender.next())
                        .collect::<Vec<u64>>()
                })
            })
            .collect();
        let mut seen = HashSet::new();
        for handle in handles {
            for seq in handle.join().expect("sender thread panicked") {
                assert!(seen.insert(seq), "sequence {seq} issued twice");
            }
        }
        let total = THREADS * PER_THREAD;
        assert_eq!(seen, (1..=total).collect::<HashSet<u64>>());
        assert_eq!(sender.peek(), total);
    }

    #[test]
    fn window_accepts_monotonic_sequence() {
        let w = PeerSeqWindow::new();
        for seq in 1..=10 {
            w.accept(7, seq).unwrap();
        }
        assert_eq!(w.highest(7), 10);
    }

    #[test]
    fn window_rejects_immediate_replay() {
        let w = PeerSeqWindow::new();
        w.accept(1, 1).unwrap();
        let err = w.accept(1, 1).unwrap_err();
        assert!(err.to_string().contains("replayed"));
    }

    #[test]
    fn window_rejects_zero_sequence() {
        let w = PeerSeqWindow::new();
        let err = w.accept(1, 0).unwrap_err();
        assert!(err.to_string().contains("reserved sequence 0"));
    }

    #[test]
    fn window_accepts_out_of_order_within_window() {
        let w = PeerSeqWindow::new();
        w.accept(1, 5).unwrap();
        w.accept(1, 3).unwrap();
        w.accept(1, 1).unwrap();
        w.accept(1, 2).unwrap();
        w.accept(1, 4).unwrap();
        assert_eq!(w.highest(1), 5);
    }

    #[test]
    fn window_rejects_replay_within_window() {
        let w = PeerSeqWindow::new();
        w.accept(1, 5).unwrap();
        w.accept(1, 3).unwrap();
        let err = w.accept(1, 3).unwrap_err();
        assert!(err.to_string().contains("replayed"));
    }

    #[test]
    fn window_rejects_stale_outside_window() {
        let w = PeerSeqWindow::new();
        w.accept(1, 100).unwrap();
        let err = w.accept(1, 36).unwrap_err();
        assert!(err.to_string().contains("stale sequence 36"));
        w.accept(1, 37).unwrap();
    }

    #[test]
    fn window_small_advance_keeps_history() {
        let w = PeerSeqWindow::new();
        w.accept(1, 10).unwrap();
        w.accept(1, 11).unwrap();
        let err = w.accept(1, 10).unwrap_err();
        assert!(err.to_string().contains("replayed"));
    }

    #[test]
    fn window_edge_advance_keeps_oldest_slot() {
        // Advancing by REPLAY_WINDOW - 1 moves the first sequence into the top
        // bit of the mask; it must still be remembered as seen.
        let w = PeerSeqWindow::new();
        w.accept(1, 1).unwrap();
        w.accept(1, REPLAY_WINDOW).unwrap();
        let err = w.accept(1, 1).unwrap_err();
        assert!(err.to_string().contains("replayed sequence 1"));
    }

    #[test]
    fn window_full_advance_drops_history() {
        // Advancing by exactly REPLAY_WINDOW pushes the first sequence out.
        let w = PeerSeqWindow::new();
        w.accept(1, 1).unwrap();
        w.accept(1, 1 + REPLAY_WINDOW).unwrap();
        let err = w.accept(1, 1).unwrap_err();
        assert!(err.to_string().contains("stale sequence 1"));
    }

    #[test]
    fn window_advances_beyond_window_clears_mask() {
        let w = PeerSeqWindow::new();
        w.accept(1, 1).unwrap();
        w.accept(1, 2).unwrap();
        w.accept(1, 100).unwrap();
        let err = w.accept(1, 1).unwrap_err();
        assert!(err.to_string().contains("stale sequence 1"));
    }

    #[test]
    fn windows_are_independent_per_peer() {
        let w = PeerSeqWindow::new();
        w.accept(1, 10).unwrap();
        w.accept(2, 10).unwrap();
        w.accept(1, 9).unwrap();
        w.accept(2, 9).unwrap();
        assert_eq!(w.highest(1), 10);
        assert_eq!(w.highest(2), 10);
    }
}