use std::{future::Future, path::PathBuf, sync::Arc, time::Duration};
use iroh::{
Endpoint, EndpointAddr, RelayMap, RelayMode, TransportAddr,
endpoint::{Connection, Path, PathEvent, presets},
tls::CaTlsConfig,
};
use iroh_metrics::MetricsGroupSet;
use n0_error::{Result, StackResultExt, StdResultExt, anyerr, ensure_any};
use n0_future::{StreamExt, boxed::BoxFuture, task::AbortOnDropHandle};
use noq::Side;
use patchbay::{Device, IpSupport, Lab, OutDir, TestGuard};
use tokio::sync::{Barrier, oneshot};
use tracing::{Instrument, debug, error, error_span, event, info};
use self::relay::run_relay_server;
/// ALPN identifier that test endpoints accept and that the client passes when connecting.
const TEST_ALPN: &[u8] = b"test";
/// Builds a patchbay [`Lab`] that writes its output to `outdir` and starts a relay inside it.
///
/// The lab is labelled with the current thread's name when the thread has one. Presumably
/// this is the test name when run under the test harness; confirm.
///
/// Returns, in order:
/// - the lab;
/// - the relay map pointing at the spawned relay;
/// - a handle that aborts the relay task when dropped;
/// - the lab's [`TestGuard`].
///
/// # Errors
///
/// Propagates failures from building the lab or spawning the relay.
pub(crate) async fn lab_with_relay(
    outdir: PathBuf,
) -> Result<(Lab, RelayMap, AbortOnDropHandle<()>, TestGuard)> {
    let unlabelled = Lab::builder().outdir(OutDir::Exact(outdir));
    let builder = match std::thread::current().name() {
        Some(label) => unlabelled.label(label),
        None => unlabelled,
    };
    let lab = builder.build().await?;
    let test_guard = lab.test_guard();
    let (relay_map, relay_task) = spawn_relay(&lab).await?;
    Ok((lab, relay_map, relay_task, test_guard))
}
/// Starts a relay server on a new `relay` device inside `lab`.
///
/// Setup steps:
/// - add a dual-stack `dc` router and a `relay` device uplinked to it;
/// - register `relay.test` in the lab DNS for both of the device's addresses;
/// - run a relay server on the device.
///
/// Returns the relay map reported by the server, together with a handle that aborts the
/// relay task when dropped. Aborting drops the server the task keeps alive.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the router or device cannot be built;
/// - DNS registration fails;
/// - the task cannot be spawned;
/// - the relay task exits before reporting its relay map, e.g. because `run_relay_server`
///   failed.
///
/// # Panics
///
/// Panics if the relay device lacks an IPv4 or IPv6 address.
async fn spawn_relay(lab: &Lab) -> Result<(RelayMap, AbortOnDropHandle<()>)> {
    let dc = lab
        .add_router("dc")
        .ip_support(IpSupport::DualStack)
        .build()
        .await?;
    let dev_relay = lab.add_device("relay").uplink(dc.id()).build().await?;
    // The router is dual-stack, so the device is expected to have both address families.
    let relay_v4 = dev_relay.ip().expect("relay has IPv4");
    let relay_v6 = dev_relay.ip6().expect("relay has IPv6");
    let dns = lab.dns_server()?;
    dns.set_host("relay.test", relay_v4.into())?;
    dns.set_host("relay.test", relay_v6.into())?;
    info!(%relay_v4, %relay_v6, "DNS entries for relay.test registered");

    let (relay_map_tx, relay_map_rx) = oneshot::channel();
    // Wrap the task handle right away. If we bail out (or are cancelled) while waiting for
    // the relay map, the relay task is aborted instead of being left running detached.
    let relay_task = AbortOnDropHandle::new(dev_relay.spawn(async move |_ctx| {
        let (relay_map, _server) = run_relay_server().await.unwrap();
        relay_map_tx.send(relay_map).unwrap();
        // Keep `_server` alive until the task is aborted.
        std::future::pending::<()>().await;
    })?);
    // If the task panics (e.g. the relay failed to start), the sender is dropped.
    // Report that as an error rather than panicking on `RecvError`.
    let relay_map = relay_map_rx
        .await
        .std_context("relay server task exited before reporting its relay map")?;
    Ok((relay_map, relay_task))
}
/// Type-erased, single-use run function for one side of a [`Pair`].
///
/// It receives the device, the endpoint bound on it, and the established connection. It
/// returns a boxed future that resolves to that side's result.
type RunFn = Box<dyn FnOnce(Device, Endpoint, Connection) -> BoxFuture<Result> + Send + 'static>;
fn box_fn<F, Fut>(f: F) -> RunFn
where
F: FnOnce(Device, Endpoint, Connection) -> Fut + Send + 'static,
Fut: Future<Output = Result> + Send + 'static,
{
Box::new(move |dev, ep, conn| Box::pin(f(dev, ep, conn)))
}
/// A two-endpoint test: one server side and one client side.
///
/// Each side has a lab device and a run function.
pub(crate) struct Pair {
    /// Relay map used to configure both endpoints.
    relay_map: RelayMap,
    /// Device on which the accepting (server) endpoint is bound.
    server_dev: Option<Device>,
    /// Device on which the connecting (client) endpoint is bound.
    client_dev: Option<Device>,
    /// Run function executed on the server side after a connection is accepted.
    server_fn: Option<RunFn>,
    /// Run function executed on the client side after connecting.
    client_fn: Option<RunFn>,
}
impl Pair {
    /// Creates a pair for `relay_map` with no devices or run functions assigned yet.
    ///
    /// Assign them afterwards with `left`/`right` or `server`/`client`.
    pub(crate) fn new(relay_map: RelayMap) -> Self {
        Self {
            relay_map,
            server_dev: None,
            client_dev: None,
            server_fn: None,
            client_fn: None,
        }
    }

    /// Assigns `device` and `run_fn` to the given `side`.
    ///
    /// Any previous assignment for that side is replaced.
    pub(crate) fn left<F, Fut>(mut self, side: Side, device: Device, run_fn: F) -> Self
    where
        F: FnOnce(Device, Endpoint, Connection) -> Fut + Send + 'static,
        Fut: Future<Output = Result> + Send + 'static,
    {
        // Select the device and run-function slots of the requested side together.
        let (dev_slot, fn_slot) = match side {
            Side::Server => (&mut self.server_dev, &mut self.server_fn),
            Side::Client => (&mut self.client_dev, &mut self.client_fn),
        };
        *dev_slot = Some(device);
        *fn_slot = Some(box_fn(run_fn));
        self
    }

    /// Assigns `device` and `run_fn` to whichever side has no device yet.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    /// - neither side has a device yet (call `left` first);
    /// - both sides already have one.
    pub(crate) fn right<F, Fut>(self, device: Device, run_fn: F) -> Self
    where
        F: FnOnce(Device, Endpoint, Connection) -> Fut + Send + 'static,
        Fut: Future<Output = Result> + Send + 'static,
    {
        // Only the device slots decide which side is still free.
        let remaining = match (&self.server_dev, &self.client_dev) {
            (Some(_), None) => Side::Client,
            (None, Some(_)) => Side::Server,
            (None, None) => panic!("call .left() before .right()"),
            (Some(_), Some(_)) => panic!("both sides already assigned"),
        };
        self.left(remaining, device, run_fn)
    }

    /// Assigns `device` and `run_fn` to the server (accepting) side.
    pub(crate) fn server<F, Fut>(mut self, device: Device, run_fn: F) -> Self
    where
        F: FnOnce(Device, Endpoint, Connection) -> Fut + Send + 'static,
        Fut: Future<Output = Result> + Send + 'static,
    {
        self.server_dev = Some(device);
        self.server_fn = Some(box_fn(run_fn));
        self
    }

    /// Assigns `device` and `run_fn` to the client (connecting) side.
    pub(crate) fn client<F, Fut>(mut self, device: Device, run_fn: F) -> Self
    where
        F: FnOnce(Device, Endpoint, Connection) -> Fut + Send + 'static,
        Fut: Future<Output = Result> + Send + 'static,
    {
        self.client_dev = Some(device);
        self.client_fn = Some(box_fn(run_fn));
        self
    }

    /// Runs the server and client sides concurrently on their devices and waits for both.
    ///
    /// Server side:
    /// - binds an endpoint and waits for `online()`;
    /// - sends its relay-only address to the client over a oneshot channel;
    /// - accepts a single connection and calls the server run function.
    ///
    /// Client side:
    /// - binds an endpoint and waits for that address;
    /// - connects with `TEST_ALPN` and calls the client run function.
    ///
    /// After its run function, each side waits on a shared two-party barrier. It then
    /// records its endpoint's metrics on its device.
    ///
    /// # Errors
    ///
    /// Returns an error if a device or run function is missing for either side. It also
    /// returns an error if either side's task panicked or returned an error; the server's
    /// error takes precedence. Each side's outcome is logged as a pass/fail event on its
    /// device.
    ///
    /// # Panics
    ///
    /// The server task panics (reported as an error here) if the client has already dropped
    /// the address receiver when the server sends its address.
    ///
    /// NOTE(review): some `?` early returns happen before the barrier:
    /// - server: bind, accept, handshake;
    /// - client: bind, address receive, connect.
    ///
    /// These skip the barrier. The peer may then wait indefinitely, either on the barrier or,
    /// on the server, presumably in `endpoint.accept()`. The `tokio::join!` below would then
    /// never complete. Confirm that an outer test timeout covers this.
    pub(crate) async fn run(mut self) -> Result {
        let server_device = self.server_dev.take().context("Missing server device")?;
        let server_run = self
            .server_fn
            .take()
            .context("Missing server run function")?;
        let client_device = self.client_dev.take().context("Missing client device")?;
        let client_run = self
            .client_fn
            .take()
            .context("Missing client run function")?;
        // Carries the server's (relay-only) address to the client.
        let (addr_tx, addr_rx) = oneshot::channel();
        let relay_map2 = self.relay_map.clone();
        // Neither side records metrics until both run functions have returned.
        let barrier_server = Arc::new(Barrier::new(2));
        let barrier_client = barrier_server.clone();
        let server_task = server_device.spawn(|dev| {
            async move {
                let endpoint = endpoint_builder(&dev, relay_map2)
                    .bind()
                    .await
                    .context("server endpoint bind")?;
                info!(
                    id=%endpoint.id().fmt_short(),
                    bound_sockets=?endpoint.bound_sockets(),
                    "server endpoint bound",
                );
                endpoint.online().await;
                info!("endpoint online");
                // Only relay addresses are handed to the client, presumably so the
                // connection starts out via the relay.
                addr_tx.send(addr_relay_only(endpoint.addr())).unwrap();
                let incoming = endpoint.accept().await.context("server accept incoming")?;
                let conn = incoming
                    .accept()
                    .anyerr()?
                    .await
                    .context("server accept handshake")?;
                info!(remote=%conn.remote_id().fmt_short(), "accepted, executing run function");
                watch_selected_path(&conn);
                let res = server_run(dev.clone(), endpoint.clone(), conn).await;
                match &res {
                    Ok(()) => info!("run function completed successfully"),
                    Err(err) => error!("run function failed: {err:#}"),
                }
                barrier_server.wait().await;
                for group in endpoint.metrics().groups() {
                    dev.record_iroh_metrics(group);
                }
                res
            }
            .instrument(error_span!("ep-server"))
        })?;
        let client_task = client_device.spawn(move |dev| {
            async move {
                let endpoint = endpoint_builder(&dev, self.relay_map)
                    .bind()
                    .await
                    .context("client endpoint bind")?;
                info!(
                    id=%endpoint.id().fmt_short(),
                    bound_sockets=?endpoint.bound_sockets(),
                    "client endpoint bound",
                );
                // Fails if the server task ended before sending its address.
                let addr = addr_rx
                    .await
                    .std_context("server did not send its address")?;
                info!(?addr, "connecting to server");
                let conn = endpoint
                    .connect(addr, TEST_ALPN)
                    .await
                    .context("client connect")?;
                watch_selected_path(&conn);
                info!(
                    remote=%conn.remote_id().fmt_short(),
                    "connected, executing run function",
                );
                let res = client_run(dev.clone(), endpoint.clone(), conn).await;
                match &res {
                    Ok(()) => info!("run function completed successfully"),
                    Err(err) => error!("run function failed: {err:#}"),
                }
                barrier_client.wait().await;
                for group in endpoint.metrics().groups() {
                    dev.record_iroh_metrics(group);
                }
                res
            }
            .instrument(error_span!("ep-client"))
        })?;
        let (server_res, client_res) = tokio::join!(server_task, client_task);
        // Tag each side's outcome with its device name.
        // The outer `Err` means the task panicked; the inner `Err` is the run result.
        let [server_res, client_res] = [(&server_device, server_res), (&client_device, client_res)]
            .map(|(dev, res)| {
                let res = match res {
                    Err(err) => Err(anyerr!(err, "device {} panicked", dev.name())),
                    Ok(Err(err)) => Err(anyerr!(err, "device {} failed", dev.name())),
                    Ok(Ok(())) => Ok(()),
                };
                let res_str = res.as_ref().map_err(|err| format!("{err:#}")).cloned();
                log_result_on_device(dev, res_str);
                res
            });
        server_res?;
        client_res?;
        Ok(())
    }
}
/// Emits a pass/fail tracing event for `res` from within `dev`'s context via `run_sync`.
///
/// Success is logged at INFO under target `test::_events::pass`. Failure is logged at ERROR
/// under target `test::_events::fail`, with the error as a field. Failure to dispatch onto
/// the device is deliberately ignored (best effort).
fn log_result_on_device<E: std::fmt::Display + Send + 'static>(dev: &Device, res: Result<(), E>) {
    let _ = dev.run_sync(move || {
        if let Err(error) = res {
            event!(
                target: "test::_events::fail",
                tracing::Level::ERROR,
                %error,
                msg = %"device failed"
            );
        } else {
            event!(
                target: "test::_events::pass",
                tracing::Level::INFO,
                msg = %"device passed"
            );
        }
        Ok(())
    });
}
/// Helpers for waiting on a connection's selected network path.
pub(crate) trait PathConnectionExt {
    /// Waits, for at most `timeout`, until the selected path satisfies `f`.
    ///
    /// Returns that path's remote transport address.
    async fn wait_selected(
        &self,
        timeout: Duration,
        f: impl FnMut(&Path<'_>) -> bool,
    ) -> Result<TransportAddr>;

    /// Waits, for at most `timeout`, until the selected path is a direct IP path.
    ///
    /// Returns the remote address. Errors from `wait_selected` get the context `wait_ip`.
    async fn wait_ip(&self, timeout: Duration) -> Result<TransportAddr> {
        let outcome = self.wait_selected(timeout, |path| path.is_ip()).await;
        outcome.context("wait_ip")
    }
}
impl PathConnectionExt for Connection {
    /// Waits until the connection's selected path satisfies `f`.
    ///
    /// Returns that path's remote address. Path snapshots come from `paths_stream()`.
    /// Snapshots that have no selected path are skipped instead of aborting the wait.
    ///
    /// # Errors
    ///
    /// Fails if `timeout` elapses first or if the path stream ends.
    async fn wait_selected(
        &self,
        timeout: Duration,
        mut f: impl FnMut(&Path<'_>) -> bool,
    ) -> Result<TransportAddr> {
        let mut stream = self.paths_stream();
        tokio::time::timeout(timeout, async {
            while let Some(paths) = stream.next().await {
                // Tolerate a snapshot without a selected path (as `watch_selected_path`
                // does) and wait for the next one. Panicking here would abort the test.
                let Some(selected) = paths.iter().find(|p| p.is_selected()) else {
                    continue;
                };
                if f(&selected) {
                    return Ok(selected.remote_addr().clone());
                }
            }
            Err(anyerr!("path stream ended"))
        })
        .await
        .with_std_context(|_| format!("wait_selected timed out after {timeout:?}"))?
    }
}
/// Reports whether the connection's currently selected path goes through a relay.
///
/// # Panics
///
/// Panics if the connection has no selected path.
pub(crate) fn is_relayed(conn: &iroh::endpoint::Connection) -> bool {
    let paths = conn.paths();
    let selected = paths
        .iter()
        .find(|candidate| candidate.is_selected())
        .expect("no selected path");
    selected.is_relay()
}
/// Sends 8 random bytes on a new bidirectional stream and checks they are echoed back.
///
/// This is the opening side of the ping; `ping_accept` is the echoing side. The whole
/// exchange runs under an `ping_open` error span and is bounded by `timeout`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - opening, writing, finishing or reading the stream fails;
/// - the reply exceeds 8 bytes or differs from what was sent;
/// - `timeout` elapses first.
pub(crate) async fn ping_open(conn: &Connection, timeout: Duration) -> Result {
    tokio::time::timeout(timeout, async {
        let data: [u8; 8] = rand::random();
        debug!("open_bi");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        debug!("write_all");
        send.write_all(&data).await.anyerr()?;
        send.finish().anyerr()?;
        debug!("read_to_end");
        // A correct echo is exactly the payload, so cap the read at its length.
        let r = recv.read_to_end(data.len()).await.anyerr()?;
        // The message is the error reported when the condition is false.
        ensure_any!(r == data, "echoed reply does not match the sent data");
        debug!("done");
        Ok(())
    })
    .instrument(error_span!("ping_open"))
    .await
    .with_std_context(|_| format!("ping_open timed out after {timeout:?}"))?
}
/// Accepts one bidirectional stream on `conn` and echoes back what the peer sent.
///
/// Steps:
/// - read the incoming data to the end, with an 8-byte limit;
/// - write it back unchanged and finish the send side.
///
/// The exchange runs under a `ping_accept` error span and is bounded by `timeout`.
///
/// # Errors
///
/// Fails if accepting, reading, writing or finishing the stream fails, or if `timeout`
/// elapses first.
pub(crate) async fn ping_accept(conn: &Connection, timeout: Duration) -> Result {
    let echo = async {
        debug!("accept_bi");
        let (mut tx, mut rx) = conn.accept_bi().await.anyerr()?;
        debug!("read_to_end");
        let payload = rx.read_to_end(8).await.anyerr()?;
        debug!("write_all");
        tx.write_all(&payload).await.anyerr()?;
        tx.finish().anyerr()?;
        debug!("done");
        Ok(())
    };
    let bounded = tokio::time::timeout(timeout, echo).instrument(error_span!("ping_accept"));
    bounded
        .await
        .with_std_context(|_| format!("ping_accept timed out after {timeout:?}"))?
}
/// Debug-logs the connection's currently selected path, if any.
///
/// Also spawns a detached task, instrumented with the caller's current span, that logs every
/// subsequent `PathEvent::Selected` until the event stream ends.
fn watch_selected_path(conn: &Connection) {
    let mut path_events = conn.path_events();
    {
        let paths = conn.paths();
        if let Some(current) = paths.iter().find(|p| p.is_selected()) {
            debug!("selected path: [{}] {}", current.id(), current.remote_addr());
        }
    }
    let logger = async move {
        while let Some(event) = path_events.next().await {
            let PathEvent::Selected {
                id, remote_addr, ..
            } = event
            else {
                continue;
            };
            debug!("selected path: [{id}] {remote_addr}");
        }
    };
    let span = tracing::Span::current();
    tokio::spawn(logger.instrument(span));
}
/// Returns an endpoint builder preconfigured for lab tests.
///
/// Configuration:
/// - minimal preset;
/// - custom relay map;
/// - TLS certificate verification disabled (the test relay uses self-signed certificates);
/// - `TEST_ALPN` as the only ALPN.
///
/// With the `qlog` feature enabled, and when `device` provides a `qlog` file path, QUIC qlog
/// output is also configured to go into that path's directory, using its file name as the
/// prefix.
fn endpoint_builder(device: &Device, relay_map: RelayMap) -> iroh::endpoint::Builder {
    let builder = Endpoint::builder(presets::Minimal)
        .relay_mode(RelayMode::Custom(relay_map))
        .ca_tls_config(CaTlsConfig::insecure_skip_verify())
        .alpns(vec![TEST_ALPN.to_vec()]);
    // `device` is only needed for qlog; silence the unused-variable lint otherwise.
    #[cfg(not(feature = "qlog"))]
    let _ = device;
    // Shadowing instead of `let mut` avoids needing `#[allow(unused_mut)]` when qlog is off.
    #[cfg(feature = "qlog")]
    let builder = match device.filepath("qlog") {
        Some(qlog_path) => {
            let prefix = qlog_path.file_name().unwrap().to_str().unwrap();
            let directory = qlog_path.parent().unwrap();
            let transport_config = iroh::endpoint::QuicTransportConfig::builder()
                .qlog_from_path(directory, prefix)
                .build();
            builder.transport_config(transport_config)
        }
        None => builder,
    };
    builder
}
/// Returns a copy of `addr` that keeps the endpoint id but only its relay transport addresses.
fn addr_relay_only(addr: EndpointAddr) -> EndpointAddr {
    let id = addr.id;
    let relay_addrs = addr.addrs.into_iter().filter(|transport| transport.is_relay());
    EndpointAddr::from_parts(id, relay_addrs)
}
/// In-lab relay server setup.
mod relay {
    use std::{
        net::{IpAddr, Ipv6Addr},
        sync::Arc,
    };

    use iroh_base::RelayUrl;
    use iroh_relay::{
        RelayConfig, RelayMap, RelayQuicConfig,
        server::{
            AllowAll, CertConfig, QuicConfig, RelayConfig as RelayServerConfig, Server,
            ServerConfig, SpawnError, TlsConfig, testing::self_signed_tls_certs_and_config,
        },
    };

    /// Spawns a relay server that listens on `[::]`.
    ///
    /// Listeners:
    /// - port 80: the relay's base address;
    /// - port 443: TLS, using a freshly generated self-signed certificate;
    /// - port 7842: QUIC.
    ///
    /// All clients are allowed, and the key cache capacity is set to 1024.
    ///
    /// Returns a relay map pointing at `https://relay.test`, with the server's QUIC port if it
    /// has one, together with the running [`Server`]. The caller must keep the server alive.
    ///
    /// # Errors
    ///
    /// Returns the [`SpawnError`] if the server fails to start.
    pub(crate) async fn run_relay_server() -> Result<(RelayMap, Server), SpawnError> {
        let unspecified = IpAddr::V6(Ipv6Addr::UNSPECIFIED);

        let (_certs, rustls_config) = self_signed_tls_certs_and_config();
        let mut relay_config = RelayServerConfig::new((unspecified, 80));
        relay_config.tls = Some(TlsConfig::new(
            (unspecified, 443),
            CertConfig::Manual {
                server_config: rustls_config,
            },
        ));
        relay_config.key_cache_capacity = Some(1024);
        relay_config.access = Arc::new(AllowAll);

        let mut server_config = ServerConfig::default();
        server_config.relay = Some(relay_config);
        server_config.quic = Some(QuicConfig::new((unspecified, 7842)));

        let server = Server::spawn(server_config).await?;
        let url: RelayUrl = "https://relay.test".parse().expect("valid relay url");
        let quic = server
            .quic_addr()
            .map(|quic_addr| RelayQuicConfig::new(quic_addr.port()));
        let relay_map: RelayMap = RelayConfig::new(url, quic).into();
        Ok((relay_map, server))
    }
}