use super::*;
use crate::config::BleConfig;
use crate::transport::ble::BleTransport;
use crate::transport::ble::addr::BleAddr;
use crate::transport::ble::io::{MockBleIo, MockBleStream};
use crate::transport::{Transport, TransportHandle, TransportId, packet_channel};
use spanning_tree::{
TestNode, cleanup_nodes, drain_all_packets, initiate_handshake, verify_tree_convergence,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex as StdMutex};
/// Builds a deterministic mock BLE address on adapter `hci0`.
///
/// The first five device bytes are the fixed prefix `0xAA..=0xEE`; `n` is the
/// final byte, so distinct `n` values give distinct device addresses.
fn ble_addr(n: u8) -> BleAddr {
    let adapter = String::from("hci0");
    let device = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, n];
    BleAddr { adapter, device }
}
/// Shared table of pre-wired mock streams awaiting a dial.
///
/// Keys are the target node's transport address rendered with `to_string()`.
/// Values are the initiator-side stream halves: `wire_ble_connection` inserts
/// them, and the handler from `install_connect_handler` removes them.
type StreamBank = Arc<StdMutex<HashMap<String, MockBleStream>>>;
/// Creates a [`TestNode`] backed by a started mock BLE transport.
///
/// The node's address is `ble_addr(node_num)` on adapter `hci0`, registered
/// under `TransportId` 1. Scanning, advertising and auto-connect are turned
/// off while inbound connections are accepted. Tests therefore control
/// connectivity by wiring mock streams explicitly.
///
/// # Panics
///
/// Panics if the transport fails to start.
async fn make_test_node_ble(node_num: u8) -> TestNode {
    let mut node = make_node();
    let transport_id = TransportId::new(1);
    let addr = ble_addr(node_num);
    let ta = addr.to_transport_addr();

    let config = BleConfig {
        adapter: Some("hci0".to_string()),
        mtu: Some(2048),
        accept_connections: Some(true),
        scan: Some(false),
        advertise: Some(false),
        auto_connect: Some(false),
        ..Default::default()
    };

    let (packet_tx, packet_rx) = packet_channel(256);
    let mut transport = BleTransport::new(
        transport_id,
        None,
        config,
        MockBleIo::new("hci0", addr),
        packet_tx,
    );
    transport.start_async().await.unwrap();

    node.transports.insert(transport_id, TransportHandle::Ble(transport));
    TestNode {
        node,
        transport_id,
        packet_rx,
        addr: ta,
    }
}
/// Recovers the [`BleAddr`] of a BLE-backed test node from its transport address.
///
/// # Panics
///
/// Panics if the node's transport address has no string form, or if that
/// string does not parse as a `BleAddr`. Either case means the node was not
/// built by `make_test_node_ble`.
fn node_ble_addr(node: &TestNode) -> BleAddr {
    let text = node
        .addr
        .as_str()
        .expect("BLE test node transport address should have a string form");
    BleAddr::parse(text).expect("BLE test node transport address should parse as a BleAddr")
}
/// Pre-wires a mock BLE link from node `i` (initiator) to node `j` (acceptor).
///
/// A connected [`MockBleStream`] pair is created. The initiator half is parked
/// in `bank` under node `j`'s transport address, where the connect handler
/// installed on node `i` can claim it. The acceptor half is handed to node `j`'s
/// mock IO through `inject_inbound`.
///
/// # Panics
///
/// Panics if `bank` already holds an unclaimed stream for node `j`. Bank keys
/// are per target only, so a second initiator wired to the same target would
/// otherwise silently replace and drop the first link. Also panics if node
/// `j`'s transport is missing or is not BLE.
async fn wire_ble_connection(nodes: &[TestNode], i: usize, j: usize, bank: &StreamBank) {
    let addr_i = node_ble_addr(&nodes[i]);
    let addr_j = node_ble_addr(&nodes[j]);
    // The addresses are not needed afterwards, so they are moved rather than cloned.
    let (stream_i, stream_j) = MockBleStream::pair(addr_j, addr_i, 2048);

    let key = nodes[j].addr.to_string();
    // The guard is a temporary of this statement, so the lock is released
    // before the `.await` below.
    let previous = bank.lock().unwrap().insert(key, stream_i);
    assert!(
        previous.is_none(),
        "stream bank already held an unclaimed stream for node {}; \
         wiring another initiator to the same target would drop that link",
        j
    );

    let transport_j = nodes[j]
        .node
        .transports
        .get(&nodes[j].transport_id)
        .unwrap();
    match transport_j {
        TransportHandle::Ble(t) => {
            t.io().inject_inbound(stream_j).await;
        }
        _ => panic!("expected BLE transport"),
    }
}
/// Makes node `i`'s mock BLE IO answer outbound dials from `bank`.
///
/// On each dial, the handler looks up the remote's transport-address string.
/// It removes the matching stream, so each parked stream can be claimed only
/// once. If nothing is parked for that address, the dial fails with
/// `TransportError::ConnectionRefused`.
///
/// # Panics
///
/// Panics if node `i` has no transport under its id or that transport is not BLE.
fn install_connect_handler(nodes: &[TestNode], i: usize, bank: &StreamBank) {
    let shared = Arc::clone(bank);
    let node_i = &nodes[i];
    let handle = node_i.node.transports.get(&node_i.transport_id).unwrap();
    match handle {
        TransportHandle::Ble(ble) => {
            ble.io().set_connect_handler(move |remote, _psm| {
                let wanted = remote.to_transport_addr().to_string();
                let parked = shared.lock().unwrap().remove(&wanted);
                parked.ok_or(crate::transport::TransportError::ConnectionRefused)
            });
        }
        _ => panic!("expected BLE transport"),
    }
}
/// Dials node `j`'s transport address from node `i`'s transport.
///
/// After the dial, this yields once to the scheduler so other tasks can make
/// progress before the caller continues.
///
/// # Panics
///
/// Panics if node `i` has no transport under its id or the connect call fails.
async fn establish_ble_connection(nodes: &[TestNode], i: usize, j: usize) {
    let initiator = &nodes[i];
    let handle = initiator
        .node
        .transports
        .get(&initiator.transport_id)
        .unwrap();
    handle.connect(&nodes[j].addr).await.unwrap();
    tokio::task::yield_now().await;
}
#[tokio::test]
async fn test_ble_two_node_handshake() {
    // Two BLE nodes; node 0 dials node 1 over a pre-wired mock link.
    let mut nodes = Vec::new();
    for n in 1..=2u8 {
        nodes.push(make_test_node_ble(n).await);
    }
    let bank: StreamBank = Arc::new(StdMutex::new(HashMap::new()));
    wire_ble_connection(&nodes, 0, 1, &bank).await;
    install_connect_handler(&nodes, 0, &bank);
    establish_ble_connection(&nodes, 0, 1).await;

    initiate_handshake(&mut nodes, 0, 1).await;
    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0, "should have processed packets");

    // After the handshake, each side must list the other as a peer.
    let checks = [
        (0usize, 1usize, "node 0 should have node 1 as peer"),
        (1, 0, "node 1 should have node 0 as peer"),
    ];
    for &(me, other, msg) in &checks {
        let other_addr = *nodes[other].node.node_addr();
        assert!(nodes[me].node.get_peer(&other_addr).is_some(), "{}", msg);
    }

    cleanup_nodes(&mut nodes).await;
}
#[tokio::test]
async fn test_ble_three_node_chain() {
    // Chain topology 0 -- 1 -- 2: each link is dialed by its lower-index node.
    let mut nodes = Vec::with_capacity(3);
    for n in 1..=3u8 {
        nodes.push(make_test_node_ble(n).await);
    }
    let bank: StreamBank = Arc::new(StdMutex::new(HashMap::new()));
    let links = [(0usize, 1usize), (1, 2)];

    for &(from, to) in &links {
        wire_ble_connection(&nodes, from, to, &bank).await;
    }
    for &(from, _) in &links {
        install_connect_handler(&nodes, from, &bank);
    }
    for &(from, to) in &links {
        establish_ble_connection(&nodes, from, to).await;
    }
    for &(from, to) in &links {
        initiate_handshake(&mut nodes, from, to).await;
    }

    let processed = drain_all_packets(&mut nodes, false).await;
    assert!(processed > 0, "should have processed packets");

    // Every node agrees on the root, and the root is the smallest node address.
    verify_tree_convergence(&nodes);
    let expected_root = nodes.iter().map(|tn| *tn.node.node_addr()).min().unwrap();
    for tn in &nodes {
        assert_eq!(*tn.node.tree_state().root(), expected_root);
    }

    // The chain ends have one neighbour each; the middle node has two.
    for &(idx, want) in &[(0usize, 1), (1, 2), (2, 1)] {
        assert_eq!(nodes[idx].node.peer_count(), want);
    }

    let far_end = *nodes[2].node.node_addr();
    assert!(
        nodes[0].node.peers().any(|p| p.may_reach(&far_end)),
        "node 0 should see node 2 as reachable"
    );

    cleanup_nodes(&mut nodes).await;
}
/// Runs two disjoint two-node networks side by side.
///
/// The `udp_*` pair (nodes 0 and 1) is built by `make_test_node`. The mock BLE
/// pair (nodes 2 and 3) is built by `make_test_node_ble`. Each pair must
/// converge on its own tree, rooted at the smaller of its two node addresses.
#[tokio::test]
async fn test_ble_mixed_transport() {
    use spanning_tree::{make_test_node, verify_tree_convergence_components};

    let udp_0 = make_test_node().await;
    let udp_1 = make_test_node().await;
    let ble_0 = make_test_node_ble(1).await;
    let ble_1 = make_test_node_ble(2).await;
    let mut nodes = vec![udp_0, udp_1, ble_0, ble_1];

    // Only the BLE pair is pre-wired; node 2 dials node 3.
    let bank: StreamBank = Arc::new(StdMutex::new(HashMap::new()));
    wire_ble_connection(&nodes, 2, 3, &bank).await;
    install_connect_handler(&nodes, 2, &bank);
    establish_ble_connection(&nodes, 2, 3).await;

    initiate_handshake(&mut nodes, 0, 1).await;
    initiate_handshake(&mut nodes, 2, 3).await;
    let total = drain_all_packets(&mut nodes, false).await;
    assert!(total > 0, "should have processed packets");

    verify_tree_convergence_components(&nodes, &[vec![0, 1], vec![2, 3]]);

    // Check both components, not just the BLE one: each root is the smaller
    // node address of its pair.
    for &(a, b) in &[(0usize, 1usize), (2, 3)] {
        let root = std::cmp::min(*nodes[a].node.node_addr(), *nodes[b].node.node_addr());
        assert_eq!(*nodes[a].node.tree_state().root(), root);
        assert_eq!(*nodes[b].node.tree_state().root(), root);
    }

    cleanup_nodes(&mut nodes).await;
}
#[tokio::test(start_paused = true)]
async fn test_ble_discovery() {
    let mut node = make_node();
    let transport_id = TransportId::new(1);
    let addr = ble_addr(1);
    let ta = addr.to_transport_addr();

    // Scanning is the only discovery mechanism enabled; advertising and
    // auto-connect stay off.
    let config = BleConfig {
        adapter: Some("hci0".to_string()),
        mtu: Some(2048),
        accept_connections: Some(true),
        scan: Some(true),
        advertise: Some(false),
        auto_connect: Some(false),
        ..Default::default()
    };
    let (packet_tx, packet_rx) = packet_channel(256);
    let mut transport = BleTransport::new(
        transport_id,
        None,
        config,
        MockBleIo::new("hci0", addr),
        packet_tx,
    );
    transport.start_async().await.unwrap();

    // Two remote devices appear in the scan.
    for remote in 2..=3u8 {
        transport.io().inject_scan_result(ble_addr(remote)).await;
    }
    tokio::task::yield_now().await;
    // The clock is paused, so advance it by 6 s explicitly. This presumably
    // moves past any scan-processing interval the transport uses (TODO confirm).
    tokio::time::advance(std::time::Duration::from_secs(6)).await;
    tokio::task::yield_now().await;

    assert_eq!(transport.discover().unwrap().len(), 2);

    // Hand the transport to a node so the shared cleanup path shuts it down.
    node.transports.insert(transport_id, TransportHandle::Ble(transport));
    let mut nodes = vec![TestNode {
        node,
        transport_id,
        packet_rx,
        addr: ta,
    }];
    cleanup_nodes(&mut nodes).await;
}