use std::collections::HashSet;
use std::fmt;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::time::{Duration, SystemTime};
use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::broadcast;
use crate::nat_traversal_api::PeerId;
use crate::reachability::ReachabilityScope;
use crate::transport::TransportAddr;
/// Tag identifying which protocol a typed stream carries.
///
/// The explicit discriminant is the byte returned by [`StreamType::as_byte`]
/// and parsed by [`StreamType::from_byte`]. Values are grouped by range into
/// families (see [`StreamType::family`]): `0x00..=0x0F` gossip,
/// `0x10..=0x1F` DHT, `0x20..=0x2F` WebRTC, anything else reserved.
///
/// NOTE(review): the serde derives encode this enum by variant (name or
/// positional index, depending on the format), not by the `repr(u8)` byte —
/// confirm that matches what any serialized form expects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum StreamType {
    /// Byte `0x00` (gossip family).
    Membership = 0x00,
    /// Byte `0x01` (gossip family).
    PubSub = 0x01,
    /// Byte `0x02` (gossip family).
    GossipBulk = 0x02,
    /// Byte `0x10` (DHT family).
    DhtQuery = 0x10,
    /// Byte `0x11` (DHT family).
    DhtStore = 0x11,
    /// Byte `0x12` (DHT family).
    DhtWitness = 0x12,
    /// Byte `0x13` (DHT family).
    DhtReplication = 0x13,
    /// Byte `0x20` (WebRTC family).
    WebRtcSignal = 0x20,
    /// Byte `0x21` (WebRTC family).
    WebRtcMedia = 0x21,
    /// Byte `0x22` (WebRTC family).
    WebRtcData = 0x22,
    /// Byte `0xF0` (reserved family).
    Reserved = 0xF0,
}
impl StreamType {
#[inline]
pub fn from_byte(byte: u8) -> Option<Self> {
match byte {
0x00 => Some(Self::Membership),
0x01 => Some(Self::PubSub),
0x02 => Some(Self::GossipBulk),
0x10 => Some(Self::DhtQuery),
0x11 => Some(Self::DhtStore),
0x12 => Some(Self::DhtWitness),
0x13 => Some(Self::DhtReplication),
0x20 => Some(Self::WebRtcSignal),
0x21 => Some(Self::WebRtcMedia),
0x22 => Some(Self::WebRtcData),
0xF0 => Some(Self::Reserved),
_ => None,
}
}
#[inline]
pub const fn as_byte(self) -> u8 {
self as u8
}
#[inline]
pub const fn family(self) -> StreamTypeFamily {
match self as u8 {
0x00..=0x0F => StreamTypeFamily::Gossip,
0x10..=0x1F => StreamTypeFamily::Dht,
0x20..=0x2F => StreamTypeFamily::WebRtc,
_ => StreamTypeFamily::Reserved,
}
}
#[inline]
pub const fn is_gossip(self) -> bool {
matches!(self.family(), StreamTypeFamily::Gossip)
}
#[inline]
pub const fn is_dht(self) -> bool {
matches!(self.family(), StreamTypeFamily::Dht)
}
#[inline]
pub const fn is_webrtc(self) -> bool {
matches!(self.family(), StreamTypeFamily::WebRtc)
}
pub const fn gossip_types() -> &'static [StreamType] {
&[Self::Membership, Self::PubSub, Self::GossipBulk]
}
pub const fn dht_types() -> &'static [StreamType] {
&[
Self::DhtQuery,
Self::DhtStore,
Self::DhtWitness,
Self::DhtReplication,
]
}
pub const fn webrtc_types() -> &'static [StreamType] {
&[Self::WebRtcSignal, Self::WebRtcMedia, Self::WebRtcData]
}
pub const fn all_types() -> &'static [StreamType] {
&[
Self::Membership,
Self::PubSub,
Self::GossipBulk,
Self::DhtQuery,
Self::DhtStore,
Self::DhtWitness,
Self::DhtReplication,
Self::WebRtcSignal,
Self::WebRtcMedia,
Self::WebRtcData,
Self::Reserved,
]
}
}
impl fmt::Display for StreamType {
    /// Writes the variant name verbatim (e.g. `DhtQuery`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Membership => "Membership",
            Self::PubSub => "PubSub",
            Self::GossipBulk => "GossipBulk",
            Self::DhtQuery => "DhtQuery",
            Self::DhtStore => "DhtStore",
            Self::DhtWitness => "DhtWitness",
            Self::DhtReplication => "DhtReplication",
            Self::WebRtcSignal => "WebRtcSignal",
            Self::WebRtcMedia => "WebRtcMedia",
            Self::WebRtcData => "WebRtcData",
            Self::Reserved => "Reserved",
        };
        f.write_str(label)
    }
}
impl From<StreamType> for u8 {
    /// Same as [`StreamType::as_byte`].
    fn from(st: StreamType) -> Self {
        st.as_byte()
    }
}
impl TryFrom<u8> for StreamType {
    type Error = LinkError;

    /// Parses a wire byte, failing with [`LinkError::InvalidStreamType`]
    /// when the byte is not assigned to any variant.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        match Self::from_byte(byte) {
            Some(stream_type) => Ok(stream_type),
            None => Err(LinkError::InvalidStreamType(byte)),
        }
    }
}
/// Coarse grouping of stream types by the high nibble of their wire byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamTypeFamily {
    /// Bytes `0x00..=0x0F`.
    Gossip,
    /// Bytes `0x10..=0x1F`.
    Dht,
    /// Bytes `0x20..=0x2F`.
    WebRtc,
    /// Bytes `0xF0..=0xFF`.
    Reserved,
}
impl StreamTypeFamily {
    /// Inclusive `(first, last)` byte range assigned to this family.
    pub const fn byte_range(self) -> (u8, u8) {
        let first = match self {
            Self::Gossip => 0x00,
            Self::Dht => 0x10,
            Self::WebRtc => 0x20,
            Self::Reserved => 0xF0,
        };
        // Every family spans the 16 values sharing its high nibble.
        (first, first | 0x0F)
    }

    /// `true` when `byte` lies inside this family's range.
    pub const fn contains(self, byte: u8) -> bool {
        let (lo, hi) = self.byte_range();
        !(byte < lo || byte > hi)
    }
}
/// Allow-list of [`StreamType`]s used when accepting typed streams.
///
/// An empty filter (the default) places no restriction and accepts every
/// stream type.
#[derive(Debug, Clone, Default)]
pub struct StreamFilter {
    allowed: HashSet<StreamType>,
}
impl StreamFilter {
    /// Creates an empty filter, which accepts every stream type.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a filter that explicitly lists every known stream type.
    pub fn accept_all() -> Self {
        Self::new().with_types(StreamType::all_types())
    }

    /// Filter restricted to the gossip family.
    pub fn gossip_only() -> Self {
        Self::new().with_types(StreamType::gossip_types())
    }

    /// Filter restricted to the DHT family.
    pub fn dht_only() -> Self {
        Self::new().with_types(StreamType::dht_types())
    }

    /// Filter restricted to the WebRTC family.
    pub fn webrtc_only() -> Self {
        Self::new().with_types(StreamType::webrtc_types())
    }

    /// Builder: adds one stream type to the allow-list.
    pub fn with_type(mut self, stream_type: StreamType) -> Self {
        self.allowed.insert(stream_type);
        self
    }

    /// Builder: adds every stream type in `stream_types` to the allow-list.
    ///
    /// Passing an empty slice leaves the filter unchanged.
    pub fn with_types(mut self, stream_types: &[StreamType]) -> Self {
        self.allowed.extend(stream_types.iter().copied());
        self
    }

    /// Returns `true` if a stream of `stream_type` passes the filter.
    ///
    /// An empty allow-list accepts everything.
    pub fn accepts(&self, stream_type: StreamType) -> bool {
        self.allowed.is_empty() || self.allowed.contains(&stream_type)
    }

    /// Returns `true` when the filter lets every stream type through.
    ///
    /// This holds when no types were listed, or when every known type has
    /// been listed explicitly (as [`StreamFilter::accept_all`] does).
    /// Previously a filter built by `accept_all()` reported `false` here.
    pub fn accepts_all(&self) -> bool {
        self.allowed.is_empty()
            || StreamType::all_types()
                .iter()
                .all(|st| self.allowed.contains(st))
    }

    /// The explicit allow-list (empty means "accept everything").
    pub fn allowed_types(&self) -> &HashSet<StreamType> {
        &self.allowed
    }
}
/// Fixed 16-byte protocol identifier.
///
/// Human-readable names are stored as their bytes, zero-padded on the right
/// and silently truncated beyond 16 bytes.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProtocolId(pub [u8; 16]);
impl ProtocolId {
    /// Wraps raw identifier bytes.
    #[inline]
    pub const fn new(bytes: [u8; 16]) -> Self {
        Self(bytes)
    }

    /// Builds an identifier from a byte string, usable in `const` context.
    ///
    /// Copies at most 16 bytes; any remaining positions stay zero.
    #[inline]
    pub const fn from_static(s: &[u8]) -> Self {
        let mut out = [0u8; 16];
        let mut idx = 0;
        while idx < s.len() && idx < out.len() {
            out[idx] = s[idx];
            idx += 1;
        }
        Self(out)
    }

    /// Borrows the raw 16 identifier bytes.
    #[inline]
    pub const fn as_bytes(&self) -> &[u8; 16] {
        &self.0
    }

    /// `"ant-quic/default"` (exactly 16 bytes).
    pub const DEFAULT: Self = Self::from_static(b"ant-quic/default");
    /// `"ant-quic/nat"`.
    pub const NAT_TRAVERSAL: Self = Self::from_static(b"ant-quic/nat");
    /// `"ant-quic/relay"`.
    pub const RELAY: Self = Self::from_static(b"ant-quic/relay");
}
impl Default for ProtocolId {
fn default() -> Self {
Self::DEFAULT
}
}
impl fmt::Debug for ProtocolId {
    /// Shows the name up to the first NUL as a quoted string, or the full
    /// 16 bytes hex-encoded when that prefix is not valid UTF-8.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Bytes before the first zero; the whole array if there is no zero.
        let name = self.0.split(|&b| b == 0).next().unwrap_or_default();
        match std::str::from_utf8(name) {
            Ok(text) => write!(f, "ProtocolId({:?})", text),
            Err(_) => write!(f, "ProtocolId({:?})", hex::encode(self.0)),
        }
    }
}
impl fmt::Display for ProtocolId {
    /// Writes the name up to the first NUL, or the full 16 bytes hex-encoded
    /// when that prefix is not valid UTF-8.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Bytes before the first zero; the whole array if there is no zero.
        let name = self.0.split(|&b| b == 0).next().unwrap_or_default();
        match std::str::from_utf8(name) {
            Ok(text) => write!(f, "{}", text),
            Err(_) => write!(f, "{}", hex::encode(self.0)),
        }
    }
}
impl From<&str> for ProtocolId {
fn from(s: &str) -> Self {
Self::from_static(s.as_bytes())
}
}
impl From<[u8; 16]> for ProtocolId {
fn from(bytes: [u8; 16]) -> Self {
Self(bytes)
}
}
/// Hint about the kind of NAT a peer sits behind.
///
/// `Capabilities::quality_score` applies a penalty that depends on this value.
/// The default is `Unknown`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum NatHint {
    /// No NAT detected.
    None,
    /// Full-cone NAT.
    FullCone,
    /// Address-restricted cone NAT.
    AddressRestrictedCone,
    /// Port-restricted cone NAT.
    PortRestrictedCone,
    /// Symmetric NAT.
    Symmetric,
    /// NAT type has not been determined.
    #[default]
    Unknown,
}
/// What is known about a remote peer: advertised roles, addresses,
/// supported protocols, link metrics and connection history.
#[derive(Debug, Clone)]
pub struct Capabilities {
    /// Whether the peer advertises relay support.
    pub supports_relay: bool,
    /// Whether the peer advertises coordination support.
    pub supports_coordination: bool,
    /// Socket addresses observed for this peer.
    pub observed_addrs: Vec<SocketAddr>,
    /// Reachability scope for direct connections, if known.
    pub direct_reachability_scope: Option<ReachabilityScope>,
    /// Protocols the peer supports (checked by `supports_protocol`).
    pub protocols: Vec<ProtocolId>,
    /// Last time the peer was seen (`UNIX_EPOCH` by default).
    pub last_seen: SystemTime,
    /// Presumably the median RTT in milliseconds — confirm with producer.
    pub rtt_ms_p50: u32,
    /// Presumably RTT jitter in milliseconds — confirm with producer.
    pub rtt_jitter_ms: u32,
    /// Packet loss; `quality_score` treats it as a fraction in `[0, 1]`.
    pub packet_loss: f32,
    /// Optional NAT type hint.
    pub nat_type_hint: Option<NatHint>,
    /// Optional bandwidth limit (units not shown here — TODO confirm).
    pub bandwidth_limit: Option<u64>,
    /// Count of successful connection attempts.
    pub successful_connections: u32,
    /// Count of failed connection attempts.
    pub failed_connections: u32,
    /// Whether a connection to the peer is currently open.
    pub is_connected: bool,
}
impl Default for Capabilities {
    /// An "unknown peer" record: no roles, addresses, protocols or history,
    /// `last_seen` at the Unix epoch, zeroed metrics and not connected.
    fn default() -> Self {
        Self {
            // Connection state and history.
            is_connected: false,
            last_seen: SystemTime::UNIX_EPOCH,
            successful_connections: 0,
            failed_connections: 0,
            // Advertised roles and protocols.
            supports_relay: false,
            supports_coordination: false,
            protocols: vec![],
            // Addressing and reachability.
            observed_addrs: vec![],
            direct_reachability_scope: None,
            nat_type_hint: None,
            // Link metrics.
            rtt_ms_p50: 0,
            rtt_jitter_ms: 0,
            packet_loss: 0.0,
            bandwidth_limit: None,
        }
    }
}
impl Capabilities {
    /// Builds a record for a peer we are connected to at `addr`.
    ///
    /// `last_seen` is stamped with the current time. All other fields take
    /// their `Default` values.
    pub fn new_connected(addr: SocketAddr) -> Self {
        Self {
            observed_addrs: vec![addr],
            direct_reachability_scope: None,
            last_seen: SystemTime::now(),
            is_connected: true,
            ..Default::default()
        }
    }

    /// Heuristic peer quality in `[0.0, 1.0]`; higher is better.
    ///
    /// The score starts at 0.5 and is adjusted as follows:
    /// - RTT adds up to +0.3, scaling linearly to 0 at 300 ms or more.
    /// - Low packet loss adds up to +0.2.
    /// - The historical success rate adds up to +0.2, if any attempts are recorded.
    /// - Relay and coordination support add +0.05 each.
    /// - A NAT hint applies a penalty.
    ///
    /// The result is clamped and is never NaN.
    pub fn quality_score(&self) -> f32 {
        let mut score = 0.5_f32;

        let rtt_score = 1.0 - (self.rtt_ms_p50 as f32 / 300.0).min(1.0);
        score += rtt_score * 0.3;

        // Loss is interpreted as a fraction. Clamp out-of-range measurements,
        // and treat NaN as total loss (conservative). Otherwise NaN would
        // propagate through `clamp` and poison the score.
        let loss = if self.packet_loss.is_nan() {
            1.0
        } else {
            self.packet_loss.clamp(0.0, 1.0)
        };
        score += (1.0 - loss) * 0.2;

        // Sum in u64: adding two u32 counters can overflow (panic in debug
        // builds, wrap in release).
        let total = u64::from(self.successful_connections) + u64::from(self.failed_connections);
        if total > 0 {
            let success_rate = self.successful_connections as f32 / total as f32;
            score += success_rate * 0.2;
        }

        if self.supports_relay {
            score += 0.05;
        }
        if self.supports_coordination {
            score += 0.05;
        }

        if let Some(nat) = self.nat_type_hint {
            score -= match nat {
                NatHint::None | NatHint::FullCone => 0.0,
                NatHint::AddressRestrictedCone | NatHint::PortRestrictedCone => 0.05,
                NatHint::Symmetric => 0.15,
                NatHint::Unknown => 0.02,
            };
        }

        score.clamp(0.0, 1.0)
    }

    /// Returns `true` if `proto` is in this peer's advertised protocol list.
    pub fn supports_protocol(&self, proto: &ProtocolId) -> bool {
        self.protocols.contains(proto)
    }
}
/// Why a peer connection ended (carried by `LinkEvent::PeerDisconnected`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisconnectReason {
    /// Closed by the local side.
    LocalClose,
    /// Closed by the remote side.
    RemoteClose,
    /// Presumably an idle or handshake timeout — confirm with the transport.
    Timeout,
    /// Transport-level failure, with a description.
    TransportError(String),
    /// Closed with an application error code.
    ApplicationError(u64),
    /// Connection was reset.
    Reset,
}
/// Events broadcast by a [`LinkTransport`] (see `LinkTransport::subscribe`).
#[derive(Debug, Clone)]
pub enum LinkEvent {
    /// A peer connected; `caps` is its capability record at that time.
    PeerConnected {
        peer: PeerId,
        caps: Capabilities,
    },
    /// A peer disconnected for `reason`.
    PeerDisconnected {
        peer: PeerId,
        reason: DisconnectReason,
    },
    /// Our externally visible address changed to `addr`.
    ExternalAddressUpdated {
        addr: TransportAddr,
    },
    /// A peer's capability record changed.
    CapabilityUpdated {
        peer: PeerId,
        caps: Capabilities,
    },
    /// A request to relay traffic between `from` and `to`; presumably
    /// `budget_bytes` caps relayed volume — confirm with the relay logic.
    RelayRequest {
        from: PeerId,
        to: PeerId,
        budget_bytes: u64,
    },
    /// A request to coordinate a connection between `peer_a` and `peer_b`;
    /// `round` presumably identifies the coordination attempt — confirm.
    CoordinationRequest {
        peer_a: PeerId,
        peer_b: PeerId,
        round: u64,
    },
    /// The bootstrap cache now holds `peer_count` peers.
    BootstrapCacheUpdated {
        peer_count: usize,
    },
}
/// Application handler for one or more [`StreamType`]s.
///
/// Must be `Send + Sync`. Async methods are provided through `async_trait`.
/// Only `stream_types` and `handle_stream` are required; the other methods
/// have defaults.
#[async_trait]
pub trait ProtocolHandler: Send + Sync {
    /// Stream types this handler wants to receive.
    fn stream_types(&self) -> &[StreamType];
    /// Handles the payload `data` of a `stream_type` stream from `peer`.
    ///
    /// Returns `Ok(Some(bytes))` to produce a response or `Ok(None)` for none.
    /// NOTE(review): how the response is sent back is up to the dispatcher
    /// (not shown here).
    async fn handle_stream(
        &self,
        peer: PeerId,
        stream_type: StreamType,
        data: Bytes,
    ) -> LinkResult<Option<Bytes>>;
    /// Handles a datagram. The default ignores it and returns `Ok(())`.
    async fn handle_datagram(
        &self,
        _peer: PeerId,
        _stream_type: StreamType,
        _data: Bytes,
    ) -> LinkResult<()> {
        Ok(())
    }
    /// Shutdown hook. The default does nothing and returns `Ok(())`.
    async fn shutdown(&self) -> LinkResult<()> {
        Ok(())
    }
    /// Name for diagnostics. The default is `"ProtocolHandler"`.
    fn name(&self) -> &str {
        "ProtocolHandler"
    }
}
/// A type-erased, heap-allocated [`ProtocolHandler`].
pub type BoxedHandler = Box<dyn ProtocolHandler>;
/// Convenience extension for converting any concrete handler into a [`BoxedHandler`].
pub trait ProtocolHandlerExt: ProtocolHandler + Sized + 'static {
    /// Moves `self` onto the heap behind `dyn ProtocolHandler`.
    fn boxed(self) -> BoxedHandler {
        Box::new(self)
    }
}
impl<T> ProtocolHandlerExt for T where T: ProtocolHandler + 'static {}
/// Errors produced by the link layer.
///
/// The error is `Clone`, so wrapped causes are stored as `String`s rather than
/// the original error values (see the `From<std::io::Error>` impl).
#[derive(Debug, Error, Clone)]
pub enum LinkError {
    /// The connection has been closed.
    #[error("connection closed")]
    ConnectionClosed,
    /// Establishing a connection failed.
    #[error("connection failed: {0}")]
    ConnectionFailed(String),
    /// The requested peer is unknown.
    #[error("peer not found: {0}")]
    PeerNotFound(String),
    /// The given protocol is not supported.
    #[error("protocol not supported: {0}")]
    ProtocolNotSupported(ProtocolId),
    /// An operation exceeded its deadline.
    #[error("operation timed out")]
    Timeout,
    /// A stream was reset with the given error code.
    #[error("stream reset: error code {0}")]
    StreamReset(u64),
    /// I/O failure; holds the stringified `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(String),
    /// The transport has shut down.
    #[error("transport shutdown")]
    Shutdown,
    /// A rate limit was exceeded.
    #[error("rate limit exceeded")]
    RateLimited,
    /// Unexpected internal failure.
    #[error("internal error: {0}")]
    Internal(String),
    /// A stream-type byte did not match any `StreamType`.
    #[error("invalid stream type byte: 0x{0:02x}")]
    InvalidStreamType(u8),
    /// A stream type was rejected by a `StreamFilter`.
    #[error("stream type {0} not accepted")]
    StreamTypeFiltered(StreamType),
    /// A handler is already registered for this stream type.
    #[error("handler already exists for stream type: {0}")]
    HandlerExists(StreamType),
    /// No handler is registered for this stream type.
    #[error("no handler for stream type: {0}")]
    NoHandler(StreamType),
    /// The transport is not running.
    #[error("transport not running")]
    NotRunning,
    /// The transport is already running.
    #[error("transport already running")]
    AlreadyRunning,
}
impl From<std::io::Error> for LinkError {
fn from(e: std::io::Error) -> Self {
Self::Io(e.to_string())
}
}
/// Result alias used throughout the link layer.
pub type LinkResult<T> = Result<T, LinkError>;
/// Boxed, pinned, `Send` future borrowing for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Boxed, pinned, `Send` stream borrowing for `'a`.
pub type BoxStream<'a, T> = Pin<Box<dyn futures_util::Stream<Item = T> + Send + 'a>>;
/// An established connection to a single remote peer.
///
/// Async operations return boxed futures and streams
/// ([`BoxFuture`]/[`BoxStream`]) rather than using generic or `async` methods,
/// presumably so the trait stays object-safe.
pub trait LinkConn: Send + Sync {
    /// The remote peer's identity.
    fn peer(&self) -> PeerId;
    /// The remote socket address.
    fn remote_addr(&self) -> SocketAddr;
    /// Opens a unidirectional send stream; takes no `StreamType`, unlike `open_uni_typed`.
    fn open_uni(&self) -> BoxFuture<'_, LinkResult<Box<dyn LinkSendStream>>>;
    /// Opens a bidirectional stream; takes no `StreamType`, unlike `open_bi_typed`.
    fn open_bi(
        &self,
    ) -> BoxFuture<'_, LinkResult<(Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>>;
    /// Opens a unidirectional send stream tagged with `stream_type`.
    fn open_uni_typed(
        &self,
        stream_type: StreamType,
    ) -> BoxFuture<'_, LinkResult<Box<dyn LinkSendStream>>>;
    /// Opens a bidirectional stream tagged with `stream_type`.
    fn open_bi_typed(
        &self,
        stream_type: StreamType,
    ) -> BoxFuture<'_, LinkResult<(Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>>;
    /// Yields incoming unidirectional streams together with their `StreamType`.
    /// NOTE(review): whether filtered-out streams are skipped or surfaced as
    /// `LinkError::StreamTypeFiltered` is implementation-defined — confirm.
    fn accept_uni_typed(
        &self,
        filter: StreamFilter,
    ) -> BoxStream<'_, LinkResult<(StreamType, Box<dyn LinkRecvStream>)>>;
    /// Yields incoming bidirectional streams together with their `StreamType`;
    /// the same filtering caveat as `accept_uni_typed` applies.
    fn accept_bi_typed(
        &self,
        filter: StreamFilter,
    ) -> BoxStream<'_, LinkResult<(StreamType, Box<dyn LinkSendStream>, Box<dyn LinkRecvStream>)>>;
    /// Sends an unreliable datagram.
    fn send_datagram(&self, data: Bytes) -> LinkResult<()>;
    /// Stream of received datagrams.
    fn recv_datagrams(&self) -> BoxStream<'_, Bytes>;
    /// Closes the connection with an application error code and reason.
    fn close(&self, error_code: u64, reason: &str);
    /// Whether the connection is still open.
    fn is_open(&self) -> bool;
    /// Snapshot of connection statistics.
    fn stats(&self) -> ConnectionStats;
}
/// Snapshot of per-connection statistics returned by `LinkConn::stats`.
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
    /// Total bytes sent.
    pub bytes_sent: u64,
    /// Total bytes received.
    pub bytes_received: u64,
    /// Round-trip time estimate.
    pub rtt: Duration,
    /// How long the connection has been established.
    pub connected_duration: Duration,
    /// Number of streams opened.
    pub streams_opened: u64,
    /// Number of packets lost.
    pub packets_lost: u64,
}
/// Sending half of a stream.
pub trait LinkSendStream: Send + Sync {
    /// Writes some of `data`, returning how many bytes were accepted.
    fn write<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<usize>>;
    /// Writes all of `data`.
    fn write_all<'a>(&'a mut self, data: &'a [u8]) -> BoxFuture<'a, LinkResult<()>>;
    /// Signals that no more data will be written.
    fn finish(&mut self) -> LinkResult<()>;
    /// Abruptly resets the stream with `error_code`.
    fn reset(&mut self, error_code: u64) -> LinkResult<()>;
    /// Stream identifier.
    fn id(&self) -> u64;
}
/// Receiving half of a stream.
pub trait LinkRecvStream: Send + Sync {
    /// Reads into `buf`, returning the number of bytes read.
    /// NOTE(review): presumably `Ok(None)` signals end of stream — confirm.
    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> BoxFuture<'a, LinkResult<Option<usize>>>;
    /// Reads the rest of the stream into a `Vec`.
    /// NOTE(review): presumably errors once more than `size_limit` bytes arrive — confirm.
    fn read_to_end(&mut self, size_limit: usize) -> BoxFuture<'_, LinkResult<Vec<u8>>>;
    /// Asks the peer to stop sending, using `error_code`.
    fn stop(&mut self, error_code: u64) -> LinkResult<()>;
    /// Stream identifier.
    fn id(&self) -> u64;
}
/// Owned stream of incoming connections, as returned by `LinkTransport::accept`.
pub type Incoming<C> = BoxStream<'static, LinkResult<C>>;
/// Top-level transport: identity, peer table, events, and dialing or
/// accepting connections of type `Self::Conn`.
pub trait LinkTransport: Send + Sync + 'static {
    /// Concrete connection type produced by this transport.
    type Conn: LinkConn + 'static;
    /// This node's peer identity.
    fn local_peer(&self) -> PeerId;
    /// This node's externally visible address, if known.
    fn external_address(&self) -> Option<SocketAddr>;
    /// All known peers with their capability records.
    fn peer_table(&self) -> Vec<(PeerId, Capabilities)>;
    /// Capability record for `peer`, if known.
    fn peer_capabilities(&self, peer: &PeerId) -> Option<Capabilities>;
    /// Subscribes to [`LinkEvent`]s via a broadcast channel.
    fn subscribe(&self) -> broadcast::Receiver<LinkEvent>;
    /// Stream of incoming connections for protocol `proto`.
    fn accept(&self, proto: ProtocolId) -> Incoming<Self::Conn>;
    /// Connects to a known `peer` using protocol `proto`.
    fn dial(&self, peer: PeerId, proto: ProtocolId) -> BoxFuture<'_, LinkResult<Self::Conn>>;
    /// Connects directly to `addr` using protocol `proto`.
    fn dial_addr(
        &self,
        addr: SocketAddr,
        proto: ProtocolId,
    ) -> BoxFuture<'_, LinkResult<Self::Conn>>;
    /// Protocols currently registered locally.
    fn supported_protocols(&self) -> Vec<ProtocolId>;
    /// Registers `proto` as supported.
    fn register_protocol(&self, proto: ProtocolId);
    /// Unregisters `proto`.
    fn unregister_protocol(&self, proto: ProtocolId);
    /// Whether there is a connection to `peer`.
    fn is_connected(&self, peer: &PeerId) -> bool;
    /// Number of active connections.
    fn active_connections(&self) -> usize;
    /// Shuts the transport down.
    fn shutdown(&self) -> BoxFuture<'_, ()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_id_from_string() {
        let proto = ProtocolId::from("saorsa-dht/1.0");
        assert_eq!(&proto.0[..14], b"saorsa-dht/1.0");
        // Shorter names are zero-padded to the full 16 bytes.
        assert_eq!(proto.0[14], 0);
        assert_eq!(proto.0[15], 0);
    }

    #[test]
    fn test_protocol_id_truncation() {
        let proto = ProtocolId::from("this-is-a-very-long-protocol-name");
        assert_eq!(&proto.0, b"this-is-a-very-l");
    }

    #[test]
    fn test_protocol_id_display() {
        let proto = ProtocolId::from("test/1.0");
        assert_eq!(format!("{}", proto), "test/1.0");
    }

    #[test]
    fn test_capabilities_quality_score() {
        let mut caps = Capabilities::default();
        let base_score = caps.quality_score();
        assert!(
            (0.9..=1.0).contains(&base_score),
            "base_score = {}",
            base_score
        );

        caps.rtt_ms_p50 = 150;
        let worse_rtt_score = caps.quality_score();
        assert!(
            worse_rtt_score < base_score,
            "worse RTT should reduce score"
        );

        caps.rtt_ms_p50 = 500;
        let bad_rtt_score = caps.quality_score();
        assert!(
            bad_rtt_score < worse_rtt_score,
            "bad RTT should reduce score more"
        );

        // Same RTT, with and without a symmetric-NAT hint.
        caps.rtt_ms_p50 = 50;
        caps.nat_type_hint = Some(NatHint::Symmetric);
        let nat_score = caps.quality_score();
        caps.nat_type_hint = None;
        let no_nat_score = caps.quality_score();
        assert!(
            nat_score < no_nat_score,
            "symmetric NAT should reduce score"
        );
    }

    #[test]
    fn test_capabilities_supports_protocol() {
        let mut caps = Capabilities::default();
        let dht = ProtocolId::from("dht/1.0");
        let gossip = ProtocolId::from("gossip/1.0");
        caps.protocols.push(dht);
        assert!(caps.supports_protocol(&dht));
        assert!(!caps.supports_protocol(&gossip));
    }

    #[test]
    fn test_stream_type_bytes() {
        assert_eq!(StreamType::Membership.as_byte(), 0x00);
        assert_eq!(StreamType::PubSub.as_byte(), 0x01);
        assert_eq!(StreamType::GossipBulk.as_byte(), 0x02);
        assert_eq!(StreamType::DhtQuery.as_byte(), 0x10);
        assert_eq!(StreamType::DhtStore.as_byte(), 0x11);
        assert_eq!(StreamType::DhtWitness.as_byte(), 0x12);
        assert_eq!(StreamType::DhtReplication.as_byte(), 0x13);
        assert_eq!(StreamType::WebRtcSignal.as_byte(), 0x20);
        assert_eq!(StreamType::WebRtcMedia.as_byte(), 0x21);
        assert_eq!(StreamType::WebRtcData.as_byte(), 0x22);
        assert_eq!(StreamType::Reserved.as_byte(), 0xF0);
    }

    #[test]
    fn test_stream_type_from_byte() {
        assert_eq!(StreamType::from_byte(0x00), Some(StreamType::Membership));
        assert_eq!(StreamType::from_byte(0x10), Some(StreamType::DhtQuery));
        assert_eq!(StreamType::from_byte(0x20), Some(StreamType::WebRtcSignal));
        assert_eq!(StreamType::from_byte(0xF0), Some(StreamType::Reserved));
        // Unassigned bytes, including ones inside the reserved range.
        assert_eq!(StreamType::from_byte(0x99), None);
        assert_eq!(StreamType::from_byte(0xFF), None);
    }

    #[test]
    fn test_stream_type_byte_roundtrip() {
        for &st in StreamType::all_types() {
            assert_eq!(StreamType::from_byte(st.as_byte()), Some(st));
            assert_eq!(u8::from(st), st.as_byte());
            assert!(matches!(StreamType::try_from(st.as_byte()), Ok(parsed) if parsed == st));
        }
        assert!(matches!(
            StreamType::try_from(0x99),
            Err(LinkError::InvalidStreamType(0x99))
        ));
    }

    #[test]
    fn test_stream_type_families() {
        assert!(StreamType::Membership.is_gossip());
        assert!(StreamType::PubSub.is_gossip());
        assert!(StreamType::GossipBulk.is_gossip());
        assert!(StreamType::DhtQuery.is_dht());
        assert!(StreamType::DhtStore.is_dht());
        assert!(StreamType::DhtWitness.is_dht());
        assert!(StreamType::DhtReplication.is_dht());
        assert!(StreamType::WebRtcSignal.is_webrtc());
        assert!(StreamType::WebRtcMedia.is_webrtc());
        assert!(StreamType::WebRtcData.is_webrtc());
    }

    #[test]
    fn test_reserved_stream_type_family() {
        assert_eq!(StreamType::Reserved.family(), StreamTypeFamily::Reserved);
        assert!(!StreamType::Reserved.is_gossip());
        assert!(!StreamType::Reserved.is_dht());
        assert!(!StreamType::Reserved.is_webrtc());
    }

    #[test]
    fn test_stream_type_family_ranges() {
        assert!(StreamTypeFamily::Gossip.contains(0x00));
        assert!(StreamTypeFamily::Gossip.contains(0x0F));
        assert!(!StreamTypeFamily::Gossip.contains(0x10));
        assert!(StreamTypeFamily::Dht.contains(0x10));
        assert!(StreamTypeFamily::Dht.contains(0x1F));
        assert!(!StreamTypeFamily::Dht.contains(0x20));
        assert!(StreamTypeFamily::WebRtc.contains(0x20));
        assert!(StreamTypeFamily::WebRtc.contains(0x2F));
        assert!(!StreamTypeFamily::WebRtc.contains(0x30));
    }

    #[test]
    fn test_stream_filter_accepts() {
        let filter = StreamFilter::new()
            .with_type(StreamType::Membership)
            .with_type(StreamType::DhtQuery);
        assert!(filter.accepts(StreamType::Membership));
        assert!(filter.accepts(StreamType::DhtQuery));
        assert!(!filter.accepts(StreamType::PubSub));
        assert!(!filter.accepts(StreamType::WebRtcMedia));
    }

    #[test]
    fn test_stream_filter_empty_accepts_all() {
        let filter = StreamFilter::new();
        assert!(filter.accepts_all());
        assert!(filter.accepts(StreamType::Membership));
        assert!(filter.accepts(StreamType::DhtQuery));
        assert!(filter.accepts(StreamType::WebRtcMedia));
    }

    #[test]
    fn test_stream_filter_presets() {
        let gossip = StreamFilter::gossip_only();
        assert!(gossip.accepts(StreamType::Membership));
        assert!(gossip.accepts(StreamType::PubSub));
        assert!(gossip.accepts(StreamType::GossipBulk));
        assert!(!gossip.accepts(StreamType::DhtQuery));

        let dht = StreamFilter::dht_only();
        assert!(dht.accepts(StreamType::DhtQuery));
        assert!(dht.accepts(StreamType::DhtStore));
        assert!(!dht.accepts(StreamType::Membership));

        let webrtc = StreamFilter::webrtc_only();
        assert!(webrtc.accepts(StreamType::WebRtcSignal));
        assert!(webrtc.accepts(StreamType::WebRtcMedia));
        assert!(!webrtc.accepts(StreamType::DhtQuery));
    }

    #[test]
    fn test_stream_type_display() {
        assert_eq!(format!("{}", StreamType::Membership), "Membership");
        assert_eq!(format!("{}", StreamType::DhtQuery), "DhtQuery");
        assert_eq!(format!("{}", StreamType::WebRtcMedia), "WebRtcMedia");
    }

    mod protocol_handler_tests {
        use super::*;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Echo handler that counts how many streams it has handled.
        struct TestHandler {
            types: Vec<StreamType>,
            call_count: Arc<AtomicUsize>,
        }

        impl TestHandler {
            fn new(types: Vec<StreamType>) -> Self {
                Self {
                    types,
                    call_count: Arc::new(AtomicUsize::new(0)),
                }
            }

            fn with_counter(types: Vec<StreamType>, counter: Arc<AtomicUsize>) -> Self {
                Self {
                    types,
                    call_count: counter,
                }
            }
        }

        #[async_trait]
        impl ProtocolHandler for TestHandler {
            fn stream_types(&self) -> &[StreamType] {
                &self.types
            }

            async fn handle_stream(
                &self,
                _peer: PeerId,
                _stream_type: StreamType,
                data: Bytes,
            ) -> LinkResult<Option<Bytes>> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                // Echo the payload back as the response.
                Ok(Some(data))
            }

            fn name(&self) -> &str {
                "TestHandler"
            }
        }

        #[test]
        fn test_handler_stream_types() {
            let handler = TestHandler::new(vec![StreamType::Membership, StreamType::PubSub]);
            assert_eq!(handler.stream_types().len(), 2);
            assert!(handler.stream_types().contains(&StreamType::Membership));
            assert!(handler.stream_types().contains(&StreamType::PubSub));
        }

        #[tokio::test]
        async fn test_handler_returns_response() {
            let handler = TestHandler::new(vec![StreamType::DhtQuery]);
            let peer = PeerId::from([0u8; 32]);
            let response = handler
                .handle_stream(peer, StreamType::DhtQuery, Bytes::from_static(b"test"))
                .await
                .expect("echo handler should not fail");
            assert_eq!(response, Some(Bytes::from_static(b"test")));
        }

        #[tokio::test]
        async fn test_handler_no_response() {
            /// Handler that consumes payloads without replying.
            struct SinkHandler;

            #[async_trait]
            impl ProtocolHandler for SinkHandler {
                fn stream_types(&self) -> &[StreamType] {
                    &[StreamType::GossipBulk]
                }

                async fn handle_stream(
                    &self,
                    _peer: PeerId,
                    _stream_type: StreamType,
                    _data: Bytes,
                ) -> LinkResult<Option<Bytes>> {
                    Ok(None)
                }
            }

            let handler = SinkHandler;
            let peer = PeerId::from([0u8; 32]);
            let response = handler
                .handle_stream(peer, StreamType::GossipBulk, Bytes::from_static(b"data"))
                .await
                .expect("sink handler should not fail");
            assert!(response.is_none());
        }

        #[tokio::test]
        async fn test_handler_tracks_calls() {
            let count = Arc::new(AtomicUsize::new(0));
            let handler = TestHandler::with_counter(vec![StreamType::Membership], count.clone());
            let peer = PeerId::from([0u8; 32]);
            assert_eq!(handler.name(), "TestHandler");
            assert_eq!(count.load(Ordering::SeqCst), 0);

            handler
                .handle_stream(peer, StreamType::Membership, Bytes::new())
                .await
                .expect("first call should succeed");
            assert_eq!(count.load(Ordering::SeqCst), 1);

            handler
                .handle_stream(peer, StreamType::Membership, Bytes::new())
                .await
                .expect("second call should succeed");
            assert_eq!(count.load(Ordering::SeqCst), 2);
        }

        #[test]
        fn test_boxed_handler() {
            let handler: BoxedHandler = TestHandler::new(vec![StreamType::DhtStore]).boxed();
            assert_eq!(handler.stream_types(), &[StreamType::DhtStore]);
            assert_eq!(handler.name(), "TestHandler");
        }

        #[tokio::test]
        async fn test_default_datagram_handler() {
            let handler = TestHandler::new(vec![StreamType::Membership]);
            let peer = PeerId::from([0u8; 32]);
            handler
                .handle_datagram(peer, StreamType::Membership, Bytes::from_static(b"dgram"))
                .await
                .expect("default datagram handler should succeed");
        }

        #[tokio::test]
        async fn test_default_shutdown() {
            let handler = TestHandler::new(vec![StreamType::Membership]);
            handler
                .shutdown()
                .await
                .expect("default shutdown should succeed");
        }
    }

    mod handler_error_tests {
        use super::*;

        #[test]
        fn test_handler_exists_error() {
            let err = LinkError::HandlerExists(StreamType::Membership);
            let msg = err.to_string();
            assert!(msg.contains("Membership"), "Error message: {}", msg);
            assert!(
                msg.to_lowercase().contains("handler"),
                "Error message: {}",
                msg
            );
        }

        #[test]
        fn test_no_handler_error() {
            let err = LinkError::NoHandler(StreamType::DhtQuery);
            let msg = err.to_string();
            assert!(msg.contains("DhtQuery"), "Error message: {}", msg);
        }

        #[test]
        fn test_not_running_error() {
            let err = LinkError::NotRunning;
            let msg = err.to_string();
            assert!(
                msg.to_lowercase().contains("not running"),
                "Error message: {}",
                msg
            );
        }

        #[test]
        fn test_already_running_error() {
            let err = LinkError::AlreadyRunning;
            let msg = err.to_string();
            assert!(
                msg.to_lowercase().contains("already running"),
                "Error message: {}",
                msg
            );
        }
    }
}