use bytes::{Buf, BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Encoded size of a [`Pingwave`]: two `u64`s, two `u8`s and six reserved bytes.
pub const PINGWAVE_SIZE: usize = 8 + 8 + 1 + 1 + 6;
/// Presumably the intended cap on encoded capability payloads — NOTE(review): it is
/// not enforced anywhere in this file.
// Not referenced within this file, hence the `allow`.
#[allow(dead_code)]
pub const MAX_CAPABILITIES_LEN: usize = 256;
/// Hop-limited probe: every successful [`Pingwave::forward`] decrements `ttl` and
/// increments `hop_count`.
///
/// Wire layout (`PINGWAVE_SIZE` bytes, little-endian):
/// `origin_id: u64 | seq: u64 | ttl: u8 | hop_count: u8 | 6 reserved zero bytes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct Pingwave {
    /// Node that created this pingwave.
    pub origin_id: u64,
    /// Sequence number assigned by the origin.
    pub seq: u64,
    /// Remaining forwards allowed.
    pub ttl: u8,
    /// Number of forwards performed so far (saturating).
    pub hop_count: u8,
    /// Reserved; always written as zeros and always decoded as zeros.
    pub _reserved: [u8; 6],
}

impl Pingwave {
    /// Builds a pingwave that has not been forwarded yet (`hop_count` 0).
    pub fn new(origin_id: u64, seq: u64, ttl: u8) -> Self {
        Self {
            origin_id,
            seq,
            ttl,
            hop_count: 0,
            _reserved: [0u8; 6],
        }
    }

    /// Encodes into the fixed wire layout; the reserved bytes are emitted as zeros.
    pub fn to_bytes(&self) -> [u8; PINGWAVE_SIZE] {
        let mut out = [0u8; PINGWAVE_SIZE];
        let (origin, tail) = out.split_at_mut(8);
        origin.copy_from_slice(&self.origin_id.to_le_bytes());
        let (seq, tail) = tail.split_at_mut(8);
        seq.copy_from_slice(&self.seq.to_le_bytes());
        tail[0] = self.ttl;
        tail[1] = self.hop_count;
        // tail[2..] keeps its zero initialisation: reserved bytes go out as zeros.
        out
    }

    /// Decodes the first `PINGWAVE_SIZE` bytes of `buf`; `None` if `buf` is shorter.
    ///
    /// Any bytes beyond `PINGWAVE_SIZE` are ignored, as are the reserved bytes.
    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
        let frame = buf.get(..PINGWAVE_SIZE)?;
        let (origin, tail) = frame.split_at(8);
        let (seq, tail) = tail.split_at(8);
        Some(Self {
            origin_id: u64::from_le_bytes(origin.try_into().ok()?),
            seq: u64::from_le_bytes(seq.try_into().ok()?),
            ttl: tail[0],
            hop_count: tail[1],
            _reserved: [0; 6],
        })
    }

    /// Appends the same `PINGWAVE_SIZE` bytes that [`Pingwave::to_bytes`] produces.
    pub fn write_to(&self, buf: &mut BytesMut) {
        buf.put_slice(&self.to_bytes());
    }

    /// Decodes one pingwave from the front of `buf` and advances past it.
    ///
    /// Returns `None` and leaves `buf` untouched if fewer than `PINGWAVE_SIZE` bytes
    /// remain.
    pub fn read_from(buf: &mut Bytes) -> Option<Self> {
        let wave = Self::from_bytes(&buf[..])?;
        buf.advance(PINGWAVE_SIZE);
        Some(wave)
    }

    /// `true` once no forwards remain (`ttl == 0`).
    #[inline]
    pub fn is_expired(&self) -> bool {
        matches!(self.ttl, 0)
    }

    /// Consumes one hop of TTL and counts the hop.
    ///
    /// Returns `false`, leaving the pingwave unchanged, if the TTL was already 0.
    #[inline]
    pub fn forward(&mut self) -> bool {
        match self.ttl.checked_sub(1) {
            Some(remaining) => {
                self.ttl = remaining;
                self.hop_count = self.hop_count.saturating_add(1);
                true
            }
            None => false,
        }
    }
}
/// Self-described resources and features of a node, carried inside a `CapabilityAd`.
///
/// Wire layout (see [`Capabilities::to_bytes`]):
/// `flags: u8 (bit 0 = gpu) | memory_gb: u32 LE | model_slots: u8 | tools | tags`,
/// where each string list is `count: u8` followed by `count` entries of
/// `len: u8 | len bytes of UTF-8`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Capabilities {
    /// Whether a GPU is advertised (bit 0 of the flags byte).
    pub gpu: bool,
    /// Advertised tool names.
    pub tools: Vec<String>,
    /// Advertised memory (gigabytes, per the field name).
    pub memory_gb: u32,
    /// Advertised number of model slots.
    pub model_slots: u8,
    /// Free-form advertised tags.
    pub tags: Vec<String>,
}

impl Capabilities {
    /// Largest list count or string length representable by a one-byte prefix.
    const MAX_WIRE_LEN: usize = u8::MAX as usize;

    /// Empty capability set; identical to `Capabilities::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets the GPU flag.
    pub fn with_gpu(mut self, gpu: bool) -> Self {
        self.gpu = gpu;
        self
    }

    /// Builder: appends one tool name.
    pub fn with_tool(mut self, tool: impl Into<String>) -> Self {
        self.tools.push(tool.into());
        self
    }

    /// Builder: sets the advertised memory.
    pub fn with_memory(mut self, memory_gb: u32) -> Self {
        self.memory_gb = memory_gb;
        self
    }

    /// Builder: sets the advertised model-slot count.
    pub fn with_model_slots(mut self, slots: u8) -> Self {
        self.model_slots = slots;
        self
    }

    /// Builder: appends one tag.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// `true` if `tool` exactly matches one of the advertised tools.
    pub fn has_tool(&self, tool: &str) -> bool {
        self.tools.iter().any(|t| t == tool)
    }

    /// `true` if `tag` exactly matches one of the advertised tags.
    pub fn has_tag(&self, tag: &str) -> bool {
        self.tags.iter().any(|t| t == tag)
    }

    /// Encodes into the wire layout documented on [`Capabilities`].
    ///
    /// Lists longer than 255 entries keep only their first 255 entries. Strings longer
    /// than 255 bytes are cut at the last UTF-8 character boundary at or below
    /// 255 bytes, so the output always decodes with [`Capabilities::from_bytes`].
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(64);
        buf.push(u8::from(self.gpu));
        buf.extend_from_slice(&self.memory_gb.to_le_bytes());
        buf.push(self.model_slots);
        Self::put_str_list(&mut buf, &self.tools);
        Self::put_str_list(&mut buf, &self.tags);
        buf
    }

    /// Decodes the wire layout produced by [`Capabilities::to_bytes`].
    ///
    /// Returns `None` if the buffer is truncated or if any string is not valid UTF-8.
    /// Bytes after the tag list are ignored.
    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
        // Fixed header: flags (1) + memory_gb (4) + model_slots (1).
        let header = buf.get(..6)?;
        let gpu = (header[0] & 0x01) != 0;
        let memory_gb = u32::from_le_bytes(header[1..5].try_into().ok()?);
        let model_slots = header[5];
        let mut rest = &buf[6..];
        let tools = Self::take_str_list(&mut rest)?;
        let tags = Self::take_str_list(&mut rest)?;
        Some(Self {
            gpu,
            tools,
            memory_gb,
            model_slots,
            tags,
        })
    }

    /// Appends `count: u8` and then each item as `len: u8 | bytes`, applying the
    /// 255-entry and 255-byte caps.
    fn put_str_list(buf: &mut Vec<u8>, items: &[String]) {
        let items = &items[..items.len().min(Self::MAX_WIRE_LEN)];
        buf.push(items.len() as u8); // <= 255 after the cap above
        for item in items {
            let s = Self::truncate_utf8(item, Self::MAX_WIRE_LEN);
            buf.push(s.len() as u8); // <= 255 after truncation
            buf.extend_from_slice(s.as_bytes());
        }
    }

    /// Reads one list written by `put_str_list` and, on success only, advances `buf`
    /// past it.
    fn take_str_list<'a>(buf: &mut &'a [u8]) -> Option<Vec<String>> {
        let data: &'a [u8] = *buf;
        let (&count, mut rest) = data.split_first()?;
        let mut items = Vec::with_capacity(usize::from(count));
        for _ in 0..count {
            let (&len, tail) = rest.split_first()?;
            let len = usize::from(len);
            if tail.len() < len {
                return None;
            }
            let (raw, tail) = tail.split_at(len);
            items.push(std::str::from_utf8(raw).ok()?.to_owned());
            rest = tail;
        }
        *buf = rest;
        Some(items)
    }

    /// Longest prefix of `s` that is at most `max` bytes and ends on a char boundary.
    ///
    /// Cutting at a raw byte offset could split a multi-byte character. The decoder
    /// would then reject the string and, with it, the whole capability payload.
    fn truncate_utf8(s: &str, max: usize) -> &str {
        if s.len() <= max {
            return s;
        }
        let mut end = max;
        // Terminates: index 0 is always a char boundary.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }
}
/// A node's capability set tagged with a version number.
///
/// Wire layout: `node_id: u64 LE | version: u32 LE | Capabilities::to_bytes()`.
#[derive(Debug, Clone)]
pub struct CapabilityAd {
    pub node_id: u64,
    pub version: u32,
    pub capabilities: Capabilities,
}

impl CapabilityAd {
    /// Bytes preceding the embedded capabilities payload (`u64` id + `u32` version).
    const HEADER_LEN: usize = 8 + 4;

    /// Bundles the three fields into an advertisement.
    pub fn new(node_id: u64, version: u32, capabilities: Capabilities) -> Self {
        Self {
            node_id,
            version,
            capabilities,
        }
    }

    /// Encodes the header followed by the capabilities payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let payload = self.capabilities.to_bytes();
        let mut out = Vec::with_capacity(Self::HEADER_LEN + payload.len());
        out.extend_from_slice(&self.node_id.to_le_bytes());
        out.extend_from_slice(&self.version.to_le_bytes());
        out.extend_from_slice(&payload);
        out
    }

    /// Decodes an advertisement.
    ///
    /// Returns `None` if the header is short or the payload fails
    /// `Capabilities::from_bytes`.
    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
        if buf.len() < Self::HEADER_LEN {
            return None;
        }
        let (header, payload) = buf.split_at(Self::HEADER_LEN);
        let (id_bytes, version_bytes) = header.split_at(8);
        let node_id = u64::from_le_bytes(id_bytes.try_into().ok()?);
        let version = u32::from_le_bytes(version_bytes.try_into().ok()?);
        let capabilities = Capabilities::from_bytes(payload)?;
        Some(Self {
            node_id,
            version,
            capabilities,
        })
    }
}
/// What the local node has recorded about one remote node.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Remote node id.
    pub node_id: u64,
    /// Socket address recorded for this node.
    pub addr: SocketAddr,
    /// Recorded hop distance.
    pub hops: u8,
    /// Time of the last refresh (creation or [`NodeInfo::touch`]).
    pub last_seen: Instant,
    /// Highest accepted sequence number (0 initially).
    pub last_seq: u64,
    /// Latest accepted capabilities, if any.
    pub capabilities: Option<Capabilities>,
    /// Version of `capabilities`; stays 0 until a newer version is accepted.
    pub cap_version: u32,
}

impl NodeInfo {
    /// Creates an entry marked as seen now, with no sequence or capability history.
    pub fn new(node_id: u64, addr: SocketAddr, hops: u8) -> Self {
        let last_seen = Instant::now();
        Self {
            node_id,
            addr,
            hops,
            last_seen,
            last_seq: 0,
            capabilities: None,
            cap_version: 0,
        }
    }

    /// Marks the entry as seen now.
    pub fn touch(&mut self) {
        self.last_seen = Instant::now();
    }

    /// `true` if strictly more than `timeout` has passed since the last refresh.
    pub fn is_stale(&self, timeout: Duration) -> bool {
        timeout < self.last_seen.elapsed()
    }

    /// Stores `caps` only when `version` is strictly newer than `cap_version`.
    ///
    /// Returns whether the update was applied; otherwise `caps` is discarded.
    pub fn update_capabilities(&mut self, version: u32, caps: Capabilities) -> bool {
        let is_newer = version > self.cap_version;
        if is_newer {
            self.cap_version = version;
            self.capabilities = Some(caps);
        }
        is_newer
    }
}
/// A directed link `from -> to` together with its last reported latency.
#[derive(Debug, Clone)]
pub struct EdgeInfo {
    pub from: u64,
    pub to: u64,
    /// Reported latency (microseconds, per the field name); 0 when built via `new`.
    pub latency_us: u32,
    /// When this edge was created or last updated.
    pub last_updated: Instant,
}

impl EdgeInfo {
    /// Edge with no latency information (`latency_us == 0`).
    pub fn new(from: u64, to: u64) -> Self {
        Self::with_latency(from, to, 0)
    }

    /// Edge with the given latency, stamped with the current time.
    pub fn with_latency(from: u64, to: u64, latency_us: u32) -> Self {
        let last_updated = Instant::now();
        Self {
            from,
            to,
            latency_us,
            last_updated,
        }
    }
}
/// Node-table size at which `LocalGraph::on_pingwave` stops admitting new origins.
pub const MAX_GRAPH_NODES: usize = 1 << 16;
/// Dedup-cache size at which `LocalGraph::on_pingwave` stops accepting new pingwaves.
pub const MAX_SEEN_PINGWAVES: usize = 1 << 18;
pub struct LocalGraph {
my_id: u64,
radius: u8,
nodes: DashMap<u64, NodeInfo>,
edges: DashMap<(u64, u64), EdgeInfo>,
seen_pingwaves: DashMap<(u64, u64), Instant>,
next_seq: AtomicU64,
node_timeout: Duration,
pingwave_cache_timeout: Duration,
}
impl LocalGraph {
pub fn new(my_id: u64, radius: u8) -> Self {
Self {
my_id,
radius,
nodes: DashMap::new(),
edges: DashMap::new(),
seen_pingwaves: DashMap::new(),
next_seq: AtomicU64::new(1),
node_timeout: Duration::from_secs(30),
pingwave_cache_timeout: Duration::from_secs(10),
}
}
pub fn with_node_timeout(mut self, timeout: Duration) -> Self {
self.node_timeout = timeout;
self
}
pub fn my_id(&self) -> u64 {
self.my_id
}
pub fn radius(&self) -> u8 {
self.radius
}
pub fn create_pingwave(&self) -> Pingwave {
let seq = self.next_seq.fetch_add(1, Ordering::Relaxed);
Pingwave::new(self.my_id, seq, self.radius)
}
pub fn on_pingwave(&self, mut pw: Pingwave, from: SocketAddr) -> Option<Pingwave> {
if pw.origin_id == self.my_id {
return None;
}
let key = (pw.origin_id, pw.seq);
if self.seen_pingwaves.contains_key(&key) {
return None;
}
if self.seen_pingwaves.len() >= MAX_SEEN_PINGWAVES {
return None;
}
self.seen_pingwaves.insert(key, Instant::now());
let hops = pw.hop_count.saturating_add(1);
if !self.nodes.contains_key(&pw.origin_id) && self.nodes.len() >= MAX_GRAPH_NODES {
return None;
}
self.nodes
.entry(pw.origin_id)
.and_modify(|n| {
let likely_restart = n.last_seq > 1 && pw.seq < n.last_seq.saturating_div(2);
let strict_progress = pw.seq >= n.last_seq && hops <= n.hops;
let strict_progress = strict_progress && (pw.seq > n.last_seq || hops < n.hops);
if strict_progress {
n.last_seq = pw.seq;
n.hops = hops;
n.addr = from;
n.touch();
} else if likely_restart {
n.touch();
}
})
.or_insert_with(|| {
let mut info = NodeInfo::new(pw.origin_id, from, hops);
info.last_seq = pw.seq;
info
});
if pw.is_expired() {
return None;
}
pw.forward();
Some(pw)
}
pub fn on_capability(&self, ad: CapabilityAd, from: SocketAddr) {
self.nodes
.entry(ad.node_id)
.and_modify(|n| {
n.update_capabilities(ad.version, ad.capabilities.clone());
})
.or_insert_with(|| {
let mut info = NodeInfo::new(ad.node_id, from, 0);
info.update_capabilities(ad.version, ad.capabilities.clone());
info
});
}
pub fn add_edge(&self, from: u64, to: u64, latency_us: u32) {
let key = (from, to);
self.edges
.entry(key)
.and_modify(|e| {
e.latency_us = latency_us;
e.last_updated = Instant::now();
})
.or_insert_with(|| EdgeInfo::with_latency(from, to, latency_us));
}
pub fn get_node(&self, node_id: u64) -> Option<NodeInfo> {
self.nodes.get(&node_id).map(|r| r.clone())
}
pub fn all_nodes(&self) -> Vec<NodeInfo> {
self.nodes.iter().map(|r| r.value().clone()).collect()
}
pub fn nodes_within_hops(&self, max_hops: u8) -> Vec<NodeInfo> {
self.nodes
.iter()
.filter(|r| r.hops <= max_hops)
.map(|r| r.value().clone())
.collect()
}
pub fn find_by_tool(&self, tool: &str) -> Vec<NodeInfo> {
self.nodes
.iter()
.filter(|r| {
r.capabilities
.as_ref()
.map(|c| c.has_tool(tool))
.unwrap_or(false)
})
.map(|r| r.value().clone())
.collect()
}
pub fn find_by_tag(&self, tag: &str) -> Vec<NodeInfo> {
self.nodes
.iter()
.filter(|r| {
r.capabilities
.as_ref()
.map(|c| c.has_tag(tag))
.unwrap_or(false)
})
.map(|r| r.value().clone())
.collect()
}
pub fn find_with_gpu(&self) -> Vec<NodeInfo> {
self.nodes
.iter()
.filter(|r| r.capabilities.as_ref().map(|c| c.gpu).unwrap_or(false))
.map(|r| r.value().clone())
.collect()
}
pub fn path_to(&self, dest: u64) -> Option<Vec<u64>> {
if dest == self.my_id {
return Some(vec![self.my_id]);
}
let mut adjacency: HashMap<u64, Vec<u64>> = HashMap::new();
for edge in self.edges.iter() {
adjacency.entry(edge.from).or_default().push(edge.to);
}
let mut parent: HashMap<u64, u64> = HashMap::new();
let mut visited: HashSet<u64> = HashSet::new();
let mut queue: VecDeque<u64> = VecDeque::new();
queue.push_back(self.my_id);
visited.insert(self.my_id);
while let Some(current) = queue.pop_front() {
if current == dest {
let mut path = vec![current];
let mut node = current;
while node != self.my_id {
node = *parent.get(&node)?;
path.push(node);
}
path.reverse();
return Some(path);
}
if let Some(neighbors) = adjacency.get(¤t) {
for &neighbor in neighbors {
if visited.insert(neighbor) {
parent.insert(neighbor, current);
queue.push_back(neighbor);
}
}
}
}
None
}
pub fn cleanup(&self) -> (usize, usize) {
let mut removed_nodes = 0;
let mut removed_pingwaves = 0;
self.nodes.retain(|_, node| {
if node.is_stale(self.node_timeout) {
removed_nodes += 1;
false
} else {
true
}
});
self.seen_pingwaves.retain(|_, instant| {
if instant.elapsed() > self.pingwave_cache_timeout {
removed_pingwaves += 1;
false
} else {
true
}
});
(removed_nodes, removed_pingwaves)
}
pub fn stats(&self) -> GraphStats {
GraphStats {
node_count: self.nodes.len(),
edge_count: self.edges.len(),
pingwave_cache_size: self.seen_pingwaves.len(),
}
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
}
impl std::fmt::Debug for LocalGraph {
    /// Summarises the graph: hex-formatted id, radius, and node/edge counts.
    /// Table contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id_hex = format!("{:016x}", self.my_id);
        let node_total = self.nodes.len();
        let edge_total = self.edges.len();
        let mut summary = f.debug_struct("LocalGraph");
        summary.field("my_id", &id_hex);
        summary.field("radius", &self.radius);
        summary.field("nodes", &node_total);
        summary.field("edges", &edge_total);
        summary.finish()
    }
}
/// Sizes of a `LocalGraph`'s tables as returned by `LocalGraph::stats`.
#[derive(Debug, Clone, Default)]
pub struct GraphStats {
    /// Number of entries in the node table.
    pub node_count: usize,
    /// Number of recorded directed edges.
    pub edge_count: usize,
    /// Number of `(origin, seq)` pairs in the pingwave dedup cache.
    pub pingwave_cache_size: usize,
}
// Unit tests for the wire codecs (`Pingwave`, `Capabilities`, `CapabilityAd`) and
// for `LocalGraph`. The graph tests cover pingwave handling, capacity limits,
// capability queries and path search. Several tests write directly into the private
// `nodes` / `seen_pingwaves` tables, either to reach capacity limits or to seed state.
#[cfg(test)]
mod tests {
    use super::*;
    // to_bytes -> from_bytes reproduces an equal pingwave.
    #[test]
    fn test_pingwave_roundtrip() {
        let pw = Pingwave::new(0x123456789ABCDEF0, 42, 3);
        let bytes = pw.to_bytes();
        let parsed = Pingwave::from_bytes(&bytes).unwrap();
        assert_eq!(pw, parsed);
    }
    // Each forward moves one unit from ttl to hop_count; at ttl 0 forward refuses.
    #[test]
    fn test_pingwave_forward() {
        let mut pw = Pingwave::new(0x1234, 1, 3);
        assert_eq!(pw.ttl, 3);
        assert_eq!(pw.hop_count, 0);
        assert!(pw.forward());
        assert_eq!(pw.ttl, 2);
        assert_eq!(pw.hop_count, 1);
        assert!(pw.forward());
        assert!(pw.forward());
        assert_eq!(pw.ttl, 0);
        assert_eq!(pw.hop_count, 3);
        assert!(!pw.forward());
    }
    // Every Capabilities field survives an encode/decode roundtrip.
    #[test]
    fn test_capabilities_roundtrip() {
        let caps = Capabilities::new()
            .with_gpu(true)
            .with_memory(16)
            .with_model_slots(4)
            .with_tool("python")
            .with_tool("rust")
            .with_tag("inference")
            .with_tag("training");
        let bytes = caps.to_bytes();
        let parsed = Capabilities::from_bytes(&bytes).unwrap();
        assert_eq!(caps.gpu, parsed.gpu);
        assert_eq!(caps.memory_gb, parsed.memory_gb);
        assert_eq!(caps.model_slots, parsed.model_slots);
        assert_eq!(caps.tools, parsed.tools);
        assert_eq!(caps.tags, parsed.tags);
    }
    // With the dedup cache filled to MAX_SEEN_PINGWAVES, a new pingwave is not
    // forwarded, not cached, and does not create a node.
    #[test]
    fn on_pingwave_drops_novel_entries_when_seen_pingwaves_is_at_cap() {
        let graph = LocalGraph::new(0x1, 8);
        let from: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        for i in 0..MAX_SEEN_PINGWAVES as u64 {
            graph
                .seen_pingwaves
                .insert((0xDEAD_BEEF_0000 + i, 0), Instant::now());
        }
        assert_eq!(graph.seen_pingwaves.len(), MAX_SEEN_PINGWAVES);
        let novel_pw = Pingwave::new(0xCAFE, 1, 3);
        let result = graph.on_pingwave(novel_pw, from);
        assert!(
            result.is_none(),
            "novel pingwave at cap must NOT be forwarded"
        );
        assert!(
            !graph.seen_pingwaves.contains_key(&(0xCAFE, 1)),
            "novel pingwave must NOT be inserted at cap"
        );
        assert!(
            !graph.nodes.contains_key(&0xCAFE),
            "novel origin must NOT be inserted at cap"
        );
    }
    // A seq below half the high-water mark (a "restart") keeps last_seq and addr
    // unchanged. A later seq above the old mark is accepted and updates addr.
    #[test]
    fn on_pingwave_likely_restart_only_touches_does_not_lower_last_seq() {
        let graph = LocalGraph::new(0x1, 8);
        let from: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        for seq in [100u64, 200, 500, 1000].iter() {
            let pw = Pingwave::new(0xCAFE, *seq, 3);
            graph.on_pingwave(pw, from);
        }
        let pre_restart_last_seq = graph.nodes.get(&0xCAFE).map(|n| n.last_seq).unwrap();
        assert_eq!(pre_restart_last_seq, 1000);
        let restart_from: SocketAddr = "127.0.0.1:9001".parse().unwrap();
        let pw = Pingwave::new(0xCAFE, 1, 3);
        graph.on_pingwave(pw, restart_from);
        let (post_restart_last_seq, post_restart_addr) = graph
            .nodes
            .get(&0xCAFE)
            .map(|n| (n.last_seq, n.addr))
            .unwrap();
        assert_eq!(
            post_restart_last_seq, 1000,
            "Cubic P1: unauthenticated restart-only path must NOT lower \
            last_seq; the high-water mark is the only credential blocking \
            a spoofed seq=2 from looking like strict progress"
        );
        assert_eq!(
            post_restart_addr, from,
            "CR-6: address must NOT auto-update on the restart-only path"
        );
        let pw2 = Pingwave::new(0xCAFE, 1001, 3);
        graph.on_pingwave(pw2, restart_from);
        let (final_last_seq, final_addr) = graph
            .nodes
            .get(&0xCAFE)
            .map(|n| (n.last_seq, n.addr))
            .unwrap();
        assert_eq!(final_last_seq, 1001);
        assert_eq!(
            final_addr, restart_from,
            "strict-progress pingwave (seq > prior high-water) updates addr"
        );
    }
    // A spoofed low-seq pingwave from another address must not repoint addr or lower
    // last_seq.
    #[test]
    fn on_pingwave_likely_restart_must_not_overwrite_addr() {
        let graph = LocalGraph::new(0x1, 8);
        let legit: SocketAddr = "10.0.0.5:9000".parse().unwrap();
        for seq in [100u64, 500, 1000].iter() {
            let pw = Pingwave::new(0xBEEF, *seq, 3);
            graph.on_pingwave(pw, legit);
        }
        assert_eq!(
            graph.nodes.get(&0xBEEF).map(|n| n.addr).unwrap(),
            legit,
            "sanity: legit addr is recorded"
        );
        let attacker: SocketAddr = "192.0.2.99:31337".parse().unwrap();
        let spoof = Pingwave::new(0xBEEF, 1, 3);
        graph.on_pingwave(spoof, attacker);
        let (recorded_addr, recorded_last_seq) = graph
            .nodes
            .get(&0xBEEF)
            .map(|n| (n.addr, n.last_seq))
            .unwrap();
        assert_eq!(
            recorded_addr, legit,
            "CR-6: spoofed restart MUST NOT repoint the recorded address \
            to the attacker; got {:?}",
            recorded_addr
        );
        assert_eq!(
            recorded_last_seq, 1000,
            "Cubic P1: last_seq must STAY at the pre-spoof high-water mark; \
            a lowered last_seq lets a follow-up seq=2 spoof masquerade as \
            strict progress and overwrite n.addr"
        );
    }
    // An older seq (800 < 1000) that claims fewer hops (1 < 3) changes nothing.
    #[test]
    fn on_pingwave_below_last_seq_with_shorter_hops_does_not_overwrite_addr() {
        let graph = LocalGraph::new(0x1, 8);
        let legit: SocketAddr = "10.0.0.5:9000".parse().unwrap();
        for seq in [100u64, 500, 1000].iter() {
            let mut pw = Pingwave::new(0xBEEF, *seq, 8);
            pw.hop_count = 2;
            graph.on_pingwave(pw, legit);
        }
        assert_eq!(
            graph
                .nodes
                .get(&0xBEEF)
                .map(|n| (n.addr, n.last_seq, n.hops))
                .unwrap(),
            (legit, 1000, 3),
        );
        let attacker: SocketAddr = "192.0.2.99:31337".parse().unwrap();
        let spoof = Pingwave::new(0xBEEF, 800, 8);
        graph.on_pingwave(spoof, attacker);
        let (recorded_addr, recorded_last_seq, recorded_hops) = graph
            .nodes
            .get(&0xBEEF)
            .map(|n| (n.addr, n.last_seq, n.hops))
            .unwrap();
        assert_eq!(
            recorded_addr, legit,
            "stale-seq + shorter-hops spoof must NOT repoint addr; \
            got {:?}",
            recorded_addr,
        );
        assert_eq!(recorded_last_seq, 1000, "last_seq must not regress");
        assert_eq!(recorded_hops, 3, "hops must not be lowered by stale seq");
    }
    // An out-of-order seq (15 after 20) is not low enough to count as a restart,
    // so last_seq stays at 20.
    #[test]
    fn on_pingwave_ignores_small_seq_regression_without_restart_signal() {
        let graph = LocalGraph::new(0x1, 8);
        let from: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        for seq in [10u64, 20].iter() {
            let pw = Pingwave::new(0xCAFE, *seq, 3);
            graph.on_pingwave(pw, from);
        }
        let pw = Pingwave::new(0xCAFE, 15, 3);
        graph.on_pingwave(pw, from);
        let last_seq = graph.nodes.get(&0xCAFE).map(|n| n.last_seq).unwrap();
        assert_eq!(
            last_seq, 20,
            "small seq regression (out-of-order) must NOT update last_seq"
        );
    }
    // With the node table at MAX_GRAPH_NODES, a new origin is refused, but an
    // existing origin is still updated.
    #[test]
    fn on_pingwave_drops_novel_origin_when_nodes_is_at_cap() {
        let graph = LocalGraph::new(0x1, 8);
        let from: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        for i in 0..MAX_GRAPH_NODES as u64 {
            let id = 0xDEAD_BEEF_0000 + i;
            graph.nodes.insert(id, NodeInfo::new(id, from, 1));
        }
        assert_eq!(graph.nodes.len(), MAX_GRAPH_NODES);
        let novel_pw = Pingwave::new(0xFACE, 1, 3);
        graph.on_pingwave(novel_pw, from);
        assert!(
            !graph.nodes.contains_key(&0xFACE),
            "novel origin at cap must NOT be inserted"
        );
        let existing_id = 0xDEAD_BEEF_0000u64;
        let existing_pw = Pingwave::new(existing_id, 99, 3);
        let pre_seq = graph.nodes.get(&existing_id).unwrap().last_seq;
        graph.on_pingwave(existing_pw, from);
        let post_seq = graph.nodes.get(&existing_id).unwrap().last_seq;
        assert!(
            post_seq > pre_seq,
            "already-known origin must keep updating despite cap"
        );
    }
    // node_id, version and the GPU flag survive a CapabilityAd roundtrip.
    #[test]
    fn test_capability_ad_roundtrip() {
        let caps = Capabilities::new().with_gpu(true).with_tool("test");
        let ad = CapabilityAd::new(0x1234, 5, caps);
        let bytes = ad.to_bytes();
        let parsed = CapabilityAd::from_bytes(&bytes).unwrap();
        assert_eq!(ad.node_id, parsed.node_id);
        assert_eq!(ad.version, parsed.version);
        assert_eq!(ad.capabilities.gpu, parsed.capabilities.gpu);
    }
    // A 300-byte ASCII tool name is truncated to 255 bytes on encode.
    #[test]
    fn test_capabilities_large_strings_capped() {
        let long_tool = "x".repeat(300);
        let caps = Capabilities::new().with_tool(&long_tool);
        let bytes = caps.to_bytes();
        let parsed = Capabilities::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.tools.len(), 1);
        assert_eq!(parsed.tools[0].len(), 255);
    }
    // A list of 300 tools is truncated to 255 entries on encode.
    #[test]
    fn test_capabilities_many_items_capped() {
        let mut caps = Capabilities::new();
        for i in 0..300 {
            caps = caps.with_tool(format!("t{}", i));
        }
        let bytes = caps.to_bytes();
        let parsed = Capabilities::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.tools.len(), 255);
    }
    // The first pingwave from an origin creates a 1-hop node and is forwarded. A
    // duplicate (same seq) is dropped, and the next seq is forwarded again.
    #[test]
    fn test_local_graph_pingwave_processing() {
        let graph = LocalGraph::new(0x1111, 3);
        let pw = Pingwave::new(0x2222, 1, 3);
        let from: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        let forwarded = graph.on_pingwave(pw, from);
        assert!(forwarded.is_some());
        let node = graph.get_node(0x2222).unwrap();
        assert_eq!(node.hops, 1);
        assert_eq!(node.addr, from);
        let pw2 = Pingwave::new(0x2222, 1, 3);
        assert!(graph.on_pingwave(pw2, from).is_none());
        let pw3 = Pingwave::new(0x2222, 2, 3);
        assert!(graph.on_pingwave(pw3, from).is_some());
    }
    // Capability ads on existing nodes make them findable by tool and by GPU flag.
    #[test]
    fn test_local_graph_capability_search() {
        let graph = LocalGraph::new(0x1111, 3);
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        graph.nodes.insert(0x2222, NodeInfo::new(0x2222, addr, 1));
        graph.nodes.insert(0x3333, NodeInfo::new(0x3333, addr, 2));
        let caps1 = Capabilities::new().with_gpu(true).with_tool("python");
        let caps2 = Capabilities::new().with_gpu(false).with_tool("rust");
        graph.on_capability(CapabilityAd::new(0x2222, 1, caps1), addr);
        graph.on_capability(CapabilityAd::new(0x3333, 1, caps2), addr);
        let python_nodes = graph.find_by_tool("python");
        assert_eq!(python_nodes.len(), 1);
        assert_eq!(python_nodes[0].node_id, 0x2222);
        let gpu_nodes = graph.find_with_gpu();
        assert_eq!(gpu_nodes.len(), 1);
        assert_eq!(gpu_nodes[0].node_id, 0x2222);
    }
    // A capability ad for an unknown node creates that node with its capabilities.
    #[test]
    fn test_capability_ad_creates_unknown_node() {
        let graph = LocalGraph::new(0x1111, 3);
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        assert!(graph.get_node(0x2222).is_none());
        let caps = Capabilities::new().with_gpu(true).with_tool("python");
        graph.on_capability(CapabilityAd::new(0x2222, 1, caps), addr);
        let node = graph.get_node(0x2222);
        assert!(node.is_some(), "node should be created from capability ad");
        let node = node.unwrap();
        assert!(node.capabilities.is_some());
        let gpu_nodes = graph.find_with_gpu();
        assert_eq!(gpu_nodes.len(), 1);
        assert_eq!(gpu_nodes[0].node_id, 0x2222);
    }
    // path_to follows a linear chain of edges. An id with no incoming edge is
    // unreachable.
    #[test]
    fn test_local_graph_path_finding() {
        let graph = LocalGraph::new(0x1111, 3);
        graph.add_edge(0x1111, 0x2222, 100);
        graph.add_edge(0x2222, 0x3333, 100);
        graph.add_edge(0x3333, 0x4444, 100);
        let path = graph.path_to(0x4444);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path, vec![0x1111, 0x2222, 0x3333, 0x4444]);
        let no_path = graph.path_to(0x9999);
        assert!(no_path.is_none());
    }
    // write_to emits exactly PINGWAVE_SIZE bytes. read_from decodes them and
    // consumes the whole buffer.
    #[test]
    fn pingwave_write_to_and_read_from_roundtrip() {
        let pw = Pingwave::new(0xDEADBEEF_CAFEBABE, 12345, 7);
        let mut buf = BytesMut::new();
        pw.write_to(&mut buf);
        assert_eq!(buf.len(), PINGWAVE_SIZE);
        let mut bytes = buf.freeze();
        let parsed = Pingwave::read_from(&mut bytes).expect("roundtrip parse must succeed");
        assert_eq!(parsed.origin_id, pw.origin_id);
        assert_eq!(parsed.seq, pw.seq);
        assert_eq!(parsed.ttl, pw.ttl);
        assert_eq!(parsed.hop_count, pw.hop_count);
        assert_eq!(bytes.remaining(), 0);
    }
    // One byte short of PINGWAVE_SIZE is rejected.
    #[test]
    fn pingwave_read_from_returns_none_on_truncated_buffer() {
        let mut buf = Bytes::from(vec![0u8; PINGWAVE_SIZE - 1]);
        assert!(Pingwave::read_from(&mut buf).is_none());
    }
    // Fixture: three nodes at 1, 2 and 3 hops. 0x2222 is a GPU "inference" node,
    // 0x3333 a CPU "training" node, and 0x4444 has no capabilities.
    fn populate_graph_for_filter_tests() -> (LocalGraph, SocketAddr) {
        let graph = LocalGraph::new(0x1111, 3);
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        graph.nodes.insert(0x2222, NodeInfo::new(0x2222, addr, 1));
        graph.nodes.insert(0x3333, NodeInfo::new(0x3333, addr, 2));
        graph.nodes.insert(0x4444, NodeInfo::new(0x4444, addr, 3));
        let caps_gpu = Capabilities::new().with_gpu(true).with_tag("inference");
        let caps_cpu = Capabilities::new().with_gpu(false).with_tag("training");
        graph.on_capability(CapabilityAd::new(0x2222, 1, caps_gpu), addr);
        graph.on_capability(CapabilityAd::new(0x3333, 1, caps_cpu), addr);
        (graph, addr)
    }
    // find_by_tag returns exactly the nodes carrying the tag, and nothing for an
    // unknown tag.
    #[test]
    fn find_by_tag_returns_only_matching_nodes() {
        let (graph, _) = populate_graph_for_filter_tests();
        let inference = graph.find_by_tag("inference");
        assert_eq!(inference.len(), 1);
        assert_eq!(inference[0].node_id, 0x2222);
        let training = graph.find_by_tag("training");
        assert_eq!(training.len(), 1);
        assert_eq!(training[0].node_id, 0x3333);
        assert!(graph.find_by_tag("nonexistent").is_empty());
    }
    // nodes_within_hops uses an inclusive bound on the recorded hop count.
    #[test]
    fn nodes_within_hops_filters_by_hop_distance() {
        let (graph, _) = populate_graph_for_filter_tests();
        let within_1 = graph.nodes_within_hops(1);
        assert_eq!(within_1.len(), 1);
        assert_eq!(within_1[0].node_id, 0x2222);
        let within_2 = graph.nodes_within_hops(2);
        assert_eq!(within_2.len(), 2);
        let within_3 = graph.nodes_within_hops(3);
        assert_eq!(within_3.len(), 3);
        let within_0 = graph.nodes_within_hops(0);
        assert!(within_0.is_empty());
    }
    // Debug output shows the id as 16 lowercase hex digits, plus radius and counts.
    #[test]
    fn local_graph_debug_format_includes_id_radius_and_counts() {
        let graph = LocalGraph::new(0xABCD_1234_5678_9000, 5);
        let s = format!("{:?}", graph);
        assert!(s.contains("LocalGraph"));
        assert!(s.contains("radius: 5"));
        assert!(s.contains("nodes: 0"));
        assert!(s.contains("edges: 0"));
        assert!(s.contains("abcd123456789000"), "got: {s}");
    }
    // EdgeInfo::new defaults latency to 0; with_latency stores the given value.
    #[test]
    fn edge_info_new_and_with_latency_set_fields_correctly() {
        let e = EdgeInfo::new(1, 2);
        assert_eq!(e.from, 1);
        assert_eq!(e.to, 2);
        assert_eq!(e.latency_us, 0);
        let e = EdgeInfo::with_latency(7, 9, 500);
        assert_eq!(e.from, 7);
        assert_eq!(e.to, 9);
        assert_eq!(e.latency_us, 500);
    }
}