use super::*;
use crate::config::TcpConfig;
use crate::transport::tcp::TcpTransport;
use crate::transport::{TransportAddr, TransportHandle, TransportId, packet_channel};
use spanning_tree::{
TestNode, cleanup_nodes, drain_all_packets, initiate_handshake, verify_tree_convergence,
};
use std::time::Duration;
/// Builds a [`TestNode`] whose only transport is a started TCP transport.
///
/// The transport is configured with `bind_addr = "127.0.0.1:0"` and
/// `mtu = 1400`, is registered in `node.transports` under `TransportId::new(1)`,
/// and its post-start local address is stored as the node's `addr`. The
/// receiving half of `packet_channel(256)` is handed back in the returned
/// `TestNode` so the test harness can read what the transport delivers.
///
/// # Panics
///
/// Panics if the transport fails to start or reports no local address
/// afterwards.
async fn make_test_node_tcp() -> TestNode {
    let mut node = make_node();
    let transport_id = TransportId::new(1);
    let config = TcpConfig {
        // Port 0 — presumably lets the OS pick a free port so concurrently
        // running tests do not collide; the real address is read back below.
        bind_addr: Some("127.0.0.1:0".to_string()),
        mtu: Some(1400),
        ..Default::default()
    };
    let (packet_tx, packet_rx) = packet_channel(256);

    let mut transport = TcpTransport::new(transport_id, None, config, packet_tx);
    // Use `expect` (matching the `local_addr` check below) so a start failure
    // is reported with context instead of a bare `unwrap` panic.
    transport
        .start_async()
        .await
        .expect("TCP transport should start on 127.0.0.1:0");
    let local_addr = transport
        .local_addr()
        .expect("TCP transport should have local addr after start");
    let addr = TransportAddr::from_string(&local_addr.to_string());

    node.transports
        .insert(transport_id, TransportHandle::Tcp(transport));

    TestNode {
        node,
        transport_id,
        packet_rx,
        addr,
    }
}
/// Two TCP-backed nodes handshake and each ends up with the other as a peer.
#[tokio::test]
async fn test_tcp_two_node_handshake() {
    let first = make_test_node_tcp().await;
    let second = make_test_node_tcp().await;
    let mut nodes = vec![first, second];

    // Kick off the handshake from node 0 towards node 1, then pump packets.
    initiate_handshake(&mut nodes, 0, 1).await;
    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0, "should have processed packets");

    let (id_first, id_second) = (*nodes[0].node.node_addr(), *nodes[1].node.node_addr());

    // Peer registration must be symmetric.
    let first_sees_second = nodes[0].node.get_peer(&id_second).is_some();
    assert!(first_sees_second, "node 0 should have node 1 as peer");
    let second_sees_first = nodes[1].node.get_peer(&id_first).is_some();
    assert!(second_sees_first, "node 1 should have node 0 as peer");

    cleanup_nodes(&mut nodes).await;
}
/// Three TCP-backed nodes linked as a chain (0 - 1 - 2) converge on one tree.
///
/// Checks that every node agrees on the minimum node address as root, that
/// peer counts match the chain topology, and that node 0 has a peer reporting
/// node 2 as possibly reachable.
#[tokio::test]
async fn test_tcp_three_node_chain() {
    let mut nodes = Vec::with_capacity(3);
    for _ in 0..3 {
        nodes.push(make_test_node_tcp().await);
    }

    // Link neighbours only: 0 <-> 1 and 1 <-> 2.
    initiate_handshake(&mut nodes, 0, 1).await;
    initiate_handshake(&mut nodes, 1, 2).await;

    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0, "should have processed packets");

    verify_tree_convergence(&nodes);

    // The smallest node address among all participants is the expected root.
    let expected_root = nodes
        .iter()
        .map(|member| *member.node.node_addr())
        .min()
        .unwrap();
    nodes
        .iter()
        .for_each(|member| assert_eq!(*member.node.tree_state().root(), expected_root));

    // (node index, expected peer count, failure message)
    let peer_expectations = [
        (0usize, 1, "endpoint should have 1 peer"),
        (1usize, 2, "middle node should have 2 peers"),
        (2usize, 1, "endpoint should have 1 peer"),
    ];
    for (idx, expected_peers, message) in peer_expectations {
        assert_eq!(nodes[idx].node.peer_count(), expected_peers, "{}", message);
    }

    let far_end = *nodes[2].node.node_addr();
    assert!(
        nodes[0].node.peers().any(|peer| peer.may_reach(&far_end)),
        "node 0 should see node 2 as reachable through bloom filters"
    );

    cleanup_nodes(&mut nodes).await;
}
/// A pair of nodes from `make_test_node` and a pair of TCP nodes run side by
/// side as two separate components, each converging independently.
#[tokio::test]
async fn test_tcp_mixed_transport_coexistence() {
    use spanning_tree::make_test_node;
    use spanning_tree::verify_tree_convergence_components;

    // Indices 0-1: nodes from `make_test_node`; indices 2-3: TCP nodes.
    let mut nodes = vec![
        make_test_node().await,
        make_test_node().await,
        make_test_node_tcp().await,
        make_test_node_tcp().await,
    ];

    // Handshakes stay within each pair; the pairs are never linked together.
    initiate_handshake(&mut nodes, 0, 1).await;
    initiate_handshake(&mut nodes, 2, 3).await;

    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0);

    verify_tree_convergence_components(&nodes, &[vec![0, 1], vec![2, 3]]);

    // Both TCP nodes must agree on the smaller of their two addresses as root.
    let tcp_first = *nodes[2].node.node_addr();
    let tcp_second = *nodes[3].node.node_addr();
    let tcp_root = Ord::min(tcp_first, tcp_second);
    for member in &nodes[2..] {
        assert_eq!(*member.node.tree_state().root(), tcp_root);
    }

    cleanup_nodes(&mut nodes).await;
}
/// After node 0's connection to node 1 is closed and four seconds pass,
/// node 0's heartbeat check removes node 1 from its peers.
#[tokio::test]
async fn test_tcp_connection_loss_detection() {
    let mut nodes = vec![make_test_node_tcp().await, make_test_node_tcp().await];

    // Short heartbeat / dead-link timers keep the wait below brief.
    for member in &mut nodes {
        let node_cfg = &mut member.node.config.node;
        node_cfg.heartbeat_interval_secs = 1;
        node_cfg.link_dead_timeout_secs = 3;
    }

    initiate_handshake(&mut nodes, 0, 1).await;
    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0);

    let id_first = *nodes[0].node.node_addr();
    let id_second = *nodes[1].node.node_addr();
    assert!(nodes[0].node.get_peer(&id_second).is_some());
    assert!(nodes[1].node.get_peer(&id_first).is_some());

    // Close node 0's connection towards node 1's address.
    let peer_listen_addr = nodes[1].addr.clone();
    nodes[0]
        .node
        .transports
        .get(&nodes[0].transport_id)
        .unwrap()
        .close_connection(&peer_listen_addr)
        .await;

    // Wait one second longer than the 3-second dead-link timeout set above.
    let past_dead_timeout = Duration::from_secs(4);
    tokio::time::sleep(past_dead_timeout).await;

    nodes[0].node.check_link_heartbeats().await;
    let peer_still_present = nodes[0].node.get_peer(&id_second).is_some();
    assert!(
        !peer_still_present,
        "node 0 should have removed dead peer node 1"
    );

    cleanup_nodes(&mut nodes).await;
}
/// Once both nodes have dropped each other after a closed connection, a new
/// handshake re-establishes the peering in both directions.
#[tokio::test]
async fn test_tcp_reconnection_after_link_death() {
    let mut nodes = vec![make_test_node_tcp().await, make_test_node_tcp().await];

    // Short heartbeat / dead-link timers keep the wait below brief.
    for member in &mut nodes {
        let node_cfg = &mut member.node.config.node;
        node_cfg.heartbeat_interval_secs = 1;
        node_cfg.link_dead_timeout_secs = 3;
    }

    // Phase 1: establish the initial link.
    initiate_handshake(&mut nodes, 0, 1).await;
    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0);

    let id_first = *nodes[0].node.node_addr();
    let id_second = *nodes[1].node.node_addr();
    assert!(nodes[0].node.get_peer(&id_second).is_some());

    // Phase 2: close node 0's connection and wait past the 3-second timeout.
    let peer_listen_addr = nodes[1].addr.clone();
    nodes[0]
        .node
        .transports
        .get(&nodes[0].transport_id)
        .unwrap()
        .close_connection(&peer_listen_addr)
        .await;
    tokio::time::sleep(Duration::from_secs(4)).await;

    // Run the heartbeat check on both sides before inspecting either.
    nodes[0].node.check_link_heartbeats().await;
    nodes[1].node.check_link_heartbeats().await;

    let first_dropped_second = nodes[0].node.get_peer(&id_second).is_none();
    assert!(first_dropped_second, "node 0 should have removed node 1");
    let second_dropped_first = nodes[1].node.get_peer(&id_first).is_none();
    assert!(second_dropped_first, "node 1 should have removed node 0");

    // Phase 3: handshake again and confirm the peering is back.
    initiate_handshake(&mut nodes, 0, 1).await;
    let reprocessed = drain_all_packets(&mut nodes, false).await;
    assert!(reprocessed > 0, "should have processed reconnection packets");

    let first_sees_second = nodes[0].node.get_peer(&id_second).is_some();
    assert!(
        first_sees_second,
        "node 0 should have re-established peer node 1"
    );
    let second_sees_first = nodes[1].node.get_peer(&id_first).is_some();
    assert!(
        second_sees_first,
        "node 1 should have re-established peer node 0"
    );

    cleanup_nodes(&mut nodes).await;
}