use super::capabilities::{Transport, TransportCapabilities};
use super::health::{HealthMonitor, HeartbeatConfig};
use super::reconnection::{ReconnectionManager, ReconnectionPolicy};
use super::{
ConnectionHealth, DisconnectReason, MeshConnection, MeshTransport, NodeId, PeerEvent,
PeerEventReceiver, PeerEventSender, Result, TransportError, PEER_EVENT_CHANNEL_CAPACITY,
};
use crate::network::iroh_transport::IrohTransport;
use crate::network::peer_config::PeerConfig;
use async_trait::async_trait;
use iroh::endpoint::Connection;
use iroh::EndpointId;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
/// Default period at which the background task sweeps dead connections,
/// checks heartbeat timeouts, and retries scheduled reconnections.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(5);
/// Mesh-transport implementation backed by an iroh QUIC endpoint.
///
/// Tracks peer identity mappings, live connections, event subscribers,
/// reconnection state for statically-configured peers, and a heartbeat
/// health monitor. All shared state is behind `Arc<RwLock<_>>` so the
/// detached cleanup/reconnection task can co-own it.
pub struct IrohMeshTransport {
    // Underlying iroh endpoint used for dialing and accepting.
    transport: Arc<IrohTransport>,
    // Static peer configuration (names, endpoint ids, addresses).
    peer_config: Arc<RwLock<PeerConfig>>,
    // Logical node id -> iroh endpoint id (forward lookup).
    node_to_endpoint: Arc<RwLock<HashMap<NodeId, EndpointId>>>,
    // Iroh endpoint id -> logical node id (reverse lookup).
    endpoint_to_node: Arc<RwLock<HashMap<EndpointId, NodeId>>>,
    // Live connections keyed by peer node id.
    connections: Arc<RwLock<HashMap<NodeId, IrohMeshConnection>>>,
    // One sender per `subscribe_peer_events` subscriber; pruned when closed.
    event_senders: Arc<RwLock<Vec<PeerEventSender>>>,
    // True while the background cleanup/reconnection task is running;
    // also doubles as the `Transport::is_available` signal.
    cleanup_running: Arc<AtomicBool>,
    // Tick period of the background task.
    cleanup_interval: Duration,
    // Backoff/retry bookkeeping for reconnection attempts.
    reconnection: Arc<RwLock<ReconnectionManager>>,
    // Peers from the static config; only these get automatic reconnection.
    static_peers: Arc<RwLock<std::collections::HashSet<NodeId>>>,
    // Heartbeat-based liveness tracking per peer.
    health_monitor: Arc<HealthMonitor>,
    // Static capability profile (QUIC) reported via `Transport`.
    capabilities: TransportCapabilities,
}
impl IrohMeshTransport {
/// Create a mesh transport with all default settings
/// (default cleanup interval, reconnection policy, and heartbeat config).
pub fn new(transport: Arc<IrohTransport>, peer_config: PeerConfig) -> Self {
    Self::with_cleanup_interval(transport, peer_config, DEFAULT_CLEANUP_INTERVAL)
}
/// Create a mesh transport with a custom cleanup interval and default
/// reconnection/heartbeat settings.
pub fn with_cleanup_interval(
    transport: Arc<IrohTransport>,
    peer_config: PeerConfig,
    cleanup_interval: Duration,
) -> Self {
    Self::with_reconnection_policy(
        transport,
        peer_config,
        cleanup_interval,
        ReconnectionPolicy::default(),
    )
}
/// Create a mesh transport with custom cleanup interval and reconnection
/// policy, using the default heartbeat configuration.
pub fn with_reconnection_policy(
    transport: Arc<IrohTransport>,
    peer_config: PeerConfig,
    cleanup_interval: Duration,
    reconnection_policy: ReconnectionPolicy,
) -> Self {
    Self::with_full_config(
        transport,
        peer_config,
        cleanup_interval,
        reconnection_policy,
        HeartbeatConfig::default(),
    )
}
/// Fully-parameterized constructor; all other constructors funnel here.
///
/// Note: this only builds state — nothing runs (accept loop, cleanup
/// task) until `start()` is called.
pub fn with_full_config(
    transport: Arc<IrohTransport>,
    peer_config: PeerConfig,
    cleanup_interval: Duration,
    reconnection_policy: ReconnectionPolicy,
    heartbeat_config: HeartbeatConfig,
) -> Self {
    Self {
        transport,
        peer_config: Arc::new(RwLock::new(peer_config)),
        node_to_endpoint: Arc::new(RwLock::new(HashMap::new())),
        endpoint_to_node: Arc::new(RwLock::new(HashMap::new())),
        connections: Arc::new(RwLock::new(HashMap::new())),
        event_senders: Arc::new(RwLock::new(Vec::new())),
        cleanup_running: Arc::new(AtomicBool::new(false)),
        cleanup_interval,
        reconnection: Arc::new(RwLock::new(ReconnectionManager::new(reconnection_policy))),
        static_peers: Arc::new(RwLock::new(std::collections::HashSet::new())),
        health_monitor: Arc::new(HealthMonitor::new(heartbeat_config)),
        capabilities: TransportCapabilities::quic(),
    }
}
/// Shared handle to the heartbeat health monitor.
pub fn health_monitor(&self) -> &Arc<HealthMonitor> {
    &self.health_monitor
}
fn emit_event(&self, event: PeerEvent) {
let mut senders = self
.event_senders
.write()
.expect("event_senders lock poisoned");
senders.retain(|sender| {
match sender.try_send(event.clone()) {
Ok(()) => true,
Err(mpsc::error::TrySendError::Full(_)) => {
warn!(
"Peer event channel full, dropping event for one subscriber: {:?}",
event
);
true }
Err(mpsc::error::TrySendError::Closed(_)) => {
debug!("Peer event subscriber disconnected, removing channel");
false }
}
});
}
/// Remove connections whose QUIC session has closed and emit a
/// `Disconnected` event for each removed peer.
///
/// Fix: also stop health monitoring for each removed peer, matching the
/// background cleanup task's behavior — previously the monitor kept
/// tracking peers whose connection had already been pruned here.
pub fn cleanup_dead_connections(&self) {
    let mut connections = self.connections.write().expect("connections lock poisoned");
    // Snapshot the dead entries first so we can mutate the map below.
    let dead_peers: Vec<_> = connections
        .iter()
        .filter(|(_, conn)| !conn.is_alive())
        .map(|(id, conn)| (id.clone(), conn.disconnect_reason(), conn.connected_at()))
        .collect();
    for (peer_id, reason, connected_at) in dead_peers {
        connections.remove(&peer_id);
        // Keep the health monitor consistent with the connection table
        // (the background cleanup task does the same).
        self.health_monitor.stop_monitoring(&peer_id);
        let event = PeerEvent::Disconnected {
            peer_id: peer_id.clone(),
            reason: reason.unwrap_or(DisconnectReason::Unknown),
            connection_duration: connected_at.elapsed(),
        };
        debug!("Peer {} disconnected: {:?}", peer_id, event);
        self.emit_event(event);
    }
}
/// Record the bidirectional mapping between a logical node id and its
/// iroh endpoint id. Each direction is updated under its own lock, in
/// forward-then-reverse order.
pub fn register_peer(&self, node_id: NodeId, endpoint_id: EndpointId) {
    {
        let mut forward = self
            .node_to_endpoint
            .write()
            .expect("node_to_endpoint lock poisoned");
        forward.insert(node_id.clone(), endpoint_id);
    }
    let mut reverse = self
        .endpoint_to_node
        .write()
        .expect("endpoint_to_node lock poisoned");
    reverse.insert(endpoint_id, node_id);
}
/// Reverse lookup: logical node id registered for an endpoint id, if any.
pub fn get_node_id(&self, endpoint_id: &EndpointId) -> Option<NodeId> {
    let map = self
        .endpoint_to_node
        .read()
        .expect("endpoint_to_node lock poisoned");
    map.get(endpoint_id).cloned()
}
/// Forward lookup: endpoint id registered for a logical node id, if any.
pub fn get_endpoint_id(&self, node_id: &NodeId) -> Option<EndpointId> {
    let map = self
        .node_to_endpoint
        .read()
        .expect("node_to_endpoint lock poisoned");
    map.get(node_id).copied()
}
/// Shared handle to the underlying iroh transport.
pub fn transport(&self) -> &Arc<IrohTransport> {
    &self.transport
}
/// Spawn the detached background task that, every `cleanup_interval`:
///   1. prunes connections whose QUIC session has closed and emits
///      `Disconnected` events,
///   2. schedules reconnection for static peers among them,
///   3. asks the health monitor for heartbeat-timeout victims, and
///   4. attempts any reconnections that are due.
///
/// Idempotent: a compare-exchange on `cleanup_running` guarantees at most
/// one task; `stop_cleanup_task` clears the flag so the loop exits after
/// its next tick.
fn start_cleanup_task(&self) {
    // Claim the running flag; if it was already set, a task exists — bail.
    if self
        .cleanup_running
        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
        .is_err()
    {
        return;
    }
    // The spawned task co-owns the shared state via Arc clones rather
    // than borrowing `self`.
    let connections = Arc::clone(&self.connections);
    let event_senders = Arc::clone(&self.event_senders);
    let cleanup_running = Arc::clone(&self.cleanup_running);
    let interval = self.cleanup_interval;
    let reconnection = Arc::clone(&self.reconnection);
    let static_peers = Arc::clone(&self.static_peers);
    let transport = Arc::clone(&self.transport);
    let peer_config = Arc::clone(&self.peer_config);
    let health_monitor = Arc::clone(&self.health_monitor);
    tokio::spawn(async move {
        info!(
            "Started peer cleanup/reconnection task with interval {:?}",
            interval
        );
        // Local fan-out helper mirroring `IrohMeshTransport::emit_event`
        // (same retain semantics: keep on Full, prune on Closed), minus
        // the per-event logging.
        let emit_event = |event: PeerEvent, senders: &Arc<RwLock<Vec<PeerEventSender>>>| {
            let mut senders = senders.write().expect("event_senders lock poisoned");
            senders.retain(|sender| match sender.try_send(event.clone()) {
                Ok(()) => true,
                Err(mpsc::error::TrySendError::Full(_)) => true,
                Err(mpsc::error::TrySendError::Closed(_)) => false,
            });
        };
        while cleanup_running.load(Ordering::SeqCst) {
            tokio::time::sleep(interval).await;
            // Re-check after sleeping so `stop()` takes effect within one tick.
            if !cleanup_running.load(Ordering::SeqCst) {
                break;
            }
            // Phase 1: snapshot dead peers under the read lock only.
            let dead_peers: Vec<_> = {
                let conns = connections.read().expect("connections lock poisoned");
                conns
                    .iter()
                    .filter(|(_, conn)| !conn.is_alive())
                    .map(|(id, conn)| {
                        (id.clone(), conn.disconnect_reason(), conn.connected_at())
                    })
                    .collect()
            };
            // Phase 2: prune the snapshot, emit events, and queue
            // reconnections for static peers. Lock order here is
            // connections -> static_peers -> reconnection.
            if !dead_peers.is_empty() {
                let mut conns = connections.write().expect("connections lock poisoned");
                let static_set = static_peers.read().expect("static_peers lock poisoned");
                let mut recon = reconnection.write().expect("reconnection lock poisoned");
                for (peer_id, reason, connected_at) in dead_peers {
                    conns.remove(&peer_id);
                    health_monitor.stop_monitoring(&peer_id);
                    let event = PeerEvent::Disconnected {
                        peer_id: peer_id.clone(),
                        reason: reason.unwrap_or(DisconnectReason::Unknown),
                        connection_duration: connected_at.elapsed(),
                    };
                    debug!("Cleanup: Peer {} disconnected: {:?}", peer_id, event);
                    emit_event(event, &event_senders);
                    let is_static = static_set.contains(&peer_id);
                    // Only statically-configured peers get auto-reconnect.
                    if is_static {
                        recon.schedule_reconnect(peer_id.clone(), true);
                        debug!("Scheduled reconnection for static peer: {}", peer_id);
                    }
                }
            }
            // Phase 3: peers the heartbeat monitor declared dead this tick.
            // NOTE(review): these peers are not removed from `connections`
            // here; presumably `check_timeouts` / the QUIC layer makes the
            // connection report dead so Phase 1 catches it next tick — confirm.
            let newly_dead_from_health: Vec<NodeId> = health_monitor.check_timeouts();
            for peer_id in newly_dead_from_health {
                let is_static = static_peers
                    .read()
                    .expect("static_peers lock poisoned")
                    .contains(&peer_id);
                if is_static {
                    let mut recon = reconnection.write().expect("reconnection lock poisoned");
                    recon.schedule_reconnect(peer_id.clone(), true);
                    debug!(
                        "Health monitor detected dead peer, scheduling reconnection: {}",
                        peer_id
                    );
                }
            }
            // Phase 4: attempt reconnections whose backoff has elapsed.
            // The `continue` releases the read lock (block is abandoned)
            // and skips straight to the next tick when reconnection is
            // globally disabled.
            let due_peers: Vec<NodeId> = {
                let recon = reconnection.read().expect("reconnection lock poisoned");
                if !recon.is_enabled() {
                    continue;
                }
                recon.due_reconnections()
            };
            for peer_id in due_peers {
                // Read attempt counters in a short lock scope; the dial
                // below must run without holding any lock.
                let (attempt, max_attempts) = {
                    let recon = reconnection.read().expect("reconnection lock poisoned");
                    let state = recon.get_state(&peer_id);
                    let attempt = state.map(|s| s.attempts + 1).unwrap_or(1);
                    let max = recon.policy().max_retries;
                    (attempt, max)
                };
                emit_event(
                    PeerEvent::Reconnecting {
                        peer_id: peer_id.clone(),
                        attempt,
                        max_attempts,
                    },
                    &event_senders,
                );
                // Dial info comes from the static peer config, keyed by name.
                let peer_info_opt = {
                    let config = peer_config.read().expect("peer_config lock poisoned");
                    config
                        .peers
                        .iter()
                        .find(|p| p.name == peer_id.as_str())
                        .cloned()
                };
                let result = if let Some(peer_info) = peer_info_opt {
                    transport.connect_peer(&peer_info).await
                } else {
                    Err(anyhow::anyhow!("Peer not found in config: {}", peer_id))
                };
                match result {
                    // Dialed successfully: record the connection, reset the
                    // reconnection state, resume health monitoring, notify.
                    Ok(Some(conn)) => {
                        let connected_at = Instant::now();
                        let mesh_conn =
                            IrohMeshConnection::new(peer_id.clone(), conn, connected_at);
                        connections
                            .write()
                            .expect("connections lock poisoned")
                            .insert(peer_id.clone(), mesh_conn);
                        reconnection
                            .write()
                            .expect("reconnection lock poisoned")
                            .reconnected(&peer_id);
                        health_monitor.start_monitoring(peer_id.clone());
                        info!("Reconnected to peer: {} (attempt {})", peer_id, attempt);
                        emit_event(
                            PeerEvent::Connected {
                                peer_id: peer_id.clone(),
                                connected_at,
                            },
                            &event_senders,
                        );
                    }
                    // `None` means the accept path owns this connection;
                    // just clear the reconnection state.
                    Ok(None) => {
                        reconnection
                            .write()
                            .expect("reconnection lock poisoned")
                            .reconnected(&peer_id);
                        debug!("Reconnection to {} handled by accept path", peer_id);
                    }
                    // Dial failed: record the failure; drop the state
                    // entirely once retries are exhausted.
                    Err(e) => {
                        let error_msg = e.to_string();
                        let will_retry = {
                            let mut recon =
                                reconnection.write().expect("reconnection lock poisoned");
                            let will_retry = recon.failed(&peer_id, error_msg.clone());
                            if !will_retry {
                                recon.remove(&peer_id);
                            }
                            will_retry
                        };
                        warn!(
                            "Reconnection to {} failed (attempt {}): {} - will_retry={}",
                            peer_id, attempt, error_msg, will_retry
                        );
                        emit_event(
                            PeerEvent::ReconnectFailed {
                                peer_id: peer_id.clone(),
                                attempt,
                                error: error_msg,
                                will_retry,
                            },
                            &event_senders,
                        );
                    }
                }
            }
        }
        info!("Stopped peer cleanup/reconnection task");
    });
}
/// Signal the background task to exit; it observes the flag after its
/// next sleep, so shutdown completes within one cleanup interval.
fn stop_cleanup_task(&self) {
    self.cleanup_running.store(false, Ordering::SeqCst);
}
}
#[async_trait]
impl MeshTransport for IrohMeshTransport {
async fn start(&self) -> Result<()> {
self.transport
.start_accept_loop()
.map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
let config = self.peer_config.read().expect("peer_config lock poisoned");
let mut static_peers = self
.static_peers
.write()
.expect("static_peers lock poisoned");
for peer_info in &config.peers {
let node_id = NodeId::new(peer_info.name.clone());
if let Ok(endpoint_id) = peer_info.endpoint_id() {
self.register_peer(node_id.clone(), endpoint_id);
static_peers.insert(node_id);
}
}
drop(static_peers);
self.start_cleanup_task();
Ok(())
}
/// Stop the transport: halt the cleanup task and accept loop, clear the
/// health monitor, and drop all tracked connections.
///
/// Fix: the original drained the connection map into a `Vec` and then
/// iterated it with an empty loop body — dead code; `clear()` achieves
/// the same drop of every entry directly.
///
/// No `Disconnected` events are emitted for connections dropped here,
/// matching the original behavior.
///
/// # Errors
/// Returns `ConnectionFailed` if the accept loop cannot be stopped.
async fn stop(&self) -> Result<()> {
    self.stop_cleanup_task();
    self.transport
        .stop_accept_loop()
        .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
    self.health_monitor.clear();
    // Drop every connection handle we hold.
    self.connections
        .write()
        .expect("connections lock poisoned")
        .clear();
    Ok(())
}
/// Connect to a registered peer, reusing an existing live connection
/// when possible.
///
/// # Errors
/// - `PeerNotFound` if the peer was never registered or is missing from
///   the peer config.
/// - `ConnectionFailed` if dialing fails, or if the accept path owns the
///   in-flight connection but has not recorded it yet (benign race; the
///   caller may retry).
async fn connect(&self, peer_id: &NodeId) -> Result<Box<dyn MeshConnection>> {
    // Fast path: hand back an existing live connection.
    if let Some(conn) = self.get_connection(peer_id) {
        if conn.is_alive() {
            return Ok(conn);
        }
        debug!("Existing connection to {} is dead, reconnecting", peer_id);
        // Prune the dead entry (and any other dead peers) before redialing.
        self.cleanup_dead_connections();
    }
    // Registration gate only: the dial below resolves the peer from the
    // config by name, not from this endpoint id.
    let _endpoint_id = self
        .node_to_endpoint
        .read()
        .expect("node_to_endpoint lock poisoned")
        .get(peer_id)
        .copied()
        .ok_or_else(|| TransportError::PeerNotFound(peer_id.as_str().to_string()))?;
    // Clone the PeerInfo out so no lock is held across the await below.
    let peer_info = {
        let config = self.peer_config.read().expect("peer_config lock poisoned");
        config
            .peers
            .iter()
            .find(|p| p.name == peer_id.as_str())
            .cloned()
            .ok_or_else(|| TransportError::PeerNotFound(peer_id.as_str().to_string()))?
    };
    let conn_opt = self
        .transport
        .connect_peer(&peer_info)
        .await
        .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
    match conn_opt {
        // We dialed: record, monitor, and announce the connection.
        Some(conn) => {
            let connected_at = Instant::now();
            let mesh_conn = IrohMeshConnection::new(peer_id.clone(), conn, connected_at);
            self.connections
                .write()
                .expect("connections lock poisoned")
                .insert(peer_id.clone(), mesh_conn.clone());
            self.health_monitor.start_monitoring(peer_id.clone());
            self.emit_event(PeerEvent::Connected {
                peer_id: peer_id.clone(),
                connected_at,
            });
            debug!("Connected to peer: {}", peer_id);
            Ok(Box::new(mesh_conn))
        }
        // `None` means the accept path owns this connection; it may or
        // may not have inserted it into the table yet.
        None => {
            self.connections
                .read()
                .expect("connections lock poisoned")
                .get(peer_id)
                .cloned()
                .map(|c| Box::new(c) as Box<dyn MeshConnection>)
                .ok_or_else(|| {
                    TransportError::ConnectionFailed(
                        "Connection being handled by accept path".to_string(),
                    )
                })
        }
    }
}
/// Remove the connection to `peer_id` (if any), stop its health
/// monitoring, and emit a `Disconnected` event with `LocalClosed`.
///
/// Idempotent: disconnecting an unknown peer is a no-op returning `Ok`.
///
/// NOTE(review): the underlying QUIC connection is not explicitly closed
/// here — we only drop our handle. Clones handed out via
/// `get_connection` may keep the session alive; confirm this is the
/// intended semantics.
async fn disconnect(&self, peer_id: &NodeId) -> Result<()> {
    if let Some(conn) = self
        .connections
        .write()
        .expect("connections lock poisoned")
        .remove(peer_id)
    {
        self.health_monitor.stop_monitoring(peer_id);
        let event = PeerEvent::Disconnected {
            peer_id: peer_id.clone(),
            reason: DisconnectReason::LocalClosed,
            connection_duration: conn.connected_at().elapsed(),
        };
        debug!("Disconnected from peer: {}", peer_id);
        self.emit_event(event);
    }
    Ok(())
}
/// Cloned handle to the tracked connection for `peer_id`, if any.
/// The returned connection is not guaranteed to still be alive.
fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
    let table = self.connections.read().expect("connections lock poisoned");
    let conn = table.get(peer_id)?.clone();
    Some(Box::new(conn) as Box<dyn MeshConnection>)
}
/// Number of peers currently present in the connection table
/// (dead-but-not-yet-pruned entries are included).
fn peer_count(&self) -> usize {
    let table = self.connections.read().expect("connections lock poisoned");
    table.len()
}
/// Node ids of every peer in the connection table, in map order.
fn connected_peers(&self) -> Vec<NodeId> {
    let table = self.connections.read().expect("connections lock poisoned");
    table.keys().cloned().collect()
}
/// Register a new peer-event subscriber and return its receiving end.
/// The sender is pruned automatically once the receiver is dropped.
fn subscribe_peer_events(&self) -> PeerEventReceiver {
    let (sender, receiver) = mpsc::channel(PEER_EVENT_CHANNEL_CAPACITY);
    let mut senders = self
        .event_senders
        .write()
        .expect("event_senders lock poisoned");
    senders.push(sender);
    receiver
}
/// Heartbeat-derived health snapshot for a peer, if it is monitored.
fn get_peer_health(&self, peer_id: &NodeId) -> Option<ConnectionHealth> {
    self.health_monitor.get_health(peer_id)
}
}
impl Transport for IrohMeshTransport {
    /// Static QUIC capability profile set at construction.
    fn capabilities(&self) -> &TransportCapabilities {
        &self.capabilities
    }
    /// Available while the cleanup/reconnection task is running, i.e.
    /// after `start()` and before `stop()`.
    /// NOTE(review): this is a proxy for "started", not link quality.
    fn is_available(&self) -> bool {
        self.cleanup_running.load(Ordering::SeqCst)
    }
    /// No signal-quality metric is reported for this transport.
    fn signal_quality(&self) -> Option<u8> {
        None
    }
    /// A peer is considered reachable if an endpoint id is registered
    /// for it (no liveness check).
    fn can_reach(&self, peer_id: &NodeId) -> bool {
        self.node_to_endpoint
            .read()
            .expect("node_to_endpoint lock poisoned")
            .contains_key(peer_id)
    }
}
/// A tracked connection to one peer over an iroh QUIC session.
///
/// Cheap to clone (the inner iroh `Connection` is itself clonable); all
/// clones refer to the same underlying session.
#[derive(Clone)]
pub struct IrohMeshConnection {
    // Logical id of the remote peer.
    peer_id: NodeId,
    // Underlying iroh QUIC connection.
    connection: Connection,
    // When this connection was established (for duration reporting).
    connected_at: Instant,
}
impl IrohMeshConnection {
    /// Wrap an established iroh connection for mesh bookkeeping.
    pub fn new(peer_id: NodeId, connection: Connection, connected_at: Instant) -> Self {
        Self {
            peer_id,
            connection,
            connected_at,
        }
    }

    /// Borrow the underlying QUIC connection.
    pub fn connection(&self) -> &Connection {
        &self.connection
    }

    /// Classify the QUIC close reason into a `DisconnectReason` by
    /// substring heuristics, or `None` while the connection is open.
    /// The branch order matters: timeout/idle wins over reset/closed,
    /// which wins over application errors.
    fn parse_close_reason(&self) -> Option<DisconnectReason> {
        let close = self.connection.close_reason()?;
        let text = close.to_string();
        let classified = if text.contains("timeout") || text.contains("idle") {
            DisconnectReason::IdleTimeout
        } else if text.contains("reset") || text.contains("closed") {
            DisconnectReason::RemoteClosed
        } else if text.contains("application") {
            DisconnectReason::ApplicationError(text)
        } else {
            DisconnectReason::NetworkError(text)
        };
        Some(classified)
    }
}
impl MeshConnection for IrohMeshConnection {
    /// Logical node id of the remote peer.
    fn peer_id(&self) -> &NodeId {
        &self.peer_id
    }
    /// Alive while the QUIC layer reports no close reason.
    fn is_alive(&self) -> bool {
        self.connection.close_reason().is_none()
    }
    /// Instant the connection was established.
    fn connected_at(&self) -> Instant {
        self.connected_at
    }
    /// Best-effort classification of why the connection closed;
    /// `None` while it is still open.
    fn disconnect_reason(&self) -> Option<DisconnectReason> {
        self.parse_close_reason()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::peer_config::{LocalConfig, PeerInfo};
    use std::net::SocketAddr;

    /// A freshly built transport tracks no peers.
    #[tokio::test]
    async fn test_iroh_mesh_transport_creation() {
        let transport = Arc::new(IrohTransport::new().await.unwrap());
        let peer_config = PeerConfig::empty();
        let mesh_transport = IrohMeshTransport::new(transport, peer_config);
        assert_eq!(mesh_transport.peer_count(), 0);
    }

    /// `register_peer` establishes both forward and reverse lookups.
    #[tokio::test]
    async fn test_peer_registration() {
        let transport = Arc::new(IrohTransport::new().await.unwrap());
        let peer_config = PeerConfig::empty();
        let mesh_transport = IrohMeshTransport::new(transport.clone(), peer_config);
        let node_id = NodeId::new("test-node".to_string());
        let endpoint_id = transport.endpoint_id();
        mesh_transport.register_peer(node_id.clone(), endpoint_id);
        assert_eq!(mesh_transport.get_endpoint_id(&node_id), Some(endpoint_id));
        assert_eq!(mesh_transport.get_node_id(&endpoint_id), Some(node_id));
    }

    /// `start`/`stop` toggle the underlying accept loop.
    #[tokio::test]
    async fn test_start_stop_lifecycle() {
        let transport = Arc::new(IrohTransport::new().await.unwrap());
        let peer_config = PeerConfig::empty();
        let mesh_transport = IrohMeshTransport::new(transport.clone(), peer_config);
        mesh_transport.start().await.unwrap();
        assert!(transport.is_accept_loop_running());
        mesh_transport.stop().await.unwrap();
        assert!(!transport.is_accept_loop_running());
    }

    /// Connecting to an unregistered peer yields `PeerNotFound`.
    #[tokio::test]
    async fn test_connect_to_unknown_peer() {
        let transport = Arc::new(IrohTransport::new().await.unwrap());
        let peer_config = PeerConfig::empty();
        let mesh_transport = IrohMeshTransport::new(transport, peer_config);
        mesh_transport.start().await.unwrap();
        let unknown_peer = NodeId::new("unknown".to_string());
        let result = mesh_transport.connect(&unknown_peer).await;
        assert!(result.is_err());
        match result {
            Err(TransportError::PeerNotFound(_)) => {}
            _ => panic!("Expected PeerNotFound error"),
        }
    }

    /// Disconnecting a peer that was never connected is a no-op `Ok`.
    #[tokio::test]
    async fn test_disconnect() {
        let transport = Arc::new(IrohTransport::new().await.unwrap());
        let peer_config = PeerConfig::empty();
        let mesh_transport = IrohMeshTransport::new(transport, peer_config);
        mesh_transport.start().await.unwrap();
        let peer_id = NodeId::new("test".to_string());
        let result = mesh_transport.disconnect(&peer_id).await;
        assert!(result.is_ok());
    }

    /// Peers from the static config are registered during `start`.
    #[tokio::test]
    async fn test_static_config_peer_registration() {
        let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let transport = Arc::new(IrohTransport::bind(bind_addr).await.unwrap());
        let endpoint_id = transport.endpoint_id();
        let peer_config = PeerConfig {
            local: LocalConfig::default(),
            formation: None,
            peers: vec![PeerInfo {
                name: "test-peer".to_string(),
                node_id: hex::encode(endpoint_id.as_bytes()),
                addresses: vec!["127.0.0.1:9999".to_string()],
                relay_url: None,
            }],
        };
        let mesh_transport = IrohMeshTransport::new(transport, peer_config);
        mesh_transport.start().await.unwrap();
        let node_id = NodeId::new("test-peer".to_string());
        assert_eq!(mesh_transport.get_endpoint_id(&node_id), Some(endpoint_id));
    }
}