#[cfg(feature = "automerge-backend")]
use super::peer_config::PeerInfo;
#[cfg(feature = "automerge-backend")]
use anyhow::{Context, Result};
#[cfg(feature = "automerge-backend")]
use iroh::address_lookup::mdns::MdnsAddressLookup;
#[cfg(feature = "automerge-backend")]
use iroh::endpoint::{Connection, Endpoint};
#[cfg(feature = "automerge-backend")]
use iroh::{EndpointAddr, EndpointId};
#[cfg(feature = "automerge-backend")]
use std::collections::HashMap;
#[cfg(feature = "automerge-backend")]
use std::net::SocketAddr;
#[cfg(feature = "automerge-backend")]
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "automerge-backend")]
use std::sync::{Arc, RwLock};
#[cfg(feature = "automerge-backend")]
use std::time::Duration;
#[cfg(feature = "automerge-backend")]
use tokio::sync::mpsc;
#[cfg(feature = "automerge-backend")]
use tokio::task::JoinHandle;
/// Peer-connection lifecycle events emitted by `IrohTransport` to
/// subscribers registered via `subscribe_peer_events()`.
///
/// Delivery is best-effort: events are pushed with `try_send`, so a slow
/// subscriber whose channel is full silently misses events.
#[cfg(feature = "automerge-backend")]
#[derive(Debug, Clone)]
pub enum TransportPeerEvent {
    /// A connection to a peer was established (or explicitly announced via
    /// `emit_peer_connected`).
    Connected {
        /// Identity of the remote endpoint.
        endpoint_id: EndpointId,
        /// Instant at which the connection was observed as established.
        connected_at: std::time::Instant,
    },
    /// A connection to a peer was closed or lost.
    Disconnected {
        /// Identity of the remote endpoint.
        endpoint_id: EndpointId,
        /// Human-readable close reason (debug-formatted QUIC close info or a
        /// locally supplied string such as "local disconnect").
        reason: String,
    },
}
/// Capacity of each per-subscriber peer-event channel. Events are delivered
/// with best-effort `try_send`, so a subscriber lagging by more than this
/// many events starts dropping them.
#[cfg(feature = "automerge-backend")]
pub const TRANSPORT_EVENT_CHANNEL_CAPACITY: usize = 256;
/// Receiving half of a peer-event subscription.
#[cfg(feature = "automerge-backend")]
pub type TransportEventReceiver = mpsc::Receiver<TransportPeerEvent>;
/// Sending half kept internally by the transport for each subscriber.
#[cfg(feature = "automerge-backend")]
pub type TransportEventSender = mpsc::Sender<TransportPeerEvent>;
/// ALPN protocol identifier negotiated for Automerge sync connections.
#[cfg(feature = "automerge-backend")]
pub const CAP_AUTOMERGE_ALPN: &[u8] = b"cap/automerge/1";
/// QUIC max idle timeout (seconds) — short so peer loss is detected fast
/// (Issue #315).
#[cfg(feature = "automerge-backend")]
pub const QUIC_MAX_IDLE_TIMEOUT_SECS: u64 = 5;
/// QUIC keep-alive interval (seconds); must be well under the idle timeout
/// so healthy-but-quiet connections are not torn down.
#[cfg(feature = "automerge-backend")]
pub const QUIC_KEEP_ALIVE_INTERVAL_SECS: u64 = 1;
/// Builds the QUIC transport configuration tuned for fast disconnect
/// detection (Issue #315): a short max-idle timeout combined with frequent
/// keep-alives, so peer loss surfaces within a few seconds instead of the
/// QUIC defaults.
#[cfg(feature = "automerge-backend")]
fn create_tactical_transport_config() -> iroh::endpoint::QuicTransportConfig {
    // The builder wants the idle timeout as a QUIC VarInt-backed type, hence
    // the fallible conversion from Duration.
    let idle_timeout = Duration::from_secs(QUIC_MAX_IDLE_TIMEOUT_SECS)
        .try_into()
        .expect("valid idle timeout duration");
    let keep_alive = Duration::from_secs(QUIC_KEEP_ALIVE_INTERVAL_SECS);
    let cfg = iroh::endpoint::QuicTransportConfig::builder()
        .max_idle_timeout(Some(idle_timeout))
        .keep_alive_interval(keep_alive)
        .build();
    tracing::debug!(
        max_idle_timeout_secs = QUIC_MAX_IDLE_TIMEOUT_SECS,
        keep_alive_interval_secs = QUIC_KEEP_ALIVE_INTERVAL_SECS,
        "Created tactical QUIC transport config (Issue #315)"
    );
    cfg
}
/// Default interval (seconds) at which long-lived connections should be
/// recycled to work around a connection memory leak (Issue #435).
#[cfg(feature = "automerge-backend")]
pub const CONNECTION_RECYCLE_INTERVAL_SECS: u64 = 60;
/// Peer-to-peer QUIC transport for Automerge sync, built on an Iroh
/// endpoint.
///
/// Tracks at most one live connection per remote `EndpointId`; simultaneous
/// dial/accept races are resolved deterministically by comparing endpoint
/// ids ("lower id wins" in `connect()` / `accept()`).
#[cfg(feature = "automerge-backend")]
pub struct IrohTransport {
    /// The bound Iroh endpoint (identity + sockets).
    endpoint: Endpoint,
    /// Live connections keyed by remote endpoint id.
    connections: Arc<RwLock<HashMap<EndpointId, Connection>>>,
    /// When each connection was established; used for age-based recycling.
    connection_timestamps: Arc<RwLock<HashMap<EndpointId, std::time::Instant>>>,
    /// Flag gating the background accept loop.
    accept_running: Arc<AtomicBool>,
    /// Handle of the spawned accept-loop task, if started via
    /// `start_accept_loop`.
    accept_task: Arc<RwLock<Option<JoinHandle<()>>>>,
    /// mDNS address lookup, if discovery was enabled at construction or later.
    mdns_discovery: Arc<RwLock<Option<MdnsAddressLookup>>>,
    /// One sender per peer-event subscriber.
    event_senders: Arc<RwLock<Vec<TransportEventSender>>>,
    /// Runtime handle captured at construction so connection monitors can be
    /// spawned from non-async contexts.
    runtime_handle: tokio::runtime::Handle,
}
#[cfg(feature = "automerge-backend")]
impl IrohTransport {
    /// Derives a deterministic Iroh secret key from a human-readable seed.
    ///
    /// The seed is domain-separated with the `peat-iroh-key-v1:` prefix and
    /// hashed with SHA-256; the 32-byte digest becomes the key material, so
    /// the same seed always yields the same `EndpointId`.
    ///
    /// Shared by all `from_seed*` constructors and `endpoint_id_from_seed`
    /// (previously copy-pasted five times).
    fn derive_secret_key(seed: &str) -> iroh::SecretKey {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(b"peat-iroh-key-v1:");
        hasher.update(seed.as_bytes());
        let hash = hasher.finalize();
        let mut seed_bytes = [0u8; 32];
        seed_bytes.copy_from_slice(&hash);
        iroh::SecretKey::from_bytes(&seed_bytes)
    }

    /// Builds the transport state around an already-bound endpoint.
    ///
    /// Captures the current Tokio runtime handle so connection monitors can
    /// later be spawned from synchronous call sites. Must therefore be called
    /// from within a Tokio runtime (all public constructors are `async`, so
    /// this holds).
    fn from_parts(endpoint: Endpoint, discovery: Option<MdnsAddressLookup>) -> Self {
        Self {
            endpoint,
            connections: Arc::new(RwLock::new(HashMap::new())),
            connection_timestamps: Arc::new(RwLock::new(HashMap::new())),
            accept_running: Arc::new(AtomicBool::new(false)),
            accept_task: Arc::new(RwLock::new(None)),
            mdns_discovery: Arc::new(RwLock::new(discovery)),
            event_senders: Arc::new(RwLock::new(Vec::new())),
            runtime_handle: tokio::runtime::Handle::current(),
        }
    }

    /// Creates a transport with a freshly generated random key and no local
    /// discovery.
    ///
    /// # Errors
    /// Returns an error if the Iroh endpoint fails to bind.
    pub async fn new() -> Result<Self> {
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint")?;
        Ok(Self::from_parts(endpoint, None))
    }

    /// Creates a transport with a random key and mDNS peer discovery
    /// enabled.
    ///
    /// `_node_name` is currently unused; the mDNS service is keyed on the
    /// generated endpoint id instead.
    pub async fn with_discovery(_node_name: &str) -> Result<Self> {
        let mut rng = rand::rng();
        let secret_key = iroh::SecretKey::generate(&mut rng);
        let endpoint_id = secret_key.public();
        let discovery = MdnsAddressLookup::builder()
            .build(endpoint_id)
            .context("Failed to create mDNS discovery")?;
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .address_lookup(discovery.clone())
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint with mDNS discovery")?;
        tracing::info!(
            endpoint_id = %endpoint.id(),
            "Created IrohTransport with mDNS discovery"
        );
        Ok(Self::from_parts(endpoint, Some(discovery)))
    }

    /// Creates a transport whose key — and therefore `EndpointId` — is
    /// derived deterministically from `seed`. No discovery is enabled.
    pub async fn from_seed(seed: &str) -> Result<Self> {
        let secret_key = Self::derive_secret_key(seed);
        tracing::info!(
            seed = seed,
            endpoint_id = %secret_key.public(),
            "Created IrohTransport with deterministic key from seed"
        );
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint from seed")?;
        Ok(Self::from_parts(endpoint, None))
    }

    /// Like [`Self::from_seed`] but with mDNS discovery enabled.
    pub async fn from_seed_with_discovery(seed: &str) -> Result<Self> {
        let secret_key = Self::derive_secret_key(seed);
        let endpoint_id = secret_key.public();
        let discovery = MdnsAddressLookup::builder()
            .build(endpoint_id)
            .context("Failed to create mDNS discovery")?;
        tracing::info!(
            seed = seed,
            endpoint_id = %endpoint_id,
            "Created IrohTransport with deterministic key and mDNS discovery"
        );
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .address_lookup(discovery.clone())
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint from seed with discovery")?;
        Ok(Self::from_parts(endpoint, Some(discovery)))
    }

    /// Like [`Self::from_seed_with_discovery`] but binds to an explicit
    /// socket address instead of an OS-chosen one.
    pub async fn from_seed_with_discovery_at_addr(
        seed: &str,
        bind_addr: SocketAddr,
    ) -> Result<Self> {
        let secret_key = Self::derive_secret_key(seed);
        let endpoint_id = secret_key.public();
        let discovery = MdnsAddressLookup::builder()
            .build(endpoint_id)
            .context("Failed to create mDNS discovery")?;
        tracing::info!(
            seed = seed,
            endpoint_id = %endpoint_id,
            bind_addr = %bind_addr,
            "Created IrohTransport with deterministic key, mDNS discovery, and bind address"
        );
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .address_lookup(discovery.clone())
            .bind_addr(bind_addr)
            .context("Invalid bind address")?
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint from seed with discovery at addr")?;
        Ok(Self::from_parts(endpoint, Some(discovery)))
    }

    /// Like [`Self::from_seed`] but binds to an explicit socket address.
    /// Skips mDNS entirely for faster startup.
    pub async fn from_seed_at_addr(seed: &str, bind_addr: SocketAddr) -> Result<Self> {
        let secret_key = Self::derive_secret_key(seed);
        let endpoint_id = secret_key.public();
        tracing::info!(
            seed = seed,
            endpoint_id = %endpoint_id,
            bind_addr = %bind_addr,
            "Created IrohTransport with deterministic key (NO mDNS discovery - fast startup)"
        );
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .bind_addr(bind_addr)
            .context("Invalid bind address")?
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint from seed at addr")?;
        Ok(Self::from_parts(endpoint, None))
    }

    /// Enables mDNS discovery after construction (deferred initialization).
    ///
    /// # Errors
    /// Fails if discovery is already enabled or the mDNS service cannot be
    /// created.
    ///
    /// NOTE(review): the lookup created here is only stored on `self`; it is
    /// not registered with the already-bound endpoint via `address_lookup`,
    /// unlike the `*_with_discovery` constructors — confirm that storing it
    /// is sufficient for the deferred case.
    pub async fn enable_mdns_discovery(&self) -> Result<()> {
        // Check-then-set under two separate lock acquisitions; callers are
        // expected not to race concurrent enables.
        {
            let guard = self
                .mdns_discovery
                .read()
                .expect("mdns_discovery lock poisoned");
            if guard.is_some() {
                anyhow::bail!("mDNS discovery is already enabled");
            }
        }
        let endpoint_id = self.endpoint.id();
        let discovery = MdnsAddressLookup::builder()
            .build(endpoint_id)
            .context("Failed to create mDNS discovery")?;
        tracing::info!(
            endpoint_id = %endpoint_id,
            "Enabled mDNS discovery (deferred initialization)"
        );
        *self
            .mdns_discovery
            .write()
            .expect("mdns_discovery lock poisoned") = Some(discovery);
        Ok(())
    }

    /// Computes the `EndpointId` a given seed would produce, without binding
    /// an endpoint. Matches the `from_seed*` constructors exactly.
    pub fn endpoint_id_from_seed(seed: &str) -> EndpointId {
        Self::derive_secret_key(seed).public()
    }

    /// Creates a transport with a random key bound to an explicit address.
    pub async fn bind(bind_addr: SocketAddr) -> Result<Self> {
        let endpoint = Endpoint::builder(iroh::endpoint::presets::N0)
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .bind_addr(bind_addr)
            .context("Invalid bind address")?
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create Iroh endpoint with bind address")?;
        Ok(Self::from_parts(endpoint, None))
    }

    /// This node's endpoint identity.
    pub fn endpoint_id(&self) -> EndpointId {
        self.endpoint.id()
    }

    /// This node's full dialable address (id + socket addresses).
    pub fn endpoint_addr(&self) -> EndpointAddr {
        self.endpoint.addr()
    }

    /// Borrow the underlying Iroh endpoint.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }

    /// Dials `addr` and registers the connection.
    ///
    /// Returns `Ok(None)` when no new connection was produced: either a live
    /// connection to the peer already exists, or a simultaneous-connect
    /// conflict was resolved in favor of the peer's incoming connection
    /// (the side with the lower endpoint id keeps its outgoing dial).
    ///
    /// # Errors
    /// Propagates QUIC dial failures.
    pub async fn connect(&self, addr: EndpointAddr) -> Result<Option<Connection>> {
        let remote_id = addr.id;
        let our_id = self.endpoint_id();
        // Fast path: reuse an existing live connection.
        {
            let connections = self.connections.read().expect("connections lock poisoned");
            if let Some(existing) = connections.get(&remote_id) {
                if existing.close_reason().is_none() {
                    tracing::debug!(
                        "Already have live connection to {:?}, accept path handling",
                        remote_id
                    );
                    return Ok(None);
                }
            }
        }
        tracing::debug!(
            our_id = %our_id,
            remote_id = %remote_id,
            "Connecting to peer (conflict resolution on detection)"
        );
        let conn = self
            .endpoint
            .connect(addr, CAP_AUTOMERGE_ALPN)
            .await
            .context("Failed to connect to peer")?;
        // Re-check under the write lock: the peer may have connected to us
        // while our dial was in flight (simultaneous-connect race).
        let mut connections = self.connections.write().expect("connections lock poisoned");
        if let Some(existing) = connections.get(&remote_id) {
            if existing.close_reason().is_none() {
                let we_are_lower = our_id.as_bytes() < remote_id.as_bytes();
                if we_are_lower {
                    tracing::info!(
                        remote_id = %remote_id,
                        our_id = %our_id,
                        "Conflict resolved in connect(): we have lower ID, closing their incoming connection"
                    );
                    if let Some(old) = connections.remove(&remote_id) {
                        old.close(100u32.into(), b"conflict_connect_lower_wins");
                    }
                } else {
                    tracing::info!(
                        remote_id = %remote_id,
                        our_id = %our_id,
                        "Conflict resolved in connect(): they have lower ID, closing our outgoing connection"
                    );
                    conn.close(101u32.into(), b"conflict_connect_yield");
                    return Ok(None);
                }
            }
        }
        connections.insert(remote_id, conn.clone());
        drop(connections);
        self.connection_timestamps
            .write()
            .expect("connection_timestamps lock poisoned")
            .insert(remote_id, std::time::Instant::now());
        self.spawn_connection_monitor(remote_id, conn.clone());
        Ok(Some(conn))
    }

    /// Broadcasts a `Connected` event for `endpoint_id` to all subscribers.
    pub fn emit_peer_connected(&self, endpoint_id: EndpointId) {
        self.emit_event(TransportPeerEvent::Connected {
            endpoint_id,
            connected_at: std::time::Instant::now(),
        });
    }

    /// Connects to a peer described by configuration (`PeerInfo`), resolving
    /// its endpoint id and socket addresses first. See [`Self::connect`] for
    /// the `Ok(None)` semantics.
    pub async fn connect_peer(&self, peer: &PeerInfo) -> Result<Option<Connection>> {
        let endpoint_id = peer.endpoint_id()?;
        let socket_addrs = peer.socket_addrs()?;
        let mut addr = EndpointAddr::new(endpoint_id);
        for socket_addr in socket_addrs {
            addr = addr.with_ip_addr(socket_addr);
        }
        self.connect(addr).await
    }

    /// Connects by endpoint id alone, relying on discovery to resolve the
    /// peer's addresses. See [`Self::connect`] for the `Ok(None)` semantics.
    pub async fn connect_by_id(&self, endpoint_id: EndpointId) -> Result<Option<Connection>> {
        let addr = EndpointAddr::new(endpoint_id);
        tracing::debug!(
            peer_id = %endpoint_id,
            "Connecting to peer by ID (using discovery-resolved addresses)"
        );
        self.connect(addr).await
    }

    /// Whether mDNS discovery is currently enabled.
    pub fn has_discovery(&self) -> bool {
        self.mdns_discovery
            .read()
            .expect("mdns_discovery lock poisoned")
            .is_some()
    }

    /// A clone of the mDNS lookup handle, if discovery is enabled.
    pub fn mdns_discovery(&self) -> Option<MdnsAddressLookup> {
        self.mdns_discovery
            .read()
            .expect("mdns_discovery lock poisoned")
            .clone()
    }

    /// Accepts the next incoming connection and registers it.
    ///
    /// Returns `Ok(None)` when the incoming handshake failed (transient) or
    /// a simultaneous-connect conflict was resolved in favor of our existing
    /// outgoing connection. Mirrors the tie-break in [`Self::connect`]: the
    /// peer with the lower endpoint id keeps its outgoing connection.
    ///
    /// # Errors
    /// Fails only when the endpoint itself is closed.
    pub async fn accept(&self) -> Result<Option<Connection>> {
        let incoming = self
            .endpoint
            .accept()
            .await
            .context("Endpoint closed - no more incoming connections")?;
        let conn = match incoming.await {
            Ok(conn) => conn,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Incoming connection failed during QUIC handshake (transient, continuing)"
                );
                return Ok(None);
            }
        };
        let remote_id = conn.remote_id();
        let our_id = self.endpoint_id();
        let mut connections = self.connections.write().expect("connections lock poisoned");
        if let Some(existing) = connections.get(&remote_id) {
            if existing.close_reason().is_none() {
                let they_are_lower = remote_id.as_bytes() < our_id.as_bytes();
                if they_are_lower {
                    tracing::info!(
                        our_id = %our_id,
                        remote_id = %remote_id,
                        "Conflict resolved in accept(): they have lower ID, closing our outgoing connection"
                    );
                    if let Some(old) = connections.remove(&remote_id) {
                        old.close(102u32.into(), b"conflict_accept_closing_outgoing");
                    }
                } else {
                    tracing::info!(
                        our_id = %our_id,
                        remote_id = %remote_id,
                        "Conflict resolved in accept(): we have lower ID, rejecting incoming connection"
                    );
                    conn.close(103u32.into(), b"conflict_accept_reject_incoming");
                    drop(connections);
                    return Ok(None);
                }
            }
        }
        connections.insert(remote_id, conn.clone());
        drop(connections);
        self.connection_timestamps
            .write()
            .expect("connection_timestamps lock poisoned")
            .insert(remote_id, std::time::Instant::now());
        self.spawn_connection_monitor(remote_id, conn.clone());
        Ok(Some(conn))
    }

    /// The live connection to `endpoint_id`, if any is registered.
    pub fn get_connection(&self, endpoint_id: &EndpointId) -> Option<Connection> {
        self.connections
            .read()
            .expect("connections lock poisoned")
            .get(endpoint_id)
            .cloned()
    }

    /// Closes and unregisters the connection to `endpoint_id`, emitting a
    /// `Disconnected` event if a connection existed. Idempotent.
    pub fn disconnect(&self, endpoint_id: &EndpointId) -> Result<()> {
        self.connection_timestamps
            .write()
            .expect("connection_timestamps lock poisoned")
            .remove(endpoint_id);
        if let Some(conn) = self
            .connections
            .write()
            .expect("connections lock poisoned")
            .remove(endpoint_id)
        {
            conn.close(0u32.into(), b"disconnecting");
            self.emit_event(TransportPeerEvent::Disconnected {
                endpoint_id: *endpoint_id,
                reason: "local disconnect".to_string(),
            });
        }
        Ok(())
    }

    /// Peers whose connections have been alive longer than `max_age`.
    pub fn connections_older_than(&self, max_age: Duration) -> Vec<EndpointId> {
        let now = std::time::Instant::now();
        let timestamps = self
            .connection_timestamps
            .read()
            .expect("connection_timestamps lock poisoned");
        timestamps
            .iter()
            .filter_map(|(id, &connected_at)| {
                if now.duration_since(connected_at) > max_age {
                    Some(*id)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Disconnects every connection older than `max_age` and returns how
    /// many were recycled. Workaround for the Issue #435 memory leak;
    /// callers are expected to reconnect as needed.
    pub fn recycle_old_connections(&self, max_age: Duration) -> usize {
        let old_connections = self.connections_older_than(max_age);
        let count = old_connections.len();
        for endpoint_id in old_connections {
            tracing::info!(
                peer_id = %endpoint_id,
                "Recycling connection to mitigate memory leak (Issue #435)"
            );
            let _ = self.disconnect(&endpoint_id);
        }
        if count > 0 {
            tracing::info!(
                count = count,
                max_age_secs = max_age.as_secs(),
                "Recycled old connections (Issue #435 memory leak workaround)"
            );
        }
        count
    }

    /// Registers a new peer-event subscriber and returns its receiver.
    /// Delivery is best-effort (see [`TransportPeerEvent`]).
    pub fn subscribe_peer_events(&self) -> TransportEventReceiver {
        let (tx, rx) = mpsc::channel(TRANSPORT_EVENT_CHANNEL_CAPACITY);
        self.event_senders
            .write()
            .expect("event_senders lock poisoned")
            .push(tx);
        rx
    }

    /// Fans `event` out to all subscribers, dropping it for any whose
    /// channel is full or closed.
    fn emit_event(&self, event: TransportPeerEvent) {
        let senders = self
            .event_senders
            .read()
            .expect("event_senders lock poisoned");
        for sender in senders.iter() {
            let _ = sender.try_send(event.clone());
        }
    }

    /// Spawns a task that waits for `conn` to close, then removes it from
    /// the map and emits a `Disconnected` event — but only if the map still
    /// holds this exact connection (compared by `stable_id`), so a
    /// replacement connection established in the meantime is left intact.
    fn spawn_connection_monitor(&self, endpoint_id: EndpointId, conn: Connection) {
        let connections = Arc::clone(&self.connections);
        let event_senders = Arc::clone(&self.event_senders);
        let monitored_stable_id = conn.stable_id();
        self.runtime_handle.spawn(async move {
            let close_reason = conn.closed().await;
            tracing::info!(
                ?endpoint_id,
                ?close_reason,
                "Connection closed, emitting disconnect event"
            );
            let should_emit_disconnect;
            {
                let mut conns = connections.write().expect("connections lock poisoned");
                if let Some(current_conn) = conns.get(&endpoint_id) {
                    if current_conn.stable_id() == monitored_stable_id {
                        conns.remove(&endpoint_id);
                        should_emit_disconnect = true;
                    } else {
                        tracing::debug!(
                            ?endpoint_id,
                            monitored_id = monitored_stable_id,
                            current_id = current_conn.stable_id(),
                            "Connection was replaced, not removing"
                        );
                        should_emit_disconnect = false;
                    }
                } else {
                    should_emit_disconnect = false;
                }
            }
            if should_emit_disconnect {
                let reason = format!("{:?}", close_reason);
                let event = TransportPeerEvent::Disconnected {
                    endpoint_id,
                    reason,
                };
                let senders = event_senders.read().expect("event_senders lock poisoned");
                for sender in senders.iter() {
                    let _ = sender.try_send(event.clone());
                }
            }
        });
    }

    /// Number of live peers (after purging closed connections).
    pub fn peer_count(&self) -> usize {
        self.cleanup_closed_connections();
        self.connections
            .read()
            .expect("connections lock poisoned")
            .len()
    }

    /// Ids of all live peers (after purging closed connections).
    pub fn connected_peers(&self) -> Vec<EndpointId> {
        self.cleanup_closed_connections();
        self.connections
            .read()
            .expect("connections lock poisoned")
            .keys()
            .copied()
            .collect()
    }

    /// Removes every connection whose QUIC state reports a close reason,
    /// drops its timestamp, and emits a `Disconnected` event per removal.
    /// The close reason is classified (by matching our conflict-resolution
    /// close codes/markers) purely for diagnostics.
    pub fn cleanup_closed_connections(&self) {
        let closed_peers: Vec<(EndpointId, String)> = {
            let mut connections = self.connections.write().expect("connections lock poisoned");
            let mut closed = Vec::new();
            connections.retain(|endpoint_id, conn| {
                if let Some(reason) = conn.close_reason() {
                    let reason_str = format!("{:?}", reason);
                    // Map our close codes / reason markers back to the code
                    // path that closed the connection, for log readability.
                    let close_source = if reason_str.contains("100")
                        || reason_str.contains("conflict_connect_lower_wins")
                    {
                        "connect() conflict resolution (we had lower ID)"
                    } else if reason_str.contains("101")
                        || reason_str.contains("conflict_connect_yield")
                    {
                        "connect() yielding to accept path"
                    } else if reason_str.contains("102")
                        || reason_str.contains("conflict_accept_closing_outgoing")
                    {
                        "accept() closing our outgoing connection"
                    } else if reason_str.contains("103")
                        || reason_str.contains("conflict_accept_reject_incoming")
                    {
                        "accept() rejecting incoming connection"
                    } else if reason_str.contains("authentication") {
                        "authentication failure"
                    } else if reason_str.contains("TimedOut") {
                        "QUIC idle timeout (no keep-alives received)"
                    } else if reason_str.contains("LocallyClosed") {
                        "local close (unknown source)"
                    } else {
                        "other"
                    };
                    tracing::warn!(
                        endpoint_id = %endpoint_id,
                        reason = %reason_str,
                        close_source = %close_source,
                        "[CLEANUP] Removing closed connection"
                    );
                    closed.push((*endpoint_id, reason_str));
                    false
                } else {
                    true
                }
            });
            closed
        };
        if !closed_peers.is_empty() {
            let mut timestamps = self
                .connection_timestamps
                .write()
                .expect("connection_timestamps lock poisoned");
            for (endpoint_id, _) in &closed_peers {
                timestamps.remove(endpoint_id);
            }
        }
        for (endpoint_id, reason) in closed_peers {
            self.emit_event(TransportPeerEvent::Disconnected {
                endpoint_id,
                reason,
            });
        }
    }

    /// Spawns the background accept loop.
    ///
    /// # Errors
    /// Fails if the loop is already running (or was marked as externally
    /// managed via [`Self::mark_accept_loop_managed`]).
    pub fn start_accept_loop(self: &Arc<Self>) -> Result<()> {
        if self
            .accept_running
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
            .is_err()
        {
            anyhow::bail!("Accept loop already running");
        }
        let transport = Arc::clone(self);
        let accept_running = Arc::clone(&self.accept_running);
        let task = tokio::spawn(async move {
            while accept_running.load(Ordering::Relaxed) {
                match transport.accept().await {
                    Ok(Some(conn)) => {
                        tracing::debug!("Accepted connection from: {:?}", conn.remote_id());
                    }
                    Ok(None) => {
                        tracing::debug!("Duplicate connection closed, using existing");
                    }
                    Err(e) => {
                        tracing::debug!("Accept loop ended: {}", e);
                        break;
                    }
                }
            }
            tracing::debug!("Accept loop stopped");
        });
        *self.accept_task.write().expect("accept_task lock poisoned") = Some(task);
        Ok(())
    }

    /// Signals the accept loop to stop. The loop only observes the flag
    /// after its current `accept().await` resolves (next connection or
    /// endpoint close); closing the endpoint unblocks it immediately.
    pub fn stop_accept_loop(&self) -> Result<()> {
        if !self.accept_running.swap(false, Ordering::SeqCst) {
            anyhow::bail!("Accept loop is not running");
        }
        Ok(())
    }

    /// Whether the accept loop flag is currently set.
    pub fn is_accept_loop_running(&self) -> bool {
        self.accept_running.load(Ordering::Relaxed)
    }

    /// Claims the accept-loop flag without spawning a task, for callers that
    /// drive [`Self::accept`] themselves.
    ///
    /// # Errors
    /// Fails if the loop is already running or already marked managed.
    pub fn mark_accept_loop_managed(&self) -> Result<()> {
        if self
            .accept_running
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
            .is_err()
        {
            anyhow::bail!("Accept loop already running");
        }
        Ok(())
    }

    /// Shuts the transport down: stops the accept loop, closes every
    /// connection, then closes the endpoint (which also unblocks a pending
    /// `accept().await`, letting the spawned loop exit).
    pub async fn close(&self) -> Result<()> {
        if self.accept_running.load(Ordering::Relaxed) {
            let _ = self.stop_accept_loop();
        }
        for (_endpoint_id, conn) in self
            .connections
            .write()
            .expect("connections lock poisoned")
            .drain()
        {
            conn.close(0u32.into(), b"shutdown");
        }
        self.endpoint.close().await;
        Ok(())
    }
}
/// Adapter exposing the transport to the mesh sync layer: both methods are
/// thin delegations to the inherent implementations of the same name.
#[cfg(feature = "automerge-backend")]
#[async_trait::async_trait]
impl peat_mesh::storage::sync_transport::SyncTransport for IrohTransport {
    fn get_connection(&self, peer_id: &EndpointId) -> Option<Connection> {
        self.get_connection(peer_id)
    }
    fn connected_peers(&self) -> Vec<EndpointId> {
        self.connected_peers()
    }
}
/// Test-only constructors that bind local-only endpoints (no relay, no
/// discovery) so integration tests stay on the loopback network.
#[cfg(all(test, feature = "automerge-backend"))]
impl IrohTransport {
    /// Local-only transport with a random key.
    pub(crate) async fn new_local() -> Result<Self> {
        let endpoint = Endpoint::empty_builder()
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create local-only Iroh endpoint")?;
        Ok(Self {
            endpoint,
            runtime_handle: tokio::runtime::Handle::current(),
            connections: Arc::new(RwLock::new(HashMap::new())),
            connection_timestamps: Arc::new(RwLock::new(HashMap::new())),
            accept_running: Arc::new(AtomicBool::new(false)),
            accept_task: Arc::new(RwLock::new(None)),
            mdns_discovery: Arc::new(RwLock::new(None)),
            event_senders: Arc::new(RwLock::new(Vec::new())),
        })
    }

    /// Local-only transport with a deterministic key derived from `seed`
    /// (same SHA-256 / `peat-iroh-key-v1:` scheme as the non-test
    /// constructors).
    pub(crate) async fn from_seed_local(seed: &str) -> Result<Self> {
        use sha2::{Digest, Sha256};
        let mut digest = Sha256::new();
        digest.update(b"peat-iroh-key-v1:");
        digest.update(seed.as_bytes());
        let mut key_material = [0u8; 32];
        key_material.copy_from_slice(&digest.finalize());
        let secret_key = iroh::SecretKey::from_bytes(&key_material);
        let endpoint = Endpoint::empty_builder()
            .alpns(vec![CAP_AUTOMERGE_ALPN.to_vec()])
            .secret_key(secret_key)
            .transport_config(create_tactical_transport_config())
            .bind()
            .await
            .context("Failed to create local-only Iroh endpoint from seed")?;
        Ok(Self {
            endpoint,
            runtime_handle: tokio::runtime::Handle::current(),
            connections: Arc::new(RwLock::new(HashMap::new())),
            connection_timestamps: Arc::new(RwLock::new(HashMap::new())),
            accept_running: Arc::new(AtomicBool::new(false)),
            accept_task: Arc::new(RwLock::new(None)),
            mdns_discovery: Arc::new(RwLock::new(None)),
            event_senders: Arc::new(RwLock::new(Vec::new())),
        })
    }
}
// Integration tests. Those that bind real sockets and race the accept loop
// are `#[serial]` and timing-sensitive (sleeps + timeouts), so their code is
// deliberately left untouched.
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
    use super::*;
    use serial_test::serial;
    // A freshly bound transport has a non-zero endpoint id.
    #[tokio::test]
    async fn test_transport_creation() {
        let transport = IrohTransport::new().await.unwrap();
        let endpoint_id = transport.endpoint_id();
        assert_ne!(endpoint_id.as_bytes(), &[0u8; 32]);
        transport.close().await.unwrap();
    }
    // endpoint_addr() must carry the same id as endpoint_id().
    #[tokio::test]
    async fn test_transport_endpoint_addr() {
        let transport = IrohTransport::new().await.unwrap();
        let addr = transport.endpoint_addr();
        assert_eq!(addr.id, transport.endpoint_id());
        transport.close().await.unwrap();
    }
    #[tokio::test]
    async fn test_peer_count_initially_zero() {
        let transport = IrohTransport::new().await.unwrap();
        assert_eq!(transport.peer_count(), 0);
        transport.close().await.unwrap();
    }
    #[tokio::test]
    async fn test_connected_peers_initially_empty() {
        let transport = IrohTransport::new().await.unwrap();
        assert!(transport.connected_peers().is_empty());
        transport.close().await.unwrap();
    }
    // with_discovery() enables mDNS and otherwise behaves like new().
    #[tokio::test]
    async fn test_transport_with_discovery() {
        let transport = IrohTransport::with_discovery("test-node").await.unwrap();
        let endpoint_id = transport.endpoint_id();
        assert_ne!(endpoint_id.as_bytes(), &[0u8; 32]);
        assert!(transport.has_discovery());
        assert_eq!(transport.peer_count(), 0);
        assert!(transport.connected_peers().is_empty());
        transport.close().await.unwrap();
    }
    #[tokio::test]
    async fn test_transport_without_discovery() {
        let transport = IrohTransport::new().await.unwrap();
        assert!(!transport.has_discovery());
        transport.close().await.unwrap();
    }
    // Identical seeds must produce identical endpoint ids across restarts.
    #[tokio::test]
    async fn test_from_seed_deterministic() {
        let seed = "test-formation/node-1";
        let transport1 = IrohTransport::from_seed(seed).await.unwrap();
        let id1 = transport1.endpoint_id();
        transport1.close().await.unwrap();
        let transport2 = IrohTransport::from_seed(seed).await.unwrap();
        let id2 = transport2.endpoint_id();
        transport2.close().await.unwrap();
        assert_eq!(id1, id2, "Same seed should produce same EndpointId");
    }
    #[tokio::test]
    async fn test_from_seed_different_seeds() {
        let transport1 = IrohTransport::from_seed("formation/node-1").await.unwrap();
        let id1 = transport1.endpoint_id();
        let transport2 = IrohTransport::from_seed("formation/node-2").await.unwrap();
        let id2 = transport2.endpoint_id();
        assert_ne!(
            id1, id2,
            "Different seeds should produce different EndpointIds"
        );
        transport1.close().await.unwrap();
        transport2.close().await.unwrap();
    }
    // The static id derivation is pure and deterministic.
    #[test]
    fn test_endpoint_id_from_seed() {
        let seed = "alpha-formation/node-1";
        let id1 = IrohTransport::endpoint_id_from_seed(seed);
        let id2 = IrohTransport::endpoint_id_from_seed(seed);
        assert_eq!(id1, id2);
        let id3 = IrohTransport::endpoint_id_from_seed("alpha-formation/node-2");
        assert_ne!(id1, id3);
    }
    // Static derivation must agree with what a bound transport reports.
    #[tokio::test]
    async fn test_from_seed_matches_static_computation() {
        let seed = "containerlab/mesh-node-1";
        let computed_id = IrohTransport::endpoint_id_from_seed(seed);
        let transport = IrohTransport::from_seed(seed).await.unwrap();
        let transport_id = transport.endpoint_id();
        assert_eq!(
            computed_id, transport_id,
            "Static and dynamic computation should match"
        );
        transport.close().await.unwrap();
    }
    #[tokio::test]
    async fn test_from_seed_with_discovery() {
        let seed = "test-formation/discovery-node";
        let transport = IrohTransport::from_seed_with_discovery(seed).await.unwrap();
        assert!(transport.has_discovery());
        let expected_id = IrohTransport::endpoint_id_from_seed(seed);
        assert_eq!(transport.endpoint_id(), expected_id);
        transport.close().await.unwrap();
    }
    // Regression: closed connections must be purged from peer accounting.
    #[tokio::test]
    #[serial]
    async fn test_stale_peer_cleanup_issue_244() {
        use std::sync::Arc;
        let transport_a = Arc::new(IrohTransport::from_seed_local("test/node-a").await.unwrap());
        let transport_b = Arc::new(IrohTransport::from_seed_local("test/node-b").await.unwrap());
        let acceptor_addr = transport_b.endpoint_addr();
        assert_eq!(transport_a.peer_count(), 0);
        assert_eq!(transport_b.peer_count(), 0);
        transport_b.start_accept_loop().unwrap();
        let _conn = transport_a.connect(acceptor_addr).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        assert_eq!(transport_a.peer_count(), 1);
        let _ = transport_b.stop_accept_loop();
        // Simulate the remote side dropping all connections.
        for (_id, conn) in transport_b.connections.write().unwrap().drain() {
            conn.close(0u32.into(), b"test_close");
        }
        tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        assert_eq!(
            transport_a.peer_count(),
            0,
            "Closed connections should be removed from the map"
        );
        assert!(
            transport_a.connected_peers().is_empty(),
            "connected_peers() should not include closed connections"
        );
        let _ = transport_a.close().await;
        let _ = transport_b.close().await;
    }
    // A subscriber with no activity should see no events.
    #[tokio::test]
    async fn test_peer_event_subscription() {
        let transport = IrohTransport::new().await.unwrap();
        let mut rx = transport.subscribe_peer_events();
        let result = tokio::time::timeout(std::time::Duration::from_millis(50), rx.recv()).await;
        assert!(result.is_err(), "Should timeout when no events");
        transport.close().await.unwrap();
    }
    // Connected events reach subscribers with the right peer id.
    #[tokio::test]
    #[serial]
    async fn test_peer_event_on_connect() {
        use std::sync::Arc;
        let transport = Arc::new(
            IrohTransport::from_seed_local("test-event/node-a")
                .await
                .unwrap(),
        );
        let transport2 = Arc::new(
            IrohTransport::from_seed_local("test-event/node-b")
                .await
                .unwrap(),
        );
        let transport2_id = transport2.endpoint_id();
        let transport2_addr = transport2.endpoint_addr();
        let mut rx = transport.subscribe_peer_events();
        transport2.start_accept_loop().unwrap();
        let conn = transport.connect(transport2_addr).await.unwrap();
        assert!(conn.is_some(), "Should get connection in asymmetric case");
        transport.emit_peer_connected(transport2_id);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let event = tokio::time::timeout(std::time::Duration::from_millis(500), rx.recv()).await;
        assert!(event.is_ok(), "Should receive connect event");
        if let Ok(Some(TransportPeerEvent::Connected { endpoint_id, .. })) = event {
            assert_eq!(
                endpoint_id, transport2_id,
                "Event should be for connected peer"
            );
        } else {
            panic!("Expected Connected event");
        }
        let _ = transport.close().await;
        let _ = transport2.close().await;
    }
    // Multiple subscribers can coexist independently.
    #[tokio::test]
    async fn test_multiple_event_subscribers() {
        let transport = IrohTransport::new().await.unwrap();
        let mut rx1 = transport.subscribe_peer_events();
        let mut rx2 = transport.subscribe_peer_events();
        let result1 = tokio::time::timeout(std::time::Duration::from_millis(50), rx1.recv()).await;
        let result2 = tokio::time::timeout(std::time::Duration::from_millis(50), rx2.recv()).await;
        assert!(result1.is_err(), "Subscriber 1 should timeout");
        assert!(result2.is_err(), "Subscriber 2 should timeout");
        transport.close().await.unwrap();
    }
    // Smoke test: the tuned QUIC config builds without panicking.
    #[test]
    fn test_tactical_transport_config() {
        let _config = create_tactical_transport_config();
    }
    // Issue #315: a dead peer must be detected within the tuned idle timeout.
    #[tokio::test]
    #[serial]
    async fn test_fast_disconnect_detection_issue_315() {
        use std::sync::Arc;
        let transport_a = Arc::new(
            IrohTransport::from_seed_local("test-315/node-a")
                .await
                .unwrap(),
        );
        let transport_b = Arc::new(
            IrohTransport::from_seed_local("test-315/node-b")
                .await
                .unwrap(),
        );
        let transport_b_id = transport_b.endpoint_id();
        let acceptor_addr = transport_b.endpoint_addr();
        let mut events = transport_a.subscribe_peer_events();
        transport_b.start_accept_loop().unwrap();
        let conn = transport_a.connect(acceptor_addr).await.unwrap();
        assert!(conn.is_some(), "Should get connection in asymmetric case");
        transport_a.emit_peer_connected(transport_b_id);
        let event = tokio::time::timeout(std::time::Duration::from_secs(1), events.recv()).await;
        assert!(event.is_ok(), "Should receive connect event");
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        assert_eq!(
            transport_a.peer_count(),
            1,
            "Should have 1 peer before disconnect"
        );
        // Simulate a crash of node B: close everything and drop it.
        let _ = transport_b.stop_accept_loop();
        for (_id, conn) in transport_b.connections.write().unwrap().drain() {
            conn.close(0u32.into(), b"crash");
        }
        drop(transport_b);
        let start = std::time::Instant::now();
        let disconnect_timeout = std::time::Duration::from_secs(QUIC_MAX_IDLE_TIMEOUT_SECS + 2);
        let event = tokio::time::timeout(disconnect_timeout, events.recv()).await;
        let elapsed = start.elapsed();
        assert!(
            event.is_ok(),
            "Should receive disconnect event within timeout"
        );
        if let Ok(Some(TransportPeerEvent::Disconnected { reason, .. })) = event {
            tracing::info!(
                elapsed_secs = elapsed.as_secs_f64(),
                reason = %reason,
                "Disconnect detected (Issue #315)"
            );
            assert!(
                elapsed.as_secs() <= QUIC_MAX_IDLE_TIMEOUT_SECS + 2,
                "Disconnect should be detected within {} seconds, took {:.1}s (Issue #315)",
                QUIC_MAX_IDLE_TIMEOUT_SECS + 2,
                elapsed.as_secs_f64()
            );
        }
        assert_eq!(
            transport_a.peer_count(),
            0,
            "Peer count should be 0 after disconnect"
        );
        let _ = transport_a.close().await;
    }
}