use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use dashmap::DashMap;
use freenet_stdlib::prelude::ContractInstanceId;
use tokio::time::Instant;
use crate::util::time_source::TimeSource;
/// Minimum spacing required between two accepted updates for the same
/// `(sender, contract)` pair when the limiter is built with defaults.
pub(crate) const MIN_UPDATE_INTERVAL: Duration = Duration::from_millis(100);
/// Tracked pairs whose last accepted update is older than this (five minutes)
/// are evicted by `UpdateRateLimiter::cleanup`.
pub(crate) const CLEANUP_AGE: Duration = Duration::from_secs(300);
/// Default upper bound on the number of distinct `(sender, contract)` pairs
/// tracked at the same time (16 Ki).
pub(crate) const MAX_TRACKED_PAIRS: usize = 16 * 1024;
/// Outcome of `UpdateRateLimiter::check_and_record` for one incoming update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RateLimitDecision {
    /// The update may proceed; its arrival time was recorded for the pair.
    Allowed,
    /// The pair already had an update accepted less than `min_interval` ago.
    Rejected {
        /// Time elapsed since the pair's last *accepted* update.
        elapsed: Duration,
        /// Interval that must pass between two accepted updates for a pair.
        min_interval: Duration,
    },
    /// The pair is not tracked yet and the limiter already tracks its maximum
    /// number of pairs, so nothing was recorded.
    CapacityExceeded,
}

impl RateLimitDecision {
    /// `true` exactly when the decision is [`RateLimitDecision::Allowed`];
    /// both rejection kinds report `false`.
    pub fn is_allowed(self) -> bool {
        self == Self::Allowed
    }
}
/// Throttle keyed by `(sender address, contract instance)`.
///
/// For each pair, at most one update is admitted per `min_interval`. At most
/// `max_tracked_pairs` distinct pairs are tracked at once; entries idle for
/// longer than `CLEANUP_AGE` are evicted by `cleanup`.
pub(crate) struct UpdateRateLimiter {
    // Timestamp of the most recently *accepted* update for each pair.
    // Rejected attempts never overwrite it.
    last_accepted: DashMap<(SocketAddr, ContractInstanceId), Instant>,
    // Count of tracked pairs, maintained alongside `last_accepted`. It is
    // compared against `max_tracked_pairs` when a new pair arrives.
    size: AtomicUsize,
    // Minimum spacing between two accepted updates for the same pair.
    min_interval: Duration,
    // Cap on the number of distinct pairs tracked at once.
    max_tracked_pairs: usize,
    // Clock used for every timestamp (swappable for a mock in tests).
    time_source: Arc<dyn TimeSource + Send + Sync>,
    // Number of updates admitted so far.
    accepted_total: AtomicU64,
    // Number of updates rejected for arriving within `min_interval`.
    rejected_total: AtomicU64,
    // Number of updates rejected because the pair cap was reached.
    capacity_rejected_total: AtomicU64,
}
impl UpdateRateLimiter {
    /// Builds a limiter using the crate defaults: [`MIN_UPDATE_INTERVAL`]
    /// between accepted updates and at most [`MAX_TRACKED_PAIRS`] pairs.
    pub fn new(time_source: Arc<dyn TimeSource + Send + Sync>) -> Self {
        Self::with_config(time_source, MIN_UPDATE_INTERVAL, MAX_TRACKED_PAIRS)
    }

    /// Builds a limiter with an explicit minimum interval between accepted
    /// updates for one `(sender, contract)` pair and an explicit cap on the
    /// number of distinct pairs tracked simultaneously.
    pub fn with_config(
        time_source: Arc<dyn TimeSource + Send + Sync>,
        min_interval: Duration,
        max_tracked_pairs: usize,
    ) -> Self {
        Self {
            time_source,
            min_interval,
            max_tracked_pairs,
            last_accepted: DashMap::new(),
            size: AtomicUsize::new(0),
            accepted_total: AtomicU64::new(0),
            rejected_total: AtomicU64::new(0),
            capacity_rejected_total: AtomicU64::new(0),
        }
    }

    /// Decides whether an update from `sender` for `contract` may proceed.
    /// When it may, the current time is recorded as that pair's last
    /// accepted update.
    ///
    /// * An untracked pair is admitted only while fewer than
    ///   `max_tracked_pairs` pairs are tracked. Otherwise
    ///   [`RateLimitDecision::CapacityExceeded`] is returned and nothing is
    ///   stored.
    /// * A tracked pair is admitted once at least `min_interval` has passed
    ///   since its last *accepted* update. Rejections leave the stored
    ///   timestamp untouched, so they never push the window forward.
    ///
    /// The whole decision is made through the pair's `DashMap` entry, so
    /// concurrent callers for the same pair are serialized.
    pub fn check_and_record(
        &self,
        sender: SocketAddr,
        contract: ContractInstanceId,
    ) -> RateLimitDecision {
        use dashmap::mapref::entry::Entry;

        let now = self.time_source.now();
        match self.last_accepted.entry((sender, contract)) {
            Entry::Vacant(slot) => {
                if !self.try_reserve_slot() {
                    self.capacity_rejected_total.fetch_add(1, Ordering::Relaxed);
                    return RateLimitDecision::CapacityExceeded;
                }
                slot.insert(now);
                self.record_accept()
            }
            Entry::Occupied(mut slot) => {
                let elapsed = now.saturating_duration_since(*slot.get());
                if elapsed >= self.min_interval {
                    *slot.get_mut() = now;
                    return self.record_accept();
                }
                self.rejected_total.fetch_add(1, Ordering::Relaxed);
                RateLimitDecision::Rejected {
                    elapsed,
                    min_interval: self.min_interval,
                }
            }
        }
    }

    /// Claims room for one more tracked pair.
    ///
    /// Returns `false`, after undoing its own increment, when the counter
    /// already stood at or above `max_tracked_pairs`.
    fn try_reserve_slot(&self) -> bool {
        let reserved_before = self.size.fetch_add(1, Ordering::Relaxed);
        if reserved_before < self.max_tracked_pairs {
            return true;
        }
        self.size.fetch_sub(1, Ordering::Relaxed);
        false
    }

    /// Counts one admitted update and yields [`RateLimitDecision::Allowed`].
    fn record_accept(&self) -> RateLimitDecision {
        self.accepted_total.fetch_add(1, Ordering::Relaxed);
        RateLimitDecision::Allowed
    }

    /// Evicts every pair whose last accepted update is older than
    /// [`CLEANUP_AGE`]. The tracked-pair count is lowered by the number
    /// evicted, so that capacity becomes available to new pairs.
    ///
    /// Does nothing if `now - CLEANUP_AGE` cannot be represented as an
    /// instant.
    pub fn cleanup(&self) {
        let Some(cutoff) = self.time_source.now().checked_sub(CLEANUP_AGE) else {
            return;
        };
        let mut evicted = 0usize;
        self.last_accepted.retain(|_, last| {
            if *last < cutoff {
                evicted += 1;
                false
            } else {
                true
            }
        });
        if evicted != 0 {
            self.size.fetch_sub(evicted, Ordering::Relaxed);
        }
    }

    /// Number of updates admitted so far.
    pub fn accepted_total(&self) -> u64 {
        self.accepted_total.load(Ordering::Relaxed)
    }

    /// Number of updates rejected for arriving within `min_interval`.
    pub fn rejected_total(&self) -> u64 {
        self.rejected_total.load(Ordering::Relaxed)
    }

    /// Number of updates rejected because the pair cap was reached.
    pub fn capacity_rejected_total(&self) -> u64 {
        self.capacity_rejected_total.load(Ordering::Relaxed)
    }

    /// Number of pairs currently present in the map. This reads the map
    /// itself, not the reservation counter.
    #[cfg_attr(not(test), allow(dead_code))]
    pub fn len(&self) -> usize {
        self.last_accepted.len()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `UpdateRateLimiter`. Time is driven explicitly through a
    // `SharedMockTimeSource` (via the local `Advance` shorthand below).
    use super::*;
    use crate::util::time_source::SharedMockTimeSource;

    /// Distinct peer address per `byte`: IP `10.0.0.<byte>`, port `30000 + byte`.
    fn mk_sender(byte: u8) -> SocketAddr {
        SocketAddr::from(([10, 0, 0, byte], 30000 + byte as u16))
    }

    /// Contract id whose 32 bytes are all `byte`.
    fn mk_contract(byte: u8) -> ContractInstanceId {
        ContractInstanceId::new([byte; 32])
    }

    /// Default-configured limiter, plus a clone of the mock clock it reads.
    fn mk_limiter() -> (UpdateRateLimiter, SharedMockTimeSource) {
        let ts = SharedMockTimeSource::new();
        let limiter = UpdateRateLimiter::new(Arc::new(ts.clone()));
        (limiter, ts)
    }

    // Local shorthand so tests can write `ts.advance(d)`.
    trait Advance {
        fn advance(&self, d: Duration);
    }
    impl Advance for SharedMockTimeSource {
        fn advance(&self, d: Duration) {
            self.advance_time(d);
        }
    }

    #[test]
    fn first_update_for_pair_is_allowed() {
        let (l, _ts) = mk_limiter();
        let d = l.check_and_record(mk_sender(1), mk_contract(1));
        assert_eq!(d, RateLimitDecision::Allowed);
        assert_eq!(l.accepted_total(), 1);
        assert_eq!(l.rejected_total(), 0);
    }

    #[test]
    fn second_update_within_min_interval_is_rejected() {
        let (l, ts) = mk_limiter();
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        // 10 ms is well inside the default 100 ms window.
        ts.advance(Duration::from_millis(10));
        let d = l.check_and_record(mk_sender(1), mk_contract(1));
        assert!(
            matches!(d, RateLimitDecision::Rejected { .. }),
            "second UPDATE 10ms after first must be rejected, got {d:?}"
        );
        assert_eq!(l.accepted_total(), 1);
        assert_eq!(l.rejected_total(), 1);
    }

    #[test]
    fn update_after_min_interval_is_allowed() {
        let (l, ts) = mk_limiter();
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        ts.advance(Duration::from_millis(200));
        let d = l.check_and_record(mk_sender(1), mk_contract(1));
        assert_eq!(d, RateLimitDecision::Allowed);
        assert_eq!(l.accepted_total(), 2);
        assert_eq!(l.rejected_total(), 0);
    }

    #[test]
    fn different_senders_same_contract_independent() {
        let (l, ts) = mk_limiter();
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        ts.advance(Duration::from_millis(1));
        // A second sender on the same contract has its own window...
        assert!(
            l.check_and_record(mk_sender(2), mk_contract(1))
                .is_allowed()
        );
        // ...while the first sender is still throttled.
        let d = l.check_and_record(mk_sender(1), mk_contract(1));
        assert!(matches!(d, RateLimitDecision::Rejected { .. }));
    }

    #[test]
    fn same_sender_different_contracts_independent() {
        let (l, _ts) = mk_limiter();
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(2))
                .is_allowed()
        );
        assert_eq!(l.accepted_total(), 2);
    }

    #[test]
    fn rejected_attempts_do_not_extend_window() {
        let (l, ts) = mk_limiter();
        // Accepted at t = 0.
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        // Rejected at t = 10, 20, ..., 90 ms.
        for _ in 0..9 {
            ts.advance(Duration::from_millis(10));
            assert!(
                !l.check_and_record(mk_sender(1), mk_contract(1))
                    .is_allowed()
            );
        }
        // Rejected at t = 95 ms.
        ts.advance(Duration::from_millis(5));
        assert!(
            !l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed()
        );
        // t = 105 ms: measured from the original accept, the window is over.
        ts.advance(Duration::from_millis(10));
        assert!(
            l.check_and_record(mk_sender(1), mk_contract(1))
                .is_allowed(),
            "after 105ms+ from original accept, next attempt MUST be allowed — \
             rejected attempts must not have moved the window forward"
        );
    }

    #[test]
    fn cleanup_removes_stale_entries() {
        let (l, ts) = mk_limiter();
        l.check_and_record(mk_sender(1), mk_contract(1));
        l.check_and_record(mk_sender(2), mk_contract(2));
        assert_eq!(l.len(), 2);
        ts.advance(CLEANUP_AGE + Duration::from_secs(1));
        l.cleanup();
        assert_eq!(l.len(), 0, "all stale entries must be cleared");
    }

    #[test]
    fn cleanup_preserves_fresh_entries() {
        let (l, ts) = mk_limiter();
        l.check_and_record(mk_sender(1), mk_contract(1));
        ts.advance(CLEANUP_AGE / 2);
        l.cleanup();
        assert_eq!(l.len(), 1, "fresh entry must be preserved");
    }

    #[test]
    fn counters_track_accepts_and_rejects() {
        let (l, ts) = mk_limiter();
        // Five accepts, each just past the minimum interval.
        for i in 0..5 {
            ts.advance(MIN_UPDATE_INTERVAL + Duration::from_millis(1));
            assert!(
                l.check_and_record(mk_sender(1), mk_contract(1))
                    .is_allowed(),
                "iter {i}"
            );
        }
        // Three immediate retries without advancing time.
        for _ in 0..3 {
            assert!(
                !l.check_and_record(mk_sender(1), mk_contract(1))
                    .is_allowed()
            );
        }
        assert_eq!(l.accepted_total(), 5);
        assert_eq!(l.rejected_total(), 3);
    }

    #[test]
    fn may21_flood_pattern_is_throttled() {
        // One pair hammered every 1 ms for 1000 ms of mock time. With the
        // default 100 ms window, roughly 10 attempts should be admitted.
        let (l, ts) = mk_limiter();
        let sender = mk_sender(1);
        let contract = mk_contract(1);
        for _ in 0..1000 {
            l.check_and_record(sender, contract);
            ts.advance(Duration::from_millis(1));
        }
        let accepted = l.accepted_total();
        let rejected = l.rejected_total();
        assert!(
            (9..=12).contains(&accepted),
            "expected ~10 admits over 1s of flooding, got {accepted}"
        );
        assert_eq!(accepted + rejected, 1000);
        assert!(
            rejected as f64 / 1000.0 > 0.95,
            "expected >95% rejection rate, got {}",
            rejected as f64 / 1000.0
        );
    }

    #[test]
    fn capacity_exceeded_when_cap_reached() {
        let ts = SharedMockTimeSource::new();
        let limiter = UpdateRateLimiter::with_config(
            Arc::new(ts.clone()),
            MIN_UPDATE_INTERVAL,
            8,
        );
        // Fill the map to its cap of 8 distinct pairs.
        for i in 0..8 {
            let d = limiter.check_and_record(mk_sender(i + 1), mk_contract(i + 1));
            assert_eq!(d, RateLimitDecision::Allowed, "pair {i} should be allowed");
        }
        assert_eq!(limiter.len(), 8);
        // A ninth, new pair is refused for capacity, not rate.
        let d = limiter.check_and_record(mk_sender(99), mk_contract(99));
        assert_eq!(
            d,
            RateLimitDecision::CapacityExceeded,
            "new pair past the cap must be CapacityExceeded"
        );
        assert_eq!(limiter.capacity_rejected_total(), 1);
        // Pairs that are already tracked keep working at the cap.
        ts.advance(MIN_UPDATE_INTERVAL + Duration::from_millis(1));
        let d = limiter.check_and_record(mk_sender(1), mk_contract(1));
        assert_eq!(
            d,
            RateLimitDecision::Allowed,
            "existing pair must keep working at cap"
        );
    }

    #[test]
    fn concurrent_check_and_record_admits_one_per_window() {
        use std::sync::{Arc as StdArc, Barrier};
        use std::thread;
        let ts = SharedMockTimeSource::new();
        let limiter = StdArc::new(UpdateRateLimiter::new(Arc::new(ts.clone())));
        let sender = mk_sender(1);
        let contract = mk_contract(1);
        const THREADS: usize = 16;
        // The barrier releases all threads together so they race on one pair.
        let barrier = StdArc::new(Barrier::new(THREADS));
        let mut handles = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let l = limiter.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                l.check_and_record(sender, contract)
            }));
        }
        let mut allowed = 0;
        let mut rejected = 0;
        for h in handles {
            match h.join().unwrap() {
                RateLimitDecision::Allowed => allowed += 1,
                RateLimitDecision::Rejected { .. } => rejected += 1,
                RateLimitDecision::CapacityExceeded => panic!("unexpected cap"),
            }
        }
        assert_eq!(
            allowed, 1,
            "exactly ONE concurrent caller must be admitted per window; \
             got {allowed} admits, {rejected} rejects"
        );
        assert_eq!(rejected, THREADS - 1);
        assert_eq!(limiter.accepted_total(), 1);
        assert_eq!(limiter.rejected_total(), (THREADS - 1) as u64);
    }

    // Source-level guard: scans `node.rs` for the UPDATE dispatch arm and checks
    // that the rate limiter is consulted before the first relay spawn, and
    // that every UPDATE wire variant and spawn helper appears in that arm.
    #[test]
    fn update_dispatch_gates_all_four_wire_variants() {
        const NODE_SRC: &str = include_str!("../node.rs");
        let block_start = NODE_SRC
            .find("NetMessageV1::Update(ref op) =>")
            .expect("could not locate UPDATE dispatch block in node.rs");
        // The arm ends at the next `NetMessageV1::` arm, or at end of file.
        let tail = &NODE_SRC[block_start + 1..];
        // NOTE(review): both search patterns read identically here. If they
        // are meant to match different indentation depths, confirm the
        // second one still differs from the first in the real file.
        let block_len = tail
            .find("\n NetMessageV1::")
            .or_else(|| tail.find("\n NetMessageV1::"))
            .unwrap_or(tail.len());
        let block = &NODE_SRC[block_start..block_start + 1 + block_len];
        let rate_limit_pos = block
            .find("update_rate_limiter")
            .expect("update_rate_limiter not invoked in UPDATE dispatch block");
        let first_spawn_pos = block
            .find("start_relay_request_update(")
            .expect("start_relay_request_update spawn not found in block");
        assert!(
            rate_limit_pos < first_spawn_pos,
            "rate limit gate (offset {rate_limit_pos}) must appear BEFORE \
             the first relay spawn (offset {first_spawn_pos}) so rejected \
             messages don't pay the spawn cost"
        );
        for variant in [
            "UpdateMsg::RequestUpdate {",
            "UpdateMsg::BroadcastTo {",
            "UpdateMsg::RequestUpdateStreaming {",
            "UpdateMsg::BroadcastToStreaming {",
        ] {
            assert!(
                block.contains(variant),
                "UPDATE dispatch block missing wire variant: `{variant}`. \
                 If a new UPDATE wire variant was added, gate it through \
                 the rate limiter and update this list. If a variant was \
                 removed, update this list."
            );
        }
        for spawn in [
            "start_relay_request_update(",
            "start_relay_broadcast_to(",
            "start_relay_request_update_streaming(",
            "start_relay_broadcast_to_streaming(",
        ] {
            let count = block.matches(spawn).count();
            assert!(
                count >= 1,
                "UPDATE dispatch block does not invoke `{spawn}` — the \
                 corresponding wire variant is not actually gated."
            );
        }
    }

    #[test]
    fn concurrent_distinct_keys_do_not_overshoot_cap() {
        use std::sync::{Arc as StdArc, Barrier};
        use std::thread;
        const CAP: usize = 8;
        const THREADS: usize = 64;
        let ts = SharedMockTimeSource::new();
        let limiter = StdArc::new(UpdateRateLimiter::with_config(
            Arc::new(ts.clone()),
            MIN_UPDATE_INTERVAL,
            CAP,
        ));
        // 64 threads, each with its own new pair, released simultaneously.
        let barrier = StdArc::new(Barrier::new(THREADS));
        let mut handles = Vec::with_capacity(THREADS);
        for i in 0..THREADS {
            let l = limiter.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                l.check_and_record(mk_sender((i + 1) as u8), mk_contract((i + 1) as u8))
            }));
        }
        let mut allowed = 0;
        let mut cap_rejected = 0;
        let mut rate_rejected = 0;
        for h in handles {
            match h.join().unwrap() {
                RateLimitDecision::Allowed => allowed += 1,
                RateLimitDecision::CapacityExceeded => cap_rejected += 1,
                RateLimitDecision::Rejected { .. } => rate_rejected += 1,
            }
        }
        assert_eq!(
            limiter.len(),
            CAP,
            "strict cap: map size must equal CAP after a 64-thread \
             concurrent flood of distinct keys, got {}",
            limiter.len()
        );
        assert_eq!(
            allowed, CAP,
            "exactly CAP admissions allowed under flood, got {allowed}"
        );
        assert_eq!(cap_rejected, THREADS - CAP);
        assert_eq!(rate_rejected, 0);
        assert_eq!(limiter.capacity_rejected_total(), (THREADS - CAP) as u64);
    }

    #[test]
    fn cleanup_decrements_size_counter() {
        let ts = SharedMockTimeSource::new();
        let limiter = UpdateRateLimiter::with_config(
            Arc::new(ts.clone()),
            MIN_UPDATE_INTERVAL,
            4,
        );
        // Fill to the cap of 4, then confirm a fifth pair is refused.
        for i in 0..4 {
            assert_eq!(
                limiter.check_and_record(mk_sender(i + 1), mk_contract(i + 1)),
                RateLimitDecision::Allowed
            );
        }
        assert_eq!(
            limiter.check_and_record(mk_sender(5), mk_contract(5)),
            RateLimitDecision::CapacityExceeded
        );
        // After every entry goes stale and is cleaned up, the freed capacity
        // must be usable by four brand-new pairs.
        ts.advance(CLEANUP_AGE + Duration::from_secs(1));
        limiter.cleanup();
        assert_eq!(limiter.len(), 0);
        for i in 10..14 {
            assert_eq!(
                limiter.check_and_record(mk_sender(i), mk_contract(i)),
                RateLimitDecision::Allowed,
                "after cleanup, new pair (sender={i}) should be admitted"
            );
        }
    }
}