use crate::MultiAddr;
use crate::PeerId;
use crate::bgp_geo_provider::BgpGeoProvider;
use crate::error::{NetworkError, P2PError, P2pResult as Result};
use crate::identity::node_identity::NodeIdentity;
use crate::network::{
ConnectionStatus, MAX_ACTIVE_REQUESTS, MAX_REQUEST_TIMEOUT, MESSAGE_RECV_CHANNEL_CAPACITY,
NetworkSender, P2PEvent, ParsedMessage, PeerInfo, PeerResponse, PendingRequest,
RequestResponseEnvelope, WireMessage, broadcast_event, normalize_wildcard_to_loopback,
parse_protocol_message, register_new_channel,
};
use crate::transport::saorsa_transport_adapter::{ConnectionEvent, DualStackNetworkNode};
use crate::validation::{RateLimitConfig, RateLimiter};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, broadcast};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
/// Capacity of the broadcast event channel created by `new_for_tests`.
const TEST_EVENT_CHANNEL_CAPACITY: usize = 16;
/// `max_requests` setting of the rate limiter built by `new_for_tests`.
const TEST_MAX_REQUESTS: u32 = 100;
/// `burst_size` setting of the rate limiter built by `new_for_tests`.
const TEST_BURST_SIZE: u32 = 100;
/// Rate-limit window, in seconds, used by `new_for_tests`.
const TEST_RATE_LIMIT_WINDOW_SECS: u64 = 1;
/// Connection timeout, in seconds, used by `new_for_tests`.
const TEST_CONNECTION_TIMEOUT_SECS: u64 = 30;
/// Protocol name of the identity-announce wire message. The receive loop
/// consumes messages on this topic without forwarding them as events.
const IDENTITY_ANNOUNCE_PROTOCOL: &str = "/saorsa/identity/1.0";
/// Settings consumed by `TransportHandle::new`.
pub struct TransportConfig {
/// Candidate listen addresses. `TransportHandle::new` uses only the first
/// dialable IPv4 and the first dialable IPv6 entry.
pub listen_addrs: Vec<MultiAddr>,
/// Upper bound applied to outbound connects and to individual sends.
pub connection_timeout: Duration,
/// Forwarded to `DualStackNetworkNode::new_with_options`.
pub max_connections: usize,
/// Capacity of the `P2PEvent` broadcast channel.
pub event_channel_capacity: usize,
/// Forwarded to `DualStackNetworkNode::new_with_options`.
pub max_message_size: Option<usize>,
/// Identity whose peer ID is stamped on, and whose key signs, every
/// outgoing wire message.
pub node_identity: Arc<NodeIdentity>,
/// User-agent string embedded in every outgoing wire message.
pub user_agent: String,
/// Forwarded to `DualStackNetworkNode::new_with_options`.
pub allow_loopback: bool,
}
impl TransportConfig {
    /// Derives a transport configuration from a node-level configuration.
    ///
    /// `event_channel_capacity` and `node_identity` are taken from the
    /// arguments; every other field is copied or computed from `config`.
    pub fn from_node_config(
        config: &crate::network::NodeConfig,
        event_channel_capacity: usize,
        node_identity: Arc<NodeIdentity>,
    ) -> Self {
        let listen_addrs = config.listen_addrs();
        let connection_timeout = config.connection_timeout;
        let max_connections = config.max_connections;
        let max_message_size = config.max_message_size;
        let user_agent = config.user_agent();
        let allow_loopback = config.allow_loopback;

        Self {
            listen_addrs,
            connection_timeout,
            max_connections,
            event_channel_capacity,
            max_message_size,
            node_identity,
            user_agent,
            allow_loopback,
        }
    }
}
/// Owns the dual-stack transport node together with all per-channel and
/// per-peer bookkeeping, background task handles and the event channel.
///
/// A "channel ID" throughout this type is the remote socket address rendered
/// as a string; several methods parse it back into a `SocketAddr`.
pub struct TransportHandle {
/// Underlying transport used to dial, accept, send and receive.
dual_node: Arc<DualStackNetworkNode>,
/// Per-channel connection metadata, keyed by channel ID.
peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
/// Channel IDs currently considered active.
active_connections: Arc<RwLock<HashSet<String>>>,
/// Broadcast sender for `P2PEvent`s.
event_tx: broadcast::Sender<P2PEvent>,
/// Local listen addresses; populated by `start_network_listeners`.
listen_addrs: RwLock<Vec<MultiAddr>>,
/// Applied per source IP to incoming connections in the accept loop.
rate_limiter: Arc<RateLimiter>,
/// In-flight request/response exchanges, keyed by message ID.
active_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
// Handed to the lifecycle monitor, which currently does not use it.
#[allow(dead_code)]
geo_provider: Arc<BgpGeoProvider>,
/// Cancelled by `stop` to terminate background tasks.
shutdown: CancellationToken,
/// Upper bound applied to outbound connects and to individual sends.
connection_timeout: Duration,
/// Handle of the connection lifecycle monitor task, if one was spawned.
connection_monitor_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
/// Handles of the receive tasks started by `start_network_listeners`.
recv_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
/// Handle of the accept-loop task, if one was spawned.
listener_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
/// Local identity used to stamp and sign outgoing wire messages.
node_identity: Arc<NodeIdentity>,
/// User-agent string embedded in outgoing wire messages.
user_agent: String,
/// Authenticated peer ID -> channel IDs the peer has been seen on.
peer_to_channel: Arc<RwLock<HashMap<PeerId, HashSet<String>>>>,
/// Channel ID -> authenticated peer IDs seen on that channel.
channel_to_peers: Arc<RwLock<HashMap<String, HashSet<PeerId>>>>,
/// User agent reported by each peer when it was first seen.
peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>>,
}
impl TransportHandle {
pub async fn new(config: TransportConfig) -> Result<Self> {
let (event_tx, _) = broadcast::channel(config.event_channel_capacity);
let mut v4_opt: Option<SocketAddr> = None;
let mut v6_opt: Option<SocketAddr> = None;
for addr in &config.listen_addrs {
if let Some(sa) = addr.dialable_socket_addr() {
match sa.ip() {
std::net::IpAddr::V4(_) if v4_opt.is_none() => v4_opt = Some(sa),
std::net::IpAddr::V6(_) if v6_opt.is_none() => v6_opt = Some(sa),
_ => {} }
}
}
let dual_node = Arc::new(
DualStackNetworkNode::new_with_options(
v6_opt,
v4_opt,
config.max_connections,
config.max_message_size,
config.allow_loopback,
)
.await
.map_err(|e| {
P2PError::Transport(crate::error::TransportError::SetupFailed(
format!("Failed to create dual-stack network nodes: {}", e).into(),
))
})?,
);
let rate_limiter = Arc::new(RateLimiter::new(RateLimitConfig::default()));
let active_connections = Arc::new(RwLock::new(HashSet::new()));
let geo_provider = Arc::new(BgpGeoProvider::new());
let peers = Arc::new(RwLock::new(HashMap::new()));
let shutdown = CancellationToken::new();
let connection_event_rx = dual_node.subscribe_connection_events();
let peer_to_channel = Arc::new(RwLock::new(HashMap::new()));
let channel_to_peers = Arc::new(RwLock::new(HashMap::new()));
let peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>> =
Arc::new(RwLock::new(HashMap::new()));
let connection_monitor_handle = {
let active_conns = Arc::clone(&active_connections);
let peers_map = Arc::clone(&peers);
let event_tx_clone = event_tx.clone();
let dual_node_clone = Arc::clone(&dual_node);
let geo_provider_clone = Arc::clone(&geo_provider);
let shutdown_token = shutdown.clone();
let p2c = Arc::clone(&peer_to_channel);
let c2p = Arc::clone(&channel_to_peers);
let pua = Arc::clone(&peer_user_agents);
let identity_clone = config.node_identity.clone();
let user_agent_clone = config.user_agent.clone();
let handle = tokio::spawn(async move {
Self::connection_lifecycle_monitor_with_rx(
dual_node_clone,
connection_event_rx,
active_conns,
peers_map,
event_tx_clone,
geo_provider_clone,
shutdown_token,
p2c,
c2p,
pua,
identity_clone,
user_agent_clone,
)
.await;
});
Arc::new(RwLock::new(Some(handle)))
};
Ok(Self {
dual_node,
peers,
active_connections,
event_tx,
listen_addrs: RwLock::new(Vec::new()),
rate_limiter,
active_requests: Arc::new(RwLock::new(HashMap::new())),
geo_provider,
shutdown,
connection_timeout: config.connection_timeout,
connection_monitor_handle,
recv_handles: Arc::new(RwLock::new(Vec::new())),
listener_handle: Arc::new(RwLock::new(None)),
node_identity: config.node_identity,
user_agent: config.user_agent,
peer_to_channel,
channel_to_peers,
peer_user_agents,
})
}
pub fn new_for_tests() -> Result<Self> {
let identity = Arc::new(NodeIdentity::generate().map_err(|e| {
P2PError::Network(NetworkError::BindError(
format!("Failed to generate test node identity: {}", e).into(),
))
})?);
let (event_tx, _) = broadcast::channel(TEST_EVENT_CHANNEL_CAPACITY);
let dual_node = {
let v6: Option<SocketAddr> = "[::1]:0"
.parse()
.ok()
.or(Some(SocketAddr::from(([0, 0, 0, 0], 0))));
let v4: Option<SocketAddr> = "127.0.0.1:0".parse().ok();
let handle = tokio::runtime::Handle::current();
let dual_attempt = handle.block_on(DualStackNetworkNode::new(v6, v4));
let dual = match dual_attempt {
Ok(d) => d,
Err(_e1) => {
let fallback = handle
.block_on(DualStackNetworkNode::new(None, "127.0.0.1:0".parse().ok()));
match fallback {
Ok(d) => d,
Err(e2) => {
return Err(P2PError::Network(NetworkError::BindError(
format!("Failed to create dual-stack network node: {}", e2).into(),
)));
}
}
}
};
Arc::new(dual)
};
Ok(Self {
dual_node,
peers: Arc::new(RwLock::new(HashMap::new())),
active_connections: Arc::new(RwLock::new(HashSet::new())),
event_tx,
listen_addrs: RwLock::new(Vec::new()),
rate_limiter: Arc::new(RateLimiter::new(RateLimitConfig {
max_requests: TEST_MAX_REQUESTS,
burst_size: TEST_BURST_SIZE,
window: std::time::Duration::from_secs(TEST_RATE_LIMIT_WINDOW_SECS),
..Default::default()
})),
active_requests: Arc::new(RwLock::new(HashMap::new())),
geo_provider: Arc::new(BgpGeoProvider::new()),
shutdown: CancellationToken::new(),
connection_timeout: Duration::from_secs(TEST_CONNECTION_TIMEOUT_SECS),
connection_monitor_handle: Arc::new(RwLock::new(None)),
recv_handles: Arc::new(RwLock::new(Vec::new())),
listener_handle: Arc::new(RwLock::new(None)),
node_identity: identity,
user_agent: crate::network::user_agent_for_mode(crate::network::NodeMode::Node),
peer_to_channel: Arc::new(RwLock::new(HashMap::new())),
channel_to_peers: Arc::new(RwLock::new(HashMap::new())),
peer_user_agents: Arc::new(RwLock::new(HashMap::new())),
})
}
}
impl TransportHandle {
    /// Returns this node's peer ID, taken from its identity.
    pub fn peer_id(&self) -> PeerId {
        let id = self.node_identity.peer_id();
        *id
    }

    /// Returns the shared node identity.
    pub fn node_identity(&self) -> &Arc<NodeIdentity> {
        let identity = &self.node_identity;
        identity
    }

    /// Returns the first recorded listen address without waiting for the
    /// lock.
    ///
    /// Yields `None` when no address is recorded or when the read lock
    /// cannot be acquired immediately.
    pub fn local_addr(&self) -> Option<MultiAddr> {
        let guard = self.listen_addrs.try_read().ok()?;
        guard.first().cloned()
    }

    /// Returns a snapshot of all recorded listen addresses.
    pub async fn listen_addrs(&self) -> Vec<MultiAddr> {
        let guard = self.listen_addrs.read().await;
        guard.clone()
    }

    /// Returns the timeout applied to connects and sends.
    pub fn connection_timeout(&self) -> Duration {
        let timeout = self.connection_timeout;
        timeout
    }
}
impl TransportHandle {
/// Returns the IDs of all peers that currently have at least one mapped
/// channel.
pub async fn connected_peers(&self) -> Vec<PeerId> {
    let mapping = self.peer_to_channel.read().await;
    let mut ids = Vec::with_capacity(mapping.len());
    ids.extend(mapping.keys().copied());
    ids
}
/// Returns how many peers currently have at least one mapped channel.
pub async fn peer_count(&self) -> usize {
    let mapping = self.peer_to_channel.read().await;
    mapping.len()
}
/// Returns the user agent recorded for `peer_id`, if any.
pub async fn peer_user_agent(&self, peer_id: &PeerId) -> Option<String> {
    let agents = self.peer_user_agents.read().await;
    let agent = agents.get(peer_id)?;
    Some(agent.clone())
}
/// Returns a snapshot of every channel ID currently marked active.
#[allow(dead_code)]
pub(crate) async fn active_channels(&self) -> Vec<String> {
    let active = self.active_connections.read().await;
    let mut channel_ids = Vec::with_capacity(active.len());
    for channel_id in active.iter() {
        channel_ids.push(channel_id.clone());
    }
    channel_ids
}
/// Returns connection metadata for `peer_id`.
///
/// A peer may be mapped to several channels, and a channel can briefly be
/// mapped without having a `PeerInfo` entry. All of the peer's channels are
/// therefore scanned and the first one that has metadata is returned, rather
/// than giving up when an arbitrarily chosen channel has none.
///
/// Returns `None` when the peer has no mapped channel or none of its
/// channels has metadata.
pub async fn peer_info(&self, peer_id: &PeerId) -> Option<PeerInfo> {
    let p2c = self.peer_to_channel.read().await;
    let channels = p2c.get(peer_id)?;
    let peers = self.peers.read().await;
    channels
        .iter()
        .find_map(|channel_id| peers.get(channel_id).cloned())
}
/// Returns the connection metadata stored for `channel_id`, if any.
#[allow(dead_code)]
pub(crate) async fn peer_info_by_channel(&self, channel_id: &str) -> Option<PeerInfo> {
    let peers = self.peers.read().await;
    let info = peers.get(channel_id)?;
    Some(info.clone())
}
/// Finds a channel whose recorded addresses include the socket address of
/// `addr`.
///
/// Returns `None` when `addr` has no socket address or no channel matches.
#[allow(dead_code)]
pub(crate) async fn get_channel_id_by_address(&self, addr: &MultiAddr) -> Option<String> {
    let target = addr.socket_addr()?;
    let peers = self.peers.read().await;
    peers
        .iter()
        .find(|(_, info)| {
            info.addresses
                .iter()
                .any(|candidate| candidate.socket_addr() == Some(target))
        })
        .map(|(channel_id, _)| channel_id.clone())
}
/// Lists every active channel together with its recorded addresses.
///
/// Channels that have no `PeerInfo` entry are listed with an empty address
/// vector.
#[allow(dead_code)]
pub(crate) async fn list_active_connections(&self) -> Vec<(String, Vec<MultiAddr>)> {
    let active = self.active_connections.read().await;
    let peers = self.peers.read().await;
    let mut connections = Vec::with_capacity(active.len());
    for channel_id in active.iter() {
        let addresses = match peers.get(channel_id) {
            Some(info) => info.addresses.clone(),
            None => Vec::new(),
        };
        connections.push((channel_id.clone(), addresses));
    }
    connections
}
/// Drops all local bookkeeping for `channel_id` without closing the
/// underlying connection.
///
/// Returns `true` when a `PeerInfo` entry existed for the channel.
pub(crate) async fn remove_channel(&self, channel_id: &str) -> bool {
    {
        let mut active = self.active_connections.write().await;
        active.remove(channel_id);
    }
    self.remove_channel_mappings(channel_id).await;
    let removed_info = {
        let mut peers = self.peers.write().await;
        peers.remove(channel_id)
    };
    removed_info.is_some()
}
/// Closes the connection behind `channel_id` (when the ID parses as a socket
/// address) and then drops all local bookkeeping for the channel.
pub(crate) async fn disconnect_channel(&self, channel_id: &str) {
    let parsed = channel_id.parse::<SocketAddr>();
    match parsed {
        Err(e) => {
            warn!(
                channel = %channel_id,
                error = %e,
                "Failed to parse channel ID as SocketAddr — QUIC connection will not be closed",
            );
        }
        Ok(addr) => self.dual_node.disconnect_peer_by_addr(&addr).await,
    }
    // Same cleanup sequence as `remove_channel`: active set, peer/channel
    // mappings, then the `PeerInfo` entry.
    let _had_info = self.remove_channel(channel_id).await;
}
/// Reports whether `peer_id` currently has at least one mapped channel.
pub async fn is_peer_connected(&self, peer_id: &PeerId) -> bool {
    let mapping = self.peer_to_channel.read().await;
    mapping.get(peer_id).is_some()
}
/// Reports whether `channel_id` is currently in the active-connection set.
pub(crate) async fn is_connection_active(&self, channel_id: &str) -> bool {
    let active = self.active_connections.read().await;
    active.contains(channel_id)
}
async fn remove_channel_mappings(&self, channel_id: &str) {
Self::remove_channel_mappings_static(
channel_id,
&self.peer_to_channel,
&self.channel_to_peers,
&self.peer_user_agents,
&self.event_tx,
)
.await;
}
/// Removes `channel_id` from both directions of the peer/channel mapping.
///
/// For every peer that was mapped to the channel, the channel is removed
/// from the peer's channel set. A peer whose set becomes empty is removed
/// entirely, its user agent is forgotten, and `P2PEvent::PeerDisconnected`
/// is broadcast (send errors are ignored).
///
/// Locks are taken in the order `peer_to_channel`, `channel_to_peers`, then
/// `peer_user_agents`.
async fn remove_channel_mappings_static(
    channel_id: &str,
    peer_to_channel: &RwLock<HashMap<PeerId, HashSet<String>>>,
    channel_to_peers: &RwLock<HashMap<String, HashSet<PeerId>>>,
    peer_user_agents: &RwLock<HashMap<PeerId, String>>,
    event_tx: &broadcast::Sender<P2PEvent>,
) {
    let mut p2c = peer_to_channel.write().await;
    let mut c2p = channel_to_peers.write().await;

    let Some(members) = c2p.remove(channel_id) else {
        return;
    };

    for member in members.iter() {
        let Some(remaining) = p2c.get_mut(member) else {
            continue;
        };
        remaining.remove(channel_id);
        if !remaining.is_empty() {
            continue;
        }
        // That was the peer's last channel: the peer is now disconnected.
        p2c.remove(member);
        peer_user_agents.write().await.remove(member);
        let _ = event_tx.send(P2PEvent::PeerDisconnected(*member));
    }
}
}
impl TransportHandle {
/// Dials `address` and registers the resulting channel.
///
/// A wildcard address is normalised to loopback before dialling. The attempt
/// is bounded by `connection_timeout`. A connection that resolves to one of
/// this node's own listen addresses is closed again and rejected.
///
/// Returns the channel ID (the connected remote socket address as a string).
///
/// # Errors
///
/// * `NetworkError::InvalidAddress` — `address` is not dialable, or the
///   connection turned out to be a self-connection.
/// * `TransportError::ConnectionFailed` — the transport reported a failure.
/// * `P2PError::Timeout` — the attempt exceeded `connection_timeout`.
pub async fn connect_peer(&self, address: &MultiAddr) -> Result<String> {
    let Some(socket_addr) = address.dialable_socket_addr() else {
        return Err(P2PError::Network(NetworkError::InvalidAddress(
            format!(
                "only QUIC transport is supported for connect, got {}: {}",
                address.transport().kind(),
                address
            )
            .into(),
        )));
    };
    let normalized_addr = normalize_wildcard_to_loopback(socket_addr);
    let addr_list = vec![normalized_addr];

    let dial = self.dual_node.connect_happy_eyeballs(&addr_list);
    let remote = match tokio::time::timeout(self.connection_timeout, dial).await {
        Err(_) => {
            warn!(
                "connect_happy_eyeballs timed out for {} after {:?}",
                address, self.connection_timeout
            );
            return Err(P2PError::Timeout(self.connection_timeout));
        }
        Ok(Err(e)) => {
            warn!("connect_happy_eyeballs failed for {}: {}", address, e);
            return Err(P2PError::Transport(
                crate::error::TransportError::ConnectionFailed {
                    addr: normalized_addr,
                    reason: e.to_string().into(),
                },
            ));
        }
        Ok(Ok(remote)) => remote,
    };

    let channel_id = remote.to_string();
    let is_self = self
        .listen_addrs
        .read()
        .await
        .iter()
        .any(|own| own.socket_addr() == Some(remote));
    if is_self {
        warn!(
            "Detected self-connection to own address {} (channel_id: {}), rejecting",
            address, channel_id
        );
        self.dual_node.disconnect_peer_by_addr(&remote).await;
        return Err(P2PError::Network(NetworkError::InvalidAddress(
            format!("Cannot connect to self ({})", address).into(),
        )));
    }
    info!("Successfully connected to channel: {}", channel_id);

    let info = PeerInfo {
        channel_id: channel_id.clone(),
        addresses: vec![address.clone()],
        connected_at: Instant::now(),
        last_seen: Instant::now(),
        status: ConnectionStatus::Connected,
        protocols: vec!["p2p-foundation/1.0".to_string()],
        heartbeat_count: 0,
    };
    {
        let mut peers = self.peers.write().await;
        peers.insert(channel_id.clone(), info);
    }
    {
        let mut active = self.active_connections.write().await;
        active.insert(channel_id.clone());
    }
    Ok(channel_id)
}
/// Disconnects the peer identified by `peer_id`.
///
/// The peer is removed from the peer/channel mappings, its user agent is
/// forgotten and `P2PEvent::PeerDisconnected` is broadcast. Channels that no
/// other peer is mapped to afterwards are closed (when their ID parses as a
/// socket address) and dropped from the active set and the peers map.
///
/// A peer without tracked channels is a no-op. This method currently always
/// returns `Ok(())`.
pub async fn disconnect_peer(&self, peer_id: &PeerId) -> Result<()> {
    info!("Disconnecting from peer: {}", peer_id);

    let orphaned_channels: Vec<String> = {
        let mut p2c = self.peer_to_channel.write().await;
        let mut c2p = self.channel_to_peers.write().await;

        let Some(channel_ids) = p2c.remove(peer_id) else {
            info!(
                "Peer {} has no tracked channels, nothing to disconnect",
                peer_id
            );
            return Ok(());
        };

        let mut orphaned = Vec::new();
        for channel_id in channel_ids {
            let Some(members) = c2p.get_mut(&channel_id) else {
                continue;
            };
            members.remove(peer_id);
            if members.is_empty() {
                // Nobody else uses this channel any more.
                c2p.remove(&channel_id);
                orphaned.push(channel_id);
            }
        }
        orphaned
    };

    self.peer_user_agents.write().await.remove(peer_id);
    let _ = self.event_tx.send(P2PEvent::PeerDisconnected(*peer_id));

    for channel_id in orphaned_channels.iter() {
        let parsed = channel_id.parse::<SocketAddr>();
        match parsed {
            Err(e) => {
                warn!(
                    peer = %peer_id,
                    channel = %channel_id,
                    error = %e,
                    "Failed to parse channel ID as SocketAddr — QUIC connection will not be closed",
                );
            }
            Ok(addr) => self.dual_node.disconnect_peer_by_addr(&addr).await,
        }
        self.active_connections.write().await.remove(channel_id);
        self.peers.write().await.remove(channel_id);
    }

    info!("Disconnected from peer: {}", peer_id);
    Ok(())
}
/// Disconnects every peer that currently has a mapped channel, stopping at
/// the first error.
async fn disconnect_all_peers(&self) -> Result<()> {
    // Snapshot the IDs first so no mapping lock is held while disconnecting.
    let snapshot: Vec<PeerId> = self.connected_peers().await;
    for peer_id in snapshot.iter() {
        self.disconnect_peer(peer_id).await?;
    }
    Ok(())
}
}
impl TransportHandle {
pub async fn send_message(
&self,
peer_id: &PeerId,
protocol: &str,
data: Vec<u8>,
) -> Result<()> {
let peer_hex = peer_id.to_hex();
let channels: Vec<String> = self
.peer_to_channel
.read()
.await
.get(peer_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
if channels.is_empty() {
return Err(P2PError::Network(NetworkError::PeerNotFound(
peer_hex.into(),
)));
}
let mut last_err = None;
for channel_id in &channels {
match self
.send_on_channel(channel_id, protocol, data.clone())
.await
{
Ok(()) => return Ok(()),
Err(e) => {
warn!(
peer = %peer_hex,
channel = %channel_id,
error = %e,
"Channel send failed, removing and trying next",
);
self.remove_channel(channel_id).await;
last_err = Some(e);
}
}
}
Err(last_err
.unwrap_or_else(|| P2PError::Network(NetworkError::PeerNotFound(peer_hex.into()))))
}
pub(crate) async fn send_on_channel(
&self,
channel_id: &str,
protocol: &str,
data: Vec<u8>,
) -> Result<()> {
debug!(
"Sending message to channel {} on protocol {}",
channel_id, protocol
);
{
let needs_insert = {
let peers = self.peers.read().await;
!peers.contains_key(channel_id)
};
if needs_insert {
let mut peers = self.peers.write().await;
peers.entry(channel_id.to_string()).or_insert_with(|| {
info!(
"send_on_channel: registering new channel {} on the fly",
channel_id
);
let addresses = channel_id
.parse::<std::net::SocketAddr>()
.map(|addr| vec![MultiAddr::quic(addr)])
.unwrap_or_default();
PeerInfo {
channel_id: channel_id.to_string(),
addresses,
status: ConnectionStatus::Connected,
last_seen: Instant::now(),
connected_at: Instant::now(),
protocols: Vec::new(),
heartbeat_count: 0,
}
});
}
}
if !self.is_connection_active(channel_id).await {
self.active_connections
.write()
.await
.insert(channel_id.to_string());
}
let raw_data_len = data.len();
let message_data = self.create_protocol_message(protocol, data)?;
info!(
"Sending {} bytes to channel {} on protocol {} (raw data: {} bytes)",
message_data.len(),
channel_id,
protocol,
raw_data_len
);
let addr: SocketAddr = channel_id.parse().map_err(|e: std::net::AddrParseError| {
P2PError::Network(NetworkError::PeerNotFound(
format!("Invalid channel ID address: {e}").into(),
))
})?;
let send_fut = self.dual_node.send_to_peer_optimized(&addr, &message_data);
let result = tokio::time::timeout(self.connection_timeout, send_fut)
.await
.map_err(|_| {
P2PError::Transport(crate::error::TransportError::StreamError(
"Timed out sending message".into(),
))
})?
.map_err(|e| {
P2PError::Transport(crate::error::TransportError::StreamError(
e.to_string().into(),
))
});
if result.is_ok() {
info!(
"Successfully sent {} bytes to channel {}",
message_data.len(),
channel_id
);
} else {
warn!("Failed to send message to channel {}", channel_id);
self.active_connections.write().await.remove(channel_id);
}
result
}
/// Returns the channel IDs currently mapped to `app_peer_id`; empty when the
/// peer is unknown.
pub async fn channels_for_peer(&self, app_peer_id: &PeerId) -> Vec<String> {
    let mapping = self.peer_to_channel.read().await;
    match mapping.get(app_peer_id) {
        Some(channel_set) => channel_set.iter().cloned().collect(),
        None => Vec::new(),
    }
}
/// Returns the peer IDs currently mapped to `channel_id`; empty when the
/// channel is unknown.
pub(crate) async fn peers_on_channel(&self, channel_id: &str) -> Vec<PeerId> {
    let mapping = self.channel_to_peers.read().await;
    match mapping.get(channel_id) {
        Some(peer_set) => peer_set.iter().copied().collect(),
        None => Vec::new(),
    }
}
/// Reports whether `peer_id` appears in the peer-to-channel mapping.
pub async fn is_known_app_peer_id(&self, peer_id: &PeerId) -> bool {
    let mapping = self.peer_to_channel.read().await;
    mapping.get(peer_id).is_some()
}
/// Waits until some peer ID is mapped to `channel_id` and returns it.
///
/// The mapping is polled every 50 ms. It is always checked at least once,
/// even when `timeout` is zero.
///
/// # Errors
///
/// Returns `P2PError::Timeout` when no peer is mapped before the deadline.
pub async fn wait_for_peer_identity(
    &self,
    channel_id: &str,
    timeout: Duration,
) -> Result<PeerId> {
    let deadline = Instant::now() + timeout;
    loop {
        let first_peer = self.peers_on_channel(channel_id).await.into_iter().next();
        match first_peer {
            Some(peer_id) => return Ok(peer_id),
            None if Instant::now() >= deadline => return Err(P2PError::Timeout(timeout)),
            None => tokio::time::sleep(Duration::from_millis(50)).await,
        }
    }
}
/// Sends a request to `peer_id` and waits for the matching response.
///
/// `timeout` is capped at `MAX_REQUEST_TIMEOUT`. The request is registered
/// under a fresh UUID, wrapped in a `RequestResponseEnvelope` and sent on the
/// wire protocol `/rr/<protocol>`. The registration is removed again on
/// every exit path of this function.
///
/// # Errors
///
/// * An error from `validate_protocol_name`.
/// * `TransportError::StreamError` — too many requests in flight, or the
///   response did not arrive in time.
/// * `P2PError::Serialization` — the envelope could not be encoded.
/// * Any error from `send_message`.
/// * `NetworkError::ConnectionClosed` — the response sender was dropped.
pub async fn send_request(
    &self,
    peer_id: &PeerId,
    protocol: &str,
    data: Vec<u8>,
    timeout: Duration,
) -> Result<PeerResponse> {
    let timeout = timeout.min(MAX_REQUEST_TIMEOUT);
    validate_protocol_name(protocol)?;

    let message_id = uuid::Uuid::new_v4().to_string();
    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    let started_at = Instant::now();

    // Register before sending so a fast response cannot arrive unmatched.
    {
        let mut pending = self.active_requests.write().await;
        if pending.len() >= MAX_ACTIVE_REQUESTS {
            return Err(P2PError::Transport(
                crate::error::TransportError::StreamError(
                    format!(
                        "Too many active requests ({MAX_ACTIVE_REQUESTS}); try again later"
                    )
                    .into(),
                ),
            ));
        }
        pending.insert(
            message_id.clone(),
            PendingRequest {
                response_tx,
                expected_peer: *peer_id,
            },
        );
    }

    let envelope = RequestResponseEnvelope {
        message_id: message_id.clone(),
        is_response: false,
        payload: data,
    };
    let dispatch = match postcard::to_allocvec(&envelope) {
        Ok(envelope_bytes) => {
            let wire_protocol = format!("/rr/{}", protocol);
            self.send_message(peer_id, &wire_protocol, envelope_bytes)
                .await
        }
        Err(e) => Err(P2PError::Serialization(
            format!("Failed to serialize request envelope: {e}").into(),
        )),
    };
    if let Err(e) = dispatch {
        self.active_requests.write().await.remove(&message_id);
        return Err(e);
    }

    let awaited = tokio::time::timeout(timeout, response_rx).await;
    let result = match awaited {
        Err(_) => Err(P2PError::Transport(
            crate::error::TransportError::StreamError(
                format!(
                    "Request to {} on {} timed out after {:?}",
                    peer_id, protocol, timeout
                )
                .into(),
            ),
        )),
        Ok(Err(_)) => Err(P2PError::Network(NetworkError::ConnectionClosed {
            peer_id: peer_id.to_hex().into(),
        })),
        Ok(Ok(response_bytes)) => {
            let latency = started_at.elapsed();
            Ok(PeerResponse {
                peer_id: *peer_id,
                data: response_bytes,
                latency,
            })
        }
    };

    // The receive loop removes the entry on delivery; this covers the rest.
    self.active_requests.write().await.remove(&message_id);
    result
}
/// Sends the response for request `message_id` to `peer_id`.
///
/// The payload is wrapped in a `RequestResponseEnvelope` flagged as a
/// response and sent on the wire protocol `/rr/<protocol>`.
///
/// # Errors
///
/// Fails when `protocol` is rejected by `validate_protocol_name`, when the
/// envelope cannot be serialized, or when `send_message` fails.
pub async fn send_response(
    &self,
    peer_id: &PeerId,
    protocol: &str,
    message_id: &str,
    data: Vec<u8>,
) -> Result<()> {
    validate_protocol_name(protocol)?;

    let envelope = RequestResponseEnvelope {
        message_id: message_id.to_string(),
        is_response: true,
        payload: data,
    };
    let envelope_bytes = match postcard::to_allocvec(&envelope) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(P2PError::Serialization(
                format!("Failed to serialize response envelope: {e}").into(),
            ));
        }
    };

    let wire_protocol = format!("/rr/{}", protocol);
    self.send_message(peer_id, &wire_protocol, envelope_bytes)
        .await
}
pub fn parse_request_envelope(data: &[u8]) -> Option<(String, bool, Vec<u8>)> {
let envelope: RequestResponseEnvelope = postcard::from_bytes(data).ok()?;
Some((envelope.message_id, envelope.is_response, envelope.payload))
}
fn create_protocol_message(&self, protocol: &str, data: Vec<u8>) -> Result<Vec<u8>> {
let mut message = WireMessage {
protocol: protocol.to_string(),
data,
from: *self.node_identity.peer_id(),
timestamp: Self::current_timestamp_secs()?,
user_agent: self.user_agent.clone(),
public_key: Vec::new(),
signature: Vec::new(),
};
Self::sign_wire_message(&mut message, &self.node_identity)?;
Self::serialize_wire_message(&message)
}
/// Builds, signs and serializes an identity-announce wire message: an
/// empty-payload message on `IDENTITY_ANNOUNCE_PROTOCOL` from `identity`.
fn create_identity_announce_bytes(
    identity: &NodeIdentity,
    user_agent: &str,
) -> Result<Vec<u8>> {
    let timestamp = Self::current_timestamp_secs()?;

    // Key and signature start empty; signing fills them in.
    let mut announce = WireMessage {
        protocol: IDENTITY_ANNOUNCE_PROTOCOL.to_string(),
        data: Vec::new(),
        from: *identity.peer_id(),
        timestamp,
        user_agent: user_agent.to_owned(),
        public_key: Vec::new(),
        signature: Vec::new(),
    };
    Self::sign_wire_message(&mut announce, identity)?;
    Self::serialize_wire_message(&announce)
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Errors
///
/// Returns `NetworkError::ProtocolError` when the system clock reports a
/// time before the epoch.
fn current_timestamp_secs() -> Result<u64> {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => Ok(since_epoch.as_secs()),
        Err(e) => Err(P2PError::Network(NetworkError::ProtocolError(
            format!("System time error: {e}").into(),
        ))),
    }
}
/// Signs `message` with `identity` and stores the public key and signature
/// in the message.
///
/// The signature covers the bytes produced by `compute_signable_bytes` from
/// the protocol, data, sender, timestamp and user agent.
///
/// # Errors
///
/// Fails when the signable bytes cannot be produced or when signing fails
/// (`NetworkError::ProtocolError`).
fn sign_wire_message(message: &mut WireMessage, identity: &NodeIdentity) -> Result<()> {
    let signable = Self::compute_signable_bytes(
        &message.protocol,
        &message.data,
        &message.from,
        message.timestamp,
        &message.user_agent,
    )?;

    let signature = match identity.sign(&signable) {
        Ok(sig) => sig,
        Err(e) => {
            return Err(P2PError::Network(NetworkError::ProtocolError(
                format!("Failed to sign message: {e}").into(),
            )));
        }
    };

    message.public_key = identity.public_key().as_bytes().to_vec();
    message.signature = signature.as_bytes().to_vec();
    Ok(())
}
fn serialize_wire_message(message: &WireMessage) -> Result<Vec<u8>> {
postcard::to_stdvec(message).map_err(|e| {
P2PError::Transport(crate::error::TransportError::StreamError(
format!("Failed to serialize wire message: {e}").into(),
))
})
}
/// Produces the byte string that is signed for a wire message: the postcard
/// encoding of the tuple `(protocol, data, from, timestamp, user_agent)`.
///
/// # Errors
///
/// Returns `NetworkError::ProtocolError` when encoding fails.
fn compute_signable_bytes(
    protocol: &str,
    data: &[u8],
    from: &PeerId,
    timestamp: u64,
    user_agent: &str,
) -> Result<Vec<u8>> {
    // Field order is part of the signature format and must not change.
    let signed_fields = (protocol, data, from, timestamp, user_agent);
    match postcard::to_stdvec(&signed_fields) {
        Ok(encoded) => Ok(encoded),
        Err(e) => Err(P2PError::Network(NetworkError::ProtocolError(
            format!("Failed to serialize signable bytes: {e}").into(),
        ))),
    }
}
}
impl TransportHandle {
    /// Logs the subscription. No subscription state is recorded here.
    pub async fn subscribe(&self, topic: &str) -> Result<()> {
        info!("Subscribed to topic: {}", topic);
        Ok(())
    }

    /// Publishes `data` on `topic` to every known peer and emits a local
    /// `P2PEvent::Message`.
    ///
    /// Channels are grouped per authenticated peer, so each peer receives the
    /// message at most once: within a group the channels are tried in turn
    /// until one send succeeds. A channel that is not mapped to any peer
    /// forms a group of its own. A failing channel has its local bookkeeping
    /// removed. Send failures are logged and never fail the call.
    pub async fn publish(&self, topic: &str, data: &[u8]) -> Result<()> {
        info!(
            "Publishing message to topic: {} ({} bytes)",
            topic,
            data.len()
        );

        let mut groups: Vec<Vec<String>> = Vec::new();
        let mut mapped: HashSet<String> = HashSet::new();
        {
            let p2c = self.peer_to_channel.read().await;
            for channel_set in p2c.values() {
                mapped.extend(channel_set.iter().cloned());
                if !channel_set.is_empty() {
                    groups.push(channel_set.iter().cloned().collect());
                }
            }
        }
        {
            // Channels without an authenticated peer each form one group.
            let peers_guard = self.peers.read().await;
            groups.extend(
                peers_guard
                    .keys()
                    .filter(|channel_id| !mapped.contains(*channel_id))
                    .map(|channel_id| vec![channel_id.clone()]),
            );
        }

        if groups.is_empty() {
            debug!("No peers connected, message will only be sent to local subscribers");
        } else {
            let total = groups.len();
            let mut send_count = 0;
            for group in groups.iter() {
                let mut delivered = false;
                for channel_id in group.iter() {
                    let outcome = self.send_on_channel(channel_id, topic, data.to_vec()).await;
                    if let Err(e) = outcome {
                        warn!(
                            channel = %channel_id,
                            error = %e,
                            "Publish channel failed, removing and trying next",
                        );
                        self.remove_channel(channel_id).await;
                        continue;
                    }
                    send_count += 1;
                    debug!("Published message via channel: {}", channel_id);
                    delivered = true;
                    break;
                }
                if !delivered {
                    warn!("All channels exhausted for one peer during publish");
                }
            }
            info!(
                "Published message to {}/{} connected peers",
                send_count, total
            );
        }

        let local_event = P2PEvent::Message {
            topic: topic.to_string(),
            source: Some(*self.node_identity.peer_id()),
            data: data.to_vec(),
        };
        self.send_event(local_event);
        Ok(())
    }
}
impl TransportHandle {
    /// Returns a new receiver for this handle's `P2PEvent` broadcast channel.
    pub fn subscribe_events(&self) -> broadcast::Receiver<P2PEvent> {
        let receiver = self.event_tx.subscribe();
        receiver
    }

    /// Broadcasts `event`. A failed send is logged at trace level and
    /// otherwise ignored.
    pub(crate) fn send_event(&self, event: P2PEvent) {
        match self.event_tx.send(event) {
            Ok(_receiver_count) => {}
            Err(e) => {
                tracing::trace!("Event broadcast has no receivers: {e}");
            }
        }
    }
}
impl TransportHandle {
pub async fn start_network_listeners(&self) -> Result<()> {
info!("Starting dual-stack listeners (saorsa-transport)...");
let socket_addrs = self.dual_node.local_addrs().await.map_err(|e| {
P2PError::Transport(crate::error::TransportError::SetupFailed(
format!("Failed to get local addresses: {}", e).into(),
))
})?;
let addrs: Vec<SocketAddr> = socket_addrs.clone();
{
let mut la = self.listen_addrs.write().await;
*la = socket_addrs.into_iter().map(MultiAddr::quic).collect();
}
let peers = self.peers.clone();
let active_connections = self.active_connections.clone();
let rate_limiter = self.rate_limiter.clone();
let dual = self.dual_node.clone();
let handle = tokio::spawn(async move {
loop {
let Some(remote_sock) = dual.accept_any().await else {
break;
};
if let Err(e) = rate_limiter.check_ip(&remote_sock.ip()) {
warn!(
"Rate-limited incoming connection from {}: {}",
remote_sock, e
);
continue;
}
let channel_id = remote_sock.to_string();
let remote_addr = MultiAddr::quic(remote_sock);
register_new_channel(&peers, &channel_id, &remote_addr).await;
active_connections.write().await.insert(channel_id);
}
});
*self.listener_handle.write().await = Some(handle);
self.start_message_receiving_system().await?;
info!("Dual-stack listeners active on: {:?}", addrs);
Ok(())
}
/// Spawns the transport receive tasks plus one dispatcher task, and stores
/// all of their handles in `recv_handles`.
///
/// For every received datagram the dispatcher:
/// 1. parses it into a `ParsedMessage` (unparsable data is logged and dropped);
/// 2. when the sender is authenticated and is not this node, records the
///    peer <-> channel mapping and, for a peer seen for the first time,
///    stores its user agent and broadcasts `P2PEvent::PeerConnected`;
/// 3. swallows identity-announce messages;
/// 4. routes `/rr/` responses to the matching pending request, after
///    checking that the authenticated sender is the expected peer;
/// 5. broadcasts everything else as a `P2PEvent`.
async fn start_message_receiving_system(&self) -> Result<()> {
info!("Starting message receiving system");
let (tx, mut rx) = tokio::sync::mpsc::channel(MESSAGE_RECV_CHANNEL_CAPACITY);
let mut handles = self
.dual_node
.spawn_recv_tasks(tx.clone(), self.shutdown.clone());
// Drop the local sender so the channel closes once every receive task has
// dropped its clone, which ends the dispatcher loop below.
drop(tx);
let event_tx = self.event_tx.clone();
let active_requests = Arc::clone(&self.active_requests);
let peer_to_channel = Arc::clone(&self.peer_to_channel);
let channel_to_peers = Arc::clone(&self.channel_to_peers);
let peer_user_agents = Arc::clone(&self.peer_user_agents);
let self_peer_id = *self.node_identity.peer_id();
handles.push(tokio::spawn(async move {
info!("Message receive loop started");
while let Some((from_addr, bytes)) = rx.recv().await {
// The channel ID is the sender's socket address rendered as a string.
let channel_id = from_addr.to_string();
trace!("Received {} bytes from channel {}", bytes.len(), channel_id);
match parse_protocol_message(&bytes, &channel_id) {
Some(ParsedMessage {
event,
authenticated_node_id,
user_agent: peer_user_agent,
}) => {
// Track the mapping only for authenticated senders other than ourselves.
if let Some(ref app_id) = authenticated_node_id
&& *app_id != self_peer_id
{
// Lock order: peer_to_channel first, then channel_to_peers.
let mut p2c = peer_to_channel.write().await;
let is_new_peer = !p2c.contains_key(app_id);
let channels = p2c.entry(*app_id).or_default();
let inserted = channels.insert(channel_id.clone());
if inserted {
channel_to_peers
.write()
.await
.entry(channel_id.clone())
.or_default()
.insert(*app_id);
}
drop(p2c);
// The user agent is stored only the first time a peer is seen.
if is_new_peer {
peer_user_agents
.write()
.await
.insert(*app_id, peer_user_agent.clone());
broadcast_event(
&event_tx,
P2PEvent::PeerConnected(*app_id, peer_user_agent.clone()),
);
}
}
// Identity announces exist only to feed the mapping above; they are
// never forwarded to subscribers.
if let P2PEvent::Message { ref topic, .. } = event
&& topic == IDENTITY_ANNOUNCE_PROTOCOL
{
continue;
}
// Request/response traffic: only envelopes flagged as responses are
// intercepted here; requests fall through to the broadcast below.
if let P2PEvent::Message {
ref topic,
ref data,
..
} = event
&& topic.starts_with("/rr/")
&& let Ok(envelope) =
postcard::from_bytes::<RequestResponseEnvelope>(data)
&& envelope.is_response
{
let mut reqs = active_requests.write().await;
let expected_peer = match reqs.get(&envelope.message_id) {
Some(pending) => pending.expected_peer,
None => {
trace!(
message_id = %envelope.message_id,
"Unmatched /rr/ response (likely timed out) — suppressing"
);
continue;
}
};
// Reject a response that was not sent by the peer the request went
// to; the pending request stays registered.
if authenticated_node_id.as_ref() != Some(&expected_peer) {
warn!(
message_id = %envelope.message_id,
expected = %expected_peer,
actual_channel = %channel_id,
authenticated = ?authenticated_node_id,
"Response origin mismatch — ignoring"
);
continue;
}
if let Some(pending) = reqs.remove(&envelope.message_id) {
if pending.response_tx.send(envelope.payload).is_err() {
warn!(
message_id = %envelope.message_id,
"Response receiver dropped before delivery"
);
}
continue;
}
// Fallback for a missing entry; the lookup above succeeded under the
// same write guard, so this path is not expected to run.
trace!(
message_id = %envelope.message_id,
"Unmatched /rr/ response (likely timed out) — suppressing"
);
continue;
}
broadcast_event(&event_tx, event);
}
None => {
warn!("Failed to parse protocol message ({} bytes)", bytes.len());
}
}
}
info!("Message receive loop ended — channel closed");
}));
*self.recv_handles.write().await = handles;
Ok(())
}
}
impl TransportHandle {
    /// Shuts the transport down.
    ///
    /// Order: cancel the shutdown token, shut down the endpoints, join the
    /// receive tasks, the listener task and the connection monitor task, then
    /// disconnect every remaining peer.
    ///
    /// # Errors
    ///
    /// Propagates an error from `disconnect_all_peers`.
    pub async fn stop(&self) -> Result<()> {
        info!("Stopping transport...");
        self.shutdown.cancel();
        self.dual_node.shutdown_endpoints().await;

        // Take the handles out first so the lock is not held while joining.
        let recv_tasks: Vec<JoinHandle<()>> =
            std::mem::take(&mut *self.recv_handles.write().await);
        Self::join_task_handles(recv_tasks, "recv").await;
        Self::join_task_slot(&self.listener_handle, "listener").await;
        Self::join_task_slot(&self.connection_monitor_handle, "connection monitor").await;

        self.disconnect_all_peers().await?;
        info!("Transport stopped");
        Ok(())
    }

    /// Takes the handle out of `handle_slot`, if any, and joins it.
    async fn join_task_slot(handle_slot: &RwLock<Option<JoinHandle<()>>>, task_name: &str) {
        // Release the lock before awaiting the task.
        let taken = {
            let mut slot = handle_slot.write().await;
            slot.take()
        };
        match taken {
            Some(handle) => Self::join_task_handle(handle, task_name).await,
            None => {}
        }
    }

    /// Joins every handle in `handles`, one after another.
    async fn join_task_handles(handles: Vec<JoinHandle<()>>, task_name: &str) {
        let mut remaining = handles.into_iter();
        while let Some(handle) = remaining.next() {
            Self::join_task_handle(handle, task_name).await;
        }
    }

    /// Awaits `handle` and logs how the task ended: cancellation at debug
    /// level, a panic at error level, any other join error at warn level.
    async fn join_task_handle(handle: JoinHandle<()>, task_name: &str) {
        let Err(e) = handle.await else {
            return;
        };
        if e.is_cancelled() {
            tracing::debug!("{task_name} task was cancelled during shutdown");
        } else if e.is_panic() {
            tracing::error!("{task_name} task panicked during shutdown: {:?}", e);
        } else {
            tracing::warn!("{task_name} task join error during shutdown: {:?}", e);
        }
    }
}
impl TransportHandle {
    /// Background task that mirrors transport connection events into the
    /// handle's bookkeeping.
    ///
    /// * `Established` — marks the channel active, creates or refreshes its
    ///   `PeerInfo`, then sends a signed identity announce to the remote.
    /// * `Lost` / `Failed` — removes the channel from the active set, the
    ///   peers map and the peer/channel mappings.
    ///
    /// The task ends when `shutdown` is cancelled or the event channel
    /// closes. A lagging receiver is logged and the loop continues.
    ///
    /// The `peers` write lock is released before the identity announce is
    /// sent, so a slow send cannot block other users of the peers map.
    #[allow(clippy::too_many_arguments)]
    async fn connection_lifecycle_monitor_with_rx(
        dual_node: Arc<DualStackNetworkNode>,
        mut event_rx: broadcast::Receiver<
            crate::transport::saorsa_transport_adapter::ConnectionEvent,
        >,
        active_connections: Arc<RwLock<HashSet<String>>>,
        peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
        event_tx: broadcast::Sender<P2PEvent>,
        _geo_provider: Arc<BgpGeoProvider>,
        shutdown: CancellationToken,
        peer_to_channel: Arc<RwLock<HashMap<PeerId, HashSet<String>>>>,
        channel_to_peers: Arc<RwLock<HashMap<String, HashSet<PeerId>>>>,
        peer_user_agents: Arc<RwLock<HashMap<PeerId, String>>>,
        node_identity: Arc<NodeIdentity>,
        user_agent: String,
    ) {
        info!("Connection lifecycle monitor started (pre-subscribed receiver)");
        loop {
            tokio::select! {
                () = shutdown.cancelled() => {
                    info!("Connection lifecycle monitor shutting down");
                    break;
                }
                recv = event_rx.recv() => {
                    match recv {
                        Ok(event) => match event {
                            ConnectionEvent::Established {
                                remote_address, ..
                            } => {
                                let channel_id = remote_address.to_string();
                                debug!(
                                    "Connection established: channel={}, addr={}",
                                    channel_id, remote_address
                                );
                                active_connections.write().await.insert(channel_id.clone());
                                {
                                    // Scoped so the write guard is dropped before
                                    // the network send below.
                                    let mut peers_lock = peers.write().await;
                                    if let Some(peer_info) = peers_lock.get_mut(&channel_id) {
                                        peer_info.status = ConnectionStatus::Connected;
                                        peer_info.connected_at = Instant::now();
                                    } else {
                                        debug!("Registering new incoming channel: {}", channel_id);
                                        peers_lock.insert(
                                            channel_id.clone(),
                                            PeerInfo {
                                                channel_id: channel_id.clone(),
                                                addresses: vec![MultiAddr::quic(remote_address)],
                                                status: ConnectionStatus::Connected,
                                                last_seen: Instant::now(),
                                                connected_at: Instant::now(),
                                                protocols: Vec::new(),
                                                heartbeat_count: 0,
                                            },
                                        );
                                    }
                                }
                                // Announce our identity so the remote can map this
                                // channel to our peer ID. Failures are only logged.
                                match Self::create_identity_announce_bytes(&node_identity, &user_agent) {
                                    Ok(announce_bytes) => {
                                        if let Err(e) = dual_node
                                            .send_to_peer_optimized(&remote_address, &announce_bytes)
                                            .await
                                        {
                                            warn!("Failed to send identity announce to {channel_id}: {e}");
                                        }
                                    }
                                    Err(e) => {
                                        warn!("Failed to create identity announce: {e}");
                                    }
                                }
                            }
                            ConnectionEvent::Lost { remote_address, reason }
                            | ConnectionEvent::Failed { remote_address, reason } => {
                                let channel_id = remote_address.to_string();
                                debug!("Connection lost/failed: channel={channel_id}, reason={reason}");
                                active_connections.write().await.remove(&channel_id);
                                peers.write().await.remove(&channel_id);
                                Self::remove_channel_mappings_static(
                                    &channel_id,
                                    &peer_to_channel,
                                    &channel_to_peers,
                                    &peer_user_agents,
                                    &event_tx,
                                ).await;
                            }
                        },
                        Err(broadcast::error::RecvError::Lagged(skipped)) => {
                            warn!(
                                "Connection event receiver lagged, skipped {} events",
                                skipped
                            );
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            info!("Connection event channel closed, stopping lifecycle monitor");
                            break;
                        }
                    }
                }
            }
        }
    }
}
/// Checks that `protocol` is usable as the suffix of a `/rr/` wire protocol.
///
/// # Errors
///
/// Returns `TransportError::StreamError` when the name is empty or contains
/// `/`, `\` or a NUL character.
fn validate_protocol_name(protocol: &str) -> Result<()> {
    let has_forbidden_char = protocol
        .chars()
        .any(|c| matches!(c, '/' | '\\' | '\0'));
    if !protocol.is_empty() && !has_forbidden_char {
        return Ok(());
    }
    Err(P2PError::Transport(
        crate::error::TransportError::StreamError(
            format!("Invalid protocol name: {:?}", protocol).into(),
        ),
    ))
}
/// `NetworkSender` implementation that delegates to the inherent
/// `TransportHandle` methods.
#[async_trait::async_trait]
impl NetworkSender for TransportHandle {
/// Forwards to the inherent `TransportHandle::send_message`.
async fn send_message(&self, peer_id: &PeerId, protocol: &str, data: Vec<u8>) -> Result<()> {
// Spelled with an explicit type path so it is clear that the inherent
// method is the target, not this trait method.
TransportHandle::send_message(self, peer_id, protocol, data).await
}
/// Returns this node's peer ID.
fn local_peer_id(&self) -> PeerId {
self.peer_id()
}
}
#[cfg(test)]
impl TransportHandle {
    /// Test helper: stores `info` in the peers map under the key `peer_id`
    /// (the map is keyed by channel ID elsewhere in this type).
    pub(crate) async fn inject_peer(&self, peer_id: String, info: PeerInfo) {
        let mut peers = self.peers.write().await;
        peers.insert(peer_id, info);
    }

    /// Test helper: marks `channel_id` as an active connection.
    pub(crate) async fn inject_active_connection(&self, channel_id: String) {
        let mut active = self.active_connections.write().await;
        active.insert(channel_id);
    }

    /// Test helper: records the mapping between `peer_id` and `channel_id`
    /// in both directions.
    pub(crate) async fn inject_peer_to_channel(&self, peer_id: PeerId, channel_id: String) {
        {
            let mut p2c = self.peer_to_channel.write().await;
            p2c.entry(peer_id).or_default().insert(channel_id.clone());
        }
        {
            let mut c2p = self.channel_to_peers.write().await;
            c2p.entry(channel_id).or_default().insert(peer_id);
        }
    }
}