use bytes::{Buf, BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// Encoded size of a [`RoutingHeader`] on the wire, spelled out per field:
/// magic (2) + ttl (1) + hop_count (1) + flags (1) + reserved (1)
/// + src_id (4) + dest_id (8).
pub const ROUTING_HEADER_SIZE: usize = 2 + 1 + 1 + 1 + 1 + 4 + 8;
/// Discriminator stored little-endian at offset 0 of every routing header.
/// Its two wire bytes are the ASCII characters `"RT"` (value `0x5452`).
pub const ROUTING_MAGIC: u16 = u16::from_le_bytes(*b"RT");
/// Presumably a TTL ceiling; it is not referenced anywhere in this file
/// (hence the leading underscore) — TODO confirm intended use.
pub const _MAX_TTL: u8 = 16;
/// Bit set of per-packet routing flags, carried in a single wire byte.
///
/// Only the low nibble (`0x0F`) is defined by this build. [`RouteFlags::from_u8`]
/// drops any high-nibble bits found on an inbound byte, and logs a warning when it does.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(transparent)]
pub struct RouteFlags(u8);

impl RouteFlags {
    pub const NONE: Self = Self(0x00);
    pub const CONTROL: Self = Self(0x01);
    pub const REQUIRES_ACK: Self = Self(0x02);
    pub const PRIORITY: Self = Self(0x04);
    pub const END_OF_STREAM: Self = Self(0x08);

    /// Bits this build knows how to interpret (the low nibble).
    const KNOWN_BITS: u8 = 0x0F;

    /// Decodes a wire byte, keeping only the known low-nibble bits.
    ///
    /// If any high-nibble bit is set, a `tracing` warning is emitted.
    /// This happens on every such call; there is no rate limiting.
    pub fn from_u8(v: u8) -> Self {
        let known = v & Self::KNOWN_BITS;
        let unknown = v & !Self::KNOWN_BITS;
        if unknown != 0 {
            tracing::warn!(
                wire_byte = format_args!("0x{:02x}", v),
                high_nibble = format_args!("0x{:02x}", unknown),
                "route flags: high-nibble bits set on inbound wire byte and \
                 silently stripped — peer may be running a newer schema. \
                 Widen RouteFlags::from_u8's mask in lock-step before any \
                 production peer relies on a high-nibble bit."
            );
        }
        Self(known)
    }

    /// Raw byte to put on the wire.
    pub fn as_u8(self) -> u8 {
        let Self(bits) = self;
        bits
    }

    /// True when every bit of `other` is also set in `self`.
    /// `NONE` is contained in every value.
    pub fn contains(self, other: Self) -> bool {
        // No bit of `other` may fall outside `self`.
        (other.0 & !self.0) == 0
    }

    /// True when the CONTROL bit is set.
    pub fn is_control(self) -> bool {
        (self.0 & Self::CONTROL.0) != 0
    }

    /// True when the PRIORITY bit is set.
    pub fn is_priority(self) -> bool {
        (self.0 & Self::PRIORITY.0) != 0
    }
}
/// Routing header for routed packets (presumably prepended to the payload —
/// confirm against callers).
///
/// The magic is not stored here; it is written and checked by the
/// encode/decode methods. The wire order (see `to_bytes`) differs from the
/// in-memory `#[repr(C)]` field order below. On the wire the fields are:
/// magic (bytes 0..2), ttl (2), hop_count (3), flags (4), reserved (5),
/// src_id (6..10), dest_id (10..18). Multi-byte fields are little-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct RoutingHeader {
// Destination node id (wire bytes 10..18).
pub dest_id: u64,
// Source node id (wire bytes 6..10); note it is 32 bits, narrower than dest_id.
pub src_id: u32,
// Remaining hops; `forward` refuses to forward once this is 0.
pub ttl: u8,
// Hops taken so far; `forward` stops incrementing it at u8::MAX.
pub hop_count: u8,
// Flag bits (wire byte 4), decoded through RouteFlags::from_u8.
pub flags: RouteFlags,
// Reserved wire byte 5: zero on construction, round-tripped verbatim on decode.
pub _reserved: u8,
}
impl RoutingHeader {
    /// Builds a header with no flags, `hop_count = 0` and a zero reserved byte.
    pub fn new(dest_id: u64, src_id: u32, ttl: u8) -> Self {
        Self {
            dest_id,
            src_id,
            ttl,
            hop_count: 0,
            flags: RouteFlags::NONE,
            _reserved: 0,
        }
    }

    /// Same as [`RoutingHeader::new`] but with [`RouteFlags::CONTROL`] set.
    pub fn control(dest_id: u64, src_id: u32, ttl: u8) -> Self {
        Self {
            flags: RouteFlags::CONTROL,
            ..Self::new(dest_id, src_id, ttl)
        }
    }

    /// Same as [`RoutingHeader::new`] but with [`RouteFlags::PRIORITY`] set.
    pub fn priority(dest_id: u64, src_id: u32, ttl: u8) -> Self {
        Self {
            flags: RouteFlags::PRIORITY,
            ..Self::new(dest_id, src_id, ttl)
        }
    }

    /// Encodes the header into a fresh fixed-size array.
    ///
    /// Delegates to [`RoutingHeader::write_at`], the single encoder, so every
    /// encoding path produces identical bytes.
    pub fn to_bytes(&self) -> [u8; ROUTING_HEADER_SIZE] {
        let mut buf = [0u8; ROUTING_HEADER_SIZE];
        self.write_at(&mut buf);
        buf
    }

    /// Decodes a header from the first [`ROUTING_HEADER_SIZE`] bytes of `buf`.
    ///
    /// Returns `None` if `buf` is shorter than a header or the magic at
    /// offset 0 is not [`ROUTING_MAGIC`]. Bytes past the header are ignored.
    /// This is the single decoder; [`RoutingHeader::read_from`] delegates here.
    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
        let raw: &[u8; ROUTING_HEADER_SIZE] = buf.get(..ROUTING_HEADER_SIZE)?.try_into().ok()?;
        // Destructuring spells out the wire layout in one place.
        let [m0, m1, ttl, hop_count, flags, reserved, s0, s1, s2, s3, d0, d1, d2, d3, d4, d5, d6, d7] =
            *raw;
        if u16::from_le_bytes([m0, m1]) != ROUTING_MAGIC {
            return None;
        }
        Some(Self {
            dest_id: u64::from_le_bytes([d0, d1, d2, d3, d4, d5, d6, d7]),
            src_id: u32::from_le_bytes([s0, s1, s2, s3]),
            ttl,
            hop_count,
            flags: RouteFlags::from_u8(flags),
            _reserved: reserved,
        })
    }

    /// Appends the encoded header to `buf`. The bytes are the same as [`RoutingHeader::to_bytes`].
    pub fn write_to(&self, buf: &mut BytesMut) {
        buf.put_slice(&self.to_bytes());
    }

    /// Encodes the header in place into the first [`ROUTING_HEADER_SIZE`]
    /// bytes of `dst`. Any bytes after that are left untouched.
    ///
    /// This is the canonical encoder; all other encode paths go through it.
    ///
    /// # Panics
    ///
    /// Panics if `dst` is shorter than [`ROUTING_HEADER_SIZE`].
    pub fn write_at(&self, dst: &mut [u8]) {
        assert!(
            dst.len() >= ROUTING_HEADER_SIZE,
            "write_at: dst is {} bytes, need {}",
            dst.len(),
            ROUTING_HEADER_SIZE,
        );
        dst[0..2].copy_from_slice(&ROUTING_MAGIC.to_le_bytes());
        dst[2] = self.ttl;
        dst[3] = self.hop_count;
        dst[4] = self.flags.as_u8();
        dst[5] = self._reserved;
        dst[6..10].copy_from_slice(&self.src_id.to_le_bytes());
        dst[10..18].copy_from_slice(&self.dest_id.to_le_bytes());
    }

    /// Decodes a header from the front of `buf` and consumes it.
    ///
    /// On success, `buf` is advanced past the header. On failure (short buffer
    /// or wrong magic) `None` is returned and `buf` is left unconsumed.
    pub fn read_from(buf: &mut Bytes) -> Option<Self> {
        let header = Self::from_bytes(&buf[..])?;
        buf.advance(ROUTING_HEADER_SIZE);
        Some(header)
    }

    /// True when no hops remain (`ttl == 0`).
    #[inline]
    pub fn is_expired(&self) -> bool {
        self.ttl == 0
    }

    /// Accounts for one forwarding hop.
    ///
    /// Returns `false` without modifying the header if the TTL is already 0.
    /// Otherwise it decrements `ttl` and increments `hop_count`. When
    /// `hop_count` is already `u8::MAX`, it stays there and a warning is logged.
    #[inline]
    pub fn forward(&mut self) -> bool {
        if self.ttl == 0 {
            return false;
        }
        self.ttl -= 1;
        if self.hop_count == u8::MAX {
            tracing::warn!(
                "RoutingHeader::forward: hop_count saturated at {}; \
                 indirect-route metrics on this packet are inaccurate",
                u8::MAX
            );
        } else {
            // Cannot overflow: the MAX case is handled above.
            self.hop_count += 1;
        }
        true
    }
}
/// Lock-free per-stream packet and byte counters.
///
/// Every counter is an independent `Ordering::Relaxed` tally. A reader may
/// therefore see one counter updated before another. Activity for idleness
/// checks is tracked on a monotonic clock (`Instant`), so wall-clock
/// adjustments cannot make streams look idle or busy.
#[derive(Debug)]
pub struct SchedulerStreamStats {
    pub packets_in: AtomicU64,
    pub packets_out: AtomicU64,
    pub packets_dropped: AtomicU64,
    pub bytes_in: AtomicU64,
    pub bytes_out: AtomicU64,
    /// Nanoseconds since `epoch` at which traffic was last recorded.
    last_activity: AtomicU64,
    /// Monotonic reference point for `last_activity`, captured at construction.
    epoch: Instant,
}

impl SchedulerStreamStats {
    /// Creates zeroed counters; the stream counts as active "now".
    pub fn new() -> Self {
        Self {
            packets_in: AtomicU64::new(0),
            packets_out: AtomicU64::new(0),
            packets_dropped: AtomicU64::new(0),
            bytes_in: AtomicU64::new(0),
            bytes_out: AtomicU64::new(0),
            last_activity: AtomicU64::new(0),
            epoch: Instant::now(),
        }
    }

    /// Counts one inbound packet of `bytes` bytes and marks the stream active.
    #[inline]
    pub fn record_in(&self, bytes: u64) {
        self.packets_in.fetch_add(1, Ordering::Relaxed);
        self.bytes_in.fetch_add(bytes, Ordering::Relaxed);
        self.touch();
    }

    /// Counts one outbound packet of `bytes` bytes and marks the stream active.
    ///
    /// Outbound traffic refreshes activity too. Otherwise a send-only stream
    /// would be reaped by idle cleanup, and its counters reset, while still busy.
    #[inline]
    pub fn record_out(&self, bytes: u64) {
        self.packets_out.fetch_add(1, Ordering::Relaxed);
        self.bytes_out.fetch_add(bytes, Ordering::Relaxed);
        self.touch();
    }

    /// Counts one dropped packet. This does not count as activity.
    #[inline]
    pub fn record_drop(&self) {
        self.packets_dropped.fetch_add(1, Ordering::Relaxed);
    }

    #[inline]
    pub fn get_packets_in(&self) -> u64 {
        self.packets_in.load(Ordering::Relaxed)
    }

    #[inline]
    pub fn get_packets_out(&self) -> u64 {
        self.packets_out.load(Ordering::Relaxed)
    }

    #[inline]
    pub fn get_drops(&self) -> u64 {
        self.packets_dropped.load(Ordering::Relaxed)
    }

    /// True when more than `idle_nanos` nanoseconds have passed since the last
    /// recorded in/out traffic (or since construction if there has been none).
    pub fn is_idle(&self, idle_nanos: u64) -> bool {
        let last = self.last_activity.load(Ordering::Relaxed);
        self.now_nanos().saturating_sub(last) > idle_nanos
    }

    /// Marks the stream active now.
    ///
    /// `fetch_max` keeps the timestamp from moving backwards when concurrent
    /// recorders race: a plain `store` of an earlier reading could overwrite a later one.
    #[inline]
    fn touch(&self) {
        self.last_activity
            .fetch_max(self.now_nanos(), Ordering::Relaxed);
    }

    /// Monotonic nanoseconds since `epoch`, saturating at `u64::MAX`.
    #[inline]
    fn now_nanos(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }
}

impl Default for SchedulerStreamStats {
    fn default() -> Self {
        Self::new()
    }
}
/// One installed route: where to send traffic for a destination, plus the
/// bookkeeping used for metric comparison and staleness checks.
#[derive(Debug, Clone)]
pub struct RouteEntry {
    /// Neighbor address packets for this destination are sent to.
    pub next_hop: SocketAddr,
    /// Route cost; lower is better.
    pub metric: u16,
    /// Inactive entries are skipped by lookups but kept in the table.
    pub active: bool,
    /// Last time the route was installed or refreshed.
    pub updated_at: Instant,
}

impl RouteEntry {
    /// Active route via `next_hop` with the default metric of 1, stamped now.
    pub fn new(next_hop: SocketAddr) -> Self {
        Self::with_metric(next_hop, 1)
    }

    /// Active route via `next_hop` with an explicit `metric`, stamped now.
    pub fn with_metric(next_hop: SocketAddr, metric: u16) -> Self {
        let updated_at = Instant::now();
        Self {
            next_hop,
            metric,
            active: true,
            updated_at,
        }
    }
}
/// Maximum number of per-stream stats entries the `RoutingTable::record_*`
/// methods will create. Existing entries keep updating past this point;
/// only new stream ids are turned away.
pub const MAX_STREAM_STATS: usize = 65_536;
/// Concurrent routing table keyed by destination node id, plus per-stream
/// traffic counters. All methods take `&self`; the maps are `DashMap`s.
pub struct RoutingTable {
// dest_id -> route. At most one route is stored per destination.
routes: DashMap<u64, RouteEntry>,
// stream_id -> counters. The record_* paths respect MAX_STREAM_STATS;
// get_stream_stats inserts without checking the cap.
stream_stats: DashMap<u64, SchedulerStreamStats>,
// This node's id; is_local compares against it.
local_id: u64,
// Freshness limit for lookups in nanoseconds; starts at u64::MAX (effectively unbounded).
max_route_age_nanos: AtomicU64,
}
impl RoutingTable {
pub fn new(local_id: u64) -> Self {
Self {
routes: DashMap::new(),
stream_stats: DashMap::new(),
local_id,
max_route_age_nanos: AtomicU64::new(u64::MAX),
}
}
#[inline]
pub fn local_id(&self) -> u64 {
self.local_id
}
pub fn add_route(&self, dest_id: u64, next_hop: SocketAddr) {
self.routes.insert(dest_id, RouteEntry::new(next_hop));
}
pub fn add_route_with_metric(&self, dest_id: u64, next_hop: SocketAddr, metric: u16) {
use dashmap::mapref::entry::Entry;
match self.routes.entry(dest_id) {
Entry::Vacant(v) => {
v.insert(RouteEntry::with_metric(next_hop, metric));
}
Entry::Occupied(mut o) => {
if metric < o.get().metric {
o.insert(RouteEntry::with_metric(next_hop, metric));
} else {
o.get_mut().updated_at = Instant::now();
}
}
}
}
pub fn remove_route(&self, dest_id: u64) -> Option<RouteEntry> {
self.routes.remove(&dest_id).map(|(_, v)| v)
}
pub fn remove_route_if_next_hop_is(&self, dest_id: u64, expected_next_hop: SocketAddr) -> bool {
self.routes
.remove_if(&dest_id, |_, entry| entry.next_hop == expected_next_hop)
.is_some()
}
pub fn lookup(&self, dest_id: u64) -> Option<SocketAddr> {
let max_age = self.max_route_age();
self.routes
.get(&dest_id)
.filter(|r| r.active && r.updated_at.elapsed() <= max_age)
.map(|r| r.next_hop)
}
pub fn lookup_alternate(
&self,
dest_id: u64,
exclude_next_hop: SocketAddr,
) -> Option<SocketAddr> {
let max_age = self.max_route_age();
self.routes
.get(&dest_id)
.filter(|r| {
r.active && r.updated_at.elapsed() <= max_age && r.next_hop != exclude_next_hop
})
.map(|r| r.next_hop)
}
pub fn sweep_stale(&self, max_age: std::time::Duration) -> usize {
let mut removed = 0;
self.routes.retain(|_, entry| {
let keep = entry.updated_at.elapsed() <= max_age;
if !keep {
removed += 1;
}
keep
});
removed
}
pub fn set_max_route_age(&self, age: std::time::Duration) {
self.max_route_age_nanos.store(
age.as_nanos().min(u64::MAX as u128) as u64,
Ordering::Relaxed,
);
}
fn max_route_age(&self) -> std::time::Duration {
let nanos = self.max_route_age_nanos.load(Ordering::Relaxed);
std::time::Duration::from_nanos(nanos)
}
#[inline]
pub fn is_local(&self, dest_id: u64) -> bool {
dest_id == self.local_id
}
pub fn get_stream_stats(
&self,
stream_id: u64,
) -> dashmap::mapref::one::Ref<'_, u64, SchedulerStreamStats> {
self.stream_stats.entry(stream_id).or_default().downgrade()
}
#[inline]
fn may_admit_stream(&self, stream_id: u64) -> bool {
if self.stream_stats.contains_key(&stream_id) {
return true;
}
self.stream_stats.len() < MAX_STREAM_STATS
}
pub fn record_in(&self, stream_id: u64, bytes: u64) {
if !self.may_admit_stream(stream_id) {
return;
}
self.stream_stats
.entry(stream_id)
.or_default()
.record_in(bytes);
}
pub fn record_out(&self, stream_id: u64, bytes: u64) {
if !self.may_admit_stream(stream_id) {
return;
}
self.stream_stats
.entry(stream_id)
.or_default()
.record_out(bytes);
}
pub fn record_drop(&self, stream_id: u64) {
if !self.may_admit_stream(stream_id) {
return;
}
self.stream_stats
.entry(stream_id)
.or_default()
.record_drop();
}
pub fn route_count(&self) -> usize {
self.routes.len()
}
pub fn stream_count(&self) -> usize {
self.stream_stats.len()
}
pub fn deactivate_route(&self, dest_id: u64) {
if let Some(mut entry) = self.routes.get_mut(&dest_id) {
entry.active = false;
}
}
pub fn activate_route(&self, dest_id: u64) {
if let Some(mut entry) = self.routes.get_mut(&dest_id) {
entry.active = true;
entry.updated_at = Instant::now();
}
}
pub fn all_routes(&self) -> Vec<(u64, RouteEntry)> {
self.routes
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect()
}
pub fn cleanup_idle_streams(&self, idle_nanos: u64) -> usize {
let mut removed = 0;
self.stream_stats.retain(|_, stats| {
if stats.is_idle(idle_nanos) {
removed += 1;
false
} else {
true
}
});
removed
}
pub fn aggregate_stats(&self) -> AggregateStats {
let mut total_in = 0u64;
let mut total_out = 0u64;
let mut total_drops = 0u64;
for entry in self.stream_stats.iter() {
total_in += entry.get_packets_in();
total_out += entry.get_packets_out();
total_drops += entry.get_drops();
}
AggregateStats {
routes: self.routes.len(),
streams: self.stream_stats.len(),
packets_in: total_in,
packets_out: total_out,
packets_dropped: total_drops,
}
}
}
/// Compact debug view: the local id as 16 hex digits plus the route and
/// stream counts. Table contents are not dumped.
impl std::fmt::Debug for RoutingTable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let local_hex = format!("{:016x}", self.local_id);
        let route_total = self.routes.len();
        let stream_total = self.stream_stats.len();
        let mut view = f.debug_struct("RoutingTable");
        view.field("local_id", &local_hex);
        view.field("routes", &route_total);
        view.field("streams", &stream_total);
        view.finish()
    }
}
/// Point-in-time totals returned by `RoutingTable::aggregate_stats`.
#[derive(Debug, Clone, Default)]
pub struct AggregateStats {
// Number of installed routes (active or not).
pub routes: usize,
// Number of tracked streams.
pub streams: usize,
// Sum of inbound packet counts across streams.
pub packets_in: u64,
// Sum of outbound packet counts across streams.
pub packets_out: u64,
// Sum of dropped packet counts across streams.
pub packets_dropped: u64,
}
// Unit tests for the routing-header wire format, route flags, per-stream
// stats, and RoutingTable bookkeeping. Coverage includes concurrent
// add_route_with_metric behavior and the MAX_STREAM_STATS admission cap.
#[cfg(test)]
mod tests {
use super::*;
// to_bytes followed by from_bytes must reproduce the header exactly.
#[test]
fn test_routing_header_roundtrip() {
let header = RoutingHeader::new(0x123456789ABCDEF0, 0xDEADBEEF, 8);
let bytes = header.to_bytes();
let parsed = RoutingHeader::from_bytes(&bytes).unwrap();
assert_eq!(header, parsed);
}
// write_at (into a buffer pre-filled with 0xCC) must emit exactly the bytes write_to appends.
#[test]
fn write_at_matches_write_to_byte_for_byte() {
let header = RoutingHeader::new(0xABCD_EF01_2345_6789, 0xDEAD_BEEF, 7);
let mut via_write_to = BytesMut::with_capacity(ROUTING_HEADER_SIZE);
header.write_to(&mut via_write_to);
let mut via_write_at = [0xCC; ROUTING_HEADER_SIZE];
header.write_at(&mut via_write_at);
assert_eq!(
&via_write_to[..],
&via_write_at[..],
"write_at must produce the same wire bytes as write_to; \
a divergence would silently malform every forwarded packet",
);
}
// write_at asserts that its destination holds at least ROUTING_HEADER_SIZE bytes.
#[test]
#[should_panic(expected = "write_at")]
fn write_at_panics_on_short_slice() {
let header = RoutingHeader::new(1, 2, 1);
let mut short = [0u8; ROUTING_HEADER_SIZE - 1];
header.write_at(&mut short);
}
// The magic occupies bytes 0-1 even when the ids contain repeating byte patterns.
#[test]
fn test_routing_header_magic_at_offset_zero() {
let header = RoutingHeader::new(0x4E45_4E45_4E45_4E45, 0x4E45_4E45, 8);
let bytes = header.to_bytes();
assert_eq!(
u16::from_le_bytes([bytes[0], bytes[1]]),
ROUTING_MAGIC,
"magic must live at bytes 0-1 independent of dest_id's own byte pattern",
);
}
// from_bytes rejects any magic other than ROUTING_MAGIC.
#[test]
fn test_routing_header_rejects_wrong_magic() {
let mut bytes = RoutingHeader::new(0x1234, 0x5678, 4).to_bytes();
bytes[0..2].copy_from_slice(&0x4E45_u16.to_le_bytes());
assert!(RoutingHeader::from_bytes(&bytes).is_none());
bytes[0..2].copy_from_slice(&0xFFFF_u16.to_le_bytes());
assert!(RoutingHeader::from_bytes(&bytes).is_none());
}
// Regression: a dest_id whose low bytes equal 0x4E45 must neither disturb the
// magic nor break the round trip.
#[test]
fn test_regression_routing_discriminator_survives_magic_collision_node_id() {
let ambiguous_dest: u64 = 0xDEAD_BEEF_FFFF_4E45;
let header = RoutingHeader::new(ambiguous_dest, 0x1111_2222, 8);
let bytes = header.to_bytes();
assert_eq!(
u16::from_le_bytes([bytes[0], bytes[1]]),
ROUTING_MAGIC,
"magic at offset 0 must be independent of dest_id",
);
let parsed = RoutingHeader::from_bytes(&bytes).unwrap();
assert_eq!(parsed.dest_id, ambiguous_dest);
assert_eq!(parsed.src_id, 0x1111_2222);
assert_eq!(parsed.ttl, 8);
}
// forward decrements ttl, increments hop_count, and returns false once ttl is 0.
#[test]
fn test_routing_header_forward() {
let mut header = RoutingHeader::new(0x1234, 0x5678, 3);
assert_eq!(header.ttl, 3);
assert_eq!(header.hop_count, 0);
assert!(header.forward());
assert_eq!(header.ttl, 2);
assert_eq!(header.hop_count, 1);
assert!(header.forward());
assert!(header.forward());
assert_eq!(header.ttl, 0);
assert_eq!(header.hop_count, 3);
assert!(!header.forward());
}
// The control()/priority() constructors set the matching flag bit.
#[test]
fn test_routing_header_flags() {
let control = RoutingHeader::control(0x1234, 0x5678, 2);
assert!(control.flags.is_control());
let priority = RoutingHeader::priority(0x1234, 0x5678, 2);
assert!(priority.flags.is_priority());
}
// Several low-nibble flags set together all survive RouteFlags::from_u8.
#[test]
fn test_route_flags_combined() {
let combined = RouteFlags::CONTROL.as_u8() | RouteFlags::REQUIRES_ACK.as_u8();
let parsed = RouteFlags::from_u8(combined);
assert!(
parsed.is_control(),
"Control bit must survive combined parse"
);
assert!(
parsed.contains(RouteFlags::REQUIRES_ACK),
"RequiresAck bit must survive combined parse"
);
let all = RouteFlags::CONTROL.as_u8()
| RouteFlags::REQUIRES_ACK.as_u8()
| RouteFlags::PRIORITY.as_u8()
| RouteFlags::END_OF_STREAM.as_u8();
let parsed_all = RouteFlags::from_u8(all);
assert!(parsed_all.is_control());
assert!(parsed_all.is_priority());
assert!(parsed_all.contains(RouteFlags::REQUIRES_ACK));
assert!(parsed_all.contains(RouteFlags::END_OF_STREAM));
}
// Flags survive a header encode/decode round trip.
#[test]
fn test_route_flags_roundtrip() {
let mut header = RoutingHeader::new(0x1234, 0x5678, 4);
header.flags =
RouteFlags::from_u8(RouteFlags::PRIORITY.as_u8() | RouteFlags::REQUIRES_ACK.as_u8());
let bytes = header.to_bytes();
let parsed = RoutingHeader::from_bytes(&bytes).unwrap();
assert!(parsed.flags.is_priority());
assert!(parsed.flags.contains(RouteFlags::REQUIRES_ACK));
}
// Basic add_route/lookup, a miss on an unknown destination, and is_local.
#[test]
fn test_routing_table_basic() {
let table = RoutingTable::new(0x1234);
let addr1: SocketAddr = "127.0.0.1:9000".parse().unwrap();
let addr2: SocketAddr = "127.0.0.1:9001".parse().unwrap();
table.add_route(0x5678, addr1);
table.add_route(0x9ABC, addr2);
assert_eq!(table.lookup(0x5678), Some(addr1));
assert_eq!(table.lookup(0x9ABC), Some(addr2));
assert_eq!(table.lookup(0xFFFF), None);
assert!(table.is_local(0x1234));
assert!(!table.is_local(0x5678));
}
// deactivate_route hides a route from lookup; activate_route brings it back.
#[test]
fn test_routing_table_deactivate() {
let table = RoutingTable::new(0x1234);
let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
table.add_route(0x5678, addr);
assert_eq!(table.lookup(0x5678), Some(addr));
table.deactivate_route(0x5678);
assert_eq!(table.lookup(0x5678), None);
table.activate_route(0x5678);
assert_eq!(table.lookup(0x5678), Some(addr));
}
// SchedulerStreamStats counters in isolation.
#[test]
fn test_stream_stats() {
let stats = SchedulerStreamStats::new();
stats.record_in(100);
stats.record_in(200);
stats.record_out(100);
stats.record_drop();
assert_eq!(stats.get_packets_in(), 2);
assert_eq!(stats.get_packets_out(), 1);
assert_eq!(stats.get_drops(), 1);
}
// aggregate_stats sums per-stream packet counters across all streams.
#[test]
fn test_routing_table_stats() {
let table = RoutingTable::new(0x1234);
table.record_in(1, 100);
table.record_in(1, 200);
table.record_in(2, 150);
table.record_out(1, 100);
table.record_drop(2);
let stats = table.aggregate_stats();
assert_eq!(stats.streams, 2);
assert_eq!(stats.packets_in, 3);
assert_eq!(stats.packets_out, 1);
assert_eq!(stats.packets_dropped, 1);
}
// A worse metric must not replace an installed route; a strictly better one must.
#[test]
fn test_add_route_with_metric_preserves_better_direct_route() {
let table = RoutingTable::new(0x1111);
let direct: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let indirect: SocketAddr = "127.0.0.1:3000".parse().unwrap();
table.add_route(0x2222, direct);
assert_eq!(table.lookup(0x2222), Some(direct));
table.add_route_with_metric(0x2222, indirect, 5);
assert_eq!(
table.lookup(0x2222),
Some(direct),
"worse indirect route must not displace the direct route"
);
let better: SocketAddr = "127.0.0.1:4000".parse().unwrap();
table.add_route_with_metric(0x2222, better, 0);
assert_eq!(
table.lookup(0x2222),
Some(better),
"strictly-better metric update must replace next_hop"
);
}
// An offer with an equal metric must not change the installed next_hop.
#[test]
fn add_route_with_metric_equal_does_not_overwrite_next_hop() {
let table = RoutingTable::new(0x1111);
let real: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let attacker: SocketAddr = "10.0.0.1:31337".parse().unwrap();
table.add_route(0x2222, real);
table.add_route_with_metric(0x2222, attacker, 1);
assert_eq!(
table.lookup(0x2222),
Some(real),
"equal-metric pingwave must not overwrite an installed \
route's next_hop (security: prevents address poisoning)"
);
}
// lookup hides routes older than the max route age; sweep_stale removes them.
#[test]
fn test_sweep_stale_and_staleness_aware_lookup() {
use std::time::Duration;
let table = RoutingTable::new(0x1111);
let addr_a: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let addr_b: SocketAddr = "127.0.0.1:3000".parse().unwrap();
table.add_route(0x2222, addr_a);
table.add_route(0x3333, addr_b);
// Backdate one entry by 200ms so it exceeds the 50ms limit set below.
let stale_ts = Instant::now()
.checked_sub(Duration::from_millis(200))
.expect("test host uptime should exceed 200ms");
{
let mut e = table.routes.get_mut(&0x2222).unwrap();
e.updated_at = stale_ts;
}
table.set_max_route_age(Duration::from_millis(50));
assert_eq!(table.lookup(0x2222), None);
assert_eq!(table.lookup(0x3333), Some(addr_b));
let removed = table.sweep_stale(Duration::from_millis(50));
assert_eq!(removed, 1);
assert!(table.routes.get(&0x2222).is_none());
assert!(table.routes.get(&0x3333).is_some());
}
// remove_route_if_next_hop_is removes only while the stored next_hop still matches.
#[test]
fn test_regression_remove_route_if_next_hop_is() {
let table = RoutingTable::new(0x1111);
let original: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let newer: SocketAddr = "127.0.0.1:3000".parse().unwrap();
table.add_route(0x4444, original);
table.add_route(0x4444, newer);
let removed = table.remove_route_if_next_hop_is(0x4444, original);
assert!(
!removed,
"rollback must not evict an entry whose next_hop changed under us"
);
assert_eq!(
table.lookup(0x4444),
Some(newer),
"newer route must survive a stale rollback attempt"
);
let removed = table.remove_route_if_next_hop_is(0x4444, newer);
assert!(removed);
assert!(table.lookup(0x4444).is_none());
assert!(!table.remove_route_if_next_hop_is(0x4444, newer));
}
// lookup_alternate returns the stored route unless its next_hop is the excluded one.
#[test]
fn test_lookup_alternate() {
let table = RoutingTable::new(0x1);
let b: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let c: SocketAddr = "127.0.0.1:3000".parse().unwrap();
assert!(table.lookup_alternate(0x4444, b).is_none());
table.add_route(0x4444, b);
assert_eq!(table.lookup_alternate(0x4444, b), None);
assert_eq!(table.lookup_alternate(0x4444, c), Some(b));
}
// lookup_alternate applies the same freshness limit as lookup.
#[test]
fn test_lookup_alternate_respects_staleness() {
use std::time::Duration;
let table = RoutingTable::new(0x1);
let b: SocketAddr = "127.0.0.1:2000".parse().unwrap();
let c: SocketAddr = "127.0.0.1:3000".parse().unwrap();
table.add_route(0x4444, b);
let stale_ts = Instant::now()
.checked_sub(Duration::from_millis(200))
.expect("test host uptime should exceed 200ms");
{
let mut e = table.routes.get_mut(&0x4444).unwrap();
e.updated_at = stale_ts;
}
table.set_max_route_age(Duration::from_millis(50));
assert!(table.lookup_alternate(0x4444, c).is_none());
}
// 8 threads race add_route_with_metric with metrics 1..=8. The entry must end
// at metric 1, paired with the metric-1 thread's next_hop.
#[test]
fn concurrent_add_route_with_metric_converges_on_lowest_metric() {
use std::sync::{Arc, Barrier};
use std::thread;
let table = Arc::new(RoutingTable::new(0x1111));
let dest = 0x2222u64;
let start = Arc::new(Barrier::new(8));
let mut handles = Vec::new();
for metric in 1u16..=8 {
let table = table.clone();
let start = start.clone();
handles.push(thread::spawn(move || {
start.wait();
let next_hop: SocketAddr =
format!("127.0.0.1:{}", 10_000 + metric).parse().unwrap();
for _ in 0..500 {
table.add_route_with_metric(dest, next_hop, metric);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
let entry = table
.routes
.get(&dest)
.expect("route must exist after all threads inserted");
assert_eq!(
entry.metric, 1,
"final metric must be the minimum (1) across all concurrent inserts — \
a metric > 1 indicates a lost update or a torn compare-and-swap",
);
let winner = table.lookup(dest).expect("dest must resolve");
assert_eq!(
winner,
"127.0.0.1:10001".parse::<SocketAddr>().unwrap(),
"lookup should return the next_hop paired with the winning metric",
);
}
// 9 threads offer metrics 2..=10 against an installed metric-1 direct route;
// the direct route must remain.
#[test]
fn direct_route_survives_concurrent_worse_indirect_inserts() {
use std::sync::{Arc, Barrier};
use std::thread;
let table = Arc::new(RoutingTable::new(0x1111));
let dest = 0x2222u64;
let direct: SocketAddr = "127.0.0.1:2000".parse().unwrap();
table.add_route(dest, direct);
assert_eq!(table.lookup(dest), Some(direct));
let start = Arc::new(Barrier::new(9));
let mut handles = Vec::new();
for metric in 2u16..=10 {
let table = table.clone();
let start = start.clone();
handles.push(thread::spawn(move || {
start.wait();
let indirect: SocketAddr =
format!("127.0.0.1:{}", 20_000 + metric).parse().unwrap();
for _ in 0..500 {
table.add_route_with_metric(dest, indirect, metric);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
assert_eq!(
table.lookup(dest),
Some(direct),
"direct route (metric=1) must not be displaced by any \
concurrent indirect insert with metric >= 2",
);
let entry = table.routes.get(&dest).unwrap();
assert_eq!(entry.metric, 1, "metric must still be 1 (direct)");
}
// Once MAX_STREAM_STATS streams exist, record_in must not add new stream ids
// but must keep updating existing ones.
#[test]
fn record_in_stops_admitting_new_streams_at_cap() {
let table = RoutingTable::new(0xCAFE);
for i in 0..MAX_STREAM_STATS as u64 {
table.record_in(i, 1);
}
assert_eq!(
table.stream_count(),
MAX_STREAM_STATS,
"all initial entries must be admitted (we're at the cap)"
);
let novel = MAX_STREAM_STATS as u64 + 1;
table.record_in(novel, 1);
assert!(
!table.stream_stats.contains_key(&novel),
"novel stream_id at cap must NOT be admitted (pre-fix \
would have inserted unconditionally and grown the map \
unboundedly)"
);
assert_eq!(
table.stream_count(),
MAX_STREAM_STATS,
"stream_count must not grow past the cap"
);
table.record_in(0, 100);
let stats = table.stream_stats.get(&0).unwrap();
assert!(
stats.get_packets_in() >= 2,
"existing entry must continue to record despite the \
cap — fix is admit-side only"
);
}
// After cleanup_idle_streams(0) evicts entries, new stream ids are admitted again.
#[test]
fn cap_admits_new_streams_after_cleanup_reclaims_slots() {
let table = RoutingTable::new(0xCAFE);
for i in 0..MAX_STREAM_STATS as u64 {
table.record_in(i, 1);
}
let removed = table.cleanup_idle_streams(0);
assert!(removed > 0, "cleanup must reclaim some entries");
let fresh: u64 = 0xDEAD_BEEF_CAFE_F00D;
table.record_in(fresh, 1);
assert!(
table.stream_stats.contains_key(&fresh),
"after cleanup_idle_streams reclaims slots, novel \
stream_ids must be admissible again"
);
}
}