//! Background driver that acquires a relay binding, publishes the node's
//! typed address set to the DHT, and re-acquires with exponential backoff
//! after loss or failure.
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::broadcast::error::RecvError;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use crate::dht::AddressType;
use crate::dht_network_manager::{DhtNetworkEvent, DhtNetworkManager};
use crate::reachability::session::{RelayAcquisitionOutcome, run_relay_acquisition};
use crate::transport_handle::TransportHandle;
use crate::{MultiAddr, PeerId};

/// How often to poll relay tunnel health while a relay is held.
const HEALTH_POLL_INTERVAL: Duration = Duration::from_secs(5);
/// Backoff applied after the first failed acquisition attempt.
const BACKOFF_INITIAL: Duration = Duration::from_secs(30);
/// Upper bound on the backoff window.
const BACKOFF_MAX: Duration = Duration::from_secs(300);
/// Multiplier applied to the backoff after each failed attempt.
const BACKOFF_FACTOR: u32 = 2;
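
/// Spawns the relay acquisition driver as a detached Tokio task.
///
/// The task runs until `shutdown` is cancelled, keeping `relayer_peer_id`
/// and `relay_address` in sync with the currently held relay binding.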
pub(crate) fn spawn_acquisition_driver(
dht: Arc<DhtNetworkManager>,
transport: Arc<TransportHandle>,
relayer_peer_id: Arc<RwLock<Option<PeerId>>>,
relay_address: Arc<RwLock<Option<SocketAddr>>>,
shutdown: CancellationToken,
) {
tokio::spawn(async move {
let mut driver = AcquisitionDriver {
dht,
transport,
relayer_peer_id,
relay_address,
shutdown,
current_backoff: BACKOFF_INITIAL,
};
driver.run().await;
});
}
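
/// Loop state for relay acquisition: handles to the DHT and transport,
/// shared slots exposing the current binding, and the adaptive backoff.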
struct AcquisitionDriver {
dht: Arc<DhtNetworkManager>,
transport: Arc<TransportHandle>,
relayer_peer_id: Arc<RwLock<Option<PeerId>>>,
relay_address: Arc<RwLock<Option<SocketAddr>>>,
shutdown: CancellationToken,
current_backoff: Duration,
}
impl AcquisitionDriver {
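/// Main loop: acquire a relay, hold it until it is lost, republish, and
/// back off between failed attempts. Returns when shutdown is requested
/// or the underlying event sources close.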
async fn run(&mut self) {
info!("relay acquisition driver starting");
loop {
if self.shutdown.is_cancelled() {
debug!("relay acquisition driver: shutdown, exiting");
return;
}
let outcome = run_relay_acquisition(self.dht.as_ref(), &self.transport).await;
match outcome {
RelayAcquisitionOutcome::Acquired(relay) => {
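// Success: reset the backoff, expose the binding, and advertise it.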
self.current_backoff = BACKOFF_INITIAL;
*self.relayer_peer_id.write().await = Some(relay.relayer);
*self.relay_address.write().await = Some(relay.allocated_public_addr);
self.transport
.set_relay_address(relay.allocated_public_addr);
self.publish_typed_set(Some(relay.allocated_public_addr))
.await;
info!(
relayer = ?relay.relayer,
allocated = %relay.allocated_public_addr,
"driver: relay acquired and published"
);
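// Block here until the relay is lost or shutdown is requested.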
if self.hold_until_lost().await {
return;
}
self.lose_relay_and_republish().await;
}
RelayAcquisitionOutcome::Failed(reason) => {
warn!(reason, "driver: acquisition failed, entering backoff");
*self.relayer_peer_id.write().await = None;
*self.relay_address.write().await = None;
self.transport.clear_relay_address();
self.publish_typed_set(None).await;
if self.wait_backoff_or_event().await {
return;
}
self.advance_backoff();
}
}
}
}
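
/// Builds the node's typed address set and publishes it to every peer in
/// the local routing table.
///
/// Direct addresses are tagged `Direct` once reachability has been
/// observed; without a relay they are still published as `Unverified`.
/// While a relay is held but direct reachability is unconfirmed, only
/// the relay address is advertised.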
async fn publish_typed_set(&self, relay: Option<SocketAddr>) {
let listen = self.transport.listen_addrs().await;
let observed = self.transport.direct_external_addresses();
let direct_verified = self.transport.direct_reachability_observed();
let direct_tag: Option<AddressType> = if direct_verified {
Some(AddressType::Direct)
} else if relay.is_none() {
Some(AddressType::Unverified)
} else {
None
};
let mut typed: Vec<(MultiAddr, AddressType)> = Vec::new();
let mut seen: HashSet<SocketAddr> = HashSet::new();
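// Observed external addresses take precedence; listen addresses are
// appended afterwards, de-duplicated after normalization.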
if let Some(tag) = direct_tag {
for sa in observed {
if sa.ip().is_unspecified() {
continue;
}
let normalized = saorsa_transport::shared::normalize_socket_addr(sa);
if seen.insert(normalized) {
typed.push((MultiAddr::quic(normalized), tag));
}
}
for addr in listen {
let Some(sa) = addr.dialable_socket_addr() else {
continue;
};
if sa.ip().is_unspecified() {
continue;
}
let normalized = saorsa_transport::shared::normalize_socket_addr(sa);
if seen.insert(normalized) {
typed.push((MultiAddr::quic(normalized), tag));
}
}
}
if let Some(relay_addr) = relay {
let normalized = saorsa_transport::shared::normalize_socket_addr(relay_addr);
typed.push((MultiAddr::quic(normalized), AddressType::Relay));
}
if typed.is_empty() {
debug!("driver: publish skipped, no dialable self addresses");
return;
}
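// Publish to the peers nearest our own key in the local routing table.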
let own_key = *self.dht.peer_id().to_bytes();
let all_peers = self
.dht
.find_closest_nodes_local(&own_key, self.dht.k_value())
.await;
trace!(
peers = all_peers.len(),
addrs = typed.len(),
relay = ?relay,
"driver: publishing typed address set to all routing table peers"
);
self.dht
.publish_address_set_to_peers(typed, &all_peers)
.await;
}
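
/// Holds the current relay until it is lost, polling tunnel health and
/// watching DHT topology. Returns `true` when the driver should exit
/// (shutdown or closed channels), `false` when the relay should be
/// re-acquired.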
async fn hold_until_lost(&self) -> bool {
let mut events = self.dht.subscribe_events();
let mut health = tokio::time::interval(HEALTH_POLL_INTERVAL);
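// The first interval tick completes immediately; consume it so health
// polling starts one full interval from now.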
health.tick().await;
loop {
tokio::select! {
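// Poll branches in declaration order so shutdown is observed first.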
biased;
_ = self.shutdown.cancelled() => {
return true;
}
lost = self.transport.recv_relay_lost() => {
match lost {
Some(addr) => {
info!(
relay = %addr,
"driver: RelayLost event received, rebinding"
);
return false;
}
None => {
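// No further relay-lost events can arrive; exit the driver.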
return true;
}
}
}
event = events.recv() => {
match event {
Ok(DhtNetworkEvent::KClosestPeersChanged { ref new, .. }) => {
if self.relayer_evicted_from_k_closest(new).await {
info!("driver: relayer evicted from K-closest, rebinding");
return false;
}
}
Ok(_) => continue,
Err(RecvError::Closed) => return true,
Err(_) => continue,
}
}
_ = health.tick() => {
if !self.transport.is_relay_healthy() {
info!("driver: relay tunnel unhealthy, rebinding");
return false;
}
}
}
}
}
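
/// Returns `true` if a relayer is currently held and it no longer
/// appears in the new K-closest peer set.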
async fn relayer_evicted_from_k_closest(&self, new_k_closest: &[PeerId]) -> bool {
let guard = self.relayer_peer_id.read().await;
let Some(relayer) = guard.as_ref() else {
return false;
};
!new_k_closest.contains(relayer)
}
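
/// Clears the relay binding and republishes the address set so peers
/// stop dialing the stale relay address.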
async fn lose_relay_and_republish(&self) {
*self.relayer_peer_id.write().await = None;
*self.relay_address.write().await = None;
self.transport.clear_relay_address();
self.publish_typed_set(None).await;
}
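
/// Sleeps for the current backoff window, waking early if the K-closest
/// peer set changes (a new relayer candidate may be reachable). Returns
/// `true` when the driver should exit.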
async fn wait_backoff_or_event(&self) -> bool {
let mut events = self.dht.subscribe_events();
let sleep = tokio::time::sleep(self.current_backoff);
tokio::pin!(sleep);
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => return true,
_ = &mut sleep => {
trace!(window = ?self.current_backoff, "driver: backoff window expired");
return false;
}
event = events.recv() => {
match event {
Ok(DhtNetworkEvent::KClosestPeersChanged { .. }) => {
debug!("driver: K-closest changed, retrying early");
return false;
}
Ok(_) => continue,
Err(RecvError::Closed) => return true,
Err(_) => continue,
}
}
}
}
}
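
/// Multiplies the backoff by `BACKOFF_FACTOR`, saturating at `BACKOFF_MAX`.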
fn advance_backoff(&mut self) {
let next = self.current_backoff.saturating_mul(BACKOFF_FACTOR);
self.current_backoff = next.min(BACKOFF_MAX);
}
}