use async_trait::async_trait;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use super::addr::{TransportAddr, TransportType};
use super::capabilities::TransportCapabilities;
/// Failure modes reported by transport providers and by the transport registry.
///
/// The user-facing wording of each variant is produced by the `Display`
/// implementation for this type.
#[derive(Debug, Clone)]
pub enum TransportError {
    /// A destination address of one transport type was handed to a provider
    /// of a different type.
    AddressMismatch {
        /// Address type the provider handles.
        expected: TransportType,
        /// Address type that was actually supplied.
        actual: TransportType,
    },
    /// The payload is larger than the transport's MTU.
    MessageTooLarge {
        /// Payload size in bytes.
        size: usize,
        /// MTU the payload was checked against.
        mtu: usize,
    },
    /// The transport is offline.
    Offline,
    /// The transport is shutting down.
    ShuttingDown,
    /// Sending failed.
    SendFailed {
        /// Free-form description of the failure.
        reason: String,
    },
    /// Receiving failed.
    ReceiveFailed {
        /// Free-form description of the failure.
        reason: String,
    },
    /// Broadcast was requested but is not supported.
    BroadcastNotSupported,
    /// No provider is available for the requested address type.
    NoProviderForAddress {
        /// Address type for which no provider was found.
        addr_type: TransportType,
    },
    /// Any failure not covered by the other variants.
    Other {
        /// Message that is displayed verbatim.
        message: String,
    },
}
/// Human-readable rendering of [`TransportError`].
impl fmt::Display for TransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants without payload: fixed text, written directly.
            Self::Offline => f.write_str("transport is offline"),
            Self::ShuttingDown => f.write_str("transport is shutting down"),
            Self::BroadcastNotSupported => f.write_str("broadcast not supported"),

            // Variants that interpolate their payload.
            Self::AddressMismatch { expected, actual } => write!(
                f,
                "address type mismatch: expected {expected}, got {actual}"
            ),
            Self::MessageTooLarge { size, mtu } => {
                write!(f, "message too large: {size} bytes exceeds MTU of {mtu}")
            }
            Self::NoProviderForAddress { addr_type } => {
                write!(f, "no provider registered for address type: {addr_type}")
            }
            Self::SendFailed { reason } => write!(f, "send failed: {reason}"),
            Self::ReceiveFailed { reason } => write!(f, "receive failed: {reason}"),
            Self::Other { message } => write!(f, "{message}"),
        }
    }
}
// Marks `TransportError` as a standard error type. The body is empty, so all
// provided methods of `std::error::Error` keep their defaults.
impl std::error::Error for TransportError {}
/// A datagram delivered by a transport provider through its inbound channel.
#[derive(Debug, Clone)]
pub struct InboundDatagram {
    /// Raw payload bytes.
    pub data: Vec<u8>,
    /// Address the datagram is attributed to.
    pub source: TransportAddr,
    /// Monotonic timestamp attached to the datagram.
    // NOTE(review): presumably taken at reception time by the provider — confirm.
    pub received_at: std::time::Instant,
    /// Link measurements accompanying the datagram, when the provider has any.
    pub link_quality: Option<LinkQuality>,
}
/// Optional link measurements for a peer or a received datagram.
///
/// Every field is optional; `Default` yields a value with nothing measured.
#[derive(Debug, Clone, Default)]
pub struct LinkQuality {
    /// RSSI reading, if known.
    // NOTE(review): units are not stated in this file (presumably dBm) — confirm.
    pub rssi: Option<i16>,
    /// SNR reading, if known.
    // NOTE(review): units are not stated in this file (presumably dB) — confirm.
    pub snr: Option<f32>,
    /// Hop count, if known.
    pub hop_count: Option<u8>,
    /// Round-trip time, if known.
    pub rtt: Option<Duration>,
}
/// Counters and timing information reported by a transport provider.
///
/// `Default` yields all counters at zero and no RTT.
#[derive(Debug, Clone, Default)]
pub struct TransportStats {
    /// Count of datagrams sent.
    pub datagrams_sent: u64,
    /// Count of datagrams received.
    pub datagrams_received: u64,
    /// Count of bytes sent.
    pub bytes_sent: u64,
    /// Count of bytes received.
    pub bytes_received: u64,
    /// Count of send errors.
    pub send_errors: u64,
    /// Count of receive errors.
    pub receive_errors: u64,
    /// Most recent round-trip time, if one is available.
    pub current_rtt: Option<Duration>,
}
/// Protocol engine selected for a transport.
///
/// `ProtocolEngine::for_transport` picks the variant from a transport's
/// capabilities.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolEngine {
    /// Chosen when the capabilities report full QUIC support.
    Quic,
    /// Chosen for every transport that does not report full QUIC support.
    Constrained,
}
impl ProtocolEngine {
    /// Chooses the engine for a transport with the given capabilities.
    ///
    /// Returns [`ProtocolEngine::Quic`] when `caps.supports_full_quic()` is
    /// true and [`ProtocolEngine::Constrained`] otherwise.
    pub fn for_transport(caps: &TransportCapabilities) -> Self {
        if !caps.supports_full_quic() {
            return Self::Constrained;
        }
        Self::Quic
    }
}
/// Displays the engine as `QUIC` or `Constrained`.
impl fmt::Display for ProtocolEngine {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Quic => "QUIC",
            Self::Constrained => "Constrained",
        };
        f.write_str(label)
    }
}
/// Interface implemented by every datagram transport (for example UDP or BLE).
///
/// Providers are shared as `Arc<dyn TransportProvider>` by
/// [`TransportRegistry`], hence the `Send + Sync + 'static` bounds.
#[async_trait]
pub trait TransportProvider: Send + Sync + 'static {
    /// Provider name; copied into [`TransportDiagnostics::name`].
    fn name(&self) -> &str;

    /// Address family this provider handles. The registry compares it with
    /// the destination's transport type when routing a send.
    fn transport_type(&self) -> TransportType;

    /// Capabilities of this transport; they drive protocol engine selection.
    fn capabilities(&self) -> &TransportCapabilities;

    /// Local address of the transport, if it has one.
    fn local_addr(&self) -> Option<TransportAddr>;

    /// Sends one datagram `data` to `dest`.
    ///
    /// # Errors
    ///
    /// Returns a provider-defined [`TransportError`] when the datagram cannot
    /// be sent.
    async fn send(&self, data: &[u8], dest: &TransportAddr) -> Result<(), TransportError>;

    /// Returns a receiver for inbound datagrams.
    // NOTE(review): this takes `&self` yet returns an owned receiver, so what
    // repeated calls return is up to the implementation — confirm the contract.
    fn inbound(&self) -> mpsc::Receiver<InboundDatagram>;

    /// Whether the transport is currently online. The registry skips offline
    /// providers when routing.
    fn is_online(&self) -> bool;

    /// Shuts the transport down.
    async fn shutdown(&self) -> Result<(), TransportError>;

    /// Broadcasts `data` on transports that support it.
    ///
    /// The default implementation has no way to transmit, so it always fails
    /// regardless of what `capabilities().broadcast` says. Providers that
    /// advertise broadcast must override this method.
    ///
    /// # Errors
    ///
    /// The default implementation always returns
    /// [`TransportError::BroadcastNotSupported`].
    async fn broadcast(&self, _data: &[u8]) -> Result<(), TransportError> {
        Err(TransportError::BroadcastNotSupported)
    }

    /// Link measurements for `peer`. The default implementation has none and
    /// returns `None`.
    async fn link_quality(&self, _peer: &TransportAddr) -> Option<LinkQuality> {
        None
    }

    /// Snapshot of the provider's statistics. The default implementation
    /// returns all-zero stats.
    fn stats(&self) -> TransportStats {
        TransportStats::default()
    }

    /// Protocol engine for this transport, derived from [`Self::capabilities`].
    fn protocol_engine(&self) -> ProtocolEngine {
        ProtocolEngine::for_transport(self.capabilities())
    }

    /// Underlying UDP socket, for providers that have one. The default
    /// implementation returns `None`.
    fn socket(&self) -> Option<&Arc<tokio::net::UdpSocket>> {
        None
    }
}
/// Point-in-time description of one transport provider.
///
/// Built by `TransportDiagnostics::from_provider`; each field mirrors the
/// provider accessor of the same name unless noted.
#[derive(Debug, Clone)]
pub struct TransportDiagnostics {
    /// Provider name.
    pub name: String,
    /// Transport type handled by the provider.
    pub transport_type: TransportType,
    /// Protocol engine reported by the provider.
    pub protocol_engine: ProtocolEngine,
    /// Bandwidth class taken from the provider's capabilities.
    pub bandwidth_class: super::capabilities::BandwidthClass,
    /// Round-trip time taken from the provider's stats, if any.
    pub current_rtt: Option<Duration>,
    /// Whether the provider reported itself online.
    pub is_online: bool,
    /// Statistics reported by the provider.
    pub stats: TransportStats,
    /// Local address of the provider, if any.
    pub local_addr: Option<TransportAddr>,
}
impl TransportDiagnostics {
    /// Collects a diagnostics record from `provider`.
    ///
    /// The provider's statistics are sampled exactly once, so `current_rtt`
    /// always equals `stats.current_rtt` in the returned value.
    pub fn from_provider(provider: &dyn TransportProvider) -> Self {
        // Two separate `stats()` calls could observe different snapshots
        // (and do the work twice); take one and reuse it.
        let stats = provider.stats();
        let caps = provider.capabilities();
        Self {
            name: provider.name().to_string(),
            transport_type: provider.transport_type(),
            protocol_engine: provider.protocol_engine(),
            bandwidth_class: caps.bandwidth_class(),
            // `Option<Duration>` is `Copy`, so this does not move out of `stats`.
            current_rtt: stats.current_rtt,
            is_online: provider.is_online(),
            stats,
            local_addr: provider.local_addr(),
        }
    }
}
/// Ordered collection of transport providers.
///
/// Providers are kept in registration order; lookups that return a single
/// provider return the first match in that order. Cloning the registry clones
/// the `Arc` handles, not the providers themselves.
#[derive(Default, Clone)]
pub struct TransportRegistry {
    /// Registered providers, in registration order.
    providers: Vec<Arc<dyn TransportProvider>>,
}
impl TransportRegistry {
    /// Creates a registry with no providers.
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Appends `provider` to the registry. No de-duplication is performed.
    pub fn register(&mut self, provider: Arc<dyn TransportProvider>) {
        self.providers.push(provider);
    }

    /// All registered providers, in registration order.
    pub fn providers(&self) -> &[Arc<dyn TransportProvider>] {
        self.providers.as_slice()
    }

    /// Providers whose transport type equals `transport_type`, online or not.
    pub fn providers_by_type(
        &self,
        transport_type: TransportType,
    ) -> Vec<Arc<dyn TransportProvider>> {
        let mut matching = Vec::new();
        for candidate in &self.providers {
            if candidate.transport_type() == transport_type {
                matching.push(Arc::clone(candidate));
            }
        }
        matching
    }

    /// Whether at least one provider of `transport_type` is registered
    /// (its online state is not considered).
    pub fn has_transport_type(&self, transport_type: TransportType) -> bool {
        self.providers
            .iter()
            .map(|provider| provider.transport_type())
            .any(|registered| registered == transport_type)
    }

    /// First online provider whose transport type matches that of `addr`.
    pub fn provider_for_addr(&self, addr: &TransportAddr) -> Option<Arc<dyn TransportProvider>> {
        let wanted = addr.transport_type();
        self.providers
            .iter()
            .filter(|candidate| candidate.transport_type() == wanted)
            .find(|candidate| candidate.is_online())
            .cloned()
    }

    /// Lazily yields handles to the providers that are currently online.
    pub fn online_providers(&self) -> impl Iterator<Item = Arc<dyn TransportProvider>> + '_ {
        self.providers.iter().filter_map(|candidate| {
            if candidate.is_online() {
                Some(Arc::clone(candidate))
            } else {
                None
            }
        })
    }

    /// Diagnostics for every registered provider, in registration order.
    pub fn diagnostics(&self) -> Vec<TransportDiagnostics> {
        let mut report = Vec::with_capacity(self.providers.len());
        for provider in &self.providers {
            report.push(TransportDiagnostics::from_provider(provider.as_ref()));
        }
        report
    }

    /// Whether any online provider reports full QUIC support.
    pub fn has_quic_capable_transport(&self) -> bool {
        self.providers
            .iter()
            .filter(|candidate| candidate.is_online())
            .any(|candidate| candidate.capabilities().supports_full_quic())
    }

    /// Number of registered providers.
    pub fn len(&self) -> usize {
        self.providers.len()
    }

    /// Whether no provider is registered.
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }

    /// Socket of the first online UDP provider that exposes one.
    pub fn get_udp_socket(&self) -> Option<Arc<tokio::net::UdpSocket>> {
        self.providers
            .iter()
            .filter(|candidate| {
                candidate.transport_type() == TransportType::Udp && candidate.is_online()
            })
            .find_map(|candidate| candidate.socket().cloned())
    }

    /// Local socket address of the first online UDP provider that reports a
    /// UDP local address.
    pub fn get_udp_local_addr(&self) -> Option<std::net::SocketAddr> {
        self.providers
            .iter()
            .filter(|candidate| {
                candidate.transport_type() == TransportType::Udp && candidate.is_online()
            })
            .find_map(|candidate| match candidate.local_addr() {
                Some(TransportAddr::Udp(addr)) => Some(addr),
                _ => None,
            })
    }

    /// Sends `data` to `dest` through the provider chosen by
    /// [`Self::provider_for_addr`].
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::NoProviderForAddress`] when no online
    /// provider handles the destination's transport type; otherwise returns
    /// whatever the selected provider's `send` returns.
    pub async fn send(&self, data: &[u8], dest: &TransportAddr) -> Result<(), TransportError> {
        match self.provider_for_addr(dest) {
            Some(provider) => provider.send(data, dest).await,
            None => Err(TransportError::NoProviderForAddress {
                addr_type: dest.transport_type(),
            }),
        }
    }
}
/// Compact debug output: provider counts only, since the providers themselves
/// are trait objects without a `Debug` bound.
impl fmt::Debug for TransportRegistry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.providers.len();
        let online = self.online_providers().count();

        let mut out = f.debug_struct("TransportRegistry");
        out.field("providers", &total);
        out.field("online", &online);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// In-memory provider used by the tests below; `send` only validates its
    /// arguments and never transmits anything.
    // `inbound_rx` is stored but never read by these tests, hence the allowance.
    #[allow(dead_code)]
    struct MockTransport {
        name: String,
        transport_type: TransportType,
        capabilities: TransportCapabilities,
        online: AtomicBool,
        local_addr: Option<TransportAddr>,
        inbound_rx: tokio::sync::Mutex<Option<mpsc::Receiver<InboundDatagram>>>,
    }

    impl MockTransport {
        /// Shared constructor: an online mock with a fresh inbound channel
        /// whose sender is dropped immediately.
        fn with_identity(
            name: &str,
            transport_type: TransportType,
            capabilities: TransportCapabilities,
            local_addr: TransportAddr,
        ) -> Self {
            let (_, rx) = mpsc::channel(16);
            Self {
                name: name.to_string(),
                transport_type,
                capabilities,
                online: AtomicBool::new(true),
                local_addr: Some(local_addr),
                inbound_rx: tokio::sync::Mutex::new(Some(rx)),
            }
        }

        /// Online UDP mock with broadband capabilities.
        fn new_udp() -> Self {
            Self::with_identity(
                "MockUDP",
                TransportType::Udp,
                TransportCapabilities::broadband(),
                TransportAddr::Udp("127.0.0.1:9000".parse().unwrap()),
            )
        }

        /// Online BLE mock with BLE capabilities.
        fn new_ble() -> Self {
            Self::with_identity(
                "MockBLE",
                TransportType::Ble,
                TransportCapabilities::ble(),
                TransportAddr::ble([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], None),
            )
        }
    }

    #[async_trait]
    impl TransportProvider for MockTransport {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn transport_type(&self) -> TransportType {
            self.transport_type
        }

        fn capabilities(&self) -> &TransportCapabilities {
            &self.capabilities
        }

        fn local_addr(&self) -> Option<TransportAddr> {
            self.local_addr.clone()
        }

        /// Validates, in order: online state, address type, payload size.
        async fn send(&self, data: &[u8], dest: &TransportAddr) -> Result<(), TransportError> {
            if !self.is_online() {
                return Err(TransportError::Offline);
            }

            let actual = dest.transport_type();
            if actual != self.transport_type {
                return Err(TransportError::AddressMismatch {
                    expected: self.transport_type,
                    actual,
                });
            }

            let (size, mtu) = (data.len(), self.capabilities.mtu);
            if size > mtu {
                return Err(TransportError::MessageTooLarge { size, mtu });
            }

            Ok(())
        }

        /// Hands out the receiving half of a brand-new channel on every call.
        fn inbound(&self) -> mpsc::Receiver<InboundDatagram> {
            mpsc::channel(16).1
        }

        fn is_online(&self) -> bool {
            self.online.load(Ordering::SeqCst)
        }

        async fn shutdown(&self) -> Result<(), TransportError> {
            self.online.store(false, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Marks `mock` offline and hands it back.
    fn offline(mock: MockTransport) -> MockTransport {
        mock.online.store(false, Ordering::SeqCst);
        mock
    }

    /// UDP destination shared by several tests.
    fn udp_dest() -> TransportAddr {
        let sock: SocketAddr = "192.168.1.1:9000".parse().unwrap();
        TransportAddr::Udp(sock)
    }

    /// BLE destination shared by several tests.
    fn ble_dest() -> TransportAddr {
        TransportAddr::ble([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], None)
    }

    /// Registry holding one online UDP mock followed by one online BLE mock.
    fn udp_and_ble_registry() -> TransportRegistry {
        let mut registry = TransportRegistry::new();
        registry.register(Arc::new(MockTransport::new_udp()));
        registry.register(Arc::new(MockTransport::new_ble()));
        registry
    }

    #[test]
    fn test_protocol_engine_selection() {
        let cases = [
            (TransportCapabilities::broadband(), ProtocolEngine::Quic),
            (TransportCapabilities::ble(), ProtocolEngine::Constrained),
            (
                TransportCapabilities::lora_long_range(),
                ProtocolEngine::Constrained,
            ),
        ];
        for (caps, expected) in &cases {
            assert_eq!(ProtocolEngine::for_transport(caps), *expected);
        }
    }

    #[tokio::test]
    async fn test_mock_transport_send() {
        let transport = MockTransport::new_udp();
        let outcome = transport.send(b"hello", &udp_dest()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_transport_address_mismatch() {
        let transport = MockTransport::new_udp();
        let outcome = transport.send(b"hello", &ble_dest()).await;
        if let Err(TransportError::AddressMismatch { expected, actual }) = outcome {
            assert_eq!(expected, TransportType::Udp);
            assert_eq!(actual, TransportType::Ble);
        } else {
            panic!("expected AddressMismatch error");
        }
    }

    #[tokio::test]
    async fn test_message_too_large() {
        let transport = MockTransport::new_ble();
        let oversized = [0u8; 500];
        let outcome = transport.send(&oversized, &ble_dest()).await;
        if let Err(TransportError::MessageTooLarge { size, mtu }) = outcome {
            assert_eq!(size, 500);
            assert_eq!(mtu, 244);
        } else {
            panic!("expected MessageTooLarge error");
        }
    }

    #[tokio::test]
    async fn test_offline_transport() {
        let transport = MockTransport::new_udp();
        transport.shutdown().await.unwrap();

        let outcome = transport.send(b"hello", &udp_dest()).await;
        assert!(matches!(outcome, Err(TransportError::Offline)));
        assert!(!transport.is_online());
    }

    #[test]
    fn test_transport_registry() {
        let mut registry = TransportRegistry::new();
        assert!(registry.is_empty());

        registry.register(Arc::new(MockTransport::new_udp()));
        registry.register(Arc::new(MockTransport::new_ble()));
        assert_eq!(registry.len(), 2);
        assert!(!registry.is_empty());

        assert_eq!(registry.providers_by_type(TransportType::Udp).len(), 1);
        assert_eq!(registry.providers_by_type(TransportType::Ble).len(), 1);
        assert!(registry.providers_by_type(TransportType::LoRa).is_empty());
    }

    #[test]
    fn test_provider_for_addr() {
        let registry = udp_and_ble_registry();

        let udp_hit = registry
            .provider_for_addr(&udp_dest())
            .map(|p| p.transport_type());
        assert_eq!(udp_hit, Some(TransportType::Udp));

        let ble_hit = registry
            .provider_for_addr(&ble_dest())
            .map(|p| p.transport_type());
        assert_eq!(ble_hit, Some(TransportType::Ble));

        let lora_addr = TransportAddr::lora([0xDE, 0xAD, 0xBE, 0xEF]);
        assert!(registry.provider_for_addr(&lora_addr).is_none());
    }

    #[test]
    fn test_quic_capable_check() {
        let mut with_udp = TransportRegistry::new();
        with_udp.register(Arc::new(MockTransport::new_udp()));
        assert!(with_udp.has_quic_capable_transport());

        let mut ble_only = TransportRegistry::new();
        ble_only.register(Arc::new(MockTransport::new_ble()));
        assert!(!ble_only.has_quic_capable_transport());
    }

    #[test]
    fn test_transport_diagnostics() {
        let transport = MockTransport::new_udp();
        let diag = TransportDiagnostics::from_provider(&transport);

        assert_eq!(diag.name, "MockUDP");
        assert_eq!(diag.transport_type, TransportType::Udp);
        assert_eq!(diag.protocol_engine, ProtocolEngine::Quic);
        assert!(diag.is_online);
        assert!(diag.local_addr.is_some());
    }

    #[test]
    fn test_transport_error_display() {
        let mismatch = TransportError::AddressMismatch {
            expected: TransportType::Udp,
            actual: TransportType::Ble,
        }
        .to_string();
        assert!(mismatch.contains("UDP"));
        assert!(mismatch.contains("BLE"));

        let too_large = TransportError::MessageTooLarge {
            size: 1000,
            mtu: 500,
        }
        .to_string();
        assert!(too_large.contains("1000"));
        assert!(too_large.contains("500"));
    }

    #[test]
    fn test_link_quality_default() {
        let quality = LinkQuality::default();
        assert!(quality.rssi.is_none());
        assert!(quality.snr.is_none());
        assert!(quality.hop_count.is_none());
        assert!(quality.rtt.is_none());
    }

    #[test]
    fn test_online_providers_filters_offline() {
        let mut registry = udp_and_ble_registry();
        registry.register(Arc::new(offline(MockTransport::new_udp())));
        assert_eq!(registry.len(), 3);

        let online: Vec<_> = registry.online_providers().collect();
        assert_eq!(online.len(), 2);
        assert!(online
            .iter()
            .any(|p| p.transport_type() == TransportType::Udp));
        assert!(online
            .iter()
            .any(|p| p.transport_type() == TransportType::Ble));
    }

    #[test]
    fn test_online_providers_empty_when_all_offline() {
        let mut registry = TransportRegistry::new();
        registry.register(Arc::new(offline(MockTransport::new_udp())));
        registry.register(Arc::new(offline(MockTransport::new_ble())));
        assert_eq!(registry.len(), 2);

        assert_eq!(registry.online_providers().count(), 0);
    }

    #[test]
    fn test_get_provider_by_type() {
        let registry = udp_and_ble_registry();

        let expectations = [
            (TransportType::Udp, "MockUDP"),
            (TransportType::Ble, "MockBLE"),
        ];
        for (kind, expected_name) in expectations.iter() {
            let found = registry.providers_by_type(*kind);
            assert_eq!(found.len(), 1);
            assert_eq!(found[0].transport_type(), *kind);
            assert_eq!(found[0].name(), *expected_name);
        }

        assert_eq!(registry.providers_by_type(TransportType::LoRa).len(), 0);
    }

    #[test]
    fn test_registry_default_includes_udp() {
        let mut registry = TransportRegistry::new();
        registry.register(Arc::new(MockTransport::new_udp()));
        assert_eq!(registry.len(), 1);

        let udp_providers = registry.providers_by_type(TransportType::Udp);
        assert_eq!(udp_providers.len(), 1);

        let only = &udp_providers[0];
        assert!(only.is_online());
        assert_eq!(only.transport_type(), TransportType::Udp);
        assert!(only.capabilities().supports_full_quic());
    }
}