use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use super::{Transport, TransportKind};
use crate::channel::fallback::FallbackController;
use crate::channel::health::HealthMonitor;
use crate::error::{SrxError, TransportError};
use crate::routing::RoutingMask;
use crate::seed::SeedRng;
/// Block threshold handed to [`HealthMonitor::new`] for every transport registered
/// through [`TransportManager::add_transport`]. The tests rely on recording this
/// many failures being enough to mark a transport as blocked.
const DEFAULT_BLOCK_THRESHOLD: u32 = 5;
/// Registry of active transports keyed by [`TransportKind`], with a per-kind
/// [`HealthMonitor`] and an optional [`FallbackController`].
///
/// At most one transport per kind is held: registering another transport of the
/// same kind replaces the previous one.
pub struct TransportManager {
/// Registered transports. Held in `Arc` so a handle can be cloned out of the map
/// before awaiting I/O on it.
transports: HashMap<TransportKind, Arc<dyn Transport>>,
/// Health state per kind, updated from the outcomes of `send`/`recv`.
health: HashMap<TransportKind, HealthMonitor>,
/// Fallback policy consulted by `try_fallback`; `None` until `set_fallback` is called.
fallback: Option<FallbackController>,
}
impl TransportManager {
pub fn new() -> Self {
Self {
transports: HashMap::new(),
health: HashMap::new(),
fallback: None,
}
}
pub fn add_transport(&mut self, transport: Arc<dyn Transport>) {
let kind = transport.kind();
self.transports.insert(kind, transport);
self.health
.entry(kind)
.or_insert_with(|| HealthMonitor::new(DEFAULT_BLOCK_THRESHOLD));
}
pub fn remove_transport(&mut self, kind: TransportKind) {
self.transports.remove(&kind);
self.health.remove(&kind);
}
pub fn get(&self, kind: TransportKind) -> Option<&Arc<dyn Transport>> {
self.transports.get(&kind)
}
pub fn active_kinds(&self) -> Vec<TransportKind> {
self.transports.keys().copied().collect()
}
pub fn healthy_kinds(&self) -> Vec<TransportKind> {
let mut scored: Vec<_> = self
.transports
.keys()
.filter_map(|&k| {
let monitor = self.health.get(&k)?;
if monitor.is_blocked() {
None
} else {
Some((k, monitor.score()))
}
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.into_iter().map(|(k, _)| k).collect()
}
pub fn health_monitor(&self, kind: TransportKind) -> Option<&HealthMonitor> {
self.health.get(&kind)
}
pub fn is_blocked(&self, kind: TransportKind) -> bool {
self.health.get(&kind).is_some_and(|m| m.is_blocked())
}
pub fn health_scores(&self) -> Vec<(TransportKind, f64)> {
self.transports
.keys()
.map(|&kind| {
let monitor = self.health.get(&kind);
let score = monitor
.filter(|m| !m.is_blocked())
.map(|m| m.score())
.unwrap_or(0.0);
(kind, score)
})
.collect()
}
pub fn set_fallback(&mut self, controller: FallbackController) {
self.fallback = Some(controller);
}
pub fn try_fallback(&mut self) -> Option<TransportKind> {
let fb = self.fallback.as_mut()?;
let current = fb.current()?;
let blocked = self.health.get(¤t).is_some_and(|m| m.is_blocked());
if blocked {
fb.fallback()
} else {
Some(current)
}
}
pub async fn send(&mut self, kind: TransportKind, data: Bytes) -> crate::error::Result<()> {
let t = self
.transports
.get(&kind)
.ok_or_else(|| SrxError::Transport(TransportError::NotRegistered(format!("{kind:?}"))))?
.clone();
let start = Instant::now();
match t.send(data).await {
Ok(()) => {
if let Some(m) = self.health.get_mut(&kind) {
m.record_success(start.elapsed());
}
Ok(())
}
Err(e) => {
if let Some(m) = self.health.get_mut(&kind) {
m.record_failure();
}
Err(e)
}
}
}
pub async fn recv(&mut self, kind: TransportKind) -> crate::error::Result<Bytes> {
let t = self
.transports
.get(&kind)
.ok_or_else(|| SrxError::Transport(TransportError::NotRegistered(format!("{kind:?}"))))?
.clone();
let start = Instant::now();
match t.recv().await {
Ok(data) => {
if let Some(m) = self.health.get_mut(&kind) {
m.record_success(start.elapsed());
}
Ok(data)
}
Err(e) => {
if let Some(m) = self.health.get_mut(&kind) {
m.record_failure();
}
Err(e)
}
}
}
pub async fn send_with_routing_mask(
&mut self,
rng: &mut SeedRng,
available: &[TransportKind],
frame_counter: u64,
data: Bytes,
) -> crate::error::Result<TransportKind> {
let mask = RoutingMask::generate(rng, available, frame_counter);
for k in &mask.transports {
if self.transports.contains_key(k) && !self.is_blocked(*k) {
self.send(*k, data).await?;
return Ok(*k);
}
}
for k in &mask.transports {
if self.transports.contains_key(k) {
self.send(*k, data).await?;
return Ok(*k);
}
}
Err(SrxError::Transport(TransportError::NotRegistered(
"no transport from routing mask is registered".into(),
)))
}
pub async fn recv_with_routing_mask(
&mut self,
rng: &mut SeedRng,
available: &[TransportKind],
frame_counter: u64,
) -> crate::error::Result<(TransportKind, Bytes)> {
let mask = RoutingMask::generate(rng, available, frame_counter);
for k in &mask.transports {
if self.transports.contains_key(k) {
let data = self.recv(*k).await?;
return Ok((*k, data));
}
}
Err(SrxError::Transport(TransportError::NotRegistered(
"no transport from routing mask is registered".into(),
)))
}
pub async fn health_check(&mut self) {
let mut to_remove = Vec::new();
for (kind, transport) in &self.transports {
if !transport.is_healthy().await {
to_remove.push(*kind);
}
}
for kind in to_remove {
self.transports.remove(&kind);
self.health.remove(&kind);
}
}
}
impl Default for TransportManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    //! Tests for `TransportManager`, using real loopback TCP/UDP transports.
    use super::*;
    use crate::error::{SrxError, TransportError};
    use crate::seed::SeedRng;
    use crate::transport::{TcpTransport, UdpTransport};

    #[tokio::test]
    async fn send_recv_dispatches_tcp() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let t = TcpTransport::from_stream(stream);
            let got = t.recv().await.unwrap();
            assert_eq!(got.as_ref(), b"mgr-ping");
            t.send(Bytes::from_static(b"mgr-pong")).await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(client));
        mgr.send(TransportKind::Tcp, Bytes::from_static(b"mgr-ping"))
            .await
            .unwrap();
        let reply = mgr.recv(TransportKind::Tcp).await.unwrap();
        assert_eq!(reply.as_ref(), b"mgr-pong");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn send_with_routing_mask_uses_registered_kind() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let t = TcpTransport::from_stream(stream);
            let got = t.recv().await.unwrap();
            assert_eq!(got.as_ref(), b"mask-ping");
            t.send(Bytes::from_static(b"mask-pong")).await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(client));
        let mut rng = SeedRng::new([0xBBu8; 32]);
        let available = [TransportKind::Tcp];
        let used = mgr
            .send_with_routing_mask(&mut rng, &available, 7, Bytes::from_static(b"mask-ping"))
            .await
            .unwrap();
        assert_eq!(used, TransportKind::Tcp);
        let reply = mgr.recv(used).await.unwrap();
        assert_eq!(reply.as_ref(), b"mask-pong");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn recv_with_routing_mask_uses_registered_kind() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut mgr = TransportManager::new();
            mgr.add_transport(Arc::new(TcpTransport::from_stream(stream)));
            let mut rng = SeedRng::new([0xCCu8; 32]);
            mgr.recv_with_routing_mask(&mut rng, &[TransportKind::Tcp], 3)
                .await
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        client
            .send(Bytes::from_static(b"recv-mask-ping"))
            .await
            .unwrap();
        let got = server.await.unwrap().unwrap();
        assert_eq!(got.0, TransportKind::Tcp);
        assert_eq!(got.1.as_ref(), b"recv-mask-ping");
    }

    #[tokio::test]
    async fn missing_kind_returns_not_registered() {
        let mut mgr = TransportManager::new();
        let e = mgr
            .send(TransportKind::Udp, Bytes::from_static(b"x"))
            .await
            .expect_err("expected error");
        match e {
            SrxError::Transport(TransportError::NotRegistered(name)) => {
                assert!(name.contains("Udp"));
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[tokio::test]
    async fn send_records_health_success() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let t = TcpTransport::from_stream(stream);
            t.recv().await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(client));
        mgr.send(TransportKind::Tcp, Bytes::from_static(b"health-ok"))
            .await
            .unwrap();
        let monitor = mgr.health_monitor(TransportKind::Tcp).unwrap();
        assert!(!monitor.is_blocked());
        assert!(monitor.score() > 0.0);
        assert_eq!(monitor.consecutive_failures, 0);
        server.await.unwrap();
    }

    /// Exercises `is_blocked` on bare monitors (no transports are registered here).
    #[test]
    fn healthy_kinds_excludes_blocked() {
        let mut mgr = TransportManager::new();
        mgr.health.insert(TransportKind::Tcp, HealthMonitor::new(2));
        mgr.health.insert(TransportKind::Udp, HealthMonitor::new(2));
        assert!(!mgr.is_blocked(TransportKind::Tcp));
        mgr.health
            .get_mut(&TransportKind::Tcp)
            .unwrap()
            .record_failure();
        mgr.health
            .get_mut(&TransportKind::Tcp)
            .unwrap()
            .record_failure();
        assert!(mgr.is_blocked(TransportKind::Tcp));
        assert!(!mgr.is_blocked(TransportKind::Udp));
    }

    #[test]
    fn fallback_advances_on_block() {
        let mut mgr = TransportManager::new();
        mgr.health.insert(TransportKind::Tcp, HealthMonitor::new(2));
        mgr.health.insert(TransportKind::Udp, HealthMonitor::new(2));
        mgr.set_fallback(FallbackController::new(vec![
            TransportKind::Tcp,
            TransportKind::Udp,
        ]));
        assert_eq!(mgr.try_fallback(), Some(TransportKind::Tcp));
        mgr.health
            .get_mut(&TransportKind::Tcp)
            .unwrap()
            .record_failure();
        mgr.health
            .get_mut(&TransportKind::Tcp)
            .unwrap()
            .record_failure();
        assert_eq!(mgr.try_fallback(), Some(TransportKind::Udp));
    }

    #[tokio::test]
    async fn register_tcp_and_udp_simultaneously() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_server_sock = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_client_sock = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_server_sock.local_addr().unwrap();
        let udp_cli_addr = udp_client_sock.local_addr().unwrap();
        udp_server_sock.connect(udp_cli_addr).await.unwrap();
        udp_client_sock.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let tcp = TcpTransport::from_stream(tcp_stream);
            let udp = UdpTransport::from_socket(udp_server_sock);
            let tcp_data = tcp.recv().await.unwrap();
            tcp.send(Bytes::from_static(b"tcp-pong")).await.unwrap();
            assert_eq!(tcp_data.as_ref(), b"tcp-ping");
            let udp_data = udp.recv().await.unwrap();
            udp.send(Bytes::from_static(b"udp-pong")).await.unwrap();
            assert_eq!(udp_data.as_ref(), b"udp-ping");
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_client_sock);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        assert_eq!(
            mgr.active_kinds().len(),
            2,
            "should have 2 active transports"
        );
        assert!(mgr.active_kinds().contains(&TransportKind::Tcp));
        assert!(mgr.active_kinds().contains(&TransportKind::Udp));
        mgr.send(TransportKind::Tcp, Bytes::from_static(b"tcp-ping"))
            .await
            .unwrap();
        let reply = mgr.recv(TransportKind::Tcp).await.unwrap();
        assert_eq!(reply.as_ref(), b"tcp-pong");
        mgr.send(TransportKind::Udp, Bytes::from_static(b"udp-ping"))
            .await
            .unwrap();
        let reply = mgr.recv(TransportKind::Udp).await.unwrap();
        assert_eq!(reply.as_ref(), b"udp-pong");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn send_with_routing_mask_multiple_transports() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let tcp = TcpTransport::from_stream(tcp_stream);
            let udp = UdpTransport::from_socket(udp_srv);
            // The client sends the same payload whichever kind the mask picks.
            tokio::select! {
                result = tcp.recv() => {
                    let data = result.unwrap();
                    tcp.send(Bytes::from_static(b"tcp-ack")).await.unwrap();
                    assert_eq!(data.as_ref(), b"mask-tcp");
                }
                result = udp.recv() => {
                    let data = result.unwrap();
                    udp.send(Bytes::from_static(b"udp-ack")).await.unwrap();
                    assert_eq!(data.as_ref(), b"mask-tcp");
                }
            }
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        let mut rng = SeedRng::new([0xDDu8; 32]);
        let available = [TransportKind::Tcp, TransportKind::Udp];
        let used = mgr
            .send_with_routing_mask(&mut rng, &available, 42, Bytes::from_static(b"mask-tcp"))
            .await
            .unwrap();
        assert!(
            used == TransportKind::Tcp || used == TransportKind::Udp,
            "routing mask selected unexpected kind {used:?}"
        );
        let reply = mgr.recv(used).await.unwrap();
        assert!(
            reply.as_ref() == b"tcp-ack" || reply.as_ref() == b"udp-ack",
            "unexpected reply: {:?}",
            reply
        );
        server.await.unwrap();
    }

    #[tokio::test]
    async fn recv_with_routing_mask_multiple_transports() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let mut mgr = TransportManager::new();
            mgr.add_transport(Arc::new(TcpTransport::from_stream(tcp_stream)));
            mgr.add_transport(Arc::new(UdpTransport::from_socket(udp_srv)));
            let mut rng = SeedRng::new([0xEEu8; 32]);
            mgr.recv_with_routing_mask(&mut rng, &[TransportKind::Tcp, TransportKind::Udp], 10)
                .await
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        // The routing mask may select either kind. Deliver the payload on both so
        // the server never blocks on a transport nobody writes to.
        tcp_client
            .send(Bytes::from_static(b"multi-recv-tcp"))
            .await
            .unwrap();
        udp_client
            .send(Bytes::from_static(b"multi-recv-tcp"))
            .await
            .unwrap();
        let (kind, data) = server.await.unwrap().unwrap();
        assert!(
            kind == TransportKind::Tcp || kind == TransportKind::Udp,
            "unexpected recv kind {kind:?}"
        );
        assert_eq!(data.as_ref(), b"multi-recv-tcp");
    }

    #[tokio::test]
    async fn health_check_removes_unhealthy_transport() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_sock = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp = UdpTransport::from_socket(udp_sock);
        udp.close().await.unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = tcp_listener.accept().await.unwrap();
            let _tcp = TcpTransport::from_stream(stream);
            // Keep the server end alive while the client runs its health check.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp));
        assert_eq!(mgr.active_kinds().len(), 2);
        mgr.health_check().await;
        let kinds = mgr.active_kinds();
        assert_eq!(kinds.len(), 1, "only TCP should remain after health check");
        assert!(kinds.contains(&TransportKind::Tcp));
        assert!(
            !kinds.contains(&TransportKind::Udp),
            "UDP should have been removed"
        );
        server.await.unwrap();
    }

    /// Checks the monitor scores that `healthy_kinds` sorts by: lower latency
    /// scores higher, and a blocked monitor scores zero.
    #[test]
    fn healthy_kinds_sorts_by_score_descending() {
        let mut tcp_mon = HealthMonitor::new(5);
        let mut udp_mon = HealthMonitor::new(5);
        let mut quic_mon = HealthMonitor::new(5);
        tcp_mon.record_success(std::time::Duration::from_millis(5));
        udp_mon.record_success(std::time::Duration::from_millis(50));
        quic_mon.record_success(std::time::Duration::from_millis(200));
        let tcp_score = tcp_mon.score();
        let udp_score = udp_mon.score();
        let quic_score = quic_mon.score();
        assert!(
            tcp_score > udp_score,
            "TCP score {tcp_score} should be > UDP score {udp_score}"
        );
        assert!(
            udp_score > quic_score,
            "UDP score {udp_score} should be > QUIC score {quic_score}"
        );
        let mut blocked_mon = HealthMonitor::new(2);
        blocked_mon.record_failure();
        blocked_mon.record_failure();
        assert!(blocked_mon.is_blocked());
        assert_eq!(blocked_mon.score(), 0.0);
    }

    /// Drives `healthy_kinds` itself with two registered kinds whose monitors
    /// differ only in recorded latency.
    #[tokio::test]
    async fn healthy_kinds_orders_registered_kinds_by_score() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_sock = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = tcp_listener.accept().await.unwrap();
            let _tcp = TcpTransport::from_stream(stream);
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(UdpTransport::from_socket(udp_sock)));
        mgr.health
            .get_mut(&TransportKind::Udp)
            .unwrap()
            .record_success(std::time::Duration::from_millis(5));
        mgr.health
            .get_mut(&TransportKind::Tcp)
            .unwrap()
            .record_success(std::time::Duration::from_millis(200));
        assert_eq!(
            mgr.healthy_kinds(),
            vec![TransportKind::Udp, TransportKind::Tcp],
            "lower-latency UDP should be ranked first"
        );
        server.await.unwrap();
    }

    #[tokio::test]
    async fn healthy_kinds_excludes_blocked_includes_healthy() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let _tcp = TcpTransport::from_stream(tcp_stream);
            let _udp = UdpTransport::from_socket(udp_srv);
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        for _ in 0..DEFAULT_BLOCK_THRESHOLD {
            mgr.health
                .get_mut(&TransportKind::Tcp)
                .unwrap()
                .record_failure();
        }
        let udp_mon = mgr.health.get_mut(&TransportKind::Udp).unwrap();
        udp_mon.record_success(std::time::Duration::from_millis(10));
        let kinds = mgr.healthy_kinds();
        assert_eq!(kinds.len(), 1, "blocked TCP should be excluded");
        assert!(!kinds.contains(&TransportKind::Tcp));
        assert!(kinds.contains(&TransportKind::Udp));
        server.await.unwrap();
    }

    #[tokio::test]
    async fn fallback_when_one_transport_blocked() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let tcp = TcpTransport::from_stream(tcp_stream);
            let udp = UdpTransport::from_socket(udp_srv);
            let _ = tcp.recv().await;
            let data = udp.recv().await.unwrap();
            assert_eq!(data.as_ref(), b"fallback-via-udp");
            udp.send(Bytes::from_static(b"udp-fallback-ok"))
                .await
                .unwrap();
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        mgr.set_fallback(FallbackController::new(vec![
            TransportKind::Tcp,
            TransportKind::Udp,
        ]));
        mgr.send(TransportKind::Tcp, Bytes::from_static(b"pre-block"))
            .await
            .unwrap();
        let tcp_mon = mgr.health.get_mut(&TransportKind::Tcp).unwrap();
        for _ in 0..DEFAULT_BLOCK_THRESHOLD {
            tcp_mon.record_failure();
        }
        assert!(mgr.is_blocked(TransportKind::Tcp));
        let fallback_kind = mgr.try_fallback();
        assert_eq!(
            fallback_kind,
            Some(TransportKind::Udp),
            "fallback should switch to UDP when TCP is blocked"
        );
        let used = fallback_kind.unwrap();
        mgr.send(used, Bytes::from_static(b"fallback-via-udp"))
            .await
            .unwrap();
        let reply = mgr.recv(used).await.unwrap();
        assert_eq!(reply.as_ref(), b"udp-fallback-ok");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn send_with_routing_mask_skips_blocked_uses_healthy() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let tcp = TcpTransport::from_stream(tcp_stream);
            let udp = UdpTransport::from_socket(udp_srv);
            tokio::select! {
                result = tcp.recv() => {
                    let data = result.unwrap();
                    tcp.send(Bytes::from_static(b"tcp-ok")).await.unwrap();
                    assert_eq!(data.as_ref(), b"skip-blocked-tcp");
                }
                result = udp.recv() => {
                    let data = result.unwrap();
                    udp.send(Bytes::from_static(b"udp-ok")).await.unwrap();
                    assert_eq!(data.as_ref(), b"skip-blocked-tcp");
                }
            }
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        for _ in 0..DEFAULT_BLOCK_THRESHOLD {
            mgr.health
                .get_mut(&TransportKind::Tcp)
                .unwrap()
                .record_failure();
        }
        assert!(mgr.is_blocked(TransportKind::Tcp));
        // Find a seed whose mask contains UDP, so the unblocked kind is routable.
        let mut found_seed = None;
        for seed_byte in 0u8..255 {
            let mut probe_rng = SeedRng::new([seed_byte; 32]);
            let available = [TransportKind::Tcp, TransportKind::Udp];
            let mask = RoutingMask::generate(&mut probe_rng, &available, 99);
            if mask.transports.contains(&TransportKind::Udp) {
                found_seed = Some(seed_byte);
                break;
            }
        }
        let seed_byte = found_seed.expect("no seed found where mask includes UDP");
        let mut rng = SeedRng::new([seed_byte; 32]);
        let available = [TransportKind::Tcp, TransportKind::Udp];
        let used = mgr
            .send_with_routing_mask(
                &mut rng,
                &available,
                99,
                Bytes::from_static(b"skip-blocked-tcp"),
            )
            .await
            .unwrap();
        assert_eq!(
            used,
            TransportKind::Udp,
            "should have fallen back to UDP since TCP is blocked"
        );
        let reply = mgr.recv(used).await.unwrap();
        assert_eq!(reply.as_ref(), b"udp-ok");
        server.await.unwrap();
    }

    #[tokio::test]
    async fn health_scores_reflect_multiple_transports() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let tcp = TcpTransport::from_stream(tcp_stream);
            let udp = UdpTransport::from_socket(udp_srv);
            let _ = tcp.recv().await;
            let _ = udp.recv().await;
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        mgr.send(TransportKind::Tcp, Bytes::from_static(b"health-tcp"))
            .await
            .unwrap();
        mgr.send(TransportKind::Udp, Bytes::from_static(b"health-udp"))
            .await
            .unwrap();
        let scores = mgr.health_scores();
        assert_eq!(scores.len(), 2);
        for (kind, score) in &scores {
            assert!(
                *score > 0.0,
                "transport {kind:?} should have positive score, got {score}"
            );
        }
        server.await.unwrap();
    }

    #[tokio::test]
    async fn remove_transport_removes_health_too() {
        let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_addr = tcp_listener.local_addr().unwrap();
        let udp_srv = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_cli = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let udp_srv_addr = udp_srv.local_addr().unwrap();
        let udp_cli_addr = udp_cli.local_addr().unwrap();
        udp_srv.connect(udp_cli_addr).await.unwrap();
        udp_cli.connect(udp_srv_addr).await.unwrap();
        let server = tokio::spawn(async move {
            let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
            let _tcp = TcpTransport::from_stream(tcp_stream);
            let _udp = UdpTransport::from_socket(udp_srv);
        });
        let tcp_client = TcpTransport::connect(tcp_addr).await.unwrap();
        let udp_client = UdpTransport::from_socket(udp_cli);
        let mut mgr = TransportManager::new();
        mgr.add_transport(Arc::new(tcp_client));
        mgr.add_transport(Arc::new(udp_client));
        assert_eq!(mgr.active_kinds().len(), 2);
        mgr.remove_transport(TransportKind::Tcp);
        assert_eq!(mgr.active_kinds().len(), 1);
        assert!(mgr.active_kinds().contains(&TransportKind::Udp));
        assert!(
            mgr.health_monitor(TransportKind::Tcp).is_none(),
            "TCP health monitor should also be removed"
        );
        assert!(
            mgr.health_monitor(TransportKind::Udp).is_some(),
            "UDP health monitor should still exist"
        );
        server.await.unwrap();
    }
}