use std::{
collections::{BTreeMap, BTreeSet},
fmt::Display,
future::poll_fn,
io,
net::{IpAddr, SocketAddr},
sync::{
Arc, Mutex, RwLock,
atomic::{AtomicBool, Ordering},
},
};
use iroh_base::{EndpointAddr, EndpointId, PublicKey, RelayUrl, SecretKey, TransportAddr};
use iroh_relay::{RelayConfig, RelayMap};
use mapped_addrs::MultipathMappedAddr;
use n0_error::{bail, e, stack_error};
use n0_future::{
MaybeFuture,
task::{self, AbortOnDropHandle},
time::{self, Duration, Instant},
};
use n0_watcher::{self, Watchable, Watcher};
use netwatch::netmon;
#[cfg(not(wasm_browser))]
use netwatch::{
interfaces::{IpNet, Ipv6AddrFlags},
ip::LocalAddresses,
};
use noq::{
NetworkChangeHint, WeakConnectionHandle,
crypto::rustls::{QuicClientConfig, QuicServerConfig},
};
use rand::RngExt;
use rustc_hash::FxHashSet;
use tokio::sync::{
Mutex as AsyncMutex,
mpsc::{self},
oneshot,
};
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tracing::{Instrument, Level, Span, debug, error, event, info_span, instrument, trace, warn};
use transports::{LocalAddrsWatch, Transport, TransportConfig};
use url::Url;
use self::{
remote_map::{RemoteMap, RemoteStateMessage},
transports::{RelayActorConfig, Transports},
};
#[cfg(not(wasm_browser))]
use crate::dns::DnsResolver;
#[cfg(not(wasm_browser))]
use crate::net_report::QuicConfig;
use crate::{
address_lookup::{self, AddressLookupFailed, EndpointData, UserData},
defaults::timeouts::NET_REPORT_TIMEOUT,
endpoint::{hooks::EndpointHooksList, quic::QuicTransportConfig},
metrics::EndpointMetrics,
net_report::{self, IfStateDetails, Report},
portmapper,
runtime::Runtime,
socket::{
concurrent_read_map::ReadOnlyMap,
remote_map::{MappedAddrs, PathWatchable, RemoteInfo},
transports::{HomeRelayStatus, HomeRelayWatch, HomeRelayWatcher, TransportBiasMap},
},
tls::{
self,
misc::{Blake3HmacKey, RustlsTokenKey},
},
};
mod metrics;
pub(crate) mod concurrent_read_map;
pub(crate) mod mapped_addrs;
pub(crate) mod remote_map;
pub(crate) mod transports;
use self::mapped_addrs::{EndpointIdMappedAddr, MappedAddr};
pub use self::metrics::Metrics;
/// Interval between heartbeats sent to keep remote state fresh.
/// NOTE(review): consumed outside this view; semantics inferred from name — confirm.
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Maximum idle time before a (direct) path is considered dead.
/// NOTE(review): consumed outside this view; semantics inferred from name — confirm.
pub(crate) const PATH_MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(15);
/// Maximum idle time for relay paths; deliberately longer than the direct-path timeout.
pub(crate) const RELAY_PATH_MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Upper bound on concurrently used multipath paths.
pub(crate) const MAX_MULTIPATH_PATHS: u32 = 12;
/// Error returned when a message can no longer be delivered to a remote-state
/// actor because that actor (and its channel receiver) has stopped.
#[stack_error(add_meta, derive)]
#[error("endpoint state actor stopped")]
#[derive(Clone)]
pub(crate) struct RemoteStateActorStoppedError;

impl From<mpsc::error::SendError<RemoteStateMessage>> for RemoteStateActorStoppedError {
    /// A failed channel send implies the receiving actor is gone; the
    /// undeliverable message payload is dropped.
    #[track_caller]
    fn from(_value: mpsc::error::SendError<RemoteStateMessage>) -> Self {
        // `Self::new()` is generated by the `stack_error` macro above.
        Self::new()
    }
}
/// All configuration required to bind an [`EndpointInner`].
#[derive(derive_more::Debug)]
pub(crate) struct Options {
    /// Transport layers (IP, relay, ...) to bind.
    pub(crate) transports: Vec<TransportConfig>,
    /// Our secret key; its public half identifies this endpoint.
    pub(crate) secret_key: SecretKey,
    /// Optional opaque user data published alongside our address.
    pub(crate) address_lookup_user_data: Option<UserData>,
    /// DNS resolver (not available in browsers).
    #[cfg(not(wasm_browser))]
    pub(crate) dns_resolver: DnsResolver,
    /// Optional proxy, handed to the relay actor configuration.
    pub(crate) proxy_url: Option<Url>,
    /// TLS client config, used for relay connections and net-report probing.
    pub(crate) tls_config: rustls::ClientConfig,
    /// QUIC server configuration for incoming connections.
    pub(crate) server_config: noq_proto::ServerConfig,
    pub(crate) metrics: EndpointMetrics,
    pub(crate) hooks: EndpointHooksList,
    /// Preference bias between transports, consumed by the remote map.
    pub(crate) transport_bias: TransportBiasMap,
    pub(crate) portmapper_config: portmapper::PortmapperConfig,
    /// Configuration that stays fixed for the endpoint's lifetime.
    pub(crate) static_config: StaticConfig,
    /// Manually configured external addresses, included in direct-address sets.
    pub(crate) configured_addrs: BTreeSet<SocketAddr>,
}
/// Shared internal state behind the public endpoint.
///
/// Derefs to [`Socket`], so all `Socket` methods are available directly.
#[derive(Debug, derive_more::Deref)]
pub(crate) struct EndpointInner {
    #[deref(forward)]
    sock: Arc<Socket>,
    /// Handle to the background [`Actor`] task; taken out on close/abort.
    actor_task: Mutex<Option<AbortOnDropHandle<()>>>,
    /// Channel into the background actor.
    actor_sender: mpsc::Sender<ActorMessage>,
    /// The underlying QUIC endpoint.
    endpoint: noq::Endpoint,
    runtime: Arc<Runtime>,
    pub(crate) static_config: StaticConfig,
}
impl Drop for EndpointInner {
    /// Last-resort cleanup: a properly used endpoint is closed via
    /// `Endpoint::close` before it is dropped, in which case this is a no-op.
    fn drop(&mut self) {
        if !self.sock.is_closed() {
            tracing::error!(
                "Endpoint dropped without calling `Endpoint::close`. Aborting ungracefully."
            );
            self.abort();
        }
    }
}
/// Configuration that is fixed for the lifetime of the endpoint.
#[derive(derive_more::Debug)]
pub(crate) struct StaticConfig {
    pub(crate) tls_config: tls::TlsConfig,
    // Fixed: this debug label was previously misspelled "QuicServerConifg".
    #[debug("QuicServerConfig")]
    pub(crate) server_config: QuicServerConfig,
    #[debug("QuicClientConfig")]
    pub(crate) client_config: QuicClientConfig,
    #[debug("Arc<RustlsTokenKey>")]
    pub(crate) token_key: Arc<RustlsTokenKey>,
    /// QUIC transport settings shared by server and (default) client configs.
    pub(crate) transport_config: QuicTransportConfig,
}
impl StaticConfig {
    /// Builds a QUIC server config for the given ALPN protocols on top of the
    /// endpoint-wide TLS, token-key and transport settings.
    pub(crate) fn create_server_config(
        &self,
        alpn_protocols: Vec<Vec<u8>>,
    ) -> noq_proto::ServerConfig {
        let mut crypto = self.server_config.clone();
        crypto.set_alpn_protocols(alpn_protocols);
        let mut server_config =
            noq::ServerConfig::new(Arc::new(crypto), self.token_key.clone());
        server_config.transport_config(self.transport_config.to_inner_arc());
        server_config
    }

    /// Builds a QUIC client config for the given ALPN protocols with a
    /// caller-provided transport configuration.
    pub(crate) fn create_client_config(
        &self,
        alpn_protocols: Vec<Vec<u8>>,
        transport_config: Arc<noq::TransportConfig>,
    ) -> noq_proto::ClientConfig {
        let mut crypto = self.client_config.clone();
        crypto.set_alpn_protocols(alpn_protocols);
        let mut client_config = noq::ClientConfig::new(Arc::new(crypto));
        client_config.transport_config(transport_config);
        client_config
    }
}
/// Tracks the stages of endpoint shutdown:
/// close requested -> QUIC endpoint closed -> fully closed.
#[derive(Debug)]
struct ShutdownState {
    /// Cancelled as soon as `close`/`abort` is first invoked.
    at_close_start: CancellationToken,
    /// Cancelled once the QUIC endpoint has been shut down.
    at_endpoint_closed: CancellationToken,
    /// Set once shutdown has fully completed.
    closed: AtomicBool,
}

impl Default for ShutdownState {
    fn default() -> Self {
        Self {
            at_close_start: CancellationToken::new(),
            at_endpoint_closed: CancellationToken::new(),
            closed: AtomicBool::new(false),
        }
    }
}

impl ShutdownState {
    /// Whether shutdown has been initiated (possibly still in progress).
    fn is_closing(&self) -> bool {
        self.at_close_start.is_cancelled()
    }

    /// Whether shutdown has fully completed.
    fn is_closed(&self) -> bool {
        // NOTE(review): the store side uses SeqCst; this Relaxed load treats
        // the flag as a standalone signal with no ordering requirements.
        self.closed.load(Ordering::Relaxed)
    }
}
/// Shared socket state used by the QUIC endpoint, transports, and the actor.
#[derive(Debug)]
pub(crate) struct Socket {
    /// Channels to the per-remote state actors, keyed by endpoint id.
    remote_actors: ReadOnlyMap<EndpointId, mpsc::Sender<RemoteStateMessage>>,
    /// Our public key (derived from the secret key at bind time).
    public_key: PublicKey,
    shutdown: ShutdownState,
    /// Direct addresses discovered via portmapper / net-report / interfaces.
    direct_addrs: DiscoveredDirectAddrs,
    /// Latest net-report together with the reason that triggered it.
    net_report: Watchable<(Option<Report>, UpdateReason)>,
    /// IPv6 availability flag, shared with the relay actor config.
    ipv6_reported: Arc<AtomicBool>,
    /// Tables mapping relay/custom addresses to private socket addresses.
    mapped_addrs: MappedAddrs,
    local_addrs_watch: LocalAddrsWatch,
    home_relay_watch: HomeRelayWatcher,
    /// Socket addresses the IP transports were bound to.
    #[cfg(not(wasm_browser))]
    ip_bind_addrs: Vec<SocketAddr>,
    #[cfg(not(wasm_browser))]
    dns_resolver: DnsResolver,
    relay_map: RelayMap,
    address_lookup: address_lookup::AddressLookupServices,
    /// User data published alongside our address, if any.
    address_lookup_user_data: RwLock<Option<UserData>>,
    /// Manually configured external addresses.
    configured_addrs: RwLock<BTreeSet<SocketAddr>>,
    pub(crate) tls_config: rustls::ClientConfig,
    pub(crate) metrics: EndpointMetrics,
    pub(crate) hooks: EndpointHooksList,
    /// Tracing span under which socket activity is instrumented.
    pub(crate) span: Span,
}
impl Socket {
    /// Returns the first relay URL among our current local addresses, if any.
    pub(crate) fn my_relay(&self) -> Option<RelayUrl> {
        self.local_addr().into_iter().find_map(|a| {
            if let transports::Addr::Relay(url, _) = a {
                Some(url)
            } else {
                None
            }
        })
    }

    /// Whether shutdown has fully completed.
    pub(crate) fn is_closed(&self) -> bool {
        self.shutdown.is_closed()
    }

    /// Whether shutdown has been initiated (but may not be complete yet).
    fn is_closing(&self) -> bool {
        self.shutdown.is_closing()
    }

    /// Future that resolves once shutdown starts.
    pub(crate) fn closed(&self) -> WaitForCancellationFutureOwned {
        self.shutdown.at_close_start.clone().cancelled_owned()
    }

    /// Snapshot of the current local transport addresses.
    pub(crate) fn local_addr(&self) -> Vec<transports::Addr> {
        self.local_addrs_watch.clone().get()
    }

    /// Socket addresses the IP transports were bound to.
    #[cfg(not(wasm_browser))]
    fn ip_bind_addrs(&self) -> &[SocketAddr] {
        &self.ip_bind_addrs
    }

    /// The IP-based subset of [`Self::local_addr`].
    fn ip_local_addrs(&self) -> impl Iterator<Item = SocketAddr> + use<> {
        self.local_addr()
            .into_iter()
            .filter_map(|addr| addr.into_socket_addr())
    }

    /// Tries to deliver `message` to the remote-state actor for `endpoint_id`.
    ///
    /// Returns the message back as `Err` when there is no actor for the
    /// endpoint or its channel is full/closed.
    pub(crate) fn try_send_remote_state_msg(
        &self,
        endpoint_id: EndpointId,
        message: RemoteStateMessage,
    ) -> Result<(), RemoteStateMessage> {
        let Some(sender) = self.remote_actors.get(&endpoint_id) else {
            return Err(message);
        };
        sender.try_send(message).map_err(|err| err.into_inner())
    }

    /// Watcher over the currently discovered direct addresses.
    pub(crate) fn ip_addrs(&self) -> n0_watcher::Direct<BTreeSet<DirectAddr>> {
        self.direct_addrs.addrs.watch()
    }

    /// Watcher over the latest net-report (dropping the trigger reason).
    pub(crate) fn net_report(&self) -> impl Watcher<Value = Option<Report>> + use<> {
        self.net_report.watch().map(|(r, _)| r)
    }

    /// Watcher over all relay URLs present in our local addresses.
    pub(crate) fn home_relay(&self) -> impl Watcher<Value = Vec<RelayUrl>> + use<> {
        self.local_addrs_watch.clone().map(|addrs| {
            addrs
                .into_iter()
                .filter_map(|addr| {
                    if let transports::Addr::Relay(url, _) = addr {
                        Some(url)
                    } else {
                        None
                    }
                })
                .collect()
        })
    }

    /// Watcher over the home relay(s) and their connection status.
    pub(crate) fn home_relay_status(
        &self,
    ) -> impl Watcher<Value = Vec<Option<(RelayUrl, HomeRelayStatus)>>> + use<> {
        self.home_relay_watch.clone()
    }

    /// Stores newly discovered direct addresses; republishes our address via
    /// the address-lookup services when the set actually changed.
    fn store_direct_addresses(&self, addrs: BTreeSet<DirectAddr>) {
        let updated = self.direct_addrs.update(addrs);
        if updated {
            self.publish_my_addr();
        }
    }

    #[cfg(not(wasm_browser))]
    pub(crate) fn dns_resolver(&self) -> &DnsResolver {
        &self.dns_resolver
    }

    /// Maps a (possibly mapped) socket address back to a transport address.
    ///
    /// Falls back to a plain IP address when it matches neither the relay nor
    /// the custom mapped-address tables.
    pub(crate) fn to_transport_addr(&self, addr: SocketAddr) -> transports::Addr {
        remote_map::to_transport_addr(
            addr,
            &self.mapped_addrs.relay_addrs,
            &self.mapped_addrs.custom_addrs,
        )
        .unwrap_or(transports::Addr::Ip(addr))
    }

    pub(crate) fn address_lookup(&self) -> &address_lookup::AddressLookupServices {
        &self.address_lookup
    }

    /// Sets the user data included in published address information.
    ///
    /// Republishes our address only when the value actually changed.
    pub(crate) fn set_user_data_for_address_lookup(&self, user_data: Option<UserData>) {
        // Fixed: expect message was previously misspelled "lock poisened".
        let mut guard = self
            .address_lookup_user_data
            .write()
            .expect("lock poisoned");
        if *guard != user_data {
            *guard = user_data;
            // Release the write lock before publishing: `publish_my_addr`
            // takes a read lock on the same RwLock.
            drop(guard);
            self.publish_my_addr();
        }
    }

    /// Post-processes received datagrams before handing them to QUIC.
    ///
    /// Updates receive metrics and, for relay/custom sources, rewrites
    /// `meta.addr` to the private mapped socket address for that source.
    fn process_datagrams(
        &self,
        bufs: &mut [io::IoSliceMut<'_>],
        metas: &mut [noq_udp::RecvMeta],
        source_addrs: &[transports::Addr],
    ) {
        debug_assert_eq!(bufs.len(), metas.len(), "non matching bufs & metas");
        debug_assert_eq!(
            bufs.len(),
            source_addrs.len(),
            "non matching bufs & source_addrs"
        );
        for i in 0..metas.len() {
            let noq_meta = &mut metas[i];
            let source_addr = &source_addrs[i];
            // GRO may coalesce several datagrams into one buffer; `stride`
            // is the length of each individual datagram.
            let datagram_count = noq_meta.len.div_ceil(noq_meta.stride);
            self.metrics
                .socket
                .recv_datagrams
                .inc_by(datagram_count as _);
            if noq_meta.len > noq_meta.stride {
                trace!(
                    src = ?source_addr,
                    len = noq_meta.len,
                    stride = %noq_meta.stride,
                    datagram_count = noq_meta.len.div_ceil(noq_meta.stride),
                    "GRO datagram received",
                );
                self.metrics.socket.recv_gro_datagrams.inc();
            } else {
                trace!(src = ?source_addr, len = noq_meta.len, "datagram received");
            }
            match source_addr {
                transports::Addr::Ip(SocketAddr::V4(..)) => {
                    self.metrics.socket.recv_data_ipv4.inc_by(noq_meta.len as _);
                }
                transports::Addr::Ip(SocketAddr::V6(..)) => {
                    self.metrics.socket.recv_data_ipv6.inc_by(noq_meta.len as _);
                }
                transports::Addr::Relay(src_url, src_node) => {
                    self.metrics
                        .socket
                        .recv_data_relay
                        .inc_by(noq_meta.len as _);
                    let mapped_addr = self
                        .mapped_addrs
                        .relay_addrs
                        .get(&(src_url.clone(), *src_node));
                    noq_meta.addr = mapped_addr.private_socket_addr();
                }
                transports::Addr::Custom(addr) => {
                    self.metrics
                        .socket
                        .recv_data_custom
                        .inc_by(noq_meta.len as _);
                    let mapped_addr = self.mapped_addrs.custom_addrs.get(addr);
                    noq_meta.addr = mapped_addr.private_socket_addr();
                }
            }
        }
    }

    /// Publishes our current addressing information (direct addrs, home relay,
    /// user data) via the address-lookup services.
    ///
    /// Skipped entirely when there is nothing at all to publish.
    fn publish_my_addr(&self) {
        let relay_url = self.my_relay();
        let mut addrs: Vec<_> = self
            .direct_addrs
            .sockaddrs()
            .map(TransportAddr::Ip)
            .collect();
        // Fixed: expect message was previously misspelled "lock poisened".
        let user_data = self
            .address_lookup_user_data
            .read()
            .expect("lock poisoned")
            .clone();
        if relay_url.is_none() && addrs.is_empty() && user_data.is_none() {
            return;
        }
        if let Some(url) = relay_url {
            addrs.push(TransportAddr::Relay(url));
        }
        let mut data = EndpointData::new(addrs);
        data.set_user_data(user_data);
        self.address_lookup.publish(&data);
    }
}
/// Drives direct-address discovery runs (portmapper + net-report).
///
/// Only one run is in flight at a time; the `net_reporter` async mutex serves
/// as the "run in flight" indicator.
#[derive(Debug)]
struct DirectAddrUpdateState {
    /// Reason of a requested-but-deferred update, if any.
    want_update: Option<UpdateReason>,
    sock: Arc<Socket>,
    port_mapper: portmapper::Client,
    /// Locked (owned guard moved into the spawned task) while a report runs.
    net_reporter: Arc<AsyncMutex<net_report::Client>>,
    relay_map: RelayMap,
    /// Signalled when a run finishes, so deferred updates can be retried.
    run_done: mpsc::Sender<()>,
    shutdown_token: CancellationToken,
}
/// Why a direct-address / net-report update was triggered.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
enum UpdateReason {
    #[default]
    None,
    Periodic,
    PortmapUpdated,
    LinkChangeMajor,
    LinkChangeMinor,
    RelayMapChange,
}

impl UpdateReason {
    /// Major reasons trigger a full re-probe; all others are lightweight.
    fn is_major(self) -> bool {
        match self {
            Self::LinkChangeMajor | Self::RelayMapChange => true,
            Self::None | Self::Periodic | Self::PortmapUpdated | Self::LinkChangeMinor => false,
        }
    }
}
impl DirectAddrUpdateState {
    fn new(
        sock: Arc<Socket>,
        port_mapper: portmapper::Client,
        net_reporter: Arc<AsyncMutex<net_report::Client>>,
        relay_map: RelayMap,
        run_done: mpsc::Sender<()>,
        shutdown_token: CancellationToken,
    ) -> Self {
        DirectAddrUpdateState {
            want_update: Default::default(),
            port_mapper,
            net_reporter,
            sock,
            relay_map,
            run_done,
            shutdown_token,
        }
    }

    /// Starts a run immediately if none is in flight, otherwise queues it.
    ///
    /// The `net_reporter` mutex doubles as the "run in flight" flag.
    /// NOTE(review): a queued reason is overwritten by a later call, so a
    /// pending major reason can be replaced by a minor one — confirm this
    /// coalescing is intended.
    fn schedule_run(&mut self, why: UpdateReason, if_state: IfStateDetails) {
        match self.net_reporter.clone().try_lock_owned() {
            Ok(net_reporter) => {
                self.run(why, if_state, net_reporter);
            }
            Err(_) => {
                // A run is already in flight; remember that another is wanted.
                let _ = self.want_update.insert(why);
            }
        }
    }

    /// Starts the queued run, if one is queued and no run is in flight.
    fn try_run(&mut self, if_state: IfStateDetails) {
        match self.net_reporter.clone().try_lock_owned() {
            Ok(net_reporter) => {
                if let Some(why) = self.want_update.take() {
                    self.run(why, if_state, net_reporter);
                }
            }
            Err(_) => {
                // Still busy; the queued reason stays in `want_update`.
            }
        }
    }

    /// Spawns the actual net-report run.
    ///
    /// The owned `net_reporter` guard is moved into the spawned task, keeping
    /// the "in flight" lock held for the duration of the report. `run_done`
    /// is signalled when the task finishes.
    fn run(
        &mut self,
        why: UpdateReason,
        if_state: IfStateDetails,
        mut net_reporter: tokio::sync::OwnedMutexGuard<net_report::Client>,
    ) {
        debug!("starting direct addr update ({:?})", why);
        if self.shutdown_token.is_cancelled() {
            debug!("skipping net_report, socket is shutting down");
            self.port_mapper.deactivate();
            return;
        }
        if self.relay_map.is_empty() {
            // Without relays there is nothing to probe; publish "no report".
            debug!("skipping net_report, empty RelayMap");
            self.sock.net_report.set((None, why)).ok();
            return;
        }
        self.port_mapper.procure_mapping();
        debug!("requesting net_report report");
        let sock = self.sock.clone();
        let run_done = self.run_done.clone();
        let token = self.shutdown_token.child_token();
        let inner_token = token.child_token();
        task::spawn(
            async move {
                // Bounded by NET_REPORT_TIMEOUT, and cancellable via shutdown.
                let fut = token.run_until_cancelled(time::timeout(
                    NET_REPORT_TIMEOUT,
                    net_reporter.get_report(if_state, why.is_major(), inner_token),
                ));
                match fut.await {
                    Some(Ok(report)) => {
                        sock.net_report.set((Some(report), why)).ok();
                    }
                    Some(Err(time::Elapsed { .. })) => {
                        warn!("net_report report timed out");
                    }
                    None => {
                        trace!("net_report cancelled");
                    }
                }
                debug!("direct addr update done ({:?})", why);
                run_done.send(()).await.ok();
            }
            .instrument(tracing::Span::current()),
        );
    }
}
/// Errors that can occur while binding an endpoint.
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum BindError {
    #[error("Failed to bind sockets")]
    Sockets { source: io::Error },
    #[error("Failed to create internal QUIC endpoint")]
    CreateQuicEndpoint { source: io::Error },
    #[error("Failed to create netmon monitor")]
    CreateNetmonMonitor { source: netmon::Error },
    /// Returned e.g. when more than one relay transport is configured.
    #[error("Invalid transport configuration")]
    InvalidTransportConfig,
    #[error("Invalid CA root configuration")]
    InvalidCaRootConfig { source: io::Error },
    #[error("Failed to create an address lookup service")]
    AddressLookup {
        #[error(from)]
        source: crate::address_lookup::AddressLookupBuilderError,
    },
    #[error("Missing or incompatible rustls crypto provider configured")]
    InvalidCryptoProvider,
    #[error("Error constructing TLS configuration")]
    TlsConfigError {
        #[error(from)]
        source: tls::TlsConfigError,
    },
}
impl EndpointInner {
    /// Binds all transports, sets up discovery/reporting machinery, and
    /// spawns the background actor.
    pub(crate) async fn bind(opts: Options) -> Result<Self, BindError> {
        let span = tracing::Span::current();
        let Options {
            secret_key,
            transports: transport_configs,
            address_lookup_user_data,
            #[cfg(not(wasm_browser))]
            dns_resolver,
            proxy_url,
            server_config,
            tls_config,
            metrics,
            hooks,
            transport_bias,
            portmapper_config,
            static_config,
            configured_addrs,
        } = opts;
        let address_lookup = address_lookup::AddressLookupServices::default();
        let port_mapper = portmapper::create_client(&metrics, &portmapper_config);
        // At most one relay transport is supported.
        let relay_transport_configs: Vec<_> = transport_configs
            .iter()
            .filter(|t| matches!(t, TransportConfig::Relay { .. }))
            .collect();
        if relay_transport_configs.len() > 1 {
            bail!(BindError::InvalidTransportConfig);
        }
        let relay_map = relay_transport_configs
            .iter()
            .filter_map(|t| {
                #[allow(irrefutable_let_patterns)]
                if let TransportConfig::Relay { relay_map, .. } = t {
                    Some(relay_map.clone())
                } else {
                    None
                }
            })
            .next()
            .unwrap_or_else(RelayMap::empty);
        let ipv6_reported = Arc::new(AtomicBool::new(false));
        let relay_actor_config = RelayActorConfig {
            my_relay: HomeRelayWatch::default(),
            secret_key: secret_key.clone(),
            #[cfg(not(wasm_browser))]
            dns_resolver: dns_resolver.clone(),
            proxy_url: proxy_url.clone(),
            ipv6_reported: ipv6_reported.clone(),
            tls_config: tls_config.clone(),
            metrics: metrics.socket.clone(),
        };
        let shutdown_state = ShutdownState::default();
        let shutdown_token = shutdown_state.at_endpoint_closed.child_token();
        let transports = Transports::bind(
            &transport_configs,
            relay_actor_config,
            &metrics,
            shutdown_token.child_token(),
        )
        .map_err(|err| e!(BindError::Sockets, err))?;
        // Tell the portmapper which local IPv4 port to map (zero = skip).
        if let Some(v4_port) = transports.local_addrs().into_iter().find_map(|t| {
            if let transports::Addr::Ip(SocketAddr::V4(addr)) = t {
                Some(addr.port())
            } else {
                None
            }
        }) {
            match v4_port.try_into() {
                Ok(non_zero_port) => {
                    port_mapper.update_local_port(non_zero_port);
                }
                Err(_zero_port) => debug!("Skipping port mapping with zero local port"),
            }
        }
        let (actor_sender, actor_receiver) = mpsc::channel(256);
        #[cfg(not(wasm_browser))]
        let has_ipv6_transport = transports
            .ip_bind_addrs()
            .into_iter()
            .any(|addr| addr.is_ipv6());
        #[cfg(not(wasm_browser))]
        let has_ip_transports = !transports.ip_bind_addrs().is_empty();
        let direct_addrs = DiscoveredDirectAddrs::default();
        let remote_map = {
            RemoteMap::new(
                metrics.socket.clone(),
                direct_addrs.addrs.watch(),
                address_lookup.clone(),
                shutdown_token.child_token(),
                transport_bias,
                span.clone(),
            )
        };
        let home_relay_watch = transports.home_relay_watch();
        let sock = Arc::new(Socket {
            public_key: secret_key.public(),
            remote_actors: remote_map.senders(),
            shutdown: shutdown_state,
            ipv6_reported,
            mapped_addrs: remote_map.mapped_addrs.clone(),
            address_lookup,
            relay_map: relay_map.clone(),
            address_lookup_user_data: RwLock::new(address_lookup_user_data),
            configured_addrs: RwLock::new(configured_addrs),
            direct_addrs,
            net_report: Watchable::new((None, UpdateReason::None)),
            #[cfg(not(wasm_browser))]
            dns_resolver: dns_resolver.clone(),
            metrics: metrics.clone(),
            local_addrs_watch: transports.local_addrs_watch(),
            home_relay_watch,
            #[cfg(not(wasm_browser))]
            ip_bind_addrs: transports.ip_bind_addrs(),
            tls_config: tls_config.clone(),
            hooks,
            span: span.clone(),
        });
        let mut endpoint_config =
            noq::EndpointConfig::new(Arc::new(Blake3HmacKey::new(&mut rand::rng())));
        endpoint_config.grease_quic_bit(false);
        let local_addrs_watch = transports.local_addrs_watch();
        let transports_network_change = transports.create_network_change_sender();
        let runtime = Arc::new(Runtime::new(secret_key.public()));
        // The QUIC endpoint drives our abstract transport socket.
        let endpoint = noq::Endpoint::new_with_abstract_socket(
            endpoint_config,
            Some(server_config),
            Box::new(Transport::new(sock.clone(), transports)),
            runtime.clone(),
        )
        .map_err(|err| e!(BindError::CreateQuicEndpoint, err))?;
        let network_monitor = netmon::Monitor::new()
            .await
            .map_err(|err| e!(BindError::CreateNetmonMonitor, err))?;
        // QUIC-based address-discovery probing needs at least one IP transport.
        #[cfg(not(wasm_browser))]
        let net_report_config = {
            let qad_config = has_ip_transports.then(|| QuicConfig {
                ep: endpoint.clone(),
                client_config: tls_config.clone(),
                ipv4: true,
                ipv6: has_ipv6_transport,
            });
            net_report::Options::new(tls_config.clone()).quic_config(qad_config)
        };
        #[cfg(wasm_browser)]
        let net_report_config = net_report::Options::default();
        let net_reporter = net_report::Client::new(
            #[cfg(not(wasm_browser))]
            dns_resolver,
            relay_map.clone(),
            net_report_config,
            metrics.net_report.clone(),
        );
        let (direct_addr_done_tx, direct_addr_done_rx) = mpsc::channel(8);
        let direct_addr_update_state = DirectAddrUpdateState::new(
            sock.clone(),
            port_mapper,
            Arc::new(AsyncMutex::new(net_reporter)),
            relay_map,
            direct_addr_done_tx,
            sock.shutdown.at_close_start.child_token(),
        );
        let local_interfaces_watcher = network_monitor.interface_state();
        #[cfg_attr(not(wasm_browser), allow(unused_mut))]
        let mut actor = Actor {
            endpoint: endpoint.clone(),
            sock: sock.clone(),
            remote_map,
            periodic_re_stun_timer: new_re_stun_timer(false),
            network_monitor,
            local_interfaces_watcher,
            direct_addr_update_state,
            transports_network_change,
            direct_addr_done_rx,
            call_notify_quic_network_change: None,
        };
        // Seed the direct addresses once before the actor loop starts.
        #[cfg(not(wasm_browser))]
        actor.update_direct_addresses(None);
        let actor_task = task::spawn(
            actor
                .run(
                    actor_receiver,
                    shutdown_token.child_token(),
                    local_addrs_watch,
                )
                .instrument(info_span!(parent: span, "actor")),
        );
        let actor_task = Mutex::new(Some(AbortOnDropHandle::new(actor_task)));
        Ok(EndpointInner {
            sock,
            actor_sender,
            actor_task,
            endpoint,
            runtime,
            static_config,
        })
    }

    /// The underlying QUIC endpoint.
    pub(crate) fn noq_endpoint(&self) -> &noq::Endpoint {
        &self.endpoint
    }

    /// Gracefully shuts the endpoint down: drains QUIC, stops the actor,
    /// then marks the socket closed. A no-op if already closing/closed.
    #[instrument(skip_all, parent = self.sock.span.clone())]
    pub(crate) async fn close(&self) {
        if self.sock.is_closed() || self.sock.is_closing() {
            return;
        }
        trace!(me = ?self.public_key, "socket closing...");
        self.sock.shutdown.at_close_start.cancel();
        self.sock.address_lookup().clear();
        self.noq_endpoint().close(0u16.into(), b"");
        trace!("wait_idle start");
        self.noq_endpoint().wait_idle().await;
        trace!("wait_idle done");
        self.sock.shutdown.at_endpoint_closed.cancel();
        let task = self.actor_task.lock().expect("poisoned").take();
        if let Some(task) = task {
            // Give the actor a short grace period; it is aborted (via the
            // AbortOnDropHandle being dropped) if it doesn't finish in time.
            let shutdown_done = time::timeout(Duration::from_millis(100), async move {
                if let Err(err) = task.await {
                    warn!("unexpected error in task shutdown: {:?}", err);
                }
            })
            .await;
            match shutdown_done {
                Ok(_) => trace!("tasks finished in time, shutdown complete"),
                Err(time::Elapsed { .. }) => {
                    warn!("tasks didn't finish in time, aborting");
                }
            }
        }
        self.runtime.shutdown().await;
        self.sock.shutdown.closed.store(true, Ordering::SeqCst);
        trace!("socket closed");
    }

    /// Ungraceful shutdown (used from `Drop`); skips QUIC draining.
    #[instrument(skip_all)]
    pub(crate) fn abort(&self) {
        if self.sock.is_closed() || self.sock.is_closing() {
            return;
        }
        trace!(me = ?self.public_key, "aborting socket...");
        self.sock.shutdown.at_close_start.cancel();
        self.sock.address_lookup().clear();
        self.sock.shutdown.at_endpoint_closed.cancel();
        self.runtime.abort();
        // Dropping the AbortOnDropHandle aborts the actor task.
        self.actor_task.lock().expect("poisoned").take();
        self.sock.shutdown.closed.store(true, Ordering::SeqCst);
        trace!("socket closed");
    }

    /// Inserts a relay into the relay map and notifies the actor.
    ///
    /// Returns the previous configuration for that URL, if any.
    pub(crate) async fn insert_relay(
        &self,
        relay: RelayUrl,
        endpoint: Arc<RelayConfig>,
    ) -> Option<Arc<RelayConfig>> {
        let res = self.relay_map.insert(relay, endpoint);
        self.actor_sender
            .send(ActorMessage::RelayMapChange)
            .await
            .ok();
        res
    }

    /// Removes a relay from the relay map and notifies the actor.
    pub(crate) async fn remove_relay(&self, relay: &RelayUrl) -> Option<Arc<RelayConfig>> {
        let res = self.relay_map.remove(relay);
        self.actor_sender
            .send(ActorMessage::RelayMapChange)
            .await
            .ok();
        res
    }

    /// Adds a manually configured external address and triggers a refresh.
    pub(crate) async fn add_external_addr(&self, addr: SocketAddr) {
        self.sock
            .configured_addrs
            .write()
            .expect("poisoned")
            .insert(addr);
        self.actor_sender
            .send(ActorMessage::DirectAddrRefresh)
            .await
            .ok();
    }

    /// Removes a manually configured external address.
    ///
    /// Returns whether the address was present; triggers a refresh if it was.
    pub(crate) async fn remove_external_addr(&self, addr: &SocketAddr) -> bool {
        let removed = self
            .sock
            .configured_addrs
            .write()
            .expect("poisoned")
            .remove(addr);
        if removed {
            self.actor_sender
                .send(ActorMessage::DirectAddrRefresh)
                .await
                .ok();
        }
        removed
    }

    /// Asks the actor to re-evaluate network conditions.
    pub(crate) async fn network_change(&self) {
        self.actor_sender
            .send(ActorMessage::NetworkChange)
            .await
            .ok();
    }

    /// Test hook: injects a (possibly major) network change.
    #[cfg(all(test, with_crypto_provider))]
    async fn force_network_change(&self, is_major: bool) {
        self.actor_sender
            .send(ActorMessage::ForceNetworkChange(is_major))
            .await
            .ok();
    }

    /// Resolves an [`EndpointAddr`] into the mapped address to dial.
    pub(crate) async fn resolve_remote(
        &self,
        addr: EndpointAddr,
    ) -> Result<Result<EndpointIdMappedAddr, AddressLookupFailed>, RemoteStateActorStoppedError>
    {
        let (tx, rx) = oneshot::channel();
        self.actor_sender
            .send(ActorMessage::ResolveRemote(addr, tx))
            .await
            .ok();
        rx.await.map_err(|_| RemoteStateActorStoppedError::new())?
    }

    /// Returns info about a known remote endpoint, if any.
    pub(crate) async fn remote_info(&self, id: EndpointId) -> Option<RemoteInfo> {
        let (tx, rx) = oneshot::channel();
        self.actor_sender
            .send(ActorMessage::RemoteInfo(id, tx))
            .await
            .ok()?;
        rx.await.ok()
    }

    /// Registers a new connection with the remote map.
    ///
    /// Returns a detached future so registration can be awaited outside the
    /// caller's borrow of `self`.
    pub(crate) fn register_connection(
        &self,
        remote: EndpointId,
        conn: WeakConnectionHandle,
    ) -> impl Future<Output = Result<PathWatchable, RemoteStateActorStoppedError>> + Send + 'static
    {
        let (tx, rx) = oneshot::channel();
        let sender = self.actor_sender.clone();
        async move {
            sender
                .send(ActorMessage::AddConnection(remote, conn, tx))
                .await
                .map_err(|_| RemoteStateActorStoppedError::new())?;
            rx.await.map_err(|_| RemoteStateActorStoppedError::new())
        }
    }
}
/// Messages handled by the background [`Actor`].
#[derive(derive_more::Debug)]
#[allow(clippy::enum_variant_names)]
enum ActorMessage {
    /// Ask the network monitor to re-evaluate the network state.
    NetworkChange,
    /// The relay map was modified; re-probe.
    RelayMapChange,
    /// Resolve an endpoint address into a dialable mapped address.
    #[debug("ResolveRemote(..)")]
    ResolveRemote(
        EndpointAddr,
        oneshot::Sender<
            Result<Result<EndpointIdMappedAddr, AddressLookupFailed>, RemoteStateActorStoppedError>,
        >,
    ),
    /// Register a new connection with the remote map.
    #[debug("AddConnection(..)")]
    AddConnection(
        EndpointId,
        WeakConnectionHandle,
        oneshot::Sender<PathWatchable>,
    ),
    /// Query info about a remote endpoint.
    #[debug("RemoteInfo(..)")]
    RemoteInfo(EndpointId, oneshot::Sender<RemoteInfo>),
    /// Recompute direct addresses from the latest net-report.
    DirectAddrRefresh,
    /// Test-only: inject a network change (`true` = major).
    #[cfg(all(test, with_crypto_provider))]
    ForceNetworkChange(bool),
}
/// Backoff state for a deferred "network changed" notification to QUIC.
///
/// Created when a link change is observed while no usable network is present;
/// re-checked with exponential backoff until connectivity returns or the
/// maximum wait elapses.
struct PendingNetworkChangeNotify {
    /// When to check for a usable network next.
    next_check: Instant,
    /// Current backoff interval.
    interval: Duration,
    /// Whether any of the coalesced changes was major.
    is_major: bool,
    /// When this pending notification was first created.
    started: Instant,
}

impl PendingNetworkChangeNotify {
    const INITIAL_INTERVAL: Duration = Duration::from_millis(100);
    const MAX_INTERVAL: Duration = Duration::from_secs(1);
    const MAX_WAIT: Duration = Duration::from_secs(5);

    /// Schedules the first check `INITIAL_INTERVAL` from now.
    fn new(is_major: bool) -> Self {
        Self {
            next_check: Instant::now() + Self::INITIAL_INTERVAL,
            interval: Self::INITIAL_INTERVAL,
            is_major,
            started: Instant::now(),
        }
    }

    /// Doubles the backoff interval (capped at `MAX_INTERVAL`) and
    /// reschedules the next check accordingly.
    fn advance(&mut self) {
        let doubled = self.interval * 2;
        self.interval = doubled.min(Self::MAX_INTERVAL);
        self.next_check = Instant::now() + self.interval;
    }

    /// True once at least `MAX_WAIT` has passed since creation.
    fn expired(&self) -> bool {
        self.started.elapsed() >= Self::MAX_WAIT
    }
}
/// Background actor owning the endpoint's network housekeeping.
struct Actor {
    endpoint: noq::Endpoint,
    sock: Arc<Socket>,
    remote_map: RemoteMap,
    /// Timer driving periodic address re-probing.
    periodic_re_stun_timer: time::Interval,
    network_monitor: netmon::Monitor,
    /// Watches local interface state for link changes.
    local_interfaces_watcher: n0_watcher::Direct<netmon::State>,
    transports_network_change: transports::NetworkChangeSender,
    direct_addr_update_state: DirectAddrUpdateState,
    /// Signalled whenever a direct-address update run finishes.
    direct_addr_done_rx: mpsc::Receiver<()>,
    /// Deferred QUIC network-change notification, if one is pending.
    call_notify_quic_network_change: Option<PendingNetworkChangeNotify>,
}
impl Actor {
/// Main loop of the background actor.
///
/// Multiplexes: shutdown, actor messages, the periodic re-STUN timer, local
/// address changes, net-report results, finished direct-address runs,
/// portmap updates, link changes, remote-map cleanup, and deferred QUIC
/// network-change notifications.
async fn run(
    mut self,
    mut msg_receiver: mpsc::Receiver<ActorMessage>,
    shutdown_token: CancellationToken,
    mut local_addrs_watcher: impl Watcher<Value = Vec<transports::Addr>> + Send + Sync,
) {
    let mut current_netmon_state = self.local_interfaces_watcher.get();
    let mut portmap_watcher = self
        .direct_addr_update_state
        .port_mapper
        .watch_external_address();
    let mut receiver_closed = false;
    let mut portmap_watcher_closed = false;
    let mut net_report_watcher = self.sock.net_report.watch();
    self.sock.publish_my_addr();
    while !shutdown_token.is_cancelled() {
        self.sock.metrics.socket.actor_tick_main.inc();
        let portmap_watcher_changed = portmap_watcher.changed();
        // Only arm the deferred-notify timer when a notification is pending.
        let notify_quic_network_change = match &self.call_notify_quic_network_change {
            Some(pending) => {
                MaybeFuture::Some(n0_future::time::sleep_until(pending.next_check))
            }
            None => MaybeFuture::None,
        };
        n0_future::pin!(notify_quic_network_change);
        tokio::select! {
            _ = shutdown_token.cancelled() => {
                debug!("tick: shutting down");
                return;
            }
            msg = msg_receiver.recv(), if !receiver_closed => {
                let Some(msg) = msg else {
                    trace!("tick: socket receiver closed");
                    self.sock.metrics.socket.actor_tick_other.inc();
                    receiver_closed = true;
                    continue;
                };
                trace!(?msg, "tick: msg");
                self.sock.metrics.socket.actor_tick_msg.inc();
                self.handle_actor_message(msg).await;
            }
            tick = self.periodic_re_stun_timer.tick() => {
                trace!("tick: re_stun {:?}", tick);
                self.sock.metrics.socket.actor_tick_re_stun.inc();
                self.re_stun(UpdateReason::Periodic);
            }
            new_addr = local_addrs_watcher.updated() => {
                match new_addr {
                    Ok(addrs) => {
                        if !addrs.is_empty() {
                            trace!(?addrs, "local addrs");
                            self.sock.publish_my_addr();
                        }
                    }
                    Err(_) => {
                        warn!("local addr watcher stopped");
                    }
                }
            }
            report = net_report_watcher.updated() => {
                match report {
                    Ok((report, _)) => {
                        self.handle_net_report_report(report);
                        #[cfg(not(wasm_browser))]
                        {
                            // Re-arm the periodic timer after a report arrived
                            // (see `new_re_stun_timer` for the schedule).
                            self.periodic_re_stun_timer = new_re_stun_timer(true);
                        }
                    }
                    Err(_) => {
                        warn!("net report watcher stopped");
                    }
                }
            }
            reason = self.direct_addr_done_rx.recv() => {
                match reason {
                    Some(()) => {
                        // A run finished; start any deferred update now.
                        let state = self.local_interfaces_watcher.get();
                        self.direct_addr_update_state.try_run(state.into());
                    }
                    None => {
                        warn!("direct addr watcher died");
                    }
                }
            }
            change = portmap_watcher_changed, if !portmap_watcher_closed => {
                if change.is_err() {
                    trace!("tick: portmap watcher closed");
                    self.sock.metrics.socket.actor_tick_other.inc();
                    portmap_watcher_closed = true;
                    continue;
                }
                trace!("tick: portmap changed");
                self.sock.metrics.socket.actor_tick_portmap_changed.inc();
                let new_external_address = *portmap_watcher.borrow();
                debug!("external address updated: {new_external_address:?}");
                self.re_stun(UpdateReason::PortmapUpdated);
            },
            state = self.local_interfaces_watcher.updated() => {
                let Ok(state) = state else {
                    trace!("tick: link change receiver closed");
                    self.sock.metrics.socket.actor_tick_other.inc();
                    continue;
                };
                // Fixed: `&current_netmon_state` had been mangled to
                // `¤t_netmon_state` (HTML-entity corruption of `&curren`),
                // which does not compile.
                let is_major = state.is_major_change(&current_netmon_state);
                event!(
                    target: "iroh::_events::link_change",
                    Level::DEBUG,
                    ?state,
                    is_major
                );
                current_netmon_state = state;
                self.sock.metrics.socket.actor_link_change.inc();
                self.handle_network_change(is_major).await;
            }
            eid = poll_fn(|cx| self.remote_map.poll_cleanup(cx)) => {
                trace!(%eid, "cleaned up RemoteStateActor");
            }
            _ = &mut notify_quic_network_change => {
                let has_network = self.has_usable_network();
                let Some(pending) = self.call_notify_quic_network_change.as_mut() else {
                    continue;
                };
                if has_network || pending.expired() {
                    // Network is back (or we gave up waiting): notify QUIC.
                    let is_major = pending.is_major;
                    self.call_notify_quic_network_change = None;
                    self.notify_quic_network_change(is_major);
                } else {
                    trace!(
                        interval = ?pending.interval,
                        elapsed = ?pending.started.elapsed(),
                        "no default route yet, retrying"
                    );
                    pending.advance();
                }
            }
            else => {
                trace!("tick: else");
            }
        }
    }
}
/// Whether the host currently has a network we can plausibly use.
///
/// In browsers interface state cannot be inspected, so connectivity is
/// assumed.
fn has_usable_network(&mut self) -> bool {
    #[cfg(target_family = "wasm")]
    {
        true
    }
    #[cfg(not(target_family = "wasm"))]
    {
        let interfaces = self.local_interfaces_watcher.get();
        // Require a default route plus at least one address family.
        interfaces.default_route_interface.is_some()
            && (interfaces.have_v4 || interfaces.have_v6)
    }
}
/// Reacts to a detected link change.
///
/// Major changes rebind the transports, re-check the relay connection and
/// reset DNS before re-probing; minor changes only re-probe. The QUIC
/// notification is deferred (with backoff) while no usable network exists.
async fn handle_network_change(&mut self, is_major: bool) {
    debug!(is_major, "link change detected");
    if is_major {
        if let Err(err) = self.transports_network_change.rebind() {
            warn!("failed to rebind transports: {err:?}");
        }
        self.transports_network_change.check_relay_connection();
        #[cfg(not(wasm_browser))]
        self.sock.dns_resolver.reset().await;
        self.re_stun(UpdateReason::LinkChangeMajor);
    } else {
        self.re_stun(UpdateReason::LinkChangeMinor);
    }
    if self.has_usable_network() {
        // Network is up: notify immediately, dropping any deferred notify.
        self.call_notify_quic_network_change = None;
        self.notify_quic_network_change(is_major);
    } else {
        match &mut self.call_notify_quic_network_change {
            Some(pending) => {
                // Coalesce into the existing pending notification.
                pending.is_major |= is_major;
            }
            None => {
                self.call_notify_quic_network_change =
                    Some(PendingNetworkChangeNotify::new(is_major));
            }
        }
    }
}
/// Tells the QUIC endpoint and the remote map that the network changed.
///
/// The hint lets QUIC decide per path whether it may survive the change:
/// relay paths are treated as recoverable, custom paths as not recoverable,
/// and IP paths as recoverable only if their local IP still exists on the
/// host (or is unknown).
fn notify_quic_network_change(&mut self, is_major: bool) {
    #[derive(Debug)]
    struct Hint {
        /// All currently assigned local IPs (regular + loopback).
        local_addrs: FxHashSet<IpAddr>,
    }
    impl NetworkChangeHint for Hint {
        fn is_path_recoverable(
            &self,
            _path_id: noq::PathId,
            network_path: noq_proto::FourTuple,
        ) -> bool {
            match MultipathMappedAddr::from(network_path.remote()) {
                MultipathMappedAddr::Mixed(_) => {
                    error!("A mixed address can not be used for network changes");
                    false
                }
                MultipathMappedAddr::Relay(_) => {
                    true
                }
                MultipathMappedAddr::Ip(_) => {
                    match network_path.local_ip() {
                        Some(local_ip) => self.local_addrs.contains(&local_ip),
                        // Unknown local IP: assume the path is recoverable.
                        None => true,
                    }
                }
                MultipathMappedAddr::Custom(_) => {
                    false
                }
            }
        }
    }
    let hint = Hint {
        #[cfg(not(wasm_browser))]
        local_addrs: {
            let interfaces = self.local_interfaces_watcher.get();
            interfaces
                .local_addresses
                .regular
                .iter()
                .chain(interfaces.local_addresses.loopback.iter())
                .copied()
                .collect()
        },
        // Browsers expose no interface addresses: empty set.
        #[cfg(wasm_browser)]
        local_addrs: Default::default(),
    };
    self.endpoint.handle_network_change(Some(Arc::new(hint)));
    self.remote_map.on_network_change(is_major);
}
/// A relay was added or removed: re-probe with the updated relay map.
fn handle_relay_map_change(&mut self) {
    self.re_stun(UpdateReason::RelayMapChange);
}
/// Schedules a direct-address update using the current interface state.
fn re_stun(&mut self, why: UpdateReason) {
    let state = self.local_interfaces_watcher.get();
    self.direct_addr_update_state
        .schedule_run(why, state.into());
}
/// Dispatches a single message sent from the endpoint to the actor.
async fn handle_actor_message(&mut self, msg: ActorMessage) {
    match msg {
        ActorMessage::NetworkChange => {
            self.network_monitor.network_change().await.ok();
        }
        ActorMessage::RelayMapChange => {
            self.handle_relay_map_change();
        }
        ActorMessage::ResolveRemote(addr, tx) => {
            tx.send(self.remote_map.resolve_remote(addr).await).ok();
        }
        ActorMessage::RemoteInfo(id, tx) => {
            // Dropping `tx` without sending signals "no info" to the caller.
            if let Some(info) = self.remote_map.remote_info(id).await {
                tx.send(info).ok();
            }
        }
        ActorMessage::AddConnection(remote, conn, tx) => {
            // Dropping `tx` without sending surfaces as a stopped-actor
            // error on the `register_connection` side.
            if let Some(watcher) = self.remote_map.add_connection(remote, conn).await {
                tx.send(watcher).ok();
            }
        }
        ActorMessage::DirectAddrRefresh => {
            #[cfg(not(wasm_browser))]
            {
                let (report, _reason) = self.sock.net_report.get();
                self.update_direct_addresses(report.as_ref());
            }
        }
        #[cfg(all(test, with_crypto_provider))]
        ActorMessage::ForceNetworkChange(is_major) => {
            self.handle_network_change(is_major).await;
        }
    }
}
/// Rebuilds the set of published direct addresses from all known sources.
///
/// Sources in insertion priority (first insert wins per socket address):
/// port-mapper external address, net_report (QAD) global addresses, locally
/// bound addresses, then statically configured addresses. Addresses whose
/// IPv6 flags mark them deprecated are dropped before storing.
#[cfg(not(wasm_browser))]
fn update_direct_addresses(&mut self, net_report_report: Option<&net_report::Report>) {
    let mut addrs: BTreeMap<SocketAddr, (DirectAddrType, Option<Ipv6AddrFlags>)> =
        BTreeMap::new();
    let portmap_watcher = self
        .direct_addr_update_state
        .port_mapper
        .watch_external_address();
    let maybe_port_mapped = *portmap_watcher.borrow();
    if let Some(portmap_ext) = maybe_port_mapped.map(SocketAddr::V4) {
        addrs
            .entry(portmap_ext)
            .or_insert((DirectAddrType::Portmapped, None));
    }
    if let Some(net_report_report) = net_report_report {
        if let Some(global_v4) = net_report_report.global_v4 {
            addrs
                .entry(global_v4.into())
                .or_insert((DirectAddrType::Qad, None));
            // First non-zero configured bind port, if any.
            let port = self.sock.ip_bind_addrs().iter().find_map(|addr| {
                if addr.port() != 0 {
                    Some(addr.port())
                } else {
                    None
                }
            });
            // If the NAT mapping varies by destination, also advertise the
            // global IPv4 combined with our local bind port as a guess.
            if let Some(port) = port
                && net_report_report
                    .mapping_varies_by_dest()
                    .unwrap_or_default()
            {
                let mut addr = global_v4;
                addr.set_port(port);
                addrs
                    .entry(addr.into())
                    .or_insert((DirectAddrType::Qad4LocalPort, None));
            }
        }
        if let Some(global_v6) = net_report_report.global_v6 {
            addrs
                .entry(global_v6.into())
                .or_insert((DirectAddrType::Qad, None));
        }
    }
    self.collect_local_addresses(&mut addrs);
    for addr in self.sock.configured_addrs.read().expect("poisoned").iter() {
        addrs.entry(*addr).or_insert((DirectAddrType::Config, None));
    }
    let stored_addrs = addrs
        .into_iter()
        .filter_map(|(addr, (typ, flags))| {
            // Deprecated IPv6 addresses must not be advertised.
            let is_deprecated = flags.map(|f| f.deprecated).unwrap_or(false);
            if is_deprecated {
                return None;
            }
            Some(DirectAddr { addr, typ })
        })
        .collect();
    self.sock.store_direct_addresses(stored_addrs);
}
/// Adds the locally bound socket addresses to `addrs`.
///
/// Wildcard (unspecified) binds are expanded into one candidate address per
/// local interface IP (falling back to loopback when nothing else is known);
/// concrete binds are added as-is. IPv6 flags are looked up so deprecated
/// addresses can be filtered by the caller.
///
/// Fix: the original fetched the interface state from the watcher twice —
/// once for the flag lookups and once for the address list — so the two
/// could disagree if the state changed in between. A single snapshot is now
/// used for both.
#[cfg(not(wasm_browser))]
fn collect_local_addresses(
    &mut self,
    addrs: &mut BTreeMap<SocketAddr, (DirectAddrType, Option<Ipv6AddrFlags>)>,
) {
    // One snapshot of the interface state, shared by the address list and
    // the flag lookups below.
    let netmon_state = self.local_interfaces_watcher.get();
    // Pairs of (configured bind address, actual local socket address).
    let local_addrs: Vec<(SocketAddr, SocketAddr)> = self
        .sock
        .ip_bind_addrs()
        .iter()
        .copied()
        .zip(self.sock.ip_local_addrs())
        .collect();
    // Port of the unspecified IPv4/IPv6 bind, if any; used to expand the
    // wildcard bind into concrete per-interface addresses.
    let ipv4_unspecified_port = local_addrs.iter().find_map(|(_, a)| {
        if a.is_ipv4() && a.ip().is_unspecified() {
            Some(a.port())
        } else {
            None
        }
    });
    let ipv6_unspecified_port = local_addrs.iter().find_map(|(_, a)| {
        if a.is_ipv6() && a.ip().is_unspecified() {
            Some(a.port())
        } else {
            None
        }
    });
    if local_addrs
        .iter()
        .any(|(_, local)| local.ip().is_unspecified())
    {
        let mut ips = &netmon_state.local_addresses.regular;
        // Fall back to loopback when no regular addresses are known and
        // nothing else has been collected so far.
        if ips.is_empty() && addrs.is_empty() {
            ips = &netmon_state.local_addresses.loopback;
        }
        for &ip in ips {
            let port_if_unspecified = match ip {
                IpAddr::V4(_) => ipv4_unspecified_port,
                IpAddr::V6(_) => ipv6_unspecified_port,
            };
            if let Some(port) = port_if_unspecified {
                let addr = SocketAddr::new(ip, port);
                let flags = find_flags(&netmon_state, ip);
                addrs.entry(addr).or_insert((DirectAddrType::Local, flags));
            }
        }
    }
    // Concrete (non-wildcard) binds are usable directly.
    for (bound, local) in local_addrs {
        if !bound.ip().is_unspecified() {
            let flags = find_flags(&netmon_state, local.ip());
            addrs.entry(local).or_insert((DirectAddrType::Local, flags));
        }
    }
}
/// Applies a freshly finished net report to the socket state.
///
/// Stores the observed IPv6 reachability, falls back to the current home
/// relay when the report carries no preferred relay, notifies the
/// transports, and (outside browsers) recomputes the direct addresses.
fn handle_net_report_report(&mut self, mut report: Option<net_report::Report>) {
    if let Some(ref mut r) = report {
        self.sock.ipv6_reported.store(r.udp_v6, Ordering::Relaxed);
        if r.preferred_relay.is_none()
            && let Some(my_relay) = self.sock.my_relay()
        {
            // Keep the existing home relay if the report didn't pick one.
            r.preferred_relay.replace(my_relay);
        }
        self.transports_network_change.on_network_change(r);
    }
    #[cfg(not(wasm_browser))]
    self.update_direct_addresses(report.as_ref());
}
}
/// Looks up the IPv6 address flags for `ip` in the interface state.
///
/// Returns `None` for IPv4 addresses and for IPv6 addresses that are not
/// assigned to any known interface.
#[cfg(not(wasm_browser))]
fn find_flags(state: &netmon::State, ip: IpAddr) -> Option<Ipv6AddrFlags> {
    if !ip.is_ipv6() {
        return None;
    }
    state
        .interfaces
        .values()
        .flat_map(|iface| iface.addrs())
        .find_map(|addr| match addr {
            IpNet::V6 { net, flags, .. } if net.addr() == ip => Some(flags),
            _ => None,
        })
}
/// Creates the jittered interval timer driving periodic re-STUN runs.
///
/// The period is drawn uniformly from 20..=26 seconds. With `initial_delay`
/// the first tick fires only after one full period; otherwise the first
/// tick fires immediately.
fn new_re_stun_timer(initial_delay: bool) -> time::Interval {
    let period: Duration =
        rand::rng().random_range(Duration::from_secs(20)..=Duration::from_secs(26));
    if initial_delay {
        debug!("scheduling periodic_stun to run in {}s", period.as_secs());
        return time::interval_at(time::Instant::now() + period, period);
    }
    debug!(
        "scheduling periodic_stun to run immediately and in {}s",
        period.as_secs()
    );
    time::interval(period)
}
/// The set of direct addresses discovered for this endpoint, together with
/// when it was last updated.
#[derive(derive_more::Debug, Clone, Default)]
struct DiscoveredDirectAddrs {
    // Watchable so other parts of the stack can react to address changes.
    addrs: Watchable<BTreeSet<DirectAddr>>,
    // Time of the most recent `update` call, if any.
    updated_at: Arc<RwLock<Option<Instant>>>,
}
impl DiscoveredDirectAddrs {
    /// Replaces the stored address set, returning `true` if it changed.
    ///
    /// `updated_at` is bumped unconditionally, even when the set is equal
    /// to the previous one.
    fn update(&self, addrs: BTreeSet<DirectAddr>) -> bool {
        *self.updated_at.write().expect("poisoned") = Some(Instant::now());
        let updated = self.addrs.set(addrs).is_ok();
        if updated {
            // Emit a structured event so address changes are observable.
            event!(
                target: "iroh::_events::direct_addrs",
                Level::DEBUG,
                addrs = ?self.addrs.get(),
            );
        }
        updated
    }
    /// Iterates over the socket addresses of the current set.
    fn sockaddrs(&self) -> impl Iterator<Item = SocketAddr> {
        self.addrs.get().into_iter().map(|da| da.addr)
    }
}
/// A socket address on which an endpoint may be directly reachable,
/// together with how it was discovered.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DirectAddr {
    /// The socket address.
    pub addr: SocketAddr,
    /// How this address was discovered.
    pub typ: DirectAddrType,
}
/// How a [`DirectAddr`] was discovered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum DirectAddrType {
    /// Origin unknown.
    Unknown,
    /// A locally bound socket address.
    Local,
    /// A global address reported by net_report (QAD).
    Qad,
    /// An external address obtained from the port mapper.
    Portmapped,
    /// The QAD global IPv4 address combined with the local bind port.
    Qad4LocalPort,
    /// Statically configured by the user.
    Config,
}
impl Display for DirectAddrType {
    /// Renders the short lowercase label used in logs and diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            DirectAddrType::Unknown => "?",
            DirectAddrType::Local => "local",
            DirectAddrType::Qad => "qad",
            DirectAddrType::Portmapped => "portmap",
            DirectAddrType::Qad4LocalPort => "qad4localport",
            DirectAddrType::Config => "config",
        };
        f.write_str(label)
    }
}
#[cfg(all(test, with_crypto_provider))]
mod tests {
use std::{net::SocketAddrV4, sync::Arc, time::Duration};
use data_encoding::HEXLOWER;
use iroh_base::{EndpointAddr, EndpointId, TransportAddr};
use iroh_relay::tls::{CaRootsConfig, default_provider};
use n0_error::{Result, StackResultExt, StdResultExt};
use n0_future::{MergeBounded, StreamExt, time};
use n0_tracing_test::traced_test;
use n0_watcher::Watcher;
use rand::{CryptoRng, Rng, RngExt, SeedableRng};
use tokio_util::task::AbortOnDropHandle;
use tracing::{Instrument, error, info, info_span, instrument};
use super::Options;
use crate::{
Endpoint, SecretKey,
address_lookup::memory::MemoryLookup,
dns::DnsResolver,
endpoint::{QuicTransportConfig, presets},
socket::{
EndpointInner, StaticConfig, TransportConfig,
mapped_addrs::{EndpointIdMappedAddr, MappedAddr},
},
tls::{self, DEFAULT_MAX_TLS_TICKETS, misc::RustlsTokenKey},
};
/// ALPN protocol identifier used by all connections in these tests.
const ALPN: &[u8] = b"n0/test/1";
/// Builds a minimal [`Options`] for binding a test `EndpointInner`:
/// a fresh random secret key, default IPv4 + IPv6 transports, no proxy,
/// and default everything else.
fn default_options(rng: &mut impl CryptoRng) -> Options {
    let crypto_provider = default_provider();
    let secret_key = SecretKey::from_bytes(&rng.random());
    let tls_config = tls::TlsConfig::new(
        secret_key.clone(),
        DEFAULT_MAX_TLS_TICKETS,
        crypto_provider.clone(),
    );
    // keylog disabled (the `false` arguments below).
    let static_config = StaticConfig {
        server_config: tls_config.make_server_config(false).unwrap(),
        client_config: tls_config.make_client_config(false).unwrap(),
        tls_config,
        token_key: Arc::new(RustlsTokenKey::new(rng, &crypto_provider).unwrap()),
        transport_config: QuicTransportConfig::default(),
    };
    // No ALPNs configured here; tests that need one set it themselves.
    let server_config = static_config.create_server_config(vec![]);
    Options {
        transports: vec![
            TransportConfig::default_ipv4(),
            TransportConfig::default_ipv6(),
        ],
        secret_key,
        proxy_url: None,
        dns_resolver: DnsResolver::new(),
        server_config,
        tls_config: CaRootsConfig::default()
            .client_config(crypto_provider.clone())
            .unwrap(),
        #[cfg(any(test, feature = "test-utils"))]
        address_lookup_user_data: None,
        metrics: Default::default(),
        hooks: Default::default(),
        transport_bias: Default::default(),
        portmapper_config: Default::default(),
        static_config,
        configured_addrs: Default::default(),
    }
}
/// Test helper: accepts one connection, reads a message from a bi stream,
/// echoes it back in 12-byte chunks, then waits for the connection to close
/// and the endpoint to go idle.
///
/// With [`ExpectedLoss::AlmostNone`] it asserts that no path lost a
/// significant number of packets.
#[instrument(skip_all, fields(me = %ep.id().fmt_short()))]
async fn echo_receiver(ep: Endpoint, loss: ExpectedLoss) -> Result {
    info!("accepting conn");
    let conn = ep.accept().await.expect("no conn");
    info!("accepting");
    let conn = conn.await.context("accepting")?;
    info!("accepting bi");
    let (mut send_bi, mut recv_bi) = conn.accept_bi().await.std_context("accept bi")?;
    info!("reading");
    let val = recv_bi
        .read_to_end(usize::MAX)
        .await
        .std_context("read to end")?;
    info!("replying");
    // Echo back in small chunks to exercise multiple writes.
    for chunk in val.chunks(12) {
        send_bi.write_all(chunk).await.std_context("write all")?;
    }
    info!("finishing");
    send_bi.finish().std_context("finish")?;
    send_bi.stopped().await.std_context("stopped")?;
    let stats = conn.stats();
    info!("stats: {:#?}", stats);
    if matches!(loss, ExpectedLoss::AlmostNone) {
        for info in conn.paths().get().iter() {
            assert!(
                info.stats().unwrap().lost_packets < 10,
                // typo fix: "loose" -> "lose"
                "[receiver] path {:?} should not lose many packets",
                info.remote_addr()
            );
        }
    }
    conn.closed().await;
    info!("closed");
    ep.inner()?.noq_endpoint().wait_idle().await;
    info!("idle");
    Ok(())
}
/// Test helper: connects to `dest_id`, sends `msg` on a bi stream and
/// asserts the exact same bytes are echoed back, then closes the
/// connection and waits for the endpoint to go idle.
///
/// With [`ExpectedLoss::AlmostNone`] it asserts that no path lost a
/// significant number of packets.
#[instrument(skip_all, fields(me = %ep.id().fmt_short()))]
async fn echo_sender(
    ep: Endpoint,
    dest_id: EndpointId,
    msg: &[u8],
    loss: ExpectedLoss,
) -> Result {
    info!("connecting to {}", dest_id.fmt_short());
    let dest = EndpointAddr::new(dest_id);
    let conn = ep.connect(dest, ALPN).await?;
    info!("opening bi");
    let (mut send_bi, mut recv_bi) = conn.open_bi().await.std_context("open bi")?;
    info!("writing message");
    send_bi.write_all(msg).await.std_context("write all")?;
    info!("finishing");
    send_bi.finish().std_context("finish")?;
    send_bi.stopped().await.std_context("stopped")?;
    info!("reading_to_end");
    let val = recv_bi
        .read_to_end(usize::MAX)
        .await
        .std_context("read to end")?;
    assert_eq!(
        val,
        msg,
        "[sender] expected {}, got {}",
        HEXLOWER.encode(msg),
        HEXLOWER.encode(&val)
    );
    let stats = conn.stats();
    info!("stats: {:#?}", stats);
    if matches!(loss, ExpectedLoss::AlmostNone) {
        for info in conn.paths().get().iter() {
            assert!(
                info.stats().unwrap().lost_packets < 10,
                // typo fix: "loose" -> "lose"
                "[sender] path {:?} should not lose many packets",
                info.remote_addr()
            );
        }
    }
    conn.close(0u32.into(), b"done");
    info!("closed");
    ep.inner()?.noq_endpoint().wait_idle().await;
    info!("idle");
    Ok(())
}
/// Whether a roundtrip is expected to run mostly loss-free, controlling
/// whether the per-path packet-loss assertions run.
#[derive(Debug, Copy, Clone)]
enum ExpectedLoss {
    /// Assert that paths lose fewer than 10 packets.
    AlmostNone,
    /// Loss is expected (e.g. during forced network changes); no assertions.
    YeahSure,
}
/// Runs one echo roundtrip between `sender` and `receiver` under a 4s
/// timeout, panicking if either side errored. Receiver panics are
/// propagated via `resume_unwind` so assertion failures surface properly.
async fn run_roundtrip(
    sender: Endpoint,
    receiver: Endpoint,
    payload: &[u8],
    loss: ExpectedLoss,
) -> Result<()> {
    tokio::time::timeout(Duration::from_secs(4), async move {
        let send_endpoint_id = sender.id();
        let recv_endpoint_id = receiver.id();
        info!("\nroundtrip: {send_endpoint_id:#} -> {recv_endpoint_id:#}");
        let receiver_task = AbortOnDropHandle::new(tokio::spawn(echo_receiver(receiver, loss)));
        let sender_res = echo_sender(sender, recv_endpoint_id, payload, loss).await;
        let sender_is_err = match sender_res {
            Ok(()) => false,
            Err(err) => {
                error!("[sender] Error:\n{err:#?}");
                true
            }
        };
        let receiver_is_err = match receiver_task.await {
            Ok(Ok(())) => false,
            Ok(Err(err)) => {
                error!("[receiver] Error:\n{err:#?}");
                true
            }
            Err(joinerr) => {
                // Re-raise panics (e.g. failed asserts) from the receiver task.
                if joinerr.is_panic() {
                    std::panic::resume_unwind(joinerr.into_panic());
                } else {
                    error!("[receiver] Error:\n{joinerr:#?}");
                }
                true
            }
        };
        if sender_is_err || receiver_is_err {
            panic!("Sender or receiver errored");
        }
    })
    .await
    .std_context("timeout")?;
    Ok(())
}
/// Creates two endpoints wired together through a shared in-memory address
/// lookup, plus a background task that keeps the lookup updated as the
/// endpoints' addresses change. Dropping the returned handle stops that task.
async fn endpoint_pair() -> (AbortOnDropHandle<()>, Endpoint, Endpoint) {
    let address_lookup = MemoryLookup::new();
    let ep1 = Endpoint::builder(presets::Minimal)
        .alpns(vec![ALPN.to_vec()])
        .address_lookup(address_lookup.clone())
        .bind()
        .await
        .unwrap();
    let ep2 = Endpoint::builder(presets::Minimal)
        .alpns(vec![ALPN.to_vec()])
        .address_lookup(address_lookup.clone())
        .bind()
        .await
        .unwrap();
    // Seed the lookup with the initial addresses.
    address_lookup.add_endpoint_info(ep1.addr());
    address_lookup.add_endpoint_info(ep2.addr());
    // Keep the lookup fresh as either endpoint's address changes.
    let ep1_addr_stream = ep1.watch_addr().stream();
    let ep2_addr_stream = ep2.watch_addr().stream();
    let mut addr_stream = MergeBounded::from_iter([ep1_addr_stream, ep2_addr_stream]);
    let task = tokio::spawn(async move {
        while let Some(addr) = addr_stream.next().await {
            address_lookup.add_endpoint_info(addr);
        }
    });
    (AbortOnDropHandle::new(task), ep1, ep2)
}
/// Small-payload echo roundtrip in both directions between two endpoints.
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_noq_small() -> Result {
    let (_guard, m1, m2) = endpoint_pair().await;
    run_roundtrip(
        m1.clone(),
        m2.clone(),
        b"hello m1",
        ExpectedLoss::AlmostNone,
    )
    .await?;
    run_roundtrip(
        m2.clone(),
        m1.clone(),
        b"hello m2",
        ExpectedLoss::AlmostNone,
    )
    .await?;
    Ok(())
}
/// 10 KiB random-payload echo roundtrip in both directions.
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_noq_large() -> Result {
    let (_guard, m1, m2) = endpoint_pair().await;
    let mut data = vec![0u8; 10 * 1024];
    // Seeded RNG for reproducible payloads.
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
    rng.fill_bytes(&mut data);
    run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::AlmostNone).await?;
    run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::AlmostNone).await?;
    Ok(())
}
/// Regression test: after a forced network change + rebind, a new
/// connection must still be possible (i.e. the connection driver wakes up).
#[tokio::test]
#[traced_test]
async fn test_regression_network_change_rebind_wakes_connection_driver() -> Result {
    let (_guard, m1, m2) = endpoint_pair().await;
    println!("Net change");
    m1.inner()?.force_network_change(true).await;
    // Give the rebind a moment to settle before connecting.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let _handle = AbortOnDropHandle::new(tokio::spawn({
        let endpoint = m2.clone();
        async move {
            while let Some(incoming) = endpoint.accept().await {
                println!("Incoming first conn!");
                let conn = incoming.await.anyerr()?;
                conn.closed().await;
            }
            n0_error::Ok(())
        }
    }));
    println!("first conn!");
    let conn = m1.connect(m2.addr(), ALPN).await?;
    println!("Closing first conn");
    conn.close(0u32.into(), b"bye lolz");
    conn.closed().await;
    println!("Closed first conn");
    Ok(())
}
/// Draws a random delay in 50 ms steps: 50..=250 ms inclusive.
fn offset(rng: &mut rand_chacha::ChaCha8Rng) -> Duration {
    Duration::from_millis(rng.random_range(1..=5) * 50)
}
/// Roundtrips while one side (`m1`) continuously forces major network
/// changes; loss is expected, so only completion is asserted.
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_network_change_only_a() -> Result {
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
    let (_guard, m1, m2) = endpoint_pair().await;
    // Background task hammering m1 with network changes for the whole test.
    let _network_change_guard = {
        let m1 = m1.clone();
        let mut rng = rng.clone();
        let task = tokio::spawn(async move {
            loop {
                info!("[m1] network change");
                m1.inner()
                    .expect("haven't closed the endpoint yet")
                    .force_network_change(true)
                    .await;
                time::sleep(offset(&mut rng)).await;
            }
        });
        AbortOnDropHandle::new(task)
    };
    let mut data = vec![0u8; 10 * 1024];
    rng.fill_bytes(&mut data);
    run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::YeahSure).await?;
    run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::YeahSure).await?;
    Ok(())
}
/// Roundtrips after both sides force a major network change once.
///
/// NOTE(review): unlike the `_only_a` sibling, the spawned task has no
/// `loop` — it triggers one change per endpoint, sleeps once, and exits.
/// This may be intentional (a one-shot change) or a lost `loop`; confirm
/// against the test's original intent.
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_roundtrip_network_change_a_and_b() -> Result {
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
    let (_guard, m1, m2) = endpoint_pair().await;
    let _network_change_guard = {
        let m1 = m1.clone();
        let m2 = m2.clone();
        let mut rng = rng.clone();
        let task = tokio::spawn(async move {
            info!("-- [m1] network change");
            m1.inner()
                .expect("haven't closed the endpoint yet")
                .force_network_change(true)
                .await;
            info!("-- [m2] network change");
            m2.inner()
                .expect("haven't closed the endpoint yet")
                .force_network_change(true)
                .await;
            time::sleep(offset(&mut rng)).await;
        });
        AbortOnDropHandle::new(task)
    };
    let mut data = vec![0u8; 10 * 1024];
    rng.fill_bytes(&mut data);
    run_roundtrip(m1.clone(), m2.clone(), &data, ExpectedLoss::YeahSure).await?;
    run_roundtrip(m2.clone(), m1.clone(), &data, ExpectedLoss::YeahSure).await?;
    Ok(())
}
/// Repeatedly binds and closes endpoint pairs, asserting the inner socket
/// actually reaches the closed state each round.
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_two_devices_setup_teardown() -> Result {
    for i in 0..10 {
        info!("-- round {i}");
        info!("setting up stack");
        let (_guard, m1, m2) = endpoint_pair().await;
        info!("closing endpoints");
        // Grab the inner handles before closing so we can check their state.
        let sock1 = m1.inner()?;
        let sock2 = m2.inner()?;
        m1.close().await;
        m2.close().await;
        assert!(sock1.is_closed());
        assert!(sock2.is_closed());
    }
    Ok(())
}
/// A freshly bound socket reports a non-empty, stable set of IP addresses.
#[tokio::test]
#[traced_test]
async fn test_direct_addresses() {
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
    let sock = EndpointInner::bind(default_options(&mut rng))
        .await
        .unwrap();
    let eps0 = sock.ip_addrs().get();
    info!("{eps0:?}");
    assert!(!eps0.is_empty());
    // A second read must yield the same set.
    let eps1 = sock.ip_addrs().get();
    info!("{eps1:?}");
    assert_eq!(eps0, eps1);
}
/// Binds a bare `EndpointInner` for the given secret key with default
/// IPv4 + IPv6 transports, keylogging enabled, and the test ALPN.
#[instrument(name = "ep", skip_all, fields(me = %secret_key.public().fmt_short()))]
async fn socket_ep(secret_key: SecretKey) -> Result<EndpointInner> {
    let crypto_provider = default_provider();
    let tls_config = tls::TlsConfig::new(
        secret_key.clone(),
        DEFAULT_MAX_TLS_TICKETS,
        crypto_provider.clone(),
    );
    let keylog = true;
    let static_config = StaticConfig {
        server_config: tls_config.make_server_config(keylog).unwrap(),
        client_config: tls_config.make_client_config(keylog).unwrap(),
        tls_config,
        token_key: Arc::new(RustlsTokenKey::new(&mut rand::rng(), &crypto_provider).unwrap()),
        transport_config: QuicTransportConfig::default(),
    };
    let server_config = static_config.create_server_config(vec![ALPN.to_vec()]);
    let dns_resolver = DnsResolver::new();
    let opts = Options {
        transports: vec![
            TransportConfig::default_ipv4(),
            TransportConfig::default_ipv6(),
        ],
        secret_key: secret_key.clone(),
        address_lookup_user_data: None,
        dns_resolver,
        proxy_url: None,
        server_config,
        tls_config: CaRootsConfig::default()
            .client_config(crypto_provider.clone())
            .unwrap(),
        metrics: Default::default(),
        hooks: Default::default(),
        transport_bias: Default::default(),
        portmapper_config: Default::default(),
        static_config,
        configured_addrs: Default::default(),
    };
    let sock = EndpointInner::bind(opts).await?;
    Ok(sock)
}
/// Connects to `endpoint_id` via its mapped address using a default
/// transport config with a 1s keep-alive.
#[instrument(name = "connect", skip_all, fields(me = %ep_secret_key.public().fmt_short()))]
async fn socket_connect(
    ep: noq::Endpoint,
    ep_secret_key: SecretKey,
    addr: EndpointIdMappedAddr,
    endpoint_id: EndpointId,
) -> Result<noq::Connection> {
    let mut transport_config = noq::TransportConfig::default();
    // Keep-alive so idle test connections don't time out.
    transport_config.keep_alive_interval(Some(Duration::from_secs(1)));
    socket_connect_with_transport_config(
        ep,
        ep_secret_key,
        addr,
        endpoint_id,
        Arc::new(transport_config),
    )
    .await
}
/// Connects to `endpoint_id` via its mapped address using the provided
/// transport config, building a fresh client TLS config with the test ALPN.
#[instrument(name = "connect", skip_all, fields(me = %ep_secret_key.public().fmt_short()))]
async fn socket_connect_with_transport_config(
    ep: noq::Endpoint,
    ep_secret_key: SecretKey,
    mapped_addr: EndpointIdMappedAddr,
    endpoint_id: EndpointId,
    transport_config: Arc<noq::TransportConfig>,
) -> Result<noq::Connection> {
    let mut quic_client_config = tls::TlsConfig::new(
        ep_secret_key.clone(),
        DEFAULT_MAX_TLS_TICKETS,
        default_provider(),
    )
    .make_client_config(true)?;
    quic_client_config.set_alpn_protocols(vec![ALPN.to_vec()]);
    let mut client_config = noq::ClientConfig::new(Arc::new(quic_client_config));
    client_config.transport_config(transport_config);
    // Dial the private (mapped) socket address; TLS name encodes the
    // remote's endpoint id.
    let connect = ep
        .connect_with(
            client_config,
            mapped_addr.private_socket_addr(),
            &tls::name::encode(endpoint_id),
        )
        .std_context("connect")?;
    let connection = connect.await.anyerr()?;
    Ok(connection)
}
/// Connecting to a mapped address that no remote is registered under must
/// fail (by timeout), while a subsequent connect to a properly resolved
/// remote must succeed.
#[tokio::test]
#[traced_test]
async fn test_try_send_no_send_addr() {
    let secret_key_1 = SecretKey::from_bytes(&[1u8; 32]);
    let secret_key_2 = SecretKey::from_bytes(&[2u8; 32]);
    let endpoint_id_2 = secret_key_2.public();
    let secret_key_missing_endpoint = SecretKey::from_bytes(&[255u8; 32]);
    let endpoint_id_missing_endpoint = secret_key_missing_endpoint.public();
    let sock_1 = socket_ep(secret_key_1.clone()).await.unwrap();
    // A mapped address that was never registered with the remote map.
    let bad_addr = EndpointIdMappedAddr::generate();
    let res = tokio::time::timeout(
        Duration::from_millis(500),
        socket_connect(
            sock_1.noq_endpoint().clone(),
            secret_key_1.clone(),
            bad_addr,
            endpoint_id_missing_endpoint,
        ),
    )
    .await;
    assert!(res.is_err(), "expecting timeout");
    let sock_2 = socket_ep(secret_key_2.clone()).await.unwrap();
    let accept_task = tokio::spawn({
        async fn accept(ep: noq::Endpoint) -> Result<()> {
            let incoming = ep.accept().await.std_context("no incoming")?;
            let _conn = incoming
                .accept()
                .std_context("accept")?
                .await
                .std_context("accepting")?;
            // Keep the connection alive long enough for the test to finish.
            tokio::time::sleep(Duration::from_secs(10)).await;
            info!("accept finished");
            Ok(())
        }
        let ep = sock_2.noq_endpoint().clone();
        async move {
            if let Err(err) = accept(ep).await {
                error!("{err:#}");
            }
        }
        // BUGFIX: the `me` field syntax was accidentally inside the
        // span-name string literal; record it as a real field like the
        // sibling test does.
        .instrument(info_span!("ep2.accept", me = %endpoint_id_2.fmt_short()))
    });
    let _accept_task = AbortOnDropHandle::new(accept_task);
    let addrs = sock_2
        .ip_addrs()
        .get()
        .into_iter()
        .map(|x| TransportAddr::Ip(x.addr));
    let endpoint_addr_2 = EndpointAddr::from_parts(endpoint_id_2, addrs);
    let addr = sock_1
        .resolve_remote(endpoint_addr_2)
        .await
        .unwrap()
        .unwrap();
    let res = tokio::time::timeout(
        Duration::from_secs(10),
        socket_connect(
            sock_1.noq_endpoint().clone(),
            secret_key_1.clone(),
            addr,
            endpoint_id_2,
        ),
    )
    .await
    .expect("timeout while connecting");
    res.unwrap();
}
/// Resolving a remote with only an unreachable (TEST-NET) address yields a
/// mapped address whose connect times out; after re-resolving with the real
/// addresses, the same mapped address must connect and carry data.
#[tokio::test]
#[traced_test]
async fn test_try_send_no_udp_addr_or_relay_url() {
    let secret_key_1 = SecretKey::from_bytes(&[1u8; 32]);
    let secret_key_2 = SecretKey::from_bytes(&[2u8; 32]);
    let endpoint_id_2 = secret_key_2.public();
    let sock_1 = socket_ep(secret_key_1.clone()).await.unwrap();
    let sock_2 = socket_ep(secret_key_2.clone()).await.unwrap();
    let ep_2 = sock_2.noq_endpoint().clone();
    let accept_task = tokio::spawn({
        async fn accept(ep: noq::Endpoint) -> Result<()> {
            let incoming = ep.accept().await.std_context("no incoming")?;
            let conn = incoming
                .accept()
                .std_context("accept")?
                .await
                .std_context("connecting")?;
            let mut stream = conn.accept_uni().await.std_context("accept uni")?;
            stream
                .read_to_end(1 << 16)
                .await
                .std_context("read to end")?;
            info!("accept finished");
            Ok(())
        }
        async move {
            if let Err(err) = accept(ep_2).await {
                error!("{err:#}");
            }
        }
        .instrument(info_span!("ep2.accept", me = %endpoint_id_2.fmt_short()))
    });
    let _accept_task = AbortOnDropHandle::new(accept_task);
    // 192.0.2.0/24 is TEST-NET-1: guaranteed unroutable.
    let empty_addr_2 = EndpointAddr::from_parts(
        endpoint_id_2,
        [TransportAddr::Ip(
            SocketAddrV4::new([192, 0, 2, 1].into(), 12345).into(),
        )],
    );
    let addr_2 = sock_1.resolve_remote(empty_addr_2).await.unwrap().unwrap();
    let mut transport_config = noq::TransportConfig::default();
    // Short idle timeout so the doomed connect fails quickly.
    transport_config.max_idle_timeout(Some(Duration::from_millis(200).try_into().unwrap()));
    let res = socket_connect_with_transport_config(
        sock_1.noq_endpoint().clone(),
        secret_key_1.clone(),
        addr_2,
        endpoint_id_2,
        Arc::new(transport_config),
    )
    .await;
    assert!(res.is_err(), "expected timeout");
    info!("first connect timed out as expected");
    let correct_addr_2 = EndpointAddr::from_parts(
        endpoint_id_2,
        sock_2
            .ip_addrs()
            .get()
            .into_iter()
            .map(|x| TransportAddr::Ip(x.addr)),
    );
    let addr_2a = sock_1
        .resolve_remote(correct_addr_2)
        .await
        .unwrap()
        .unwrap();
    // Same remote id: the mapped address must be stable across re-resolves.
    assert_eq!(addr_2, addr_2a);
    tokio::time::timeout(Duration::from_secs(10), async move {
        info!("establishing new connection");
        let conn = socket_connect(
            sock_1.noq_endpoint().clone(),
            secret_key_1.clone(),
            addr_2,
            endpoint_id_2,
        )
        .await
        .unwrap();
        info!("have connection");
        let mut stream = conn.open_uni().await.unwrap();
        stream.write_all(b"hello").await.unwrap();
        stream.finish().unwrap();
        stream.stopped().await.unwrap();
        info!("finished stream");
    })
    .await
    .expect("connection timed out");
}
}