use super::*;
use crate::protocol::TreeAnnounce;
use crate::tree::{CoordEntry, ParentDeclaration, TreeCoordinate};
/// Process-wide lock used to run the large-network tests one at a time.
///
/// The mutex is created lazily on first use and lives for the whole test binary.
static LARGE_NETWORK_TEST_LOCK: std::sync::LazyLock<tokio::sync::Mutex<()>> =
    std::sync::LazyLock::new(tokio::sync::Mutex::<()>::default);

/// Waits until no other large-network test holds the lock and returns the guard.
///
/// Keep the guard alive (e.g. bind it to `_guard`) for the duration of the test;
/// dropping it releases the lock.
pub(super) async fn lock_large_network_test() -> tokio::sync::MutexGuard<'static, ()> {
    let mutex: &'static tokio::sync::Mutex<()> = std::sync::LazyLock::force(&LARGE_NETWORK_TEST_LOCK);
    mutex.lock().await
}
/// A node under test together with the single UDP transport created for it
/// by `make_test_node_with_mtu`.
pub(super) struct TestNode {
    /// The node being exercised.
    pub(super) node: Node,
    /// Identifier under which the UDP transport is registered in `node.transports`.
    pub(super) transport_id: TransportId,
    /// Receiving end of the channel the UDP transport delivers inbound packets to;
    /// the helpers in this module drain it manually with `try_recv`.
    pub(super) packet_rx: PacketRx,
    /// Local address the UDP transport is bound to.
    pub(super) addr: TransportAddr,
}
/// Creates a test node on an ephemeral localhost UDP port using the default
/// test MTU of 1280 bytes.
pub(super) async fn make_test_node() -> TestNode {
    const DEFAULT_TEST_MTU: u16 = 1280;
    make_test_node_with_mtu(DEFAULT_TEST_MTU).await
}
/// Creates a fresh node with one started UDP transport bound to
/// `127.0.0.1:0` (an OS-assigned port) and configured with the given `mtu`.
///
/// The transport is registered on the node under `TransportId::new(1)`, and its
/// inbound packets are delivered to the returned `TestNode::packet_rx`
/// (channel capacity 256).
///
/// # Panics
/// Panics if the transport fails to start or cannot report its local address.
pub(super) async fn make_test_node_with_mtu(mtu: u16) -> TestNode {
    use crate::config::UdpConfig;
    use crate::transport::udp::UdpTransport;

    let mut node = make_node();
    let transport_id = TransportId::new(1);
    let (packet_tx, packet_rx) = packet_channel(256);

    let mut transport = UdpTransport::new(
        transport_id,
        None,
        UdpConfig {
            bind_addr: Some(String::from("127.0.0.1:0")),
            mtu: Some(mtu),
            ..Default::default()
        },
        packet_tx,
    );
    transport.start_async().await.unwrap();

    // Resolve the OS-assigned port now that the socket is bound.
    let bound = transport.local_addr().unwrap();
    let addr = TransportAddr::from_string(&bound.to_string());

    node.transports
        .insert(transport_id, TransportHandle::Udp(transport));

    TestNode {
        node,
        transport_id,
        packet_rx,
        addr,
    }
}
/// Starts a handshake from `nodes[i]` (initiator) to `nodes[j]` (responder):
/// registers the initiator's outbound link/connection bookkeeping and sends msg1
/// over the initiator's UDP transport.
///
/// Only msg1 is sent here. The rest of the exchange happens when the nodes'
/// packet queues are processed (e.g. by `process_available_packets`).
///
/// # Panics
/// Panics if `i` or `j` is out of bounds, if index allocation or
/// `start_handshake` fails, if the initiator has no transport registered under
/// its `transport_id`, or if sending msg1 fails.
pub(super) async fn initiate_handshake(nodes: &mut [TestNode], i: usize, j: usize) {
    use crate::node::wire::build_msg1;
    // Copy what we need from the responder before taking `&mut nodes[i]`.
    let responder_addr = nodes[j].addr.clone();
    let responder_pubkey_full = nodes[j].node.identity().pubkey_full();
    let peer_identity = PeerIdentity::from_pubkey_full(responder_pubkey_full);
    let initiator = &mut nodes[i];
    let transport_id = initiator.transport_id;
    let link_id = initiator.node.allocate_link_id();
    // NOTE(review): the literal `1000` here and in `start_handshake` looks like a
    // fixed synthetic timestamp — confirm against the PeerConnection API.
    let mut conn = PeerConnection::outbound(link_id, peer_identity, 1000);
    let our_index = initiator.node.index_allocator.allocate().unwrap();
    let our_keypair = initiator.node.identity().keypair();
    let noise_msg1 = conn
        .start_handshake(our_keypair, initiator.node.startup_epoch, 1000)
        .unwrap();
    conn.set_our_index(our_index);
    conn.set_transport_id(transport_id);
    conn.set_source_addr(responder_addr.clone());
    // Frame the Noise payload for the wire, tagged with our session index.
    let wire_msg1 = build_msg1(our_index, &noise_msg1);
    // NOTE(review): the meaning of the 100 ms duration is defined by
    // `Link::connectionless` — confirm there.
    let link = Link::connectionless(
        link_id,
        transport_id,
        responder_addr.clone(),
        LinkDirection::Outbound,
        Duration::from_millis(100),
    );
    // Register the link and the connection by link id, by (transport, remote
    // address), and by (transport, our index). These are presumably the lookups
    // used when the responder's reply arrives.
    initiator.node.links.insert(link_id, link);
    initiator
        .node
        .addr_to_link
        .insert((transport_id, responder_addr.clone()), link_id);
    initiator.node.connections.insert(link_id, conn);
    initiator
        .node
        .pending_outbound
        .insert((transport_id, our_index.as_u32()), link_id);
    let transport = initiator.node.transports.get(&transport_id).unwrap();
    transport
        .send(&responder_addr, &wire_msg1)
        .await
        .expect("Failed to send msg1");
}
/// Prints a diagnostic snapshot of the spanning-tree state of `nodes` to stderr.
///
/// The summary line reports:
/// - how many nodes agree on the expected root (the smallest `NodeAddr` in `nodes`);
/// - how many distinct roots are currently believed;
/// - a depth histogram (`d<depth>=<count>`);
/// - the total number of peers with a pending tree announce.
///
/// For at most 20 nodes, one line per node follows: root, depth, parent, peer count
/// and pending announces. Node indices are resolved where possible, and
/// `?<addr>` marks an address outside `nodes`. For larger sets, only the
/// unconverged node indices are printed, or just their count when there are
/// more than 20.
///
/// An empty `nodes` slice prints only the label; there is no expected root to
/// report.
pub(super) fn print_tree_snapshot(label: &str, nodes: &[TestNode]) {
    eprintln!("\n --- {} ---", label);
    let Some(expected_root) = nodes.iter().map(|tn| *tn.node.node_addr()).min() else {
        return;
    };
    let correct_root_count = nodes
        .iter()
        .filter(|tn| *tn.node.tree_state().root() == expected_root)
        .count();
    let total_pending: usize = nodes.iter().map(pending_tree_announce_count).sum();
    let mut depth_counts = std::collections::BTreeMap::new();
    for tn in nodes {
        *depth_counts
            .entry(tn.node.tree_state().my_coords().depth())
            .or_insert(0usize) += 1;
    }
    let depth_str: Vec<String> = depth_counts
        .iter()
        .map(|(d, c)| format!("d{}={}", d, c))
        .collect();
    let roots: std::collections::BTreeSet<_> = nodes
        .iter()
        .map(|tn| *tn.node.tree_state().root())
        .collect();
    eprintln!(
        " converged={}/{} roots={} depths=[{}] pending={}",
        correct_root_count,
        nodes.len(),
        roots.len(),
        depth_str.join(" "),
        total_pending,
    );
    if nodes.len() <= 20 {
        for (i, tn) in nodes.iter().enumerate() {
            let ts = tn.node.tree_state();
            let parent_idx = if ts.is_root() {
                "self".to_string()
            } else {
                nodes
                    .iter()
                    .position(|n| n.node.node_addr() == ts.my_declaration().parent_id())
                    .map(|p| format!("{}", p))
                    .unwrap_or_else(|| format!("?{}", ts.my_declaration().parent_id()))
            };
            let root_idx = nodes
                .iter()
                .position(|n| n.node.node_addr() == ts.root())
                .map(|r| format!("{}", r))
                .unwrap_or_else(|| format!("?{}", ts.root()));
            eprintln!(
                " node[{}] root=node[{}] depth={} parent=node[{}] peers={} pending={}",
                i,
                root_idx,
                ts.my_coords().depth(),
                parent_idx,
                tn.node.peer_count(),
                pending_tree_announce_count(tn),
            );
        }
    } else if correct_root_count < nodes.len() {
        let wrong: Vec<usize> = nodes
            .iter()
            .enumerate()
            .filter(|(_, tn)| *tn.node.tree_state().root() != expected_root)
            .map(|(i, _)| i)
            .collect();
        if wrong.len() <= 20 {
            eprintln!(" unconverged nodes: {:?}", wrong);
        } else {
            eprintln!(" unconverged nodes: {} remaining", wrong.len());
        }
    }
}

/// Counts the peers of `tn` that still have a tree announce queued.
fn pending_tree_announce_count(tn: &TestNode) -> usize {
    tn.node
        .peers
        .values()
        .filter(|p| p.has_pending_tree_announce())
        .count()
}
/// Drains every node's inbound packet queue without blocking and dispatches
/// each packet to the handler for its phase (msg1, msg2, or established frame).
///
/// Packets are dropped without being counted if they are shorter than the
/// common prefix, fail to parse, or carry a different FMP version. Packets with
/// an unrecognized phase are counted but otherwise ignored.
///
/// Returns the number of packets counted across all nodes.
pub(super) async fn process_available_packets(nodes: &mut [TestNode]) -> usize {
    use crate::node::wire::{
        COMMON_PREFIX_SIZE, CommonPrefix, FMP_VERSION, PHASE_ESTABLISHED, PHASE_MSG1, PHASE_MSG2,
    };

    let mut handled = 0;
    for test_node in nodes.iter_mut() {
        while let Ok(packet) = test_node.packet_rx.try_recv() {
            if packet.data.len() < COMMON_PREFIX_SIZE {
                continue;
            }
            let Some(prefix) = CommonPrefix::parse(&packet.data) else {
                continue;
            };
            if prefix.version != FMP_VERSION {
                continue;
            }
            let receiver = &mut test_node.node;
            match prefix.phase {
                PHASE_MSG1 => receiver.handle_msg1(packet).await,
                PHASE_MSG2 => receiver.handle_msg2(packet).await,
                PHASE_ESTABLISHED => receiver.handle_encrypted_frame(packet).await,
                _ => {}
            }
            handled += 1;
        }
    }
    handled
}
/// Drives the test network until it goes quiet, and returns the total number
/// of packets processed (as counted by `process_available_packets`).
///
/// Phase 1: up to 200 rounds of "sleep 10 ms, process queues", stopping at the
/// first round that processes nothing. This lets handshakes and the immediately
/// triggered traffic settle.
///
/// Phase 2: up to 20 flush cycles. Each cycle sleeps 550 ms, asks every node to
/// send its pending tree and filter announces, then processes queues
/// (after 20 ms, then up to 20 more 10 ms rounds until a round is empty). The
/// loop stops at the first cycle that processes nothing.
///
/// When `verbose` is true, a tree snapshot is printed after phase 1 and after
/// each flush cycle that processed packets.
pub(super) async fn drain_all_packets(nodes: &mut [TestNode], verbose: bool) -> usize {
    let mut total = 0;
    // Phase 1: drain until a 10 ms round produces no packets.
    for _round in 0..200 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        let count = process_available_packets(nodes).await;
        total += count;
        if count == 0 {
            break;
        }
    }
    if verbose {
        print_tree_snapshot(
            &format!("After handshakes + initial announces ({} packets)", total),
            nodes,
        );
    }
    // Phase 2: explicitly flush queued announces until a cycle is silent.
    for flush in 0..20 {
        // NOTE(review): 550 ms presumably outlasts the nodes' announce pacing
        // interval so pending announces become sendable — confirm.
        tokio::time::sleep(Duration::from_millis(550)).await;
        for tn in nodes.iter_mut() {
            tn.node.send_pending_tree_announces().await;
            tn.node.send_pending_filter_announces().await;
        }
        // Give the UDP sockets a moment to deliver, then drain the follow-on traffic.
        tokio::time::sleep(Duration::from_millis(20)).await;
        let mut flush_total = process_available_packets(nodes).await;
        for _sub in 0..20 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            let count = process_available_packets(nodes).await;
            flush_total += count;
            if count == 0 {
                break;
            }
        }
        total += flush_total;
        if flush_total == 0 {
            break;
        }
        if verbose {
            print_tree_snapshot(
                &format!("After flush cycle {} ({} packets)", flush + 1, flush_total),
                nodes,
            );
        }
    }
    total
}
/// Retries handshakes for edges where either endpoint is missing the other
/// from its peer list, and returns how many handshakes were re-initiated.
///
/// Runs up to 5 attempts. Each attempt snapshots which edges are incomplete
/// (and in which direction), then repairs them one at a time:
/// - if `i` lacked `j`, `i` re-initiates to `j` and the network is drained;
/// - then current state is re-checked. If `j` originally lacked `i` and still
///   does, `j` initiates to `i`. Otherwise, if `i` still lacks `j`, `i` tries
///   again.
///
/// Each retry is followed by a full `drain_all_packets`. The loop stops early
/// once no edges are missing.
async fn repair_missing_edge_handshakes(
    nodes: &mut [TestNode],
    edges: &[(usize, usize)],
    verbose: bool,
) -> usize {
    let mut retries = 0;
    for attempt in 0..5 {
        // Snapshot every edge that is not established in both directions.
        let mut missing = Vec::new();
        for &(i, j) in edges {
            let j_addr = *nodes[j].node.node_addr();
            let i_addr = *nodes[i].node.node_addr();
            let i_has_j = nodes[i].node.get_peer(&j_addr).is_some();
            let j_has_i = nodes[j].node.get_peer(&i_addr).is_some();
            if !i_has_j || !j_has_i {
                missing.push((i, j, i_has_j, j_has_i));
            }
        }
        if missing.is_empty() {
            break;
        }
        if verbose {
            eprintln!(
                " Repairing {} missing synthetic edge handshake(s), attempt {}",
                missing.len(),
                attempt + 1
            );
        }
        for (i, j, i_has_j, j_has_i) in missing {
            if !i_has_j {
                initiate_handshake(nodes, i, j).await;
                retries += 1;
                let _ = drain_all_packets(nodes, false).await;
            }
            // Re-read current state: the retry above (or an earlier edge's
            // repair) may already have fixed one or both directions.
            let j_addr = *nodes[j].node.node_addr();
            let i_addr = *nodes[i].node.node_addr();
            let j_still_missing_i = nodes[j].node.get_peer(&i_addr).is_none();
            let i_still_missing_j = nodes[i].node.get_peer(&j_addr).is_none();
            if !j_has_i && j_still_missing_i {
                initiate_handshake(nodes, j, i).await;
                retries += 1;
                let _ = drain_all_packets(nodes, false).await;
            } else if i_still_missing_j {
                initiate_handshake(nodes, i, j).await;
                retries += 1;
                let _ = drain_all_packets(nodes, false).await;
            }
        }
    }
    retries
}
/// Generates a deterministic random connected graph over `n` nodes.
///
/// The graph is built in two phases, driven by a `StdRng` seeded with `seed`:
/// 1. A random spanning tree is grown from node 0 by attaching a random
///    unconnected node to a random already-connected node until every node is
///    connected. This phase yields `n - 1` edges.
/// 2. Extra random edges are added until there are `target_edges` edges in
///    total, or until `target_edges * 10` candidates have been tried.
///
/// No self-loops or duplicate (undirected) edges are produced. If `n - 1` is
/// already at least `target_edges`, only the spanning tree is returned. For
/// `n == 0`, an empty edge list is returned.
pub(super) fn generate_random_edges(
    n: usize,
    target_edges: usize,
    seed: u64,
) -> Vec<(usize, usize)> {
    use rand::rngs::StdRng;
    use rand::{RngExt, SeedableRng};
    // With no nodes there is nothing to connect. Without this guard,
    // `connected[0]` and `random_range(0..0)` would panic.
    if n == 0 {
        return Vec::new();
    }
    let mut rng = StdRng::seed_from_u64(seed);
    let mut edges = Vec::new();
    let mut adj = vec![vec![false; n]; n];
    let mut connected = vec![false; n];
    connected[0] = true;
    let mut connected_count = 1;
    // Phase 1: random spanning tree (guarantees connectivity).
    while connected_count < n {
        let from = rng.random_range(0..n);
        if !connected[from] {
            continue;
        }
        let to = rng.random_range(0..n);
        if connected[to] || from == to {
            continue;
        }
        edges.push((from, to));
        adj[from][to] = true;
        adj[to][from] = true;
        connected[to] = true;
        connected_count += 1;
    }
    // Phase 2: extra edges, with the attempt budget bounding the loop when
    // `target_edges` exceeds what the graph can hold.
    let mut attempts = 0;
    while edges.len() < target_edges && attempts < target_edges * 10 {
        let a = rng.random_range(0..n);
        let b = rng.random_range(0..n);
        attempts += 1;
        if a == b || adj[a][b] {
            continue;
        }
        edges.push((a, b));
        adj[a][b] = true;
        adj[b][a] = true;
    }
    edges
}
/// Asserts that `nodes` form a single, fully converged spanning tree.
///
/// Checks, in this order:
/// 1. every node reports the smallest `NodeAddr` in `nodes` as its root;
/// 2. the node with that address believes it is the root and has depth 0;
/// 3. every other node has depth > 0;
/// 4. every non-root node's declared parent is in its own peer list;
/// 5. every node's coordinates are rooted at the expected root;
/// 6. every non-root node whose parent is in `nodes` sits exactly one level
///    below that parent.
///
/// # Panics
/// Panics if `nodes` is empty or any check fails.
pub(super) fn verify_tree_convergence(nodes: &[TestNode]) {
    let n = nodes.len();
    assert!(n > 0);
    let expected_root = nodes
        .iter()
        .map(|tn| *tn.node.node_addr())
        .min()
        .unwrap();

    // 1. Global agreement on the root.
    for (i, tn) in nodes.iter().enumerate() {
        let believed_root = *tn.node.tree_state().root();
        assert_eq!(
            believed_root,
            expected_root,
            "Node {} (addr={}) has root {} but expected {}",
            i,
            tn.node.node_addr(),
            believed_root,
            expected_root
        );
    }

    // 2. The root node itself.
    let root_state = nodes
        .iter()
        .find(|tn| *tn.node.node_addr() == expected_root)
        .unwrap()
        .node
        .tree_state();
    assert!(
        root_state.is_root(),
        "Expected root node should have is_root = true"
    );
    assert_eq!(
        root_state.my_coords().depth(),
        0,
        "Root node should have depth 0"
    );

    // 3. Every other node is strictly below the root.
    for (i, tn) in nodes.iter().enumerate() {
        if *tn.node.node_addr() == expected_root {
            continue;
        }
        let depth = tn.node.tree_state().my_coords().depth();
        assert!(
            depth > 0,
            "Non-root node {} should have depth > 0, got {}",
            i,
            depth
        );
    }

    // 4. Parents are real peers.
    for (i, tn) in nodes.iter().enumerate() {
        let state = tn.node.tree_state();
        if state.is_root() {
            continue;
        }
        let parent_id = state.my_declaration().parent_id();
        assert!(
            tn.node.get_peer(parent_id).is_some(),
            "Node {}'s parent {} should be in its peer list",
            i,
            parent_id
        );
    }

    // 5. Coordinates terminate at the expected root.
    for (i, tn) in nodes.iter().enumerate() {
        let coord_root = tn.node.tree_state().my_coords().root_id();
        assert_eq!(
            *coord_root,
            expected_root,
            "Node {}'s coordinate root {} should match expected root {}",
            i,
            coord_root,
            expected_root
        );
    }

    // 6. Depth grows by exactly one along each parent link.
    for (i, tn) in nodes.iter().enumerate() {
        let state = tn.node.tree_state();
        if state.is_root() {
            continue;
        }
        let parent_id = state.my_declaration().parent_id();
        let Some(parent) = nodes.iter().find(|pn| pn.node.node_addr() == parent_id) else {
            continue;
        };
        let my_depth = state.my_coords().depth();
        let parent_depth = parent.node.tree_state().my_coords().depth();
        assert_eq!(
            my_depth,
            parent_depth + 1,
            "Node {}'s depth ({}) should be parent's depth ({}) + 1",
            i,
            my_depth,
            parent_depth
        );
    }
}
/// Asserts that each connected component converged on its own root.
///
/// For every component (a list of indices into `nodes`), every member must
/// report the smallest `NodeAddr` within that component as its root. Because
/// each component's minimum differs, this also catches components that
/// accidentally merged.
///
/// Empty components are skipped; they have no members to check.
///
/// # Panics
/// Panics if any member has the wrong root, or if an index is out of bounds.
pub(super) fn verify_tree_convergence_components(nodes: &[TestNode], components: &[Vec<usize>]) {
    for component in components {
        let Some(expected_root) = component
            .iter()
            .map(|&i| *nodes[i].node.node_addr())
            .min()
        else {
            continue;
        };
        for &idx in component {
            let root = *nodes[idx].node.tree_state().root();
            assert_eq!(
                root,
                expected_root,
                "Node {} in component should have root {}",
                idx,
                expected_root
            );
        }
    }
}
/// Builds `num_nodes` UDP-backed test nodes, handshakes every edge in `edges`,
/// and drives packet processing until the network goes quiet.
///
/// Edges that did not come up in both directions are retried with
/// `repair_missing_edge_handshakes`. Before the nodes are returned, every edge
/// `(i, j)` is asserted to be a peer relationship on both endpoints.
///
/// When `verbose` is true, topology details and tree snapshots are printed to
/// stderr.
///
/// # Panics
/// Panics if:
/// - no packets were processed at all;
/// - an edge is missing a peer entry on either side after repair;
/// - an edge references an index `>= num_nodes`.
pub(super) async fn run_tree_test(
    num_nodes: usize,
    edges: &[(usize, usize)],
    verbose: bool,
) -> Vec<TestNode> {
    let mut nodes = Vec::with_capacity(num_nodes);
    for _ in 0..num_nodes {
        nodes.push(make_test_node().await);
    }
    if verbose {
        print_topology_summary(&nodes, edges);
    }
    for &(i, j) in edges {
        initiate_handshake(&mut nodes, i, j).await;
    }
    let total = drain_all_packets(&mut nodes, verbose).await;
    assert!(total > 0, "Should have processed at least some packets");
    let repaired = repair_missing_edge_handshakes(&mut nodes, edges, verbose).await;
    if verbose {
        eprintln!("\n Total packets processed: {}", total);
        if repaired > 0 {
            eprintln!(" Synthetic handshake retries: {}", repaired);
            print_tree_snapshot("After synthetic handshake repair", &nodes);
        }
    }
    assert_edges_established(&nodes, edges);
    nodes
}

/// Prints the test header, the expected root, and node-degree statistics.
///
/// For networks of at most 20 nodes, it also prints the address-sorted node
/// list and the edge list. For an empty network, only the header is printed.
fn print_topology_summary(nodes: &[TestNode], edges: &[(usize, usize)]) {
    let num_nodes = nodes.len();
    eprintln!(
        "\n === Spanning Tree Convergence ({} nodes, {} edges) ===",
        num_nodes,
        edges.len()
    );
    // `min_by_key` returns the first minimum, matching "first position of the
    // smallest address".
    let Some((root_idx, expected_root)) = nodes
        .iter()
        .enumerate()
        .map(|(i, tn)| (i, *tn.node.node_addr()))
        .min_by_key(|&(_, addr)| addr)
    else {
        return;
    };
    eprintln!(" Expected root: node[{}] = {}", root_idx, expected_root);
    let mut degree = vec![0usize; num_nodes];
    for &(i, j) in edges {
        degree[i] += 1;
        degree[j] += 1;
    }
    let avg_degree = degree.iter().sum::<usize>() as f64 / num_nodes as f64;
    let max_degree = degree.iter().max().copied().unwrap_or(0);
    let min_degree = degree.iter().min().copied().unwrap_or(0);
    eprintln!(
        " Degree: min={} max={} avg={:.1}",
        min_degree, max_degree, avg_degree
    );
    if num_nodes <= 20 {
        let mut sorted: Vec<(usize, NodeAddr)> = nodes
            .iter()
            .enumerate()
            .map(|(i, tn)| (i, *tn.node.node_addr()))
            .collect();
        sorted.sort_by_key(|(_, addr)| *addr);
        eprintln!(" Node addresses (sorted, smallest = expected root):");
        for (i, addr) in &sorted {
            let marker = if *i == root_idx { " <-- root" } else { "" };
            eprintln!(" node[{}] = {}{}", i, addr, marker);
        }
        eprintln!(" Edges:");
        for (idx, &(i, j)) in edges.iter().enumerate() {
            eprintln!(" edge[{}]: node[{}] -- node[{}]", idx, i, j);
        }
    }
}

/// Asserts that both endpoints of every edge list each other as peers.
fn assert_edges_established(nodes: &[TestNode], edges: &[(usize, usize)]) {
    for &(i, j) in edges {
        let j_addr = *nodes[j].node.node_addr();
        let i_addr = *nodes[i].node.node_addr();
        assert!(
            nodes[i].node.get_peer(&j_addr).is_some(),
            "Node {} should have peer {} (node {})",
            i,
            j_addr,
            j
        );
        assert!(
            nodes[j].node.get_peer(&i_addr).is_some(),
            "Node {} should have peer {} (node {})",
            j,
            i_addr,
            i
        );
    }
}
/// Like `run_tree_test`, but creates one node per entry in `mtus`, each with its
/// own UDP MTU, and prints nothing.
///
/// Handshakes every edge, drains the network, and retries incomplete edges.
/// It then asserts that both endpoints of every edge list each other as peers.
///
/// # Panics
/// Panics if no packets were processed, if any edge is not established in both
/// directions, or if an edge index is out of bounds.
pub(super) async fn run_tree_test_with_mtus(
    mtus: &[u16],
    edges: &[(usize, usize)],
) -> Vec<TestNode> {
    let mut nodes = Vec::new();
    for mtu in mtus.iter().copied() {
        nodes.push(make_test_node_with_mtu(mtu).await);
    }
    for (from, to) in edges.iter().copied() {
        initiate_handshake(&mut nodes, from, to).await;
    }
    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0, "Should have processed at least some packets");
    let _retries = repair_missing_edge_handshakes(&mut nodes, edges, false).await;

    for (a, b) in edges.iter().copied() {
        let b_addr = *nodes[b].node.node_addr();
        let a_addr = *nodes[a].node.node_addr();
        assert!(
            nodes[a].node.get_peer(&b_addr).is_some(),
            "Node {} should have peer {} (node {})",
            a,
            b_addr,
            b
        );
        assert!(
            nodes[b].node.get_peer(&a_addr).is_some(),
            "Node {} should have peer {} (node {})",
            b,
            a_addr,
            a
        );
    }
    nodes
}
/// Stops every transport on every node.
///
/// Cleanup is best-effort: errors returned by `stop` are ignored.
pub(super) async fn cleanup_nodes(nodes: &mut [TestNode]) {
    for test_node in nodes {
        for (_id, transport) in test_node.node.transports.iter_mut() {
            let _ = transport.stop().await;
        }
    }
}
/// A 100-node random graph (about 250 edges, fixed seed) must converge on a
/// single spanning tree rooted at the smallest address.
#[tokio::test]
async fn test_spanning_tree_convergence_100_nodes() {
    // Hold the shared lock for the entire test so large networks do not
    // compete with each other for time.
    let _guard = lock_large_network_test().await;

    const NODE_COUNT: usize = 100;
    const EDGE_TARGET: usize = 250;
    const RNG_SEED: u64 = 42;

    let topology = generate_random_edges(NODE_COUNT, EDGE_TARGET, RNG_SEED);
    let mut nodes = run_tree_test(NODE_COUNT, &topology, true).await;
    verify_tree_convergence(&nodes);
    cleanup_nodes(&mut nodes).await;
}
/// A 5-node ring (0-1-2-3-4-0) must converge despite containing a cycle.
#[tokio::test]
async fn test_spanning_tree_ring() {
    let edges: Vec<(usize, usize)> = (0..5).map(|i| (i, (i + 1) % 5)).collect();
    let mut nodes = run_tree_test(5, &edges, false).await;
    verify_tree_convergence(&nodes);
    cleanup_nodes(&mut nodes).await;
}
/// A star with hub node 0 and leaves 1-4 must converge.
#[tokio::test]
async fn test_spanning_tree_star() {
    let edges: Vec<(usize, usize)> = (1..5).map(|leaf| (0, leaf)).collect();
    let mut nodes = run_tree_test(5, &edges, false).await;
    verify_tree_convergence(&nodes);
    cleanup_nodes(&mut nodes).await;
}
/// A linear chain 0-1-2-3-4 must converge.
#[tokio::test]
async fn test_spanning_tree_chain() {
    let edges: Vec<(usize, usize)> = (0..4).map(|i| (i, i + 1)).collect();
    let mut nodes = run_tree_test(5, &edges, false).await;
    verify_tree_convergence(&nodes);
    cleanup_nodes(&mut nodes).await;
}
/// Two disjoint 3-node chains, {0, 1, 2} and {3, 4, 5}, must each converge on
/// the smallest address within their own component.
#[tokio::test]
async fn test_spanning_tree_disconnected() {
    let edges = [
        // Component A.
        (0, 1),
        (1, 2),
        // Component B.
        (3, 4),
        (4, 5),
    ];
    let mut nodes = run_tree_test(6, &edges, false).await;
    verify_tree_convergence_components(&nodes, &[vec![0, 1, 2], vec![3, 4, 5]]);
    cleanup_nodes(&mut nodes).await;
}
/// A tree announce whose signed parent declaration and coordinate path do not
/// agree must be ignored by the receiver.
///
/// Node 0 (A) signs a declaration naming a fake parent (the all-zeros address).
/// The accompanying coordinates are A -> fake parent -> a different fake root
/// (`0x00..01`). The test checks that node 1 leaves its state untouched: its
/// root, its depth, its `tree.accepted` counter, and the coordinates it stores
/// for A.
#[tokio::test]
async fn test_rejects_tree_announce_with_inconsistent_root() {
    let mut nodes = run_tree_test(2, &[(0, 1)], false).await;
    let a_addr = *nodes[0].node.node_addr();
    // Capture node 1's state before the bogus announce arrives.
    let current_root = *nodes[1].node.tree_state().root();
    let current_depth = nodes[1].node.tree_state().my_coords().depth();
    let peer_coords_before = nodes[1]
        .node
        .get_peer(&a_addr)
        .unwrap()
        .coords()
        .unwrap()
        .clone();
    let accepted_before = nodes[1].node.stats().tree.accepted;
    let fake_parent = NodeAddr::from_bytes([0u8; 16]);
    let mut fake_root_bytes = [0u8; 16];
    fake_root_bytes[15] = 1;
    let fake_root = NodeAddr::from_bytes(fake_root_bytes);
    // Validly signed by A, so any rejection must come from the
    // declaration/coordinate mismatch rather than the signature.
    // NOTE(review): the numeric arguments (99/12345 and the decreasing values in
    // the coord entries) look like sequence/timestamp values — confirm against
    // `ParentDeclaration::new` and `CoordEntry::new`.
    let mut declaration = ParentDeclaration::new(a_addr, fake_parent, 99, 12345);
    declaration.sign(nodes[0].node.identity()).unwrap();
    let announce = TreeAnnounce::new(
        declaration,
        TreeCoordinate::new(vec![
            CoordEntry::new(a_addr, 99, 12345),
            CoordEntry::new(fake_parent, 98, 12344),
            CoordEntry::new(fake_root, 97, 12343),
        ])
        .unwrap(),
    );
    let encoded = announce.encode().unwrap();
    // NOTE(review): the first encoded byte is skipped, presumably a message-type
    // tag that the handler does not expect — confirm against `TreeAnnounce::encode`.
    nodes[1]
        .node
        .handle_tree_announce(&a_addr, &encoded[1..])
        .await;
    // Nothing observable on node 1 may have changed.
    assert_eq!(*nodes[1].node.tree_state().root(), current_root);
    assert_eq!(
        nodes[1].node.tree_state().my_coords().depth(),
        current_depth
    );
    assert_eq!(nodes[1].node.stats().tree.accepted, accepted_before);
    assert_eq!(
        nodes[1].node.get_peer(&a_addr).unwrap().coords().unwrap(),
        &peer_coords_before
    );
    cleanup_nodes(&mut nodes).await;
}