use std::{
collections::BTreeSet,
hash::Hash,
net::{IpAddr, SocketAddr},
sync::Arc,
task::{Context, Poll, Waker, ready},
};
use iroh_base::{CustomAddr, EndpointAddr, EndpointId, RelayUrl};
use n0_future::task::JoinSet;
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{Span, debug, error};
pub(crate) use self::remote_state::PathWatchable;
use self::remote_state::RemoteStateActor;
pub(super) use self::remote_state::RemoteStateMessage;
pub use self::remote_state::{
PathInfo, PathInfoList, PathInfoListIter, PathWatcher, RemoteInfo, TransportAddrInfo,
TransportAddrUsage,
};
use super::{
DirectAddr, Metrics as SocketMetrics,
mapped_addrs::{
AddrMap, CustomMappedAddr, EndpointIdMappedAddr, MultipathMappedAddr, RelayMappedAddr,
},
transports,
};
use crate::{
address_lookup::{self, AddressLookupFailed},
socket::{
RemoteStateActorStoppedError,
concurrent_read_map::{ConcurrentReadMap, ReadOnlyMap},
transports::TransportBiasMap,
},
};
mod remote_state;
/// Registry of per-remote state actors, keyed by [`EndpointId`].
///
/// Holds the message channel to every [`RemoteStateActor`] that has been started, the
/// mapped-address tables, and everything needed to start further actors.
#[derive(Debug)]
pub(crate) struct RemoteMap {
/// The mapped-address tables. The relay and custom tables are cloned into every actor
/// that gets started.
pub(crate) mapped_addrs: MappedAddrs,
/// Sending half of the message channel of each started actor, keyed by remote endpoint id.
senders: ConcurrentReadMap<EndpointId, mpsc::Sender<RemoteStateMessage>>,
/// Shared actor dependencies plus the set of spawned actor tasks.
tasks: Tasks,
}
/// The address-mapping tables, one per kind of mapped address.
///
/// Within this file an [`AddrMap`] is used in both directions: `get` takes a key and yields
/// its mapped address, `lookup` takes a mapped address and yields the key it stands for.
// NOTE(review): cloning presumably yields handles onto the same underlying tables rather
// than independent copies — confirm against `AddrMap`.
#[derive(Clone, Debug, Default)]
pub(crate) struct MappedAddrs {
/// Maps an [`EndpointId`] to its [`EndpointIdMappedAddr`].
pub(super) endpoint_addrs: AddrMap<EndpointId, EndpointIdMappedAddr>,
/// Maps a relay URL and endpoint id pair to its [`RelayMappedAddr`].
pub(super) relay_addrs: AddrMap<(RelayUrl, EndpointId), RelayMappedAddr>,
/// Maps a [`CustomAddr`] to its [`CustomMappedAddr`].
pub(super) custom_addrs: AddrMap<CustomAddr, CustomMappedAddr>,
}
/// Resolves a mapped address into the concrete transport address it stands for.
///
/// - IP addresses convert directly.
/// - Relay mapped addresses are looked up in `relay_addrs`.
/// - Custom mapped addresses are looked up in `custom_addrs`.
/// - Mixed mapped addresses have no transport address.
///
/// Returns `None`, after logging an error, when the address is a mixed mapped address or
/// when the lookup finds no entry.
pub(super) fn to_transport_addr(
    addr: impl Into<MultipathMappedAddr>,
    relay_addrs: &AddrMap<(RelayUrl, EndpointId), RelayMappedAddr>,
    custom_addrs: &AddrMap<CustomAddr, CustomMappedAddr>,
) -> Option<transports::Addr> {
    match addr.into() {
        MultipathMappedAddr::Ip(ip_addr) => Some(transports::Addr::from(ip_addr)),
        MultipathMappedAddr::Relay(mapped) => {
            let Some(parts) = relay_addrs.lookup(&mapped) else {
                error!("Failed to convert addr to transport addr: Unknown relay mapped addr");
                return None;
            };
            Some(transports::Addr::from(parts))
        }
        MultipathMappedAddr::Custom(mapped) => {
            let Some(custom_addr) = custom_addrs.lookup(&mapped) else {
                error!("Failed to convert addr to transport addr: Unknown custom mapped addr");
                return None;
            };
            Some(transports::Addr::Custom(custom_addr))
        }
        MultipathMappedAddr::Mixed(_) => {
            error!(
                "Failed to convert addr to transport addr: Mixed mapped addr has no transport address"
            );
            None
        }
    }
}
/// Everything needed to start a [`RemoteStateActor`], plus the set of spawned actor tasks.
#[derive(Debug)]
struct Tasks {
/// Socket metrics, cloned into every actor that is started.
metrics: Arc<SocketMetrics>,
/// Watcher over the local direct addresses, cloned into every actor that is started.
local_direct_addrs: n0_watcher::Direct<BTreeSet<DirectAddr>>,
/// Address lookup services, cloned into every actor that is started.
address_lookup: address_lookup::AddressLookupServices,
/// Cancellation token, a clone of which is passed to every actor when it is started.
shutdown_token: CancellationToken,
/// The spawned actor tasks. A finished task yields the endpoint id it served together
/// with the messages it left unprocessed.
tasks: JoinSet<(EndpointId, Vec<RemoteStateMessage>)>,
/// Waker stored by `RemoteMap::poll_cleanup` when it finds no tasks left to join; it is
/// taken and woken the next time an actor is started.
poll_cleanup_waker: Option<Waker>,
/// Transport bias configuration, cloned into every actor that is started.
transport_bias: TransportBiasMap,
/// Tracing span, a clone of which is passed to every actor when it is started.
span: Span,
}
impl RemoteMap {
    /// Creates an empty map.
    ///
    /// No actor is started here; the arguments are stored and handed to every
    /// [`RemoteStateActor`] that is started later.
    pub(super) fn new(
        metrics: Arc<SocketMetrics>,
        local_direct_addrs: n0_watcher::Direct<BTreeSet<DirectAddr>>,
        address_lookup: address_lookup::AddressLookupServices,
        shutdown_token: CancellationToken,
        transport_bias: TransportBiasMap,
        span: Span,
    ) -> Self {
        let tasks = Tasks {
            metrics,
            local_direct_addrs,
            address_lookup,
            shutdown_token,
            tasks: JoinSet::default(),
            poll_cleanup_waker: None,
            transport_bias,
            span,
        };
        Self {
            mapped_addrs: MappedAddrs::default(),
            senders: ConcurrentReadMap::default(),
            tasks,
        }
    }

    /// Reaps finished remote state actor tasks.
    ///
    /// - A task that finished without leftover messages has its sender removed, and its
    ///   endpoint id is returned as `Poll::Ready`.
    /// - A task that finished with leftover messages is restarted with those messages as
    ///   its initial input, and reaping continues.
    /// - A task that panicked has its panic resumed on the calling thread; any other join
    ///   error is ignored.
    ///
    /// Once there are no tasks left to join, the waker is stored so that starting the next
    /// actor wakes this poll again, and `Poll::Pending` is returned.
    pub(super) fn poll_cleanup(&mut self, cx: &mut Context<'_>) -> Poll<EndpointId> {
        loop {
            // `Ready(None)` means the join set is currently empty.
            let Some(joined) = ready!(self.tasks.tasks.poll_join_next(cx)) else {
                break;
            };
            let (eid, leftover_msgs) = match joined {
                Ok(finished) => finished,
                Err(err) => {
                    if let Ok(panic) = err.try_into_panic() {
                        error!("RemoteStateActor panicked.");
                        std::panic::resume_unwind(panic);
                    }
                    continue;
                }
            };
            if leftover_msgs.is_empty() {
                // NOTE(review): if `remote_state_actor` already replaced a closed sender
                // for this id before the old task is reaped here, this removes the
                // replacement's entry — confirm whether that ordering can occur.
                self.senders.remove(&eid);
                return Poll::Ready(eid);
            }
            debug!(%eid, "restarting terminated remote state actor: messages received during shutdown");
            let restarted =
                self.tasks
                    .start_remote_state_actor(eid, leftover_msgs, &self.mapped_addrs);
            self.senders.insert(eid, restarted);
        }
        self.tasks.poll_cleanup_waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Tells every known actor about a network change.
    ///
    /// Delivery is best effort: `try_send` never waits, and a failed send is ignored.
    pub(super) fn on_network_change(&mut self, is_major: bool) {
        let senders = self.senders.read_only();
        let guard = senders.guard();
        for actor in senders.values(&guard) {
            let _ = actor.try_send(RemoteStateMessage::NetworkChange { is_major });
        }
    }

    /// Asks the actor for `addr.id` to resolve the remote, starting the actor if needed.
    ///
    /// The outer `Result` is an error when the actor could not be reached or dropped the
    /// reply channel. The inner `Result` carries the actor's own answer: the mapped address
    /// for the endpoint id on success, or the address lookup failure it reported.
    pub(super) async fn resolve_remote(
        &mut self,
        addr: EndpointAddr,
    ) -> Result<Result<EndpointIdMappedAddr, AddressLookupFailed>, RemoteStateActorStoppedError>
    {
        let EndpointAddr { id, addrs } = addr;
        let actor = self.remote_state_actor(id);
        let (reply_tx, reply_rx) = oneshot::channel();
        actor
            .send(RemoteStateMessage::ResolveRemote(addrs, reply_tx))
            .await?;
        let reply = reply_rx
            .await
            .map_err(|_| RemoteStateActorStoppedError::new())?;
        Ok(reply.map(|()| self.mapped_addrs.endpoint_addrs.get(&id)))
    }

    /// Requests the [`RemoteInfo`] for `id` from its actor.
    ///
    /// Returns `None` when no sender is registered for `id`, when the request cannot be
    /// sent, or when the actor drops the reply channel. Never starts an actor.
    pub(super) async fn remote_info(&mut self, id: EndpointId) -> Option<RemoteInfo> {
        let actor = self.remote_state_actor_if_exists(id)?;
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = RemoteStateMessage::RemoteInfo(reply_tx);
        if actor.send(request).await.is_err() {
            return None;
        }
        reply_rx.await.ok()
    }

    /// Hands `conn` to the actor for `remote`, starting the actor if needed.
    ///
    /// Returns the [`PathWatchable`] the actor replies with, or `None` when the request
    /// cannot be sent or the actor drops the reply channel.
    pub(super) async fn add_connection(
        &mut self,
        remote: EndpointId,
        conn: noq::WeakConnectionHandle,
    ) -> Option<PathWatchable> {
        let actor = self.remote_state_actor(remote);
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = RemoteStateMessage::AddConnection(conn, reply_tx);
        if actor.send(request).await.is_err() {
            return None;
        }
        reply_rx.await.ok()
    }

    /// Returns a sender to the actor for `eid`, starting an actor when none is registered.
    ///
    /// If the registered sender turns out to be closed, a fresh actor is started and its
    /// sender replaces the stale entry.
    pub(super) fn remote_state_actor(
        &mut self,
        eid: EndpointId,
    ) -> mpsc::Sender<RemoteStateMessage> {
        let registered = self.senders.get_or_insert_with(eid, || {
            self.tasks
                .start_remote_state_actor(eid, Vec::new(), &self.mapped_addrs)
        });
        if !registered.is_closed() {
            return registered.clone();
        }
        let replacement = self
            .tasks
            .start_remote_state_actor(eid, Vec::new(), &self.mapped_addrs);
        self.senders.insert(eid, replacement.clone());
        replacement
    }

    /// Returns the sender registered for `eid`, if any, without starting an actor.
    pub(super) fn remote_state_actor_if_exists(
        &self,
        eid: EndpointId,
    ) -> Option<mpsc::Sender<RemoteStateMessage>> {
        self.senders.get(&eid)
    }

    /// Returns a read-only view of the registered actor senders.
    pub(super) fn senders(&self) -> ReadOnlyMap<EndpointId, mpsc::Sender<RemoteStateMessage>> {
        self.senders.read_only()
    }
}
impl Tasks {
fn start_remote_state_actor(
&mut self,
eid: EndpointId,
initial_msgs: Vec<RemoteStateMessage>,
mapped_addrs: &MappedAddrs,
) -> mpsc::Sender<RemoteStateMessage> {
mapped_addrs.endpoint_addrs.get(&eid);
let sender = RemoteStateActor::new(
eid,
self.local_direct_addrs.clone(),
mapped_addrs.relay_addrs.clone(),
mapped_addrs.custom_addrs.clone(),
self.metrics.clone(),
self.address_lookup.clone(),
self.transport_bias.clone(),
)
.start(
initial_msgs,
&mut self.tasks,
self.shutdown_token.clone(),
self.span.clone(),
);
if let Some(waker) = self.poll_cleanup_waker.take() {
waker.wake();
}
sender
}
}
/// Identifies where a piece of information about a remote came from.
///
/// The `Display` implementation is derived by `strum`: variants without an explicit
/// `serialize` attribute are rendered in kebab-case, the others as given by their attribute.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant names —
/// confirm against the places that construct them.
#[derive(Serialize, Deserialize, strum::Display, Debug, Clone, Eq, PartialEq, Hash)]
#[strum(serialize_all = "kebab-case")]
// Some variants carry a field of the module-private `Private` type, which is what the
// `private_interfaces` lint would otherwise flag on this public enum.
#[allow(private_interfaces)]
pub enum Source {
/// Presumably: learned via UDP traffic.
Udp,
/// Presumably: learned via a relay.
Relay,
/// Presumably: supplied by the application.
App,
/// Presumably: supplied by the address lookup service called `name`.
///
/// The `{name}` attribute presumably makes `Display` render the value of `name` — confirm
/// against the `strum` version in use.
#[strum(serialize = "{name}")]
AddressLookup {
/// Name used when displaying this source.
name: String,
},
/// Presumably: supplied by the application under the label `name`.
#[strum(serialize = "{name}")]
NamedApp {
/// Name used when displaying this source.
name: String,
},
/// Displayed as `CallMeMaybe`. The `Private` field keeps code outside this module from
/// building this variant with a struct literal.
#[strum(serialize = "CallMeMaybe")]
CallMeMaybe {
/// Marker of a module-private type.
_0: Private,
},
/// Displayed as `Ping`. The `Private` field keeps code outside this module from building
/// this variant with a struct literal.
#[strum(serialize = "Ping")]
Ping {
/// Marker of a module-private type.
_0: Private,
},
/// Displayed as `Connection`. The `Private` field keeps code outside this module from
/// building this variant with a struct literal.
#[strum(serialize = "Connection")]
Connection {
/// Marker of a module-private type.
_0: Private,
},
}
/// Module-private unit type used as a marker field in some [`Source`] variants.
///
/// Because this type is not public, code outside this module cannot name it and therefore
/// cannot write a struct literal for the variants that contain it.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq, Hash)]
struct Private;
/// An IP address paired with a port.
///
/// Converts to and from [`SocketAddr`]; `Display` formats the value as the [`SocketAddr`]
/// built from it. The derived ordering compares `ip` first and `port` second, following the
/// field declaration order.
#[derive(Debug, derive_more::Display, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[display("{}", SocketAddr::from(*self))]
pub struct IpPort {
/// The IP address.
ip: IpAddr,
/// The port number.
port: u16,
}
/// Builds an [`IpPort`] from the IP address and port of a [`SocketAddr`].
impl From<SocketAddr> for IpPort {
    fn from(socket_addr: SocketAddr) -> Self {
        let (ip, port) = (socket_addr.ip(), socket_addr.port());
        Self { ip, port }
    }
}
impl From<IpPort> for SocketAddr {
fn from(ip_port: IpPort) -> Self {
let IpPort { ip, port } = ip_port;
(ip, port).into()
}
}
impl IpPort {
    /// Returns a reference to the IP address.
    pub fn ip(&self) -> &IpAddr {
        let Self { ip, .. } = self;
        ip
    }

    /// Returns the port number.
    pub fn port(&self) -> u16 {
        let Self { port, .. } = *self;
        port
    }
}