use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::broadcast::error::RecvError;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use crate::dht::AddressType;
use crate::dht_network_manager::{DhtNetworkEvent, DhtNetworkManager};
use crate::reachability::session::{RelayAcquisitionOutcome, run_relay_acquisition};
use crate::self_address::build_typed_self_address_set;
use crate::transport_handle::TransportHandle;
use crate::{MultiAddr, PeerId};
/// How often the relay tunnel health is polled while a relay is held.
const HEALTH_POLL_INTERVAL: Duration = Duration::from_secs(5);
/// Initial backoff window after a failed relay acquisition.
const BACKOFF_INITIAL: Duration = Duration::from_secs(30);
/// Upper bound the exponential backoff window is clamped to.
const BACKOFF_MAX: Duration = Duration::from_secs(300);
/// Multiplier applied to the backoff window after each failed attempt.
const BACKOFF_FACTOR: u32 = 2;
/// Spawns the background task that owns the relay acquisition loop.
///
/// The task runs until `shutdown` is cancelled. Acquired relay state is
/// exposed to the rest of the system through the shared `relayer_peer_id`
/// and `relay_address` slots and through the transport handle.
pub(crate) fn spawn_acquisition_driver(
    dht: Arc<DhtNetworkManager>,
    transport: Arc<TransportHandle>,
    relayer_peer_id: Arc<RwLock<Option<PeerId>>>,
    relay_address: Arc<RwLock<Option<SocketAddr>>>,
    shutdown: CancellationToken,
) {
    // Build the driver up front; the spawned task just takes ownership and
    // runs it to completion.
    let mut driver = AcquisitionDriver {
        dht,
        transport,
        relayer_peer_id,
        relay_address,
        shutdown,
        current_backoff: BACKOFF_INITIAL,
        last_published_typed_set: None,
    };
    tokio::spawn(async move { driver.run().await });
}
/// State machine that acquires a relay, holds it while healthy, and keeps
/// the node's typed self-address set published as conditions change.
struct AcquisitionDriver {
    /// DHT handle used for peer lookups, event subscription, and publishing.
    dht: Arc<DhtNetworkManager>,
    /// Transport handle used for relay binding and address events.
    transport: Arc<TransportHandle>,
    /// Shared slot exposing the current relayer's peer id (None when no relay).
    relayer_peer_id: Arc<RwLock<Option<PeerId>>>,
    /// Shared slot exposing the relay-allocated public address (None when no relay).
    relay_address: Arc<RwLock<Option<SocketAddr>>>,
    /// Cooperative shutdown signal; the driver exits once cancelled.
    shutdown: CancellationToken,
    /// Current retry window; reset to BACKOFF_INITIAL on successful acquisition.
    current_backoff: Duration,
    /// Snapshot of the last publish, used to skip redundant re-publishes.
    last_published_typed_set: Option<PublishedTypedSet>,
}
/// Snapshot of the most recent publish: which typed addresses were sent and
/// to which peers. Compared against the next candidate publish so unchanged
/// sets are not re-sent (unless forced).
#[derive(Clone, Debug, PartialEq)]
struct PublishedTypedSet {
    /// Addresses, each tagged with its address type, included in the publish.
    typed_addresses: Vec<(MultiAddr, AddressType)>,
    /// Routing-table peers the set was published to.
    peers: Vec<PeerId>,
}
impl AcquisitionDriver {
/// Main driver loop.
///
/// Repeatedly attempts relay acquisition: on success, publishes the relayed
/// address set and holds the relay until it is lost, then rebinds; on
/// failure, clears relay state and waits out a (growing) backoff window
/// before retrying. Exits when `shutdown` is cancelled.
async fn run(&mut self) {
    info!("relay acquisition driver starting");
    while !self.shutdown.is_cancelled() {
        match run_relay_acquisition(self.dht.as_ref(), &self.transport).await {
            RelayAcquisitionOutcome::Acquired(acquired) => {
                // A successful acquisition resets the failure backoff.
                self.current_backoff = BACKOFF_INITIAL;
                *self.relayer_peer_id.write().await = Some(acquired.relayer);
                *self.relay_address.write().await = Some(acquired.allocated_public_addr);
                self.transport
                    .set_relay_address(acquired.allocated_public_addr);
                // Forced publish: the relayed address must go out even if the
                // computed set happens to match the last published one.
                self.force_publish_typed_set(Some(acquired.allocated_public_addr))
                    .await;
                info!(
                    relayer = ?acquired.relayer,
                    allocated = %acquired.allocated_public_addr,
                    "driver: relay acquired and published"
                );
                // `true` means shutdown was requested while holding the relay.
                if self.hold_until_lost().await {
                    return;
                }
                self.lose_relay_and_republish().await;
            }
            RelayAcquisitionOutcome::Failed(reason) => {
                warn!(reason, "driver: acquisition failed, entering backoff");
                *self.relayer_peer_id.write().await = None;
                *self.relay_address.write().await = None;
                self.transport.clear_relay_address();
                self.publish_typed_set(None).await;
                // `true` means shutdown fired during the backoff window.
                if self.wait_backoff_or_event().await {
                    return;
                }
                self.advance_backoff();
            }
        }
    }
    debug!("relay acquisition driver: shutdown, exiting");
}
/// Publishes the typed self-address set, skipping the publish when both the
/// set and the target peer list are unchanged since the last publish.
async fn publish_typed_set(&mut self, relay: Option<SocketAddr>) {
    self.publish_typed_set_with_policy(relay, false).await;
}
/// Publishes the typed self-address set unconditionally, even if identical
/// to the previously published snapshot (e.g. right after relay acquisition
/// or loss, when peers must be refreshed regardless).
async fn force_publish_typed_set(&mut self, relay: Option<SocketAddr>) {
    self.publish_typed_set_with_policy(relay, true).await;
}
/// Builds the current typed self-address set and publishes it to the
/// routing-table peers closest to our own key.
///
/// With `force == false`, the publish is skipped when both the address set
/// and the destination peer list exactly match the previous publish. Nothing
/// is published when no dialable self address exists.
async fn publish_typed_set_with_policy(&mut self, relay: Option<SocketAddr>, force: bool) {
    let listen_addrs = self.transport.listen_addrs().await;
    let observed_addrs = self.transport.non_relay_external_addresses();
    debug!(
        relay = ?relay,
        observed = ?observed_addrs,
        listen = ?listen_addrs,
        "driver: preparing typed self address set"
    );
    let typed_set = build_typed_self_address_set(observed_addrs, listen_addrs, relay, |sa| {
        self.transport.is_external_proven(sa)
    });
    if typed_set.is_empty() {
        debug!("driver: publish skipped, no dialable self addresses");
        return;
    }
    // Publish targets are the K nodes closest to our own key, per the local
    // routing table.
    let own_key = *self.dht.peer_id().to_bytes();
    let closest = self
        .dht
        .find_closest_nodes_local(&own_key, self.dht.k_value())
        .await;
    let snapshot = PublishedTypedSet {
        typed_addresses: typed_set.clone(),
        peers: closest.iter().map(|node| node.peer_id).collect(),
    };
    // Dedup: an identical snapshot means peers already have this exact set.
    if !force && self.last_published_typed_set.as_ref() == Some(&snapshot) {
        debug!(
            peers = closest.len(),
            typed_addresses = ?typed_set,
            relay = ?relay,
            "driver: publish skipped, typed self address set unchanged"
        );
        return;
    }
    debug!(
        peers = closest.len(),
        typed_addresses = ?typed_set,
        relay = ?relay,
        "driver: publishing typed self address set"
    );
    trace!(
        peers = closest.len(),
        addrs = typed_set.len(),
        relay = ?relay,
        "driver: publishing typed address set to all routing table peers"
    );
    self.dht
        .publish_address_set_to_peers(typed_set, &closest)
        .await;
    self.last_published_typed_set = Some(snapshot);
}
/// Holds an acquired relay, reacting to transport and DHT events, until the
/// relay is lost or the driver should stop.
///
/// Returns `true` when the driver must exit entirely (shutdown requested, or
/// a transport/event channel closed), and `false` when the relay was lost or
/// became unsuitable and the caller should rebind to a new one.
async fn hold_until_lost(&mut self) -> bool {
    let mut events = self.dht.subscribe_events();
    let mut health = tokio::time::interval(HEALTH_POLL_INTERVAL);
    // The first interval tick resolves immediately; consume it so the first
    // real health check happens only after a full HEALTH_POLL_INTERVAL.
    health.tick().await;
    loop {
        tokio::select! {
            // Poll arms top-to-bottom so shutdown always wins a race.
            biased;
            _ = self.shutdown.cancelled() => {
                return true;
            }
            lost = self.transport.recv_relay_lost() => {
                match lost {
                    Some(addr) => {
                        info!(
                            relay = %addr,
                            "driver: RelayLost event received, rebinding"
                        );
                        // Relay is gone: ask the caller to rebind.
                        return false;
                    }
                    None => {
                        // Transport channel closed: exit the driver.
                        return true;
                    }
                }
            }
            promoted = self.transport.recv_direct_address_promoted() => {
                match promoted {
                    Some(addr) => {
                        let relay = *self.relay_address.read().await;
                        info!(
                            address = %addr,
                            relay = ?relay,
                            "driver: direct address promoted, republishing typed self address set"
                        );
                        self.publish_typed_set(relay).await;
                    }
                    None => {
                        return true;
                    }
                }
            }
            updated = self.transport.recv_self_address_updated() => {
                match updated {
                    Some(addr) => {
                        let relay = *self.relay_address.read().await;
                        debug!(
                            address = %addr,
                            relay = ?relay,
                            "driver: self address updated, refreshing typed self address set"
                        );
                        self.publish_typed_set(relay).await;
                    }
                    None => {
                        return true;
                    }
                }
            }
            event = events.recv() => {
                match event {
                    Ok(DhtNetworkEvent::KClosestPeersChanged { ref new, .. }) => {
                        // Our relayer falling out of the K-closest set means it
                        // is no longer a good relay choice: rebind.
                        if self.relayer_evicted_from_k_closest(new).await {
                            info!("driver: relayer evicted from K-closest, rebinding");
                            return false;
                        }
                    }
                    Ok(_) => continue,
                    // Event bus closed: nothing more to react to, exit.
                    Err(RecvError::Closed) => return true,
                    // Lagged receiver (missed events): keep listening.
                    Err(_) => continue,
                }
            }
            _ = health.tick() => {
                if !self.transport.is_relay_healthy() {
                    info!("driver: relay tunnel unhealthy, rebinding");
                    return false;
                }
            }
        }
    }
}
/// Returns `true` when we currently hold a relayer and it is absent from the
/// new K-closest peer set; `false` when no relayer is held or it is still
/// present.
async fn relayer_evicted_from_k_closest(&self, new_k_closest: &[PeerId]) -> bool {
    self.relayer_peer_id
        .read()
        .await
        .as_ref()
        .map_or(false, |relayer| !new_k_closest.contains(relayer))
}
/// Clears all relay state (shared slots and transport) and force-publishes
/// the typed self-address set so peers stop using the stale relayed address.
async fn lose_relay_and_republish(&mut self) {
    *self.relayer_peer_id.write().await = None;
    *self.relay_address.write().await = None;
    self.transport.clear_relay_address();
    // Forced: the relay-less set must go out even if it happens to match the
    // last published snapshot.
    self.force_publish_typed_set(None).await;
}
/// Waits out the current backoff window, reacting to address events in the
/// meantime.
///
/// Returns `true` when the driver must exit (shutdown requested or event bus
/// closed) and `false` when the window expired — or the K-closest set
/// changed, which warrants an early acquisition retry.
async fn wait_backoff_or_event(&mut self) -> bool {
    let mut events = self.dht.subscribe_events();
    let sleep = tokio::time::sleep(self.current_backoff);
    tokio::pin!(sleep);
    loop {
        tokio::select! {
            // Poll arms top-to-bottom so shutdown always wins a race.
            biased;
            _ = self.shutdown.cancelled() => return true,
            _ = &mut sleep => {
                trace!(window = ?self.current_backoff, "driver: backoff window expired");
                return false;
            }
            promoted = self.transport.recv_direct_address_promoted() => {
                match promoted {
                    Some(addr) => {
                        info!(
                            address = %addr,
                            "driver: direct address promoted during relay backoff, republishing typed self address set"
                        );
                        // No relay is held during backoff, so publish with None.
                        self.publish_typed_set(None).await;
                    }
                    None => {
                        return true;
                    }
                }
            }
            updated = self.transport.recv_self_address_updated() => {
                match updated {
                    Some(addr) => {
                        debug!(
                            address = %addr,
                            "driver: self address updated during relay backoff, refreshing typed self address set"
                        );
                        self.publish_typed_set(None).await;
                    }
                    None => {
                        return true;
                    }
                }
            }
            event = events.recv() => {
                match event {
                    Ok(DhtNetworkEvent::KClosestPeersChanged { .. }) => {
                        // Topology changed: a suitable relayer may now exist,
                        // so retry before the window expires.
                        debug!("driver: K-closest changed, retrying early");
                        return false;
                    }
                    Ok(_) => continue,
                    Err(RecvError::Closed) => return true,
                    // Lagged receiver: keep waiting.
                    Err(_) => continue,
                }
            }
        }
    }
}
/// Grows the backoff window by BACKOFF_FACTOR (saturating), clamped to
/// BACKOFF_MAX.
fn advance_backoff(&mut self) {
    self.current_backoff = self
        .current_backoff
        .saturating_mul(BACKOFF_FACTOR)
        .min(BACKOFF_MAX);
}
}