use crate::MultiAddr;
use crate::PeerId;
use crate::bgp_geo_provider::BgpGeoProvider;
use crate::error::{NetworkError, P2PError, P2pResult as Result};
use crate::identity::node_identity::NodeIdentity;
use crate::network::{
ConnectionStatus, MAX_ACTIVE_REQUESTS, MAX_REQUEST_TIMEOUT, MESSAGE_RECV_CHANNEL_CAPACITY,
NetworkSender, P2PEvent, ParsedMessage, PeerInfo, PeerResponse, PendingRequest,
RequestResponseEnvelope, WireMessage, broadcast_event, normalize_wildcard_to_loopback,
parse_protocol_message, register_new_channel,
};
use crate::transport::observed_address_cache::ObservedAddressCache;
use crate::transport::saorsa_transport_adapter::{ConnectionEvent, DualStackNetworkNode};
use crate::validation::{RateLimitConfig, RateLimiter};
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::{RwLock, broadcast};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
// Capacity of the broadcast event channel used by the test constructor.
const TEST_EVENT_CHANNEL_CAPACITY: usize = 16;
// Rate-limiter settings applied only by `new_for_tests`.
const TEST_MAX_REQUESTS: u32 = 100;
const TEST_BURST_SIZE: u32 = 100;
const TEST_RATE_LIMIT_WINDOW_SECS: u64 = 1;
// Connection timeout applied only by `new_for_tests`.
const TEST_CONNECTION_TIMEOUT_SECS: u64 = 30;
// Wire protocol id used when announcing this node's identity to a peer.
const IDENTITY_ANNOUNCE_PROTOCOL: &str = "/saorsa/identity/1.0";
/// Configuration required to construct a [`TransportHandle`].
pub struct TransportConfig {
/// Addresses the dual-stack node should bind to (at most one IPv4 and one
/// IPv6 address are actually used; see `TransportHandle::new`).
pub listen_addrs: Vec<MultiAddr>,
/// Timeout applied to outbound connects and per-message sends.
pub connection_timeout: Duration,
/// Maximum concurrent connections for the dual-stack node.
pub max_connections: usize,
/// Capacity of the `P2PEvent` broadcast channel.
pub event_channel_capacity: usize,
/// Optional cap on inbound/outbound message size, forwarded to the node.
pub max_message_size: Option<usize>,
/// Identity used to sign outgoing wire messages.
pub node_identity: Arc<NodeIdentity>,
/// User-agent string embedded in signed wire messages.
pub user_agent: String,
/// Whether loopback addresses are accepted by the transport.
pub allow_loopback: bool,
}
impl TransportConfig {
    /// Builds a transport configuration from a higher-level `NodeConfig`,
    /// combining it with the event-channel capacity and node identity that
    /// the transport layer needs.
    pub fn from_node_config(
        config: &crate::network::NodeConfig,
        event_channel_capacity: usize,
        node_identity: Arc<NodeIdentity>,
    ) -> Self {
        // Pull the derived values out of the node config first, then copy
        // the plain fields across.
        let listen_addrs = config.listen_addrs();
        let user_agent = config.user_agent();
        Self {
            node_identity,
            event_channel_capacity,
            listen_addrs,
            user_agent,
            connection_timeout: config.connection_timeout,
            max_connections: config.max_connections,
            max_message_size: config.max_message_size,
            allow_loopback: config.allow_loopback,
        }
    }
}
/// Handle over the dual-stack QUIC transport: tracks channels (QUIC
/// connections keyed by remote `SocketAddr` string), application peer ids,
/// and the background tasks that service them.
pub struct TransportHandle {
// The underlying dual-stack (IPv4 + IPv6) QUIC node.
dual_node: Arc<DualStackNetworkNode>,
// Channel-id (remote socket address string) -> peer connection info.
peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
// Set of channel-ids currently considered live.
active_connections: Arc<RwLock<HashSet<String>>>,
// Broadcast side of the P2P event stream (`subscribe_events`).
event_tx: broadcast::Sender<P2PEvent>,
// Addresses this node is actually listening on; filled by
// `start_network_listeners`.
listen_addrs: RwLock<Vec<MultiAddr>>,
// Per-IP rate limiting of inbound connections.
rate_limiter: Arc<RateLimiter>,
// In-flight request/response exchanges keyed by message id.
active_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
#[allow(dead_code)]
geo_provider: Arc<BgpGeoProvider>,
// Cancels background receive tasks on shutdown.
shutdown: CancellationToken,
// Receivers for (old, new) peer address updates from the adapter;
// wrapped in async Mutex so `&self` methods can take them exclusively.
peer_address_update_rx:
tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(SocketAddr, SocketAddr)>>,
relay_established_rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<SocketAddr>>,
// Cache of externally-observed addresses (sync parking_lot mutex; never
// held across an await).
observed_address_cache: Arc<parking_lot::Mutex<ObservedAddressCache>>,
// Timeout used for connects and per-message sends.
connection_timeout: Duration,
// Join handles for the background tasks, kept so they can be aborted.
connection_monitor_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
recv_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
listener_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
// Identity used to sign outgoing wire messages.
node_identity: Arc<NodeIdentity>,
user_agent: String,
// Bidirectional peer <-> channel maps; a peer may span several channels
// and (transiently) a channel several peers.
peer_to_channel: Arc<DashMap<PeerId, HashSet<String>>>,
channel_to_peers: Arc<DashMap<String, HashSet<PeerId>>>,
// Last user-agent announced by each connected peer.
peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>>,
}
impl TransportHandle {
/// Builds a live transport handle from `config`.
///
/// Selects at most one dialable IPv4 and one dialable IPv6 listen address,
/// creates the dual-stack QUIC node, wires the observed-address forwarder,
/// and spawns the connection-lifecycle monitor task.
///
/// # Errors
/// Returns a `TransportError::SetupFailed` if the dual-stack node cannot be
/// created.
pub async fn new(config: TransportConfig) -> Result<Self> {
let (event_tx, _) = broadcast::channel(config.event_channel_capacity);
// Keep only the first dialable IPv4 and first dialable IPv6 address;
// any further listen addresses are silently ignored.
let mut v4_opt: Option<SocketAddr> = None;
let mut v6_opt: Option<SocketAddr> = None;
for addr in &config.listen_addrs {
if let Some(sa) = addr.dialable_socket_addr() {
match sa.ip() {
std::net::IpAddr::V4(_) if v4_opt.is_none() => v4_opt = Some(sa),
std::net::IpAddr::V6(_) if v6_opt.is_none() => v6_opt = Some(sa),
_ => {} }
}
}
let dual_node = Arc::new(
DualStackNetworkNode::new_with_options(
v6_opt,
v4_opt,
config.max_connections,
config.max_message_size,
config.allow_loopback,
)
.await
.map_err(|e| {
P2PError::Transport(crate::error::TransportError::SetupFailed(
format!("Failed to create dual-stack network nodes: {}", e).into(),
))
})?,
);
let rate_limiter = Arc::new(RateLimiter::new(RateLimitConfig::default()));
let active_connections = Arc::new(RwLock::new(HashSet::new()));
let geo_provider = Arc::new(BgpGeoProvider::new());
let peers = Arc::new(RwLock::new(HashMap::new()));
let shutdown = CancellationToken::new();
let observed_address_cache = Arc::new(parking_lot::Mutex::new(ObservedAddressCache::new()));
// The adapter forwards peer-address updates and relay notifications into
// mpsc channels whose receivers we keep on the handle.
let (peer_addr_update_rx, relay_established_rx) =
dual_node.spawn_peer_address_update_forwarder(Arc::clone(&observed_address_cache));
let connection_event_rx = dual_node.subscribe_connection_events();
let peer_to_channel = Arc::new(DashMap::new());
let channel_to_peers = Arc::new(DashMap::new());
let peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>> =
Arc::new(RwLock::new(HashMap::new()));
// Spawn the background task that reacts to connection events and keeps
// the peer/channel maps, peer info, and event stream in sync.
let connection_monitor_handle = {
let active_conns = Arc::clone(&active_connections);
let peers_map = Arc::clone(&peers);
let event_tx_clone = event_tx.clone();
let dual_node_clone = Arc::clone(&dual_node);
let geo_provider_clone = Arc::clone(&geo_provider);
let shutdown_token = shutdown.clone();
let p2c = Arc::clone(&peer_to_channel);
let c2p = Arc::clone(&channel_to_peers);
let pua = Arc::clone(&peer_user_agents);
let identity_clone = config.node_identity.clone();
let user_agent_clone = config.user_agent.clone();
let handle = tokio::spawn(async move {
Self::connection_lifecycle_monitor_with_rx(
dual_node_clone,
connection_event_rx,
active_conns,
peers_map,
event_tx_clone,
geo_provider_clone,
shutdown_token,
p2c,
c2p,
pua,
identity_clone,
user_agent_clone,
)
.await;
});
Arc::new(RwLock::new(Some(handle)))
};
Ok(Self {
dual_node,
peers,
active_connections,
event_tx,
// Filled later by `start_network_listeners`.
listen_addrs: RwLock::new(Vec::new()),
rate_limiter,
active_requests: Arc::new(RwLock::new(HashMap::new())),
geo_provider,
shutdown,
peer_address_update_rx: tokio::sync::Mutex::new(peer_addr_update_rx),
relay_established_rx: tokio::sync::Mutex::new(relay_established_rx),
observed_address_cache,
connection_timeout: config.connection_timeout,
connection_monitor_handle,
recv_handles: Arc::new(RwLock::new(Vec::new())),
listener_handle: Arc::new(RwLock::new(None)),
node_identity: config.node_identity,
user_agent: config.user_agent,
peer_to_channel,
channel_to_peers,
peer_user_agents,
})
}
/// Synchronous test-only constructor with a freshly generated identity,
/// permissive rate limits, and no background monitor task.
///
/// NOTE(review): `Handle::current().block_on` panics when invoked from
/// within an async execution context — presumably this is only called from
/// a blocking thread that has entered a runtime; confirm callers.
pub fn new_for_tests() -> Result<Self> {
let identity = Arc::new(NodeIdentity::generate().map_err(|e| {
P2PError::Network(NetworkError::BindError(
format!("Failed to generate test node identity: {}", e).into(),
))
})?);
let (event_tx, _) = broadcast::channel(TEST_EVENT_CHANNEL_CAPACITY);
let dual_node = {
// "[::1]:0" always parses, so the 0.0.0.0 fallback is effectively
// dead code kept for safety.
let v6: Option<SocketAddr> = "[::1]:0"
.parse()
.ok()
.or(Some(SocketAddr::from(([0, 0, 0, 0], 0))));
let v4: Option<SocketAddr> = "127.0.0.1:0".parse().ok();
let handle = tokio::runtime::Handle::current();
let dual_attempt = handle.block_on(DualStackNetworkNode::new(v6, v4));
let dual = match dual_attempt {
Ok(d) => d,
Err(_e1) => {
// Dual-stack bind failed (e.g. no IPv6); retry IPv4-only.
let fallback = handle
.block_on(DualStackNetworkNode::new(None, "127.0.0.1:0".parse().ok()));
match fallback {
Ok(d) => d,
Err(e2) => {
return Err(P2PError::Network(NetworkError::BindError(
format!("Failed to create dual-stack network node: {}", e2).into(),
)));
}
}
}
};
Arc::new(dual)
};
Ok(Self {
dual_node,
peers: Arc::new(RwLock::new(HashMap::new())),
active_connections: Arc::new(RwLock::new(HashSet::new())),
event_tx,
listen_addrs: RwLock::new(Vec::new()),
rate_limiter: Arc::new(RateLimiter::new(RateLimitConfig {
max_requests: TEST_MAX_REQUESTS,
burst_size: TEST_BURST_SIZE,
window: std::time::Duration::from_secs(TEST_RATE_LIMIT_WINDOW_SECS),
..Default::default()
})),
active_requests: Arc::new(RwLock::new(HashMap::new())),
geo_provider: Arc::new(BgpGeoProvider::new()),
shutdown: CancellationToken::new(),
// Dangling receivers: the senders are dropped immediately, so these
// channels never yield updates in tests.
peer_address_update_rx: {
let (_tx, rx) = tokio::sync::mpsc::channel(
crate::transport::saorsa_transport_adapter::ADDRESS_EVENT_CHANNEL_CAPACITY,
);
tokio::sync::Mutex::new(rx)
},
relay_established_rx: {
let (_tx, rx) = tokio::sync::mpsc::channel(
crate::transport::saorsa_transport_adapter::ADDRESS_EVENT_CHANNEL_CAPACITY,
);
tokio::sync::Mutex::new(rx)
},
observed_address_cache: Arc::new(parking_lot::Mutex::new(ObservedAddressCache::new())),
connection_timeout: Duration::from_secs(TEST_CONNECTION_TIMEOUT_SECS),
connection_monitor_handle: Arc::new(RwLock::new(None)),
recv_handles: Arc::new(RwLock::new(Vec::new())),
listener_handle: Arc::new(RwLock::new(None)),
node_identity: identity,
user_agent: crate::network::user_agent_for_mode(crate::network::NodeMode::Node),
peer_to_channel: Arc::new(DashMap::new()),
channel_to_peers: Arc::new(DashMap::new()),
peer_user_agents: Arc::new(RwLock::new(HashMap::new())),
})
}
}
impl TransportHandle {
/// This node's application-level peer id (copied from the identity).
pub fn peer_id(&self) -> PeerId {
*self.node_identity.peer_id()
}
/// Borrow the identity used to sign outgoing wire messages.
pub fn node_identity(&self) -> &Arc<NodeIdentity> {
&self.node_identity
}
/// First known listen address, if any.
///
/// Uses `try_read` so it can be called from synchronous contexts; returns
/// `None` both when no address is known and when the lock is currently
/// held for writing.
pub fn local_addr(&self) -> Option<MultiAddr> {
self.listen_addrs
.try_read()
.ok()
.and_then(|addrs| addrs.first().cloned())
}
/// All addresses this node is listening on.
pub async fn listen_addrs(&self) -> Vec<MultiAddr> {
self.listen_addrs.read().await.clone()
}
/// Best single externally-observed address (first of
/// `observed_external_addresses`).
pub fn observed_external_address(&self) -> Option<SocketAddr> {
self.observed_external_addresses().into_iter().next()
}
/// Externally-observed addresses: the dual node's view first, then any
/// cached observations not already present (deduplicated, order kept).
pub fn observed_external_addresses(&self) -> Vec<SocketAddr> {
let mut out: Vec<SocketAddr> = self.dual_node.get_observed_external_addresses();
let cached = self
.observed_address_cache
.lock()
.most_frequent_recent_per_local_bind();
for addr in cached {
if !out.contains(&addr) {
out.push(addr);
}
}
out
}
/// Most frequent recently-observed external address from the cache only.
pub fn cached_observed_external_address(&self) -> Option<SocketAddr> {
self.observed_address_cache.lock().most_frequent_recent()
}
/// Timeout applied to connects and per-message sends.
pub fn connection_timeout(&self) -> Duration {
self.connection_timeout
}
}
impl TransportHandle {
/// All application peer ids that currently have at least one channel.
pub async fn connected_peers(&self) -> Vec<PeerId> {
self.peer_to_channel
.iter()
.map(|entry| *entry.key())
.collect()
}
/// Up to `limit` `(address, peer_id)` pairs, taking the first channel id
/// per peer that parses as a `SocketAddr`.
pub async fn connected_peer_addresses(&self, limit: usize) -> Vec<(SocketAddr, PeerId)> {
let mut result = Vec::new();
for entry in self.peer_to_channel.iter() {
if result.len() >= limit {
break;
}
let peer_id = *entry.key();
for channel_id in entry.value() {
if let Ok(sa) = channel_id.parse::<SocketAddr>() {
result.push((sa, peer_id));
break; }
}
}
result
}
/// Number of distinct connected application peers.
pub async fn peer_count(&self) -> usize {
self.peer_to_channel.len()
}
/// Last user-agent announced by `peer_id`, if known.
pub async fn peer_user_agent(&self, peer_id: &PeerId) -> Option<String> {
self.peer_user_agents.read().await.get(peer_id).cloned()
}
#[allow(dead_code)]
pub(crate) async fn active_channels(&self) -> Vec<String> {
self.active_connections
.read()
.await
.iter()
.cloned()
.collect()
}
/// Connection info for `peer_id` via any one of its channels.
pub async fn peer_info(&self, peer_id: &PeerId) -> Option<PeerInfo> {
// The DashMap guard is dropped before the await below to avoid holding
// a shard lock across an await point.
let channel = {
let entry = self.peer_to_channel.get(peer_id)?;
entry.iter().next().cloned()?
};
let peers = self.peers.read().await;
peers.get(&channel).cloned()
}
#[allow(dead_code)]
pub(crate) async fn peer_info_by_channel(&self, channel_id: &str) -> Option<PeerInfo> {
self.peers.read().await.get(channel_id).cloned()
}
/// Reverse lookup: channel id whose peer info lists `addr`.
#[allow(dead_code)]
pub(crate) async fn get_channel_id_by_address(&self, addr: &MultiAddr) -> Option<String> {
let target = addr.socket_addr()?;
let peers = self.peers.read().await;
for (channel_id, peer_info) in peers.iter() {
for peer_addr in &peer_info.addresses {
if peer_addr.socket_addr() == Some(target) {
return Some(channel_id.clone());
}
}
}
None
}
#[allow(dead_code)]
pub(crate) async fn list_active_connections(&self) -> Vec<(String, Vec<MultiAddr>)> {
let active = self.active_connections.read().await;
let peers = self.peers.read().await;
active
.iter()
.map(|peer_id| {
let addresses = peers
.get(peer_id)
.map(|info| info.addresses.clone())
.unwrap_or_default();
(peer_id.clone(), addresses)
})
.collect()
}
/// Drops all bookkeeping for `channel_id` WITHOUT closing the underlying
/// QUIC connection. Returns true if peer info existed for the channel.
pub(crate) async fn remove_channel(&self, channel_id: &str) -> bool {
self.active_connections.write().await.remove(channel_id);
self.remove_channel_mappings(channel_id).await;
self.peers.write().await.remove(channel_id).is_some()
}
/// Closes the QUIC connection for `channel_id` (when it parses as a
/// socket address) and drops all bookkeeping for it.
pub(crate) async fn disconnect_channel(&self, channel_id: &str) {
match channel_id.parse::<SocketAddr>() {
Ok(addr) => self.dual_node.disconnect_peer_by_addr(&addr).await,
Err(e) => {
warn!(
channel = %channel_id,
error = %e,
"Failed to parse channel ID as SocketAddr — QUIC connection will not be closed",
);
}
}
self.active_connections.write().await.remove(channel_id);
self.remove_channel_mappings(channel_id).await;
self.peers.write().await.remove(channel_id);
}
/// Resolve `addr` (or its dual-stack alternate, e.g. the v4-mapped-v6
/// form) to an application peer id.
pub async fn peer_id_for_addr(&self, addr: &SocketAddr) -> Option<PeerId> {
let channel_id = addr.to_string();
if let Some(peer_id) = self
.channel_to_peers
.get(&channel_id)
.and_then(|p| p.iter().next().copied())
{
return Some(peer_id);
}
let alt_addr = saorsa_transport::shared::dual_stack_alternate(addr)?;
let alt_channel_id = alt_addr.to_string();
self.channel_to_peers
.get(&alt_channel_id)
.and_then(|p| p.iter().next().copied())
}
/// Non-blocking drain of all queued (old, new) peer address updates.
pub async fn drain_peer_address_updates(&self) -> Vec<(SocketAddr, SocketAddr)> {
let mut rx = self.peer_address_update_rx.lock().await;
let mut updates = Vec::new();
while let Ok(update) = rx.try_recv() {
updates.push(update);
}
updates
}
/// Non-blocking: at most one queued relay-established notification.
pub async fn drain_relay_established(&self) -> Option<SocketAddr> {
let mut rx = self.relay_established_rx.lock().await;
rx.try_recv().ok()
}
/// Awaits the next peer address update; `None` when the channel closed.
pub async fn recv_peer_address_update(&self) -> Option<(SocketAddr, SocketAddr)> {
let mut rx = self.peer_address_update_rx.lock().await;
rx.recv().await
}
/// Awaits the next relay-established notification; `None` when closed.
pub async fn recv_relay_established(&self) -> Option<SocketAddr> {
let mut rx = self.relay_established_rx.lock().await;
rx.recv().await
}
/// True if `peer_id` has at least one mapped channel.
pub async fn is_peer_connected(&self, peer_id: &PeerId) -> bool {
self.peer_to_channel.contains_key(peer_id)
}
pub(crate) async fn is_connection_active(&self, channel_id: &str) -> bool {
self.active_connections.read().await.contains(channel_id)
}
async fn remove_channel_mappings(&self, channel_id: &str) {
Self::remove_channel_mappings_static(
channel_id,
&self.peer_to_channel,
&self.channel_to_peers,
&self.peer_user_agents,
&self.event_tx,
)
.await;
}
/// Removes `channel_id` from both peer/channel maps; peers whose last
/// channel was removed also lose their cached user-agent and get a
/// `PeerDisconnected` event.
async fn remove_channel_mappings_static(
channel_id: &str,
peer_to_channel: &DashMap<PeerId, HashSet<String>>,
channel_to_peers: &DashMap<String, HashSet<PeerId>>,
peer_user_agents: &RwLock<HashMap<PeerId, String>>,
event_tx: &broadcast::Sender<P2PEvent>,
) {
let app_peers = match channel_to_peers.remove(channel_id) {
Some((_, peers)) => peers,
None => return,
};
let mut fully_disconnected: Vec<PeerId> = Vec::new();
for app_peer in &app_peers {
// Mutate inside a scoped guard, then remove the entry with
// `remove_if` so a concurrent re-insertion between the two steps
// does not get clobbered.
let became_empty = {
if let Some(mut channels_ref) = peer_to_channel.get_mut(app_peer) {
channels_ref.remove(channel_id);
channels_ref.is_empty()
} else {
false
}
}; if became_empty
&& peer_to_channel
.remove_if(app_peer, |_, v| v.is_empty())
.is_some()
{
fully_disconnected.push(*app_peer);
}
}
if !fully_disconnected.is_empty() {
let mut pua = peer_user_agents.write().await;
for app_peer in fully_disconnected {
pua.remove(&app_peer);
let _ = event_tx.send(P2PEvent::PeerDisconnected(app_peer));
}
}
}
}
impl TransportHandle {
/// Forward the expected peer id for a pending hole punch to the node.
pub async fn set_hole_punch_target_peer_id(&self, target: SocketAddr, peer_id: [u8; 32]) {
self.dual_node
.set_hole_punch_target_peer_id(target, peer_id)
.await;
}
/// Forward a single preferred hole-punch coordinator to the node.
pub async fn set_hole_punch_preferred_coordinator(
&self,
target: SocketAddr,
coordinator: SocketAddr,
) {
self.dual_node
.set_hole_punch_preferred_coordinator(target, coordinator)
.await;
}
/// Forward a list of preferred hole-punch coordinators to the node.
pub async fn set_hole_punch_preferred_coordinators(
&self,
target: SocketAddr,
coordinators: Vec<SocketAddr>,
) {
self.dual_node
.set_hole_punch_preferred_coordinators(target, coordinators)
.await;
}
/// Dials `address` (QUIC only) and registers the resulting channel.
///
/// Rejects connections to any of this node's own listen addresses. The
/// returned channel id is the remote socket address rendered as a string.
///
/// # Errors
/// Invalid/non-QUIC address, self-connection, connect failure, or timeout
/// after `connection_timeout`.
pub async fn connect_peer(&self, address: &MultiAddr) -> Result<String> {
let socket_addr = address.dialable_socket_addr().ok_or_else(|| {
P2PError::Network(NetworkError::InvalidAddress(
format!(
"only QUIC transport is supported for connect, got {}: {}",
address.transport().kind(),
address
)
.into(),
))
})?;
// Wildcard targets (0.0.0.0 / ::) are re-pointed at loopback.
let normalized_addr = normalize_wildcard_to_loopback(socket_addr);
let addr_list = vec![normalized_addr];
let peer_id = match tokio::time::timeout(
self.connection_timeout,
self.dual_node.connect_happy_eyeballs(&addr_list),
)
.await
{
Ok(Ok(addr)) => {
let connected_peer_id = addr.to_string();
// Self-connection guard: compare against our own listen set.
let is_self = {
let addrs = self.listen_addrs.read().await;
addrs.iter().any(|a| a.socket_addr() == Some(addr))
};
if is_self {
warn!(
"Detected self-connection to own address {} (channel_id: {}), rejecting",
address, connected_peer_id
);
self.dual_node.disconnect_peer_by_addr(&addr).await;
return Err(P2PError::Network(NetworkError::InvalidAddress(
format!("Cannot connect to self ({})", address).into(),
)));
}
info!("Successfully connected to channel: {}", connected_peer_id);
connected_peer_id
}
Ok(Err(e)) => {
warn!("connect_happy_eyeballs failed for {}: {}", address, e);
return Err(P2PError::Transport(
crate::error::TransportError::ConnectionFailed {
addr: normalized_addr,
reason: e.to_string().into(),
},
));
}
Err(_) => {
warn!(
"connect_happy_eyeballs timed out for {} after {:?}",
address, self.connection_timeout
);
return Err(P2PError::Timeout(self.connection_timeout));
}
};
let peer_info = PeerInfo {
channel_id: peer_id.clone(),
addresses: vec![address.clone()],
connected_at: Instant::now(),
last_seen: Instant::now(),
status: ConnectionStatus::Connected,
protocols: vec!["p2p-foundation/1.0".to_string()],
heartbeat_count: 0,
};
self.peers.write().await.insert(peer_id.clone(), peer_info);
self.active_connections
.write()
.await
.insert(peer_id.clone());
Ok(peer_id)
}
/// Tears down every channel mapped to `peer_id`, closing orphaned QUIC
/// connections and emitting `PeerDisconnected` once fully drained.
///
/// Drains up to `MAX_DISCONNECT_ROUNDS` times because a concurrent
/// re-authentication can re-insert channels between rounds; errors if the
/// peer is still mapped after the final round.
pub async fn disconnect_peer(&self, peer_id: &PeerId) -> Result<()> {
info!("Disconnecting from peer: {}", peer_id);
const MAX_DISCONNECT_ROUNDS: usize = 3;
let first_channels = match self.peer_to_channel.remove(peer_id) {
Some((_, chs)) => chs,
None => {
info!(
"Peer {} has no tracked channels, nothing to disconnect",
peer_id
);
return Ok(());
}
};
let mut all_orphaned: Vec<String> = Vec::new();
let mut to_scrub: HashSet<String> = first_channels;
let mut rounds_done: usize = 0;
loop {
for channel_id in &to_scrub {
// Same guard-then-remove_if pattern as
// `remove_channel_mappings_static`: avoids racing a concurrent
// re-insertion between emptiness check and removal.
let became_empty = {
if let Some(mut peers_ref) = self.channel_to_peers.get_mut(channel_id) {
peers_ref.remove(peer_id);
peers_ref.is_empty()
} else {
false
}
}; if became_empty
&& self
.channel_to_peers
.remove_if(channel_id, |_, v| v.is_empty())
.is_some()
{
all_orphaned.push(channel_id.clone());
}
}
rounds_done += 1;
if rounds_done >= MAX_DISCONNECT_ROUNDS {
break;
}
// Pick up channels re-added concurrently since the last round.
match self.peer_to_channel.remove(peer_id) {
Some((_, chs)) => to_scrub = chs,
None => break,
}
}
let still_present = self.peer_to_channel.contains_key(peer_id);
if !still_present {
self.peer_user_agents.write().await.remove(peer_id);
let _ = self.event_tx.send(P2PEvent::PeerDisconnected(*peer_id));
}
// Close QUIC connections only for channels no other peer still uses.
for channel_id in &all_orphaned {
match channel_id.parse::<SocketAddr>() {
Ok(addr) => self.dual_node.disconnect_peer_by_addr(&addr).await,
Err(e) => {
warn!(
peer = %peer_id,
channel = %channel_id,
error = %e,
"Failed to parse channel ID as SocketAddr — QUIC connection will not be closed",
);
}
}
self.active_connections.write().await.remove(channel_id);
self.peers.write().await.remove(channel_id);
}
if still_present {
warn!(
peer = %peer_id,
rounds = MAX_DISCONNECT_ROUNDS,
"disconnect_peer: peer kept being re-authenticated across drain rounds",
);
return Err(P2PError::Network(NetworkError::ProtocolError(
format!(
"disconnect_peer: peer {} remained mapped after {} drain rounds (concurrent re-authentication)",
peer_id, MAX_DISCONNECT_ROUNDS
)
.into(),
)));
}
info!("Disconnected from peer: {}", peer_id);
Ok(())
}
/// Disconnects every known peer, continuing past failures; returns the
/// last error encountered, if any.
async fn disconnect_all_peers(&self) -> Result<()> {
let peer_ids: Vec<PeerId> = self
.peer_to_channel
.iter()
.map(|entry| *entry.key())
.collect();
let mut last_err: Option<P2PError> = None;
for peer_id in &peer_ids {
if let Err(e) = self.disconnect_peer(peer_id).await {
warn!(
peer = %peer_id,
error = %e,
"disconnect_all_peers: peer could not be fully drained, continuing",
);
last_err = Some(e);
}
}
match last_err {
Some(e) => Err(e),
None => Ok(()),
}
}
}
impl TransportHandle {
/// Sends `data` to `peer_id` on `protocol`, trying each mapped channel in
/// turn; failed channels are removed from bookkeeping before trying the
/// next one.
///
/// # Errors
/// `PeerNotFound` when no channel is mapped, otherwise the last channel's
/// send error.
pub async fn send_message(
&self,
peer_id: &PeerId,
protocol: &str,
data: Vec<u8>,
) -> Result<()> {
let peer_hex = peer_id.to_hex();
let channels: Vec<String> = self
.peer_to_channel
.get(peer_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
if channels.is_empty() {
return Err(P2PError::Network(NetworkError::PeerNotFound(
peer_hex.into(),
)));
}
let mut last_err = None;
for channel_id in &channels {
match self
.send_on_channel(channel_id, protocol, data.clone())
.await
{
Ok(()) => return Ok(()),
Err(e) => {
warn!(
peer = %peer_hex,
channel = %channel_id,
error = %e,
"Channel send failed, removing and trying next",
);
self.remove_channel(channel_id).await;
last_err = Some(e);
}
}
}
Err(last_err
.unwrap_or_else(|| P2PError::Network(NetworkError::PeerNotFound(peer_hex.into()))))
}
/// Wraps `data` in a signed `WireMessage` and sends it on `channel_id`,
/// registering unknown channels on the fly and bounding the send with
/// `connection_timeout`. On failure the channel is marked inactive.
pub(crate) async fn send_on_channel(
&self,
channel_id: &str,
protocol: &str,
data: Vec<u8>,
) -> Result<()> {
debug!(
"Sending message to channel {} on protocol {}",
channel_id, protocol
);
{
// Check under the read lock first so the common (known-channel)
// path avoids taking the write lock.
let needs_insert = {
let peers = self.peers.read().await;
!peers.contains_key(channel_id)
};
if needs_insert {
let mut peers = self.peers.write().await;
peers.entry(channel_id.to_string()).or_insert_with(|| {
info!(
"send_on_channel: registering new channel {} on the fly",
channel_id
);
let addresses = channel_id
.parse::<std::net::SocketAddr>()
.map(|addr| vec![MultiAddr::quic(addr)])
.unwrap_or_default();
PeerInfo {
channel_id: channel_id.to_string(),
addresses,
status: ConnectionStatus::Connected,
last_seen: Instant::now(),
connected_at: Instant::now(),
protocols: Vec::new(),
heartbeat_count: 0,
}
});
}
}
if !self.is_connection_active(channel_id).await {
self.active_connections
.write()
.await
.insert(channel_id.to_string());
}
let raw_data_len = data.len();
let message_data = self.create_protocol_message(protocol, data)?;
info!(
"Sending {} bytes to channel {} on protocol {} (raw data: {} bytes)",
message_data.len(),
channel_id,
protocol,
raw_data_len
);
let addr: SocketAddr = channel_id.parse().map_err(|e: std::net::AddrParseError| {
P2PError::Network(NetworkError::PeerNotFound(
format!("Invalid channel ID address: {e}").into(),
))
})?;
let send_fut = self.dual_node.send_to_peer_optimized(&addr, &message_data);
let result = tokio::time::timeout(self.connection_timeout, send_fut)
.await
.map_err(|_| {
P2PError::Transport(crate::error::TransportError::StreamError(
"Timed out sending message".into(),
))
})?
.map_err(|e| {
P2PError::Transport(crate::error::TransportError::StreamError(
e.to_string().into(),
))
});
if result.is_ok() {
info!(
"Successfully sent {} bytes to channel {}",
message_data.len(),
channel_id
);
} else {
warn!("Failed to send message to channel {}", channel_id);
self.active_connections.write().await.remove(channel_id);
}
result
}
/// Channel ids currently mapped to `app_peer_id`.
pub async fn channels_for_peer(&self, app_peer_id: &PeerId) -> Vec<String> {
self.peer_to_channel
.get(app_peer_id)
.map(|channels| channels.iter().cloned().collect())
.unwrap_or_default()
}
/// Application peer ids authenticated on `channel_id`.
pub(crate) async fn peers_on_channel(&self, channel_id: &str) -> Vec<PeerId> {
self.channel_to_peers
.get(channel_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default()
}
pub async fn is_known_app_peer_id(&self, peer_id: &PeerId) -> bool {
self.peer_to_channel.contains_key(peer_id)
}
/// Polls (every 50ms) until some peer id is mapped to `channel_id`.
///
/// # Errors
/// `P2PError::Timeout` if nothing is mapped within `timeout`.
pub async fn wait_for_peer_identity(
&self,
channel_id: &str,
timeout: Duration,
) -> Result<PeerId> {
let deadline = Instant::now() + timeout;
let poll_interval = Duration::from_millis(50);
loop {
let peers = self.peers_on_channel(channel_id).await;
if let Some(peer_id) = peers.into_iter().next() {
return Ok(peer_id);
}
if Instant::now() >= deadline {
return Err(P2PError::Timeout(timeout));
}
tokio::time::sleep(poll_interval).await;
}
}
/// Sends a request envelope on `/rr/<protocol>` and awaits the matching
/// response (correlated by a fresh UUID message id).
///
/// The timeout is capped at `MAX_REQUEST_TIMEOUT` and the number of
/// in-flight requests at `MAX_ACTIVE_REQUESTS`. The pending entry is
/// removed on every exit path.
pub async fn send_request(
&self,
peer_id: &PeerId,
protocol: &str,
data: Vec<u8>,
timeout: Duration,
) -> Result<PeerResponse> {
let timeout = timeout.min(MAX_REQUEST_TIMEOUT);
validate_protocol_name(protocol)?;
let message_id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::oneshot::channel();
let started_at = Instant::now();
{
let mut reqs = self.active_requests.write().await;
if reqs.len() >= MAX_ACTIVE_REQUESTS {
return Err(P2PError::Transport(
crate::error::TransportError::StreamError(
format!(
"Too many active requests ({MAX_ACTIVE_REQUESTS}); try again later"
)
.into(),
),
));
}
reqs.insert(
message_id.clone(),
PendingRequest {
response_tx: tx,
expected_peer: *peer_id,
},
);
}
let envelope = RequestResponseEnvelope {
message_id: message_id.clone(),
is_response: false,
payload: data,
};
let envelope_bytes = match postcard::to_allocvec(&envelope) {
Ok(bytes) => bytes,
Err(e) => {
// Roll back the pending entry before bailing out.
self.active_requests.write().await.remove(&message_id);
return Err(P2PError::Serialization(
format!("Failed to serialize request envelope: {e}").into(),
));
}
};
let wire_protocol = format!("/rr/{}", protocol);
if let Err(e) = self
.send_message(peer_id, &wire_protocol, envelope_bytes)
.await
{
self.active_requests.write().await.remove(&message_id);
return Err(e);
}
let result = match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response_bytes)) => {
let latency = started_at.elapsed();
Ok(PeerResponse {
peer_id: *peer_id,
data: response_bytes,
latency,
})
}
// Sender dropped without responding: connection closed.
Ok(Err(_)) => Err(P2PError::Network(NetworkError::ConnectionClosed {
peer_id: peer_id.to_hex().into(),
})),
Err(_) => Err(P2PError::Transport(
crate::error::TransportError::StreamError(
format!(
"Request to {} on {} timed out after {:?}",
peer_id, protocol, timeout
)
.into(),
),
)),
};
// Always clean up; harmless if the responder path already removed it.
self.active_requests.write().await.remove(&message_id);
result
}
/// Sends a response envelope for `message_id` back on `/rr/<protocol>`.
pub async fn send_response(
&self,
peer_id: &PeerId,
protocol: &str,
message_id: &str,
data: Vec<u8>,
) -> Result<()> {
validate_protocol_name(protocol)?;
let envelope = RequestResponseEnvelope {
message_id: message_id.to_string(),
is_response: true,
payload: data,
};
let envelope_bytes = postcard::to_allocvec(&envelope).map_err(|e| {
P2PError::Serialization(format!("Failed to serialize response envelope: {e}").into())
})?;
let wire_protocol = format!("/rr/{}", protocol);
self.send_message(peer_id, &wire_protocol, envelope_bytes)
.await
}
/// Decodes a request/response envelope into (message_id, is_response,
/// payload); `None` if the bytes do not deserialize.
pub fn parse_request_envelope(data: &[u8]) -> Option<(String, bool, Vec<u8>)> {
let envelope: RequestResponseEnvelope = postcard::from_bytes(data).ok()?;
Some((envelope.message_id, envelope.is_response, envelope.payload))
}
/// Builds, signs, and serializes a `WireMessage` carrying `data`.
fn create_protocol_message(&self, protocol: &str, data: Vec<u8>) -> Result<Vec<u8>> {
let mut message = WireMessage {
protocol: protocol.to_string(),
data,
from: *self.node_identity.peer_id(),
timestamp: Self::current_timestamp_secs()?,
user_agent: self.user_agent.clone(),
// Filled in by `sign_wire_message` below.
public_key: Vec::new(),
signature: Vec::new(),
};
Self::sign_wire_message(&mut message, &self.node_identity)?;
Self::serialize_wire_message(&message)
}
/// Builds a signed, empty-payload identity announcement message.
fn create_identity_announce_bytes(
identity: &NodeIdentity,
user_agent: &str,
) -> Result<Vec<u8>> {
let mut message = WireMessage {
protocol: IDENTITY_ANNOUNCE_PROTOCOL.to_string(),
data: vec![],
from: *identity.peer_id(),
timestamp: Self::current_timestamp_secs()?,
user_agent: user_agent.to_owned(),
public_key: Vec::new(),
signature: Vec::new(),
};
Self::sign_wire_message(&mut message, identity)?;
Self::serialize_wire_message(&message)
}
/// Seconds since the Unix epoch; errors if the clock is before the epoch.
fn current_timestamp_secs() -> Result<u64> {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.map_err(|e| {
P2PError::Network(NetworkError::ProtocolError(
format!("System time error: {e}").into(),
))
})
}
/// Signs the message's signable fields and fills `public_key`/`signature`.
fn sign_wire_message(message: &mut WireMessage, identity: &NodeIdentity) -> Result<()> {
let signable = Self::compute_signable_bytes(
&message.protocol,
&message.data,
&message.from,
message.timestamp,
&message.user_agent,
)?;
let sig = identity.sign(&signable).map_err(|e| {
P2PError::Network(NetworkError::ProtocolError(
format!("Failed to sign message: {e}").into(),
))
})?;
message.public_key = identity.public_key().as_bytes().to_vec();
message.signature = sig.as_bytes().to_vec();
Ok(())
}
fn serialize_wire_message(message: &WireMessage) -> Result<Vec<u8>> {
postcard::to_stdvec(message).map_err(|e| {
P2PError::Transport(crate::error::TransportError::StreamError(
format!("Failed to serialize wire message: {e}").into(),
))
})
}
/// Canonical byte encoding of the signed fields (postcard tuple); must
/// stay stable — verification recomputes exactly these bytes.
fn compute_signable_bytes(
protocol: &str,
data: &[u8],
from: &PeerId,
timestamp: u64,
user_agent: &str,
) -> Result<Vec<u8>> {
postcard::to_stdvec(&(protocol, data, from, timestamp, user_agent)).map_err(|e| {
P2PError::Network(NetworkError::ProtocolError(
format!("Failed to serialize signable bytes: {e}").into(),
))
})
}
}
impl TransportHandle {
/// Subscribe to a topic. Currently a logging no-op: no per-topic state is
/// kept here; delivery filtering presumably happens at a higher layer —
/// TODO confirm.
pub async fn subscribe(&self, topic: &str) -> Result<()> {
info!("Subscribed to topic: {}", topic);
Ok(())
}
/// Publishes `data` under `topic` to every connected peer (one successful
/// channel per peer, removing failed channels), then echoes the message to
/// local subscribers via the event stream.
pub async fn publish(&self, topic: &str, data: &[u8]) -> Result<()> {
info!(
"Publishing message to topic: {} ({} bytes)",
topic,
data.len()
);
// Group channels by peer so each peer receives at most one copy.
let mut peer_channel_groups: Vec<Vec<String>> = Vec::new();
let mut mapped_channels: HashSet<String> = HashSet::new();
for entry in self.peer_to_channel.iter() {
let chs: Vec<String> = entry.value().iter().cloned().collect();
mapped_channels.extend(chs.iter().cloned());
if !chs.is_empty() {
peer_channel_groups.push(chs);
}
}
{
// Channels with no authenticated peer mapping yet are treated as
// single-channel "peers" so they still receive the publish.
let peers_guard = self.peers.read().await;
for channel_id in peers_guard.keys() {
if !mapped_channels.contains(channel_id) {
peer_channel_groups.push(vec![channel_id.clone()]);
}
}
}
if peer_channel_groups.is_empty() {
debug!("No peers connected, message will only be sent to local subscribers");
} else {
let mut send_count = 0;
let total = peer_channel_groups.len();
for channels in &peer_channel_groups {
let mut sent = false;
for channel_id in channels {
match self.send_on_channel(channel_id, topic, data.to_vec()).await {
Ok(()) => {
send_count += 1;
debug!("Published message via channel: {}", channel_id);
sent = true;
break;
}
Err(e) => {
warn!(
channel = %channel_id,
error = %e,
"Publish channel failed, removing and trying next",
);
self.remove_channel(channel_id).await;
}
}
}
if !sent {
warn!("All channels exhausted for one peer during publish");
}
}
info!(
"Published message to {}/{} connected peers",
send_count, total
);
}
// Local loopback delivery for our own subscribers.
self.send_event(P2PEvent::Message {
topic: topic.to_string(),
source: Some(*self.node_identity.peer_id()),
data: data.to_vec(),
});
Ok(())
}
}
impl TransportHandle {
    /// Returns a fresh receiver on the transport's broadcast event stream.
    pub fn subscribe_events(&self) -> broadcast::Receiver<P2PEvent> {
        self.event_tx.subscribe()
    }
    /// Broadcasts `event` to all subscribers. A send error only means there
    /// are currently no receivers, which is expected and merely traced.
    pub(crate) fn send_event(&self, event: P2PEvent) {
        match self.event_tx.send(event) {
            Ok(_) => {}
            Err(e) => tracing::trace!("Event broadcast has no receivers: {e}"),
        }
    }
}
impl TransportHandle {
/// Publishes the node's bound listen addresses, spawns the inbound accept
/// loop, and starts the message-receiving system.
///
/// The accept loop registers each inbound QUIC connection as a channel keyed
/// by the remote socket address. Rate-limited sources are rejected.
///
/// # Errors
/// Returns `TransportError::SetupFailed` if the local addresses cannot be
/// queried, or any error from `start_message_receiving_system`.
pub async fn start_network_listeners(&self) -> Result<()> {
    info!("Starting dual-stack listeners (saorsa-transport)...");
    let socket_addrs = self.dual_node.local_addrs().await.map_err(|e| {
        P2PError::Transport(crate::error::TransportError::SetupFailed(
            format!("Failed to get local addresses: {}", e).into(),
        ))
    })?;
    let addrs: Vec<SocketAddr> = socket_addrs.clone();
    {
        let mut la = self.listen_addrs.write().await;
        *la = socket_addrs.into_iter().map(MultiAddr::quic).collect();
    }
    let peers = self.peers.clone();
    let active_connections = self.active_connections.clone();
    let rate_limiter = self.rate_limiter.clone();
    let dual = self.dual_node.clone();
    let handle = tokio::spawn(async move {
        loop {
            let Some(remote_sock) = dual.accept_any().await else {
                break;
            };
            if let Err(e) = rate_limiter.check_ip(&remote_sock.ip()) {
                warn!(
                    "Rate-limited incoming connection from {}: {}",
                    remote_sock, e
                );
                // Fix: previously a rate-limited connection was only skipped,
                // leaving the already-accepted QUIC connection open (a
                // connection leak). Close it explicitly, mirroring the
                // teardown used by disconnect_channel.
                dual.disconnect_peer_by_addr(&remote_sock).await;
                continue;
            }
            let channel_id = remote_sock.to_string();
            let remote_addr = MultiAddr::quic(remote_sock);
            register_new_channel(&peers, &channel_id, &remote_addr).await;
            active_connections.write().await.insert(channel_id);
        }
    });
    *self.listener_handle.write().await = Some(handle);
    self.start_message_receiving_system().await?;
    info!("Dual-stack listeners active on: {:?}", addrs);
    Ok(())
}
async fn start_message_receiving_system(&self) -> Result<()> {
info!(
"Starting message receiving system ({} dispatch shards)",
MESSAGE_DISPATCH_SHARDS
);
let (upstream_tx, mut upstream_rx) =
tokio::sync::mpsc::channel(MESSAGE_RECV_CHANNEL_CAPACITY);
let mut handles = self
.dual_node
.spawn_recv_tasks(upstream_tx.clone(), self.shutdown.clone());
drop(upstream_tx);
let per_shard_capacity = (MESSAGE_RECV_CHANNEL_CAPACITY / MESSAGE_DISPATCH_SHARDS)
.max(MIN_SHARD_CHANNEL_CAPACITY);
let mut shard_txs: Vec<tokio::sync::mpsc::Sender<(SocketAddr, Vec<u8>)>> =
Vec::with_capacity(MESSAGE_DISPATCH_SHARDS);
for shard_idx in 0..MESSAGE_DISPATCH_SHARDS {
let (shard_tx, shard_rx) = tokio::sync::mpsc::channel(per_shard_capacity);
shard_txs.push(shard_tx);
let event_tx = self.event_tx.clone();
let active_requests = Arc::clone(&self.active_requests);
let peer_to_channel = Arc::clone(&self.peer_to_channel);
let channel_to_peers = Arc::clone(&self.channel_to_peers);
let peer_user_agents = Arc::clone(&self.peer_user_agents);
let self_peer_id = *self.node_identity.peer_id();
let dual_node_for_peer_reg = Arc::clone(&self.dual_node);
handles.push(tokio::spawn(async move {
Self::run_shard_consumer(
shard_idx,
shard_rx,
event_tx,
active_requests,
peer_to_channel,
channel_to_peers,
peer_user_agents,
self_peer_id,
dual_node_for_peer_reg,
)
.await;
}));
}
let drop_counter = Arc::new(AtomicU64::new(0));
handles.push(tokio::spawn(async move {
info!(
"Message dispatcher loop started (sharded across {} consumers)",
MESSAGE_DISPATCH_SHARDS
);
while let Some((from_addr, bytes)) = upstream_rx.recv().await {
let shard_idx = shard_index_for_addr(&from_addr);
match shard_txs[shard_idx].try_send((from_addr, bytes)) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(_dropped)) => {
let prev = drop_counter.fetch_add(1, Ordering::Relaxed);
if prev.is_multiple_of(SHARD_DROP_LOG_INTERVAL) {
warn!(
shard = shard_idx,
from = %from_addr,
total_drops = prev + 1,
"Dispatcher dropped inbound message: shard channel full"
);
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_dropped)) => {
let prev = drop_counter.fetch_add(1, Ordering::Relaxed);
if prev.is_multiple_of(SHARD_DROP_LOG_INTERVAL) {
warn!(
shard = shard_idx,
from = %from_addr,
total_drops = prev + 1,
"Dispatcher dropped inbound message: shard consumer closed"
);
}
}
}
}
info!("Message dispatcher loop ended — upstream channel closed");
}));
*self.recv_handles.write().await = handles;
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn run_shard_consumer(
shard_idx: usize,
mut shard_rx: tokio::sync::mpsc::Receiver<(SocketAddr, Vec<u8>)>,
event_tx: broadcast::Sender<P2PEvent>,
active_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
peer_to_channel: Arc<DashMap<PeerId, HashSet<String>>>,
channel_to_peers: Arc<DashMap<String, HashSet<PeerId>>>,
peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>>,
self_peer_id: PeerId,
dual_node_for_peer_reg: Arc<DualStackNetworkNode>,
) {
info!("Message dispatch shard {shard_idx} started");
while let Some((from_addr, bytes)) = shard_rx.recv().await {
let channel_id = from_addr.to_string();
trace!(
shard = shard_idx,
"Received {} bytes from channel {}",
bytes.len(),
channel_id
);
match parse_protocol_message(&bytes, &channel_id) {
Some(ParsedMessage {
event,
authenticated_node_id,
user_agent: peer_user_agent,
}) => {
if let Some(ref app_id) = authenticated_node_id
&& *app_id != self_peer_id
{
let is_new_peer = {
let is_new = match peer_to_channel.entry(*app_id) {
Entry::Vacant(vacant) => {
let mut set = HashSet::new();
set.insert(channel_id.clone());
vacant.insert(set);
true
}
Entry::Occupied(mut occupied) => {
occupied.get_mut().insert(channel_id.clone());
false
}
}; channel_to_peers
.entry(channel_id.clone())
.or_default()
.insert(*app_id);
is_new
};
dual_node_for_peer_reg
.register_connection_peer_id(from_addr, *app_id.to_bytes())
.await;
if is_new_peer {
peer_user_agents
.write()
.await
.insert(*app_id, peer_user_agent.clone());
broadcast_event(
&event_tx,
P2PEvent::PeerConnected(*app_id, peer_user_agent.clone()),
);
}
}
if let P2PEvent::Message { ref topic, .. } = event
&& topic == IDENTITY_ANNOUNCE_PROTOCOL
{
continue;
}
if let P2PEvent::Message {
ref topic,
ref data,
..
} = event
&& topic.starts_with("/rr/")
&& let Ok(envelope) = postcard::from_bytes::<RequestResponseEnvelope>(data)
&& envelope.is_response
{
let mut reqs = active_requests.write().await;
let expected_peer = match reqs.get(&envelope.message_id) {
Some(pending) => pending.expected_peer,
None => {
trace!(
message_id = %envelope.message_id,
"Unmatched /rr/ response (likely timed out) — suppressing"
);
continue;
}
};
if authenticated_node_id.as_ref() != Some(&expected_peer) {
warn!(
message_id = %envelope.message_id,
expected = %expected_peer,
actual_channel = %channel_id,
authenticated = ?authenticated_node_id,
"Response origin mismatch — ignoring"
);
continue;
}
if let Some(pending) = reqs.remove(&envelope.message_id) {
if pending.response_tx.send(envelope.payload).is_err() {
warn!(
message_id = %envelope.message_id,
"Response receiver dropped before delivery"
);
}
continue;
}
trace!(
message_id = %envelope.message_id,
"Unmatched /rr/ response (likely timed out) — suppressing"
);
continue;
}
broadcast_event(&event_tx, event);
}
None => {
warn!(
shard = shard_idx,
"Failed to parse protocol message ({} bytes)",
bytes.len()
);
}
}
}
info!("Message dispatch shard {shard_idx} ended — channel closed");
}
}
/// Number of independent consumer tasks the inbound message stream is
/// sharded across.
const MESSAGE_DISPATCH_SHARDS: usize = 8;
/// Lower bound on each shard channel's capacity, regardless of how the
/// upstream capacity divides across shards.
const MIN_SHARD_CHANNEL_CAPACITY: usize = 128;
/// Emit a dropped-message warning only once per this many drops.
const SHARD_DROP_LOG_INTERVAL: u64 = 64;

/// Map a remote socket address onto a dispatch shard. Only the IP portion is
/// hashed, so every connection from the same host lands on the same shard
/// and keeps its per-host message ordering.
fn shard_index_for_addr(addr: &SocketAddr) -> usize {
    let mut h = DefaultHasher::default();
    addr.ip().hash(&mut h);
    h.finish() as usize % MESSAGE_DISPATCH_SHARDS
}
impl TransportHandle {
    /// Shut the transport down in order: signal cancellation, close the
    /// endpoints (which unblocks the accept and recv tasks), join every
    /// background task, then disconnect all peers.
    ///
    /// # Errors
    /// Propagates any error from disconnecting peers.
    pub async fn stop(&self) -> Result<()> {
        info!("Stopping transport...");
        self.shutdown.cancel();
        self.dual_node.shutdown_endpoints().await;
        // Take the handles out under the lock, then join them without it.
        let handles: Vec<_> = self.recv_handles.write().await.drain(..).collect();
        Self::join_task_handles(handles, "recv").await;
        Self::join_task_slot(&self.listener_handle, "listener").await;
        Self::join_task_slot(&self.connection_monitor_handle, "connection monitor").await;
        self.disconnect_all_peers().await?;
        info!("Transport stopped");
        Ok(())
    }

    /// Join the task stored in `handle_slot`, if any, leaving the slot empty.
    async fn join_task_slot(handle_slot: &RwLock<Option<JoinHandle<()>>>, task_name: &str) {
        // `take` under the lock so the guard is released before awaiting.
        let handle = handle_slot.write().await.take();
        if let Some(handle) = handle {
            Self::join_task_handle(handle, task_name).await;
        }
    }

    /// Join a batch of task handles sequentially.
    async fn join_task_handles(handles: Vec<JoinHandle<()>>, task_name: &str) {
        for handle in handles {
            Self::join_task_handle(handle, task_name).await;
        }
    }

    /// Await a single task handle, logging expected shutdown outcomes
    /// (cancellation) at debug level and surfacing panics as errors.
    async fn join_task_handle(handle: JoinHandle<()>, task_name: &str) {
        match handle.await {
            Ok(()) => {}
            Err(e) if e.is_cancelled() => {
                tracing::debug!("{task_name} task was cancelled during shutdown");
            }
            Err(e) if e.is_panic() => {
                tracing::error!("{task_name} task panicked during shutdown: {:?}", e);
            }
            Err(e) => {
                tracing::warn!("{task_name} task join error during shutdown: {:?}", e);
            }
        }
    }
}
impl TransportHandle {
    /// Long-running monitor that consumes transport-layer [`ConnectionEvent`]s
    /// and keeps the connection/peer tables in sync, announcing our identity
    /// on every newly established connection.
    ///
    /// The receiver is pre-subscribed by the caller so no events are missed
    /// between spawning this task and the first `recv`.
    ///
    /// Fix: the `peers` write guard is now dropped before the identity
    /// announce is sent. Previously the guard was held across the
    /// `send_to_peer_optimized(...).await`, serializing every other `peers`
    /// reader/writer behind a network round-trip.
    #[allow(clippy::too_many_arguments)]
    async fn connection_lifecycle_monitor_with_rx(
        dual_node: Arc<DualStackNetworkNode>,
        mut event_rx: broadcast::Receiver<
            crate::transport::saorsa_transport_adapter::ConnectionEvent,
        >,
        active_connections: Arc<RwLock<HashSet<String>>>,
        peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
        event_tx: broadcast::Sender<P2PEvent>,
        _geo_provider: Arc<BgpGeoProvider>,
        shutdown: CancellationToken,
        peer_to_channel: Arc<DashMap<PeerId, HashSet<String>>>,
        channel_to_peers: Arc<DashMap<String, HashSet<PeerId>>>,
        peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>>,
        node_identity: Arc<NodeIdentity>,
        user_agent: String,
    ) {
        info!("Connection lifecycle monitor started (pre-subscribed receiver)");
        loop {
            tokio::select! {
                // Cooperative shutdown takes priority over event handling.
                () = shutdown.cancelled() => {
                    info!("Connection lifecycle monitor shutting down");
                    break;
                }
                recv = event_rx.recv() => {
                    match recv {
                        Ok(event) => match event {
                            ConnectionEvent::Established {
                                remote_address, ..
                            } => {
                                let channel_id = remote_address.to_string();
                                debug!(
                                    "Connection established: channel={}, addr={}",
                                    channel_id, remote_address
                                );
                                active_connections.write().await.insert(channel_id.clone());
                                let mut peers_lock = peers.write().await;
                                if let Some(peer_info) = peers_lock.get_mut(&channel_id) {
                                    // Known channel: refresh its status and timestamp.
                                    peer_info.status = ConnectionStatus::Connected;
                                    peer_info.connected_at = Instant::now();
                                } else {
                                    debug!("Registering new incoming channel: {}", channel_id);
                                    peers_lock.insert(
                                        channel_id.clone(),
                                        PeerInfo {
                                            channel_id: channel_id.clone(),
                                            addresses: vec![MultiAddr::quic(remote_address)],
                                            status: ConnectionStatus::Connected,
                                            last_seen: Instant::now(),
                                            connected_at: Instant::now(),
                                            protocols: Vec::new(),
                                            heartbeat_count: 0,
                                        },
                                    );
                                }
                                // Release the peers lock before the network send
                                // below; holding it across the await would block
                                // all other peers-table users for a full RTT.
                                drop(peers_lock);
                                // Proactively announce our identity so the remote
                                // side can associate this connection with our peer id.
                                match Self::create_identity_announce_bytes(&node_identity, &user_agent) {
                                    Ok(announce_bytes) => {
                                        if let Err(e) = dual_node
                                            .send_to_peer_optimized(&remote_address, &announce_bytes)
                                            .await
                                        {
                                            warn!("Failed to send identity announce to {channel_id}: {e}");
                                        }
                                    }
                                    Err(e) => {
                                        warn!("Failed to create identity announce: {e}");
                                    }
                                }
                            }
                            ConnectionEvent::Lost { remote_address, reason }
                            | ConnectionEvent::Failed { remote_address, reason } => {
                                let channel_id = remote_address.to_string();
                                debug!("Connection lost/failed: channel={channel_id}, reason={reason}");
                                active_connections.write().await.remove(&channel_id);
                                peers.write().await.remove(&channel_id);
                                // Drop all peer <-> channel mappings for this channel.
                                Self::remove_channel_mappings_static(
                                    &channel_id,
                                    &peer_to_channel,
                                    &channel_to_peers,
                                    &peer_user_agents,
                                    &event_tx,
                                ).await;
                            }
                            ConnectionEvent::PeerAddressUpdated { .. } => {
                                // Currently ignored by this monitor.
                            }
                        },
                        Err(broadcast::error::RecvError::Lagged(skipped)) => {
                            // Broadcast buffer overflowed; continue with newer events.
                            warn!(
                                "Connection event receiver lagged, skipped {} events",
                                skipped
                            );
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            info!("Connection event channel closed, stopping lifecycle monitor");
                            break;
                        }
                    }
                }
            }
        }
    }
}
/// Validate that a protocol identifier is usable as a stream name: it must
/// be non-empty and contain no path separators or NUL characters.
///
/// # Errors
/// Returns a transport `StreamError` describing the offending name.
fn validate_protocol_name(protocol: &str) -> Result<()> {
    let forbidden: &[char] = &['/', '\\', '\0'];
    let invalid = protocol.is_empty() || protocol.contains(forbidden);
    if invalid {
        Err(P2PError::Transport(
            crate::error::TransportError::StreamError(
                format!("Invalid protocol name: {:?}", protocol).into(),
            ),
        ))
    } else {
        Ok(())
    }
}
#[async_trait::async_trait]
impl NetworkSender for TransportHandle {
    /// Forwards to the inherent `TransportHandle::send_message` so the
    /// handle can be used through the `NetworkSender` trait.
    async fn send_message(&self, peer_id: &PeerId, protocol: &str, data: Vec<u8>) -> Result<()> {
        TransportHandle::send_message(self, peer_id, protocol, data).await
    }

    /// Returns this node's own peer id.
    fn local_peer_id(&self) -> PeerId {
        self.peer_id()
    }
}
#[cfg(test)]
impl TransportHandle {
    /// Test helper: place a peer entry directly into the peers table.
    pub(crate) async fn inject_peer(&self, peer_id: String, info: PeerInfo) {
        self.peers.write().await.insert(peer_id, info);
    }

    /// Test helper: mark a channel as actively connected.
    pub(crate) async fn inject_active_connection(&self, channel_id: String) {
        self.active_connections.write().await.insert(channel_id);
    }

    /// Test helper: wire up both directions of the peer <-> channel mapping.
    pub(crate) async fn inject_peer_to_channel(&self, peer_id: PeerId, channel_id: String) {
        // Scope the forward-map update so its guard is dropped before the
        // reverse map is touched — never hold two DashMap guards at once.
        {
            self.peer_to_channel
                .entry(peer_id)
                .or_default()
                .insert(channel_id.clone());
        }
        self.channel_to_peers
            .entry(channel_id)
            .or_default()
            .insert(peer_id);
    }
}
#[cfg(test)]
mod peer_to_channel_concurrency_tests {
    //! Concurrency stress tests for the DashMap-based peer/channel
    //! bookkeeping. The invariant under test: no DashMap guard may be held
    //! across an `.await` point, which would deadlock the runtime.
    use super::*;
    use std::time::Duration;

    /// Deterministic 32-byte peer id built from a single repeated byte.
    fn make_peer(b: u8) -> PeerId {
        PeerId::from_bytes([b; 32])
    }

    /// Mirrors the hot-path insert done by the shard consumer: update the
    /// forward map, then the reverse map, with all guards dropped before the
    /// trailing `.await`.
    async fn hot_path_insert(
        peer_to_channel: &DashMap<PeerId, HashSet<String>>,
        channel_to_peers: &DashMap<String, HashSet<PeerId>>,
        app_id: &PeerId,
        channel_id: &str,
    ) {
        {
            let _is_new = match peer_to_channel.entry(*app_id) {
                Entry::Vacant(vacant) => {
                    let mut set = HashSet::new();
                    set.insert(channel_id.to_string());
                    vacant.insert(set);
                    true
                }
                Entry::Occupied(mut occupied) => {
                    occupied.get_mut().insert(channel_id.to_string());
                    false
                }
            };
            channel_to_peers
                .entry(channel_id.to_string())
                .or_default()
                .insert(*app_id);
        }
        // Yield only after every guard is dropped; a guard held across this
        // await is exactly the deadlock these tests exist to catch.
        tokio::task::yield_now().await;
    }

    /// Tear down all mappings for one channel, removing peers whose last
    /// channel disappeared and purging their user-agent entries.
    async fn remove_channel_mapping(
        peer_to_channel: &DashMap<PeerId, HashSet<String>>,
        channel_to_peers: &DashMap<String, HashSet<PeerId>>,
        peer_user_agents: &RwLock<HashMap<PeerId, String>>,
        channel_id: &str,
    ) {
        let app_peers = match channel_to_peers.remove(channel_id) {
            Some((_, peers)) => peers,
            None => return,
        };
        let mut fully_disconnected: Vec<PeerId> = Vec::new();
        for app_peer in &app_peers {
            // Confine the get_mut guard to this block; `remove_if` below
            // re-checks emptiness atomically to tolerate racing inserts.
            let became_empty = {
                if let Some(mut channels_ref) = peer_to_channel.get_mut(app_peer) {
                    channels_ref.remove(channel_id);
                    channels_ref.is_empty()
                } else {
                    false
                }
            };
            if became_empty
                && peer_to_channel
                    .remove_if(app_peer, |_, v| v.is_empty())
                    .is_some()
            {
                fully_disconnected.push(*app_peer);
            }
        }
        // Only now take the async RwLock — no DashMap guard is live here.
        if !fully_disconnected.is_empty() {
            let mut pua = peer_user_agents.write().await;
            for app_peer in fully_disconnected {
                pua.remove(&app_peer);
            }
        }
    }

    /// Disconnect-peer pattern: drop the forward entry, then prune the peer
    /// from each reverse entry, removing channels that become empty.
    async fn disconnect_peer_pattern(
        peer_to_channel: &DashMap<PeerId, HashSet<String>>,
        channel_to_peers: &DashMap<String, HashSet<PeerId>>,
        peer_id: &PeerId,
    ) {
        let channel_ids = match peer_to_channel.remove(peer_id) {
            Some((_, chs)) => chs,
            None => return,
        };
        for channel_id in &channel_ids {
            let became_empty = {
                if let Some(mut peers_ref) = channel_to_peers.get_mut(channel_id) {
                    peers_ref.remove(peer_id);
                    peers_ref.is_empty()
                } else {
                    false
                }
            };
            if became_empty {
                channel_to_peers.remove_if(channel_id, |_, v| v.is_empty());
            }
        }
    }

    /// Hammers the access patterns above from many tasks at once on a
    /// dedicated runtime; a watchdog on the test thread fails the test
    /// instead of letting a deadlock hang the suite.
    #[test]
    fn concurrent_peer_channel_stress_test() {
        const NUM_TASKS: usize = 100;
        const ITERATIONS_PER_TASK: usize = 50;
        const PEER_POOL_SIZE: u8 = 20;
        const CHANNEL_POOL_SIZE: usize = 10;
        const WATCHDOG: Duration = Duration::from_secs(10);
        let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();
        // Run the tokio runtime on its own thread so the watchdog below can
        // observe a wedged runtime from the outside.
        let _runtime_thread = std::thread::Builder::new()
            .name("stress-test-runtime".into())
            .spawn(move || {
                let rt = tokio::runtime::Builder::new_multi_thread()
                    .worker_threads(8)
                    .enable_all()
                    .build()
                    .expect("build stress test runtime");
                rt.block_on(async move {
                    let peer_to_channel: Arc<DashMap<PeerId, HashSet<String>>> =
                        Arc::new(DashMap::new());
                    let channel_to_peers: Arc<DashMap<String, HashSet<PeerId>>> =
                        Arc::new(DashMap::new());
                    let peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>> =
                        Arc::new(RwLock::new(HashMap::new()));
                    let mut handles = Vec::new();
                    for task_idx in 0..NUM_TASKS {
                        let p2c = Arc::clone(&peer_to_channel);
                        let c2p = Arc::clone(&channel_to_peers);
                        let pua = Arc::clone(&peer_user_agents);
                        handles.push(tokio::spawn(async move {
                            for i in 0..ITERATIONS_PER_TASK {
                                let peer =
                                    make_peer(((task_idx * 7 + i) % PEER_POOL_SIZE as usize) as u8);
                                let channel =
                                    format!("127.0.0.1:{}", 10000 + (i % CHANNEL_POOL_SIZE));
                                // Rotate through insert, point reads, full
                                // iteration, and both removal patterns.
                                match i % 6 {
                                    0 => {
                                        hot_path_insert(&p2c, &c2p, &peer, &channel).await;
                                        pua.write()
                                            .await
                                            .entry(peer)
                                            .or_insert_with(|| format!("agent-{task_idx}"));
                                    }
                                    1 => {
                                        let _ = p2c.contains_key(&peer);
                                        let _ = p2c.len();
                                        let _ = p2c.get(&peer).map(|r| r.len());
                                    }
                                    2 => {
                                        let count = p2c.iter().count();
                                        assert!(count <= PEER_POOL_SIZE as usize);
                                        for entry in p2c.iter() {
                                            let _ = entry.value().len();
                                        }
                                    }
                                    3 => {
                                        remove_channel_mapping(&p2c, &c2p, &pua, &channel).await;
                                    }
                                    4 => {
                                        disconnect_peer_pattern(&p2c, &c2p, &peer).await;
                                    }
                                    5 => {
                                        let _peers: Vec<PeerId> =
                                            p2c.iter().map(|e| *e.key()).collect();
                                    }
                                    _ => unreachable!(),
                                }
                            }
                        }));
                    }
                    for h in handles {
                        h.await.expect("stress task should not panic");
                    }
                    assert!(
                        peer_to_channel.len() <= PEER_POOL_SIZE as usize,
                        "peer count exceeds pool size: {}",
                        peer_to_channel.len()
                    );
                });
                let _ = done_tx.send(());
            })
            .expect("spawn stress test runtime thread");
        // Watchdog: if the runtime thread doesn't report completion in time,
        // assume a deadlock and fail loudly.
        if done_rx.recv_timeout(WATCHDOG).is_err() {
            panic!(
                "stress test deadlocked — tokio runtime wedged for {WATCHDOG:?}, \
                 likely a DashMap guard held across .await"
            );
        }
    }

    /// Single-shot regression test: `remove_channel_mapping` must complete
    /// promptly even when the async RwLock is involved.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn remove_channel_mapping_does_not_hold_refmut_across_await() {
        let peer_to_channel: Arc<DashMap<PeerId, HashSet<String>>> = Arc::new(DashMap::new());
        let channel_to_peers: Arc<DashMap<String, HashSet<PeerId>>> = Arc::new(DashMap::new());
        let peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>> =
            Arc::new(RwLock::new(HashMap::new()));
        let peer = make_peer(1);
        let channel_id = "127.0.0.1:10000".to_string();
        // Seed one peer on one channel with a user-agent entry.
        peer_to_channel
            .entry(peer)
            .or_default()
            .insert(channel_id.clone());
        channel_to_peers
            .entry(channel_id.clone())
            .or_default()
            .insert(peer);
        peer_user_agents
            .write()
            .await
            .insert(peer, "agent".to_string());
        tokio::time::timeout(Duration::from_secs(2), async {
            remove_channel_mapping(
                &peer_to_channel,
                &channel_to_peers,
                &peer_user_agents,
                &channel_id,
            )
            .await
        })
        .await
        .expect("remove_channel_mapping timed out — RefMut likely held across .await");
        assert!(!peer_to_channel.contains_key(&peer));
        assert!(peer_user_agents.read().await.get(&peer).is_none());
    }
}