use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Maximum number of distinct `(local bind, observed)` pairs the cache keeps; recording a new
/// pair while the cache already holds this many entries evicts the least recently seen one first.
pub(crate) const MAX_CACHED_OBSERVATIONS: usize = 16;
/// How long after its last observation an entry still counts as "recent" (the bound is
/// inclusive). Recent entries are always preferred over stale ones, regardless of count.
pub(crate) const OBSERVATION_RECENCY_WINDOW: Duration = Duration::from_secs(600);
/// Bookkeeping for a single `(local bind, observed)` pair.
#[derive(Debug, Clone, Copy)]
struct ObservedEntry {
    /// Number of times this pair has been recorded (saturates at `u64::MAX`).
    count: u64,
    /// Timestamp supplied to the most recent recording of this pair.
    last_seen: Instant,
}
/// Cache key: `(local bind address, observed address)`.
type CacheKey = (SocketAddr, SocketAddr);
/// Bounded tally of observed socket addresses, grouped by the local bind address each
/// observation was recorded against.
///
/// Holds at most `MAX_CACHED_OBSERVATIONS` entries; when a new pair arrives at capacity the
/// least recently seen entry is evicted.
#[derive(Debug, Default)]
pub(crate) struct ObservedAddressCache {
    /// Observation count and last-seen time for each `(local bind, observed)` pair.
    entries: HashMap<CacheKey, ObservedEntry>,
}
impl ObservedAddressCache {
pub(crate) fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
pub(crate) fn record(&mut self, local_bind: SocketAddr, observed: SocketAddr) {
self.record_at(local_bind, observed, Instant::now());
}
pub(crate) fn record_at(&mut self, local_bind: SocketAddr, observed: SocketAddr, now: Instant) {
let key = (local_bind, observed);
if let Some(entry) = self.entries.get_mut(&key) {
entry.count = entry.count.saturating_add(1);
entry.last_seen = now;
return;
}
if self.entries.len() >= MAX_CACHED_OBSERVATIONS {
self.evict_oldest();
}
self.entries.insert(
key,
ObservedEntry {
count: 1,
last_seen: now,
},
);
}
pub(crate) fn most_frequent_recent_per_local_bind(&self) -> Vec<SocketAddr> {
self.most_frequent_recent_per_local_bind_at(Instant::now())
}
pub(crate) fn most_frequent_recent_per_local_bind_at(&self, now: Instant) -> Vec<SocketAddr> {
let mut binds: Vec<SocketAddr> = self.entries.keys().map(|(bind, _)| *bind).collect();
binds.sort();
binds.dedup();
let mut result = Vec::with_capacity(binds.len());
for bind in binds {
if let Some(addr) = self.best_observed_for_bind_at(bind, now) {
result.push(addr);
}
}
result
}
fn best_observed_for_bind_at(
&self,
local_bind: SocketAddr,
now: Instant,
) -> Option<SocketAddr> {
let recent = self
.entries
.iter()
.filter(|((bind, _), _)| *bind == local_bind)
.filter(|(_, e)| now.duration_since(e.last_seen) <= OBSERVATION_RECENCY_WINDOW)
.max_by_key(|(_, e)| (e.count, e.last_seen))
.map(|((_, observed), _)| *observed);
if recent.is_some() {
return recent;
}
self.entries
.iter()
.filter(|((bind, _), _)| *bind == local_bind)
.max_by_key(|(_, e)| (e.count, e.last_seen))
.map(|((_, observed), _)| *observed)
}
pub(crate) fn most_frequent_recent(&self) -> Option<SocketAddr> {
self.most_frequent_recent_at(Instant::now())
}
pub(crate) fn most_frequent_recent_at(&self, now: Instant) -> Option<SocketAddr> {
let recent = self
.entries
.iter()
.filter(|(_, e)| now.duration_since(e.last_seen) <= OBSERVATION_RECENCY_WINDOW)
.max_by_key(|(_, e)| (e.count, e.last_seen))
.map(|((_, observed), _)| *observed);
if recent.is_some() {
return recent;
}
self.entries
.iter()
.max_by_key(|(_, e)| (e.count, e.last_seen))
.map(|((_, observed), _)| *observed)
}
fn evict_oldest(&mut self) {
let oldest = self
.entries
.iter()
.min_by_key(|(_, e)| e.last_seen)
.map(|(key, _)| *key);
if let Some(key) = oldest {
self.entries.remove(&key);
}
}
#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.entries.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    const DEFAULT_LOCAL_BIND_PORT: u16 = 7000;
    const ALT_LOCAL_BIND_PORT: u16 = 7001;

    /// Builds `192.0.2.<last_octet>:<port>` (an RFC 5737 documentation address).
    fn addr(last_octet: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, last_octet)), port)
    }

    /// `0.0.0.0:<port>`, the shape of a wildcard local bind.
    fn wildcard_bind(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
    }

    fn default_bind() -> SocketAddr {
        wildcard_bind(DEFAULT_LOCAL_BIND_PORT)
    }

    fn alt_bind() -> SocketAddr {
        wildcard_bind(ALT_LOCAL_BIND_PORT)
    }

    /// Fills `cache` to exactly `MAX_CACHED_OBSERVATIONS` entries on `default_bind()`.
    ///
    /// The entry with last octet `k` (starting at 1, port 9000) is last seen at
    /// `base + (k - 1)` seconds, so `addr(1, 9000)` is the oldest. Iterating the octet range
    /// directly avoids a truncating `usize -> u8` cast; if the capacity ever exceeds the 255
    /// available octets, the length assertion below fails loudly instead.
    fn fill_to_capacity(cache: &mut ObservedAddressCache, base: Instant) {
        for octet in (1..=u8::MAX).take(MAX_CACHED_OBSERVATIONS) {
            let seen_at = base + Duration::from_secs(u64::from(octet - 1));
            cache.record_at(default_bind(), addr(octet, 9000), seen_at);
        }
        assert_eq!(
            cache.len(),
            MAX_CACHED_OBSERVATIONS,
            "fill_to_capacity assigns one last octet per entry; capacity must not exceed 255"
        );
    }

    #[test]
    fn empty_cache_returns_none() {
        let cache = ObservedAddressCache::new();
        assert_eq!(cache.most_frequent_recent(), None);
        assert!(cache.most_frequent_recent_per_local_bind().is_empty());
    }

    #[test]
    fn single_observation_returns_that_address() {
        let mut cache = ObservedAddressCache::new();
        let a = addr(1, 9000);
        cache.record(default_bind(), a);
        assert_eq!(cache.most_frequent_recent(), Some(a));
        assert_eq!(cache.most_frequent_recent_per_local_bind(), vec![a]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn repeated_observation_increments_count_without_growing() {
        let mut cache = ObservedAddressCache::new();
        let a = addr(1, 9000);
        cache.record(default_bind(), a);
        cache.record(default_bind(), a);
        cache.record(default_bind(), a);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.most_frequent_recent(), Some(a));
    }

    #[test]
    fn higher_count_wins_among_recent_entries() {
        let mut cache = ObservedAddressCache::new();
        let popular = addr(1, 9000);
        let unpopular = addr(2, 9000);
        for _ in 0..5 {
            cache.record(default_bind(), popular);
        }
        cache.record(default_bind(), unpopular);
        assert_eq!(cache.most_frequent_recent(), Some(popular));
    }

    #[test]
    fn equal_counts_break_tie_by_recency() {
        let mut cache = ObservedAddressCache::new();
        let older = addr(1, 9000);
        let newer = addr(2, 9000);
        let base = Instant::now();
        cache.record_at(default_bind(), older, base);
        cache.record_at(default_bind(), newer, base + Duration::from_secs(1));
        assert_eq!(
            cache.most_frequent_recent_at(base + Duration::from_secs(2)),
            Some(newer)
        );
    }

    #[test]
    fn stale_high_count_loses_to_recent_low_count() {
        let mut cache = ObservedAddressCache::new();
        let stale = addr(1, 9000);
        let fresh = addr(2, 9000);
        let base = Instant::now();
        let stale_time = base;
        for _ in 0..1000 {
            cache.record_at(default_bind(), stale, stale_time);
        }
        let fresh_time = base + OBSERVATION_RECENCY_WINDOW + Duration::from_secs(60);
        for _ in 0..3 {
            cache.record_at(default_bind(), fresh, fresh_time);
        }
        let now = fresh_time + Duration::from_secs(1);
        assert_eq!(cache.most_frequent_recent_at(now), Some(fresh));
    }

    #[test]
    fn falls_back_to_global_highest_count_when_nothing_is_recent() {
        let mut cache = ObservedAddressCache::new();
        let popular = addr(1, 9000);
        let unpopular = addr(2, 9000);
        let base = Instant::now();
        for _ in 0..5 {
            cache.record_at(default_bind(), popular, base);
        }
        cache.record_at(default_bind(), unpopular, base);
        let far_future = base + OBSERVATION_RECENCY_WINDOW * 10;
        assert_eq!(cache.most_frequent_recent_at(far_future), Some(popular));
    }

    #[test]
    fn eviction_removes_oldest_by_last_seen_when_full() {
        let mut cache = ObservedAddressCache::new();
        let base = Instant::now();
        fill_to_capacity(&mut cache, base);
        let oldest_key = (default_bind(), addr(1, 9000));
        assert!(cache.entries.contains_key(&oldest_key));
        // A distinct port guarantees the newcomer never collides with a fill entry.
        let newcomer_key = (default_bind(), addr(99, 9001));
        cache.record_at(
            newcomer_key.0,
            newcomer_key.1,
            base + Duration::from_secs(MAX_CACHED_OBSERVATIONS as u64),
        );
        assert_eq!(cache.len(), MAX_CACHED_OBSERVATIONS);
        assert!(
            !cache.entries.contains_key(&oldest_key),
            "oldest entry should have been evicted"
        );
        assert!(
            cache.entries.contains_key(&newcomer_key),
            "newcomer should be present"
        );
    }

    #[test]
    fn re_observing_an_existing_entry_does_not_trigger_eviction() {
        let mut cache = ObservedAddressCache::new();
        let base = Instant::now();
        fill_to_capacity(&mut cache, base);
        let oldest_key = (default_bind(), addr(1, 9000));
        let refresh_time = base + Duration::from_secs(1000);
        cache.record_at(oldest_key.0, oldest_key.1, refresh_time);
        assert_eq!(cache.len(), MAX_CACHED_OBSERVATIONS);
        let entry = cache.entries.get(&oldest_key).copied().unwrap();
        assert_eq!(entry.count, 2);
        assert_eq!(entry.last_seen, refresh_time);
    }

    #[test]
    fn observations_for_different_local_binds_do_not_collide() {
        let mut cache = ObservedAddressCache::new();
        let observed = addr(1, 9000);
        cache.record(default_bind(), observed);
        cache.record(alt_bind(), observed);
        cache.record(alt_bind(), observed);
        assert_eq!(cache.len(), 2);
        let default_entry = cache.entries.get(&(default_bind(), observed)).unwrap();
        let alt_entry = cache.entries.get(&(alt_bind(), observed)).unwrap();
        assert_eq!(default_entry.count, 1);
        assert_eq!(alt_entry.count, 2);
    }

    #[test]
    fn per_local_bind_returns_one_address_per_distinct_bind() {
        let mut cache = ObservedAddressCache::new();
        let observed_default = addr(1, 9000);
        let observed_alt = addr(2, 9000);
        cache.record(default_bind(), observed_default);
        cache.record(alt_bind(), observed_alt);
        let mut result = cache.most_frequent_recent_per_local_bind();
        result.sort();
        let mut expected = vec![observed_default, observed_alt];
        expected.sort();
        assert_eq!(result, expected);
    }

    #[test]
    fn per_local_bind_picks_best_within_each_bind_independently() {
        let mut cache = ObservedAddressCache::new();
        let default_winner = addr(1, 9000);
        let default_loser = addr(2, 9000);
        let alt_winner = addr(3, 9000);
        let alt_loser = addr(4, 9000);
        for _ in 0..5 {
            cache.record(default_bind(), default_winner);
        }
        cache.record(default_bind(), default_loser);
        for _ in 0..3 {
            cache.record(alt_bind(), alt_winner);
        }
        cache.record(alt_bind(), alt_loser);
        let mut result = cache.most_frequent_recent_per_local_bind();
        result.sort();
        let mut expected = vec![default_winner, alt_winner];
        expected.sort();
        assert_eq!(result, expected);
    }

    #[test]
    fn stale_observation_on_one_bind_does_not_affect_recency_on_another() {
        let mut cache = ObservedAddressCache::new();
        let stale_for_default = addr(1, 9000);
        let fresh_for_alt = addr(2, 9000);
        let base = Instant::now();
        cache.record_at(default_bind(), stale_for_default, base);
        let fresh_time = base + OBSERVATION_RECENCY_WINDOW + Duration::from_secs(60);
        cache.record_at(alt_bind(), fresh_for_alt, fresh_time);
        let now = fresh_time + Duration::from_secs(1);
        let mut result = cache.most_frequent_recent_per_local_bind_at(now);
        result.sort();
        let mut expected = vec![stale_for_default, fresh_for_alt];
        expected.sort();
        assert_eq!(result, expected);
    }
}