use std::{
collections::BTreeSet,
future::poll_fn,
hash::Hash,
sync::Arc,
task::{Context, Poll, Waker, ready},
};
use iroh_base::{CustomAddr, EndpointAddr, EndpointId, RelayUrl};
use n0_future::task::JoinSet;
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{Span, error, trace};
pub(crate) use self::remote_state::PathStateReceiver;
use self::remote_state::RemoteStateActor;
pub(super) use self::remote_state::RemoteStateMessage;
pub use self::remote_state::{
Path, PathEvent, PathEventStream, PathList, PathListIter, PathListStream, RemoteInfo,
TransportAddrInfo, TransportAddrUsage,
};
#[cfg(feature = "unstable-custom-transports")]
pub use self::remote_state::{
PathSelection, PathSelectionContext, PathSelectionData, PathSelector,
};
#[cfg(not(feature = "unstable-custom-transports"))]
pub(crate) use self::remote_state::{
PathSelection, PathSelectionContext, PathSelectionData, PathSelector,
};
use super::{
DirectAddr, Metrics as SocketMetrics,
mapped_addrs::{
AddrMap, CustomMappedAddr, EndpointIdMappedAddr, MultipathMappedAddr, RelayMappedAddr,
},
transports,
};
use crate::{
address_lookup::{self, AddressLookupFailed},
socket::concurrent_read_map::{ConcurrentReadMap, ReadOnlyMap},
};
mod remote_state;
/// Registry of one [`RemoteStateActor`] per remote endpoint, plus the mapped-address tables
/// those actors use.
#[derive(Debug)]
pub(crate) struct RemoteMap {
    /// Mapped-address tables; the relay and custom tables are cloned into each actor when it
    /// is started.
    pub(super) mapped_addrs: MappedAddrs,
    /// Sender half of each remote endpoint's actor inbox. An entry is removed when its actor
    /// is cleaned up without leftover messages, and replaced when the actor is restarted.
    senders: ConcurrentReadMap<EndpointId, mpsc::Sender<RemoteStateMessage>>,
    /// State needed to start actors, together with the set of their running tasks.
    tasks: Tasks,
}
/// The mapped-address tables used for remote endpoints.
///
/// NOTE(review): the tables are read and updated through `&self` elsewhere in this file, so
/// `AddrMap` presumably uses interior mutability and its clones share one table — confirm in
/// `mapped_addrs`.
#[derive(Clone, Debug, Default)]
pub(crate) struct MappedAddrs {
    /// Mapped addresses keyed by endpoint id; `Tasks::start_remote_state_actor` looks up the
    /// entry for each endpoint whose actor it starts.
    pub(super) endpoint_addrs: AddrMap<EndpointId, EndpointIdMappedAddr>,
    /// Mapped addresses for `(relay URL, endpoint id)` pairs; resolved back by
    /// [`to_transport_addr`].
    pub(super) relay_addrs: AddrMap<(RelayUrl, EndpointId), RelayMappedAddr>,
    /// Mapped addresses for custom-transport addresses; resolved back by
    /// [`to_transport_addr`].
    pub(super) custom_addrs: AddrMap<CustomAddr, CustomMappedAddr>,
}
/// Turns a mapped address back into the transport address it stands for.
///
/// IP addresses convert directly. Relay and custom mapped addresses are looked up in
/// `relay_addrs` and `custom_addrs` respectively.
///
/// Returns `None`, after logging an error, in two cases:
/// - a [`MultipathMappedAddr::Mixed`] address, which has no single transport address;
/// - a relay or custom mapped address that is missing from its table.
pub(super) fn to_transport_addr(
    addr: impl Into<MultipathMappedAddr>,
    relay_addrs: &AddrMap<(RelayUrl, EndpointId), RelayMappedAddr>,
    custom_addrs: &AddrMap<CustomAddr, CustomMappedAddr>,
) -> Option<transports::Addr> {
    let mapped: MultipathMappedAddr = addr.into();
    match mapped {
        MultipathMappedAddr::Ip(socket_addr) => Some(transports::Addr::from(socket_addr)),
        MultipathMappedAddr::Relay(relay) => {
            let resolved = relay_addrs.lookup(&relay).map(transports::Addr::from);
            if resolved.is_none() {
                error!("Failed to convert addr to transport addr: Unknown relay mapped addr");
            }
            resolved
        }
        MultipathMappedAddr::Custom(custom) => {
            let resolved = custom_addrs.lookup(&custom).map(transports::Addr::Custom);
            if resolved.is_none() {
                error!("Failed to convert addr to transport addr: Unknown custom mapped addr");
            }
            resolved
        }
        MultipathMappedAddr::Mixed(_) => {
            error!(
                "Failed to convert addr to transport addr: Mixed mapped addr has no transport address"
            );
            None
        }
    }
}
/// State used to start [`RemoteStateActor`]s and to collect their tasks once they end.
#[derive(Debug)]
struct Tasks {
    /// Socket metrics; cloned into every actor.
    metrics: Arc<SocketMetrics>,
    /// Watched set of local direct addresses; cloned into every actor.
    local_direct_addrs: n0_watcher::Direct<BTreeSet<DirectAddr>>,
    /// Address lookup services; cloned into every actor.
    address_lookup: address_lookup::AddressLookupServices,
    /// Token handed to every actor's `start`.
    /// NOTE(review): presumably used by actors to stop on shutdown — confirm in `remote_state`.
    shutdown_token: CancellationToken,
    /// Running actor tasks. Each yields its endpoint id and the messages it did not process,
    /// which `RemoteMap::remove_or_restart_actor` uses to decide whether to restart it.
    tasks: JoinSet<(EndpointId, Vec<RemoteStateMessage>)>,
    /// Waker stored by `RemoteMap::poll_join_next` when `tasks` was empty; taken and woken
    /// whenever a new actor is started.
    poll_cleanup_waker: Option<Waker>,
    /// Path selection strategy; cloned into every actor.
    path_selector: Arc<dyn PathSelector>,
    /// Tracing span handed to every actor's `start`.
    span: Span,
}
impl RemoteMap {
pub(super) fn new(
metrics: Arc<SocketMetrics>,
local_direct_addrs: n0_watcher::Direct<BTreeSet<DirectAddr>>,
address_lookup: address_lookup::AddressLookupServices,
shutdown_token: CancellationToken,
path_selector: Arc<dyn PathSelector>,
span: Span,
) -> Self {
Self {
mapped_addrs: Default::default(),
senders: Default::default(),
tasks: Tasks {
metrics,
local_direct_addrs,
address_lookup,
shutdown_token,
tasks: Default::default(),
poll_cleanup_waker: None,
path_selector,
span,
},
}
}
pub(super) async fn cleanup(&mut self) -> EndpointId {
loop {
let (remote_id, leftover_messages) = poll_fn(|cx| self.poll_join_next(cx)).await;
if self.remove_or_restart_actor(remote_id, leftover_messages) {
return remote_id;
}
}
}
fn poll_join_next(
&mut self,
cx: &mut Context<'_>,
) -> Poll<(EndpointId, Vec<RemoteStateMessage>)> {
while let Some(result) = ready!(self.tasks.tasks.poll_join_next(cx)) {
match result {
Ok((remote_id, leftover_msgs)) => {
return Poll::Ready((remote_id, leftover_msgs));
}
Err(err) => {
if let Ok(panic) = err.try_into_panic() {
error!("RemoteStateActor panicked.");
std::panic::resume_unwind(panic);
}
}
}
}
self.tasks.poll_cleanup_waker.replace(cx.waker().clone());
Poll::Pending
}
fn remove_or_restart_actor(
&mut self,
remote_id: iroh_base::PublicKey,
leftover_msgs: Vec<RemoteStateMessage>,
) -> bool {
if leftover_msgs.is_empty() {
self.senders.remove(&remote_id);
trace!(%remote_id, "cleaned up RemoteStateActor");
true
} else {
trace!(%remote_id, "restarting terminated RemoteStateActor: messages received during shutdown");
let sender =
self.tasks
.start_remote_state_actor(remote_id, leftover_msgs, &self.mapped_addrs);
self.senders.insert(remote_id, sender);
false
}
}
pub(super) fn on_network_change(&mut self, is_major: bool) {
let read = self.senders.read_only();
let guard = read.guard();
for sender in read.values(&guard) {
sender
.try_send(RemoteStateMessage::NetworkChange { is_major })
.ok();
}
}
pub(super) async fn resolve_remote(
&mut self,
addr: EndpointAddr,
tx: oneshot::Sender<Result<(), AddressLookupFailed>>,
) {
let EndpointAddr { id, addrs } = addr;
self.send_to_actor(id, RemoteStateMessage::ResolveRemote(addrs, tx))
.await
}
pub(super) async fn add_connection(
&mut self,
remote: EndpointId,
conn: noq::Connection,
tx: oneshot::Sender<PathStateReceiver>,
) {
self.send_to_actor(remote, RemoteStateMessage::AddConnection(conn, tx))
.await
}
async fn send_to_actor(&mut self, remote_id: EndpointId, message: RemoteStateMessage) {
let sender = self.senders.get_or_insert_with(remote_id, || {
self.tasks
.start_remote_state_actor(remote_id, vec![], &self.mapped_addrs)
});
if let Err(mpsc::error::SendError(message)) = sender.send(message).await {
loop {
let (id, leftover_messages) = poll_fn(|cx| self.poll_join_next(cx)).await;
if id != remote_id {
self.remove_or_restart_actor(id, leftover_messages);
} else {
let mut messages = leftover_messages;
messages.push(message);
self.remove_or_restart_actor(id, messages);
break;
}
}
}
}
pub(super) fn senders(&self) -> ReadOnlyMap<EndpointId, mpsc::Sender<RemoteStateMessage>> {
self.senders.read_only()
}
}
impl Tasks {
    /// Starts a new [`RemoteStateActor`] for `eid` and returns the sender feeding its inbox.
    ///
    /// `initial_msgs` are passed to the actor's `start`. They are used to carry over messages
    /// a previous, terminated actor for the same endpoint never processed.
    ///
    /// The actor is handed `self.tasks`, into which it presumably spawns its task. Any waiter
    /// parked in `poll_cleanup_waker` is then woken, so it notices the new task.
    fn start_remote_state_actor(
        &mut self,
        eid: EndpointId,
        initial_msgs: Vec<RemoteStateMessage>,
        mapped_addrs: &MappedAddrs,
    ) -> mpsc::Sender<RemoteStateMessage> {
        // Called only for its side effect; the returned mapped address is not needed here.
        // NOTE(review): presumably `AddrMap::get` allocates a mapped address for `eid` when
        // none exists yet — confirm before removing this line.
        let _ = mapped_addrs.endpoint_addrs.get(&eid);

        let actor = RemoteStateActor::new(
            eid,
            self.local_direct_addrs.clone(),
            mapped_addrs.relay_addrs.clone(),
            mapped_addrs.custom_addrs.clone(),
            self.metrics.clone(),
            self.address_lookup.clone(),
            self.path_selector.clone(),
        );
        let inbox = actor.start(
            initial_msgs,
            &mut self.tasks,
            self.shutdown_token.clone(),
            self.span.clone(),
        );

        // A poll that found the task set empty is asleep on this waker and would otherwise
        // never learn about the task just added.
        if let Some(parked) = self.poll_cleanup_waker.take() {
            parked.wake();
        }
        inbox
    }
}
/// Origin tag with three kinds: `App`, a named `AddressLookup`, or `Connection`.
///
/// NOTE(review): presumably records where a piece of remote address information was learned
/// from — confirm at the use sites in `remote_state`.
///
/// `Display` uses kebab-case variant names (`app`, `connection`). `AddressLookup` uses the
/// `"{name}"` strum template, presumably rendering its `name` field.
#[derive(Serialize, Deserialize, strum::Display, Debug, Clone, Eq, PartialEq, Hash)]
#[strum(serialize_all = "kebab-case")]
// NOTE(review): the reason for allowing `private_interfaces` is not visible in this file.
#[allow(private_interfaces)]
#[non_exhaustive]
pub(crate) enum Source {
    /// Displayed as `app`.
    App,
    /// Displayed using the `name` field.
    #[strum(serialize = "{name}")]
    AddressLookup {
        /// Identifies the address lookup service.
        name: String,
    },
    /// Displayed as `connection`.
    Connection,
}
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, time::Duration};
    use iroh_base::{SecretKey, TransportAddr};
    use n0_future::future::now_or_never;
    use n0_tracing_test::traced_test;
    use n0_watcher::Watchable;
    use tokio::sync::oneshot;
    use tracing::Span;
    use super::*;
    use crate::socket::biased_rtt_path_selector::BiasedRttPathSelector;
    /// Builds a [`RemoteMap`] for tests.
    ///
    /// The map uses default address lookup services, an empty local direct-address set and
    /// a [`BiasedRttPathSelector`].
    ///
    /// Returns the map, its shutdown token, and guards the caller must keep alive:
    /// - the [`Watchable`] backing the direct-address watcher (presumably needed for the
    ///   watcher to stay connected);
    /// - a drop guard that cancels the shutdown token when dropped.
    fn make_remote_map() -> (RemoteMap, CancellationToken, impl Sized) {
        let metrics = Arc::new(SocketMetrics::default());
        let watchable: Watchable<BTreeSet<DirectAddr>> = Watchable::new(BTreeSet::new());
        let local_direct_addrs = watchable.watch();
        let shutdown_token = CancellationToken::new();
        let remote_map = RemoteMap::new(
            metrics,
            local_direct_addrs,
            address_lookup::AddressLookupServices::default(),
            shutdown_token.clone(),
            Arc::new(BiasedRttPathSelector::default()),
            Span::none(),
        );
        let guards = (watchable, shutdown_token.clone().drop_guard());
        (remote_map, shutdown_token, guards)
    }
    /// Resolves sent around an actor's shutdown, cleanup and restart must all complete `Ok`.
    ///
    /// This checks that the sender registered for a restarted actor stays usable after
    /// `cleanup` is polled.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    #[traced_test]
    async fn poll_cleanup_preserves_restarted_sender() {
        let (mut remote_map, _shutdown_token, _guards) = make_remote_map();
        let eid = SecretKey::from_bytes(&[0u8; 32]).public();
        let addr_with_ip = |port: u16| {
            EndpointAddr::from_parts(
                eid,
                [TransportAddr::Ip(SocketAddr::from(([127, 0, 0, 1], port)))],
            )
        };
        let (tx1, rx1) = oneshot::channel();
        remote_map.resolve_remote(addr_with_ip(1234), tx1).await;
        assert!(
            matches!(rx1.await, Ok(Ok(()))),
            "First resolve completes Ok"
        );
        // Advance paused time so the actor winds down.
        // NOTE(review): assumes the actor's idle timeout is below 65s — confirm in remote_state.
        tokio::time::sleep(Duration::from_secs(65)).await;
        tokio::time::resume();
        // This resolve may reach the actor while it shuts down and end up in its leftovers.
        let (tx2, rx2) = oneshot::channel();
        remote_map.resolve_remote(addr_with_ip(5678), tx2).await;
        // Poll cleanup exactly once. It may reap the terminated actor and restart it with
        // its leftover messages.
        now_or_never(remote_map.cleanup());
        // Must reach whichever actor is now registered for `eid`.
        let (tx3, rx3) = oneshot::channel();
        remote_map.resolve_remote(EndpointAddr::new(eid), tx3).await;
        let outcome2 = rx2.await.expect("the resolve tx must be sent");
        let outcome3 = rx3.await.expect("the resolve tx must be sent");
        assert!(outcome2.is_ok(), "expected Ok, but got {outcome2:?}");
        assert!(outcome3.is_ok(), "expected Ok, but got {outcome3:?}");
    }
}