use std::{
collections::BTreeMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};
use iroh_base::EndpointId;
use n0_error::{AnyError, e, stack_error};
use n0_future::{
join_all,
task::{self, AbortOnDropHandle, JoinSet},
};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, field::Empty, info_span, trace, warn};
use crate::{
Endpoint,
endpoint::{Accepting, Connection, RemoteEndpointIdError, quic},
};
/// Handle to a spawned accept loop that dispatches incoming connections to the
/// protocol handlers registered on a [`RouterBuilder`].
///
/// Clones share the same endpoint, task handle and cancellation token. The task is
/// stored in an [`AbortOnDropHandle`]; use [`Router::shutdown`] to stop it cleanly
/// instead of relying on drop.
#[derive(Clone, Debug)]
pub struct Router {
/// Endpoint whose incoming connections are being routed.
endpoint: Endpoint,
/// Handle of the spawned accept loop. Becomes `None` once [`Router::shutdown`]
/// has taken it out in order to await it.
task: Arc<Mutex<Option<AbortOnDropHandle<()>>>>,
/// Token used to ask the accept loop to stop; also what [`Router::is_shutdown`] reads.
cancel_token: CancellationToken,
}
/// Builder that collects protocol handlers and an optional incoming filter before
/// the accept loop is started with [`RouterBuilder::spawn`].
#[derive(derive_more::Debug)]
pub struct RouterBuilder {
/// Endpoint the router will accept connections on.
endpoint: Endpoint,
/// Registered handlers, keyed by ALPN.
protocols: ProtocolMap,
/// Optional filter consulted for every incoming connection before it is accepted.
/// Skipped in `Debug` output.
#[debug(skip)]
incoming_filter: Option<IncomingFilter>,
}
/// Error produced while accepting or handling an incoming connection.
///
/// This is the error type returned by [`ProtocolHandler::on_accepting`] and
/// [`ProtocolHandler::accept`].
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum AcceptError {
/// Failure while the connection was still being established.
#[error(transparent)]
Connecting {
source: crate::endpoint::ConnectingError,
},
/// Failure reported by an established connection.
#[error(transparent)]
Connection {
source: crate::endpoint::ConnectionError,
},
/// The remote endpoint id could not be determined.
#[error(transparent)]
MissingRemoteEndpointId { source: RemoteEndpointIdError },
/// The remote was denied access, e.g. by [`AccessLimit`].
#[error("Not allowed.")]
NotAllowed {},
/// Arbitrary error raised by a protocol handler; see [`AcceptError::from_err`]
/// and [`AcceptError::from_boxed`].
#[error(transparent)]
User { source: AnyError },
}
impl AcceptError {
#[track_caller]
pub fn from_err<T: std::error::Error + Send + Sync + 'static>(value: T) -> Self {
e!(AcceptError::User {
source: AnyError::from_std(value)
})
}
#[track_caller]
pub fn from_boxed(value: Box<dyn std::error::Error + Send + Sync>) -> Self {
e!(AcceptError::User {
source: AnyError::from_std_box(value)
})
}
}
impl From<std::io::Error> for AcceptError {
fn from(err: std::io::Error) -> Self {
Self::from_err(err)
}
}
impl From<quic::ClosedStream> for AcceptError {
fn from(err: quic::ClosedStream) -> Self {
Self::from_err(err)
}
}
/// Decision returned by an [`IncomingFilter`] for a single incoming connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IncomingFilterOutcome {
/// Let the connection proceed to ALPN-based handler dispatch.
Accept,
/// Call `retry()` on the incoming connection; if that fails, the router refuses it.
Retry,
/// Call `refuse()` on the incoming connection.
Reject,
/// Call `ignore()` on the incoming connection.
Ignore,
}
/// Shared callback that inspects an incoming connection before it is accepted and
/// decides how the router should treat it.
///
/// The callback is synchronous and is invoked from inside the router's accept loop.
pub type IncomingFilter =
Arc<dyn Fn(&crate::endpoint::Incoming) -> IncomingFilterOutcome + Send + Sync + 'static>;
/// Handler for connections that negotiated a particular ALPN.
///
/// For each matching connection the router first calls
/// [`ProtocolHandler::on_accepting`] and, if that succeeds, passes the resulting
/// connection to [`ProtocolHandler::accept`]. When the router shuts down it calls
/// [`ProtocolHandler::shutdown`] on every registered handler.
pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
    /// Turns an in-progress connection into an established one.
    ///
    /// The default implementation simply awaits `accepting` and converts a failure
    /// into an [`AcceptError`].
    fn on_accepting(
        &self,
        accepting: Accepting,
    ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
        async move { Ok(accepting.await?) }
    }

    /// Handles an established connection.
    ///
    /// An `Err` return is logged by the router as a warning.
    fn accept(
        &self,
        connection: Connection,
    ) -> impl Future<Output = Result<(), AcceptError>> + Send;

    /// Called when the router shuts down. The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        std::future::ready(())
    }
}
/// Forwards every [`ProtocolHandler`] method to the handler behind the `Arc`.
impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
    async fn on_accepting(&self, accepting: Accepting) -> Result<Connection, AcceptError> {
        <T as ProtocolHandler>::on_accepting(&**self, accepting).await
    }

    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
        <T as ProtocolHandler>::accept(&**self, conn).await
    }

    async fn shutdown(&self) {
        <T as ProtocolHandler>::shutdown(&**self).await
    }
}
/// Forwards every [`ProtocolHandler`] method to the handler inside the `Box`.
impl<T: ProtocolHandler> ProtocolHandler for Box<T> {
    async fn on_accepting(&self, accepting: Accepting) -> Result<Connection, AcceptError> {
        <T as ProtocolHandler>::on_accepting(&**self, accepting).await
    }

    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
        <T as ProtocolHandler>::accept(&**self, conn).await
    }

    async fn shutdown(&self) {
        <T as ProtocolHandler>::shutdown(&**self).await
    }
}
/// Boxes any [`ProtocolHandler`] as a type-erased [`DynProtocolHandler`], which is
/// what [`RouterBuilder::accept`] stores.
impl<T: ProtocolHandler> From<T> for Box<dyn DynProtocolHandler> {
    fn from(handler: T) -> Self {
        // The concrete box unsizes to `Box<dyn DynProtocolHandler>` at the return site.
        Box::<T>::new(handler)
    }
}
/// Dyn-compatible counterpart of [`ProtocolHandler`] whose methods return boxed
/// futures, so handlers of different types can be stored behind
/// `Box<dyn DynProtocolHandler>`.
pub trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
    /// Turns an in-progress connection into an established one.
    ///
    /// The default implementation awaits `accepting` and converts a failure into an
    /// [`AcceptError`].
    fn on_accepting(
        &self,
        accepting: Accepting,
    ) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
        Box::pin(async move { Ok(accepting.await?) })
    }

    /// Handles an established connection.
    fn accept(
        &self,
        connection: Connection,
    ) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>>;

    /// Called when the router shuts down. The default implementation does nothing.
    fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(std::future::ready(()))
    }
}
impl<P: ProtocolHandler> DynProtocolHandler for P {
fn accept(
&self,
connection: Connection,
) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>> {
Box::pin(<Self as ProtocolHandler>::accept(self, connection))
}
fn on_accepting(
&self,
accepting: Accepting,
) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
Box::pin(<Self as ProtocolHandler>::on_accepting(self, accepting))
}
fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
Box::pin(<Self as ProtocolHandler>::shutdown(self))
}
}
/// Registry of type-erased protocol handlers, keyed by ALPN bytes.
#[derive(Debug, Default)]
pub(crate) struct ProtocolMap(BTreeMap<Vec<u8>, Box<dyn DynProtocolHandler>>);
impl ProtocolMap {
    /// Returns the handler registered for `alpn`, if any.
    pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> {
        let boxed = self.0.get(alpn)?;
        Some(&**boxed)
    }

    /// Registers `handler` for `alpn`, replacing a handler previously stored under
    /// the same ALPN.
    pub(crate) fn insert(&mut self, alpn: Vec<u8>, handler: Box<dyn DynProtocolHandler>) {
        let _replaced = self.0.insert(alpn, handler);
    }

    /// Iterates over all registered ALPNs.
    pub(crate) fn alpns(&self) -> impl Iterator<Item = &Vec<u8>> {
        self.0.iter().map(|(alpn, _handler)| alpn)
    }

    /// Runs `shutdown` on every registered handler and waits until all of them
    /// have completed. The shutdown futures are driven together via `join_all`.
    pub(crate) async fn shutdown(&self) {
        join_all(self.0.values().map(|handler| handler.shutdown())).await;
    }
}
impl Router {
    /// Starts building a router for `endpoint`.
    pub fn builder(endpoint: Endpoint) -> RouterBuilder {
        RouterBuilder::new(endpoint)
    }

    /// Returns the endpoint this router accepts connections on.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }

    /// Returns `true` once the router's cancellation token has been cancelled.
    pub fn is_shutdown(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Requests the accept loop to stop and waits for it to finish.
    ///
    /// Returns `Ok(())` right away if the cancellation token is already cancelled.
    ///
    /// # Errors
    ///
    /// Returns the `JoinError` of the accept-loop task if awaiting it fails.
    pub async fn shutdown(&self) -> Result<(), n0_future::task::JoinError> {
        if !self.is_shutdown() {
            self.cancel_token.cancel();
            // Take the handle in its own statement so the mutex guard is released
            // before the task is awaited.
            let pending = self.task.lock().expect("poisoned").take();
            if let Some(handle) = pending {
                handle.await?;
            }
        }
        Ok(())
    }
}
impl RouterBuilder {
    /// Creates a builder that will route connections accepted on `endpoint`.
    ///
    /// The builder starts without any protocol handlers and without an incoming filter.
    pub fn new(endpoint: Endpoint) -> Self {
        Self {
            endpoint,
            protocols: ProtocolMap::default(),
            incoming_filter: None,
        }
    }

    /// Installs a filter that is consulted for every incoming connection before it
    /// is accepted. A later call replaces an earlier filter.
    pub fn incoming_filter(mut self, filter: IncomingFilter) -> Self {
        self.incoming_filter = Some(filter);
        self
    }

    /// Registers `handler` for connections that negotiate `alpn`.
    ///
    /// Registering a second handler for the same ALPN replaces the first one.
    pub fn accept(
        mut self,
        alpn: impl AsRef<[u8]>,
        handler: impl Into<Box<dyn DynProtocolHandler>>,
    ) -> Self {
        self.protocols
            .insert(alpn.as_ref().to_vec(), handler.into());
        self
    }

    /// Returns the endpoint the router will accept connections on.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }

    /// Sets the endpoint's ALPNs to the registered ones, spawns the accept loop and
    /// returns a [`Router`] handle to it.
    ///
    /// The loop runs until the router's cancellation token is cancelled, the
    /// endpoint stops yielding incoming connections, or a connection task panics or
    /// fails. On exit it shuts down all protocol handlers, cancels the remaining
    /// connection tasks, closes the endpoint and drains the task set.
    #[must_use = "Router aborts when dropped, use Router::shutdown to shut the router down cleanly"]
    pub fn spawn(self) -> Router {
        // Advertise exactly the ALPNs that have a registered handler.
        let alpns = self.protocols.alpns().cloned().collect::<Vec<_>>();

        let protocols = Arc::new(self.protocols);
        let incoming_filter = self.incoming_filter;
        self.endpoint.set_alpns(alpns);

        let mut join_set = JoinSet::new();
        let endpoint = self.endpoint.clone();

        // `cancel` is kept by the returned `Router`; its clone moves into the loop.
        let cancel = CancellationToken::new();
        let cancel_token = cancel.clone();

        let run_loop_fut = async move {
            // Cancel the token whenever this future exits, whatever the reason.
            let _cancel_guard = cancel_token.clone().drop_guard();
            // Separate token for the per-connection tasks, so they keep running while
            // the protocol handlers perform their shutdown.
            let handler_cancel_token = CancellationToken::new();

            loop {
                tokio::select! {
                    biased;
                    _ = cancel_token.cancelled() => {
                        break;
                    },
                    // Reap finished connection tasks; a panic or failure stops the router.
                    Some(res) = join_set.join_next() => {
                        match res {
                            Err(outer) => {
                                if outer.is_panic() {
                                    error!("Task panicked: {outer:?}");
                                    break;
                                } else if outer.is_cancelled() {
                                    trace!("Task cancelled: {outer:?}");
                                } else {
                                    error!("Task failed: {outer:?}");
                                    break;
                                }
                            }
                            Ok(Some(())) => {
                                trace!("Task finished");
                            }
                            Ok(None) => {
                                trace!("Task cancelled");
                            }
                        }
                    },
                    incoming = endpoint.accept() => {
                        // `None` means the endpoint no longer yields connections.
                        let Some(incoming) = incoming else {
                            break;
                        };
                        if let Some(filter) = &incoming_filter {
                            match filter(&incoming) {
                                IncomingFilterOutcome::Accept => {}
                                IncomingFilterOutcome::Retry => {
                                    // A retry only makes sense for an address that has
                                    // not been validated yet; warn about filter misuse.
                                    if incoming.remote_addr_validated() {
                                        warn!(
                                            "filter returned Retry for an already validated connection",
                                        );
                                    }
                                    // If the retry cannot be sent, refuse the connection.
                                    if let Err(err) = incoming.retry() {
                                        err.into_incoming().refuse();
                                    }
                                    continue;
                                }
                                IncomingFilterOutcome::Reject => {
                                    incoming.refuse();
                                    continue;
                                }
                                IncomingFilterOutcome::Ignore => {
                                    incoming.ignore();
                                    continue;
                                }
                            }
                        }
                        let protocols = protocols.clone();
                        let token = handler_cancel_token.child_token();
                        // `remote` and `alpn` are recorded later by `handle_connection`.
                        let span = info_span!("router.accept", me=%endpoint.id().fmt_short(), remote=Empty, alpn=Empty);
                        join_set.spawn(async move {
                            token.run_until_cancelled(handle_connection(incoming, protocols)).await
                        }.instrument(span));
                    },
                }
            }

            // Shutdown order: handlers first, then the connection tasks, then the
            // endpoint, and finally whatever tasks are still left.
            protocols.shutdown().await;
            handler_cancel_token.cancel();
            endpoint.close().await;
            tracing::debug!("Shutting down remaining tasks");
            join_set.abort_all();
            while let Some(res) = join_set.join_next().await {
                match res {
                    Err(err) if err.is_panic() => error!("Task panicked: {err:?}"),
                    _ => {}
                }
            }
        };
        let task = task::spawn(run_loop_fut.instrument(tracing::Span::current()));
        let task = AbortOnDropHandle::new(task);
        Router {
            endpoint: self.endpoint,
            task: Arc::new(Mutex::new(Some(task))),
            cancel_token: cancel,
        }
    }
}
/// Drives one incoming connection: accepts it, reads its ALPN, looks up the matching
/// handler and runs the handler's `on_accepting` and `accept` steps.
///
/// Every failure is logged as a warning and ends processing of this connection.
/// The `alpn` and `remote` fields are recorded on the current tracing span.
async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<ProtocolMap>) {
    let mut accepting = match incoming.accept() {
        Err(err) => {
            warn!("Ignoring connection: accepting failed: {err:#}");
            return;
        }
        Ok(accepting) => accepting,
    };

    let alpn = match accepting.alpn().await {
        Err(err) => {
            warn!("Ignoring connection: invalid handshake: {err:#}");
            return;
        }
        Ok(alpn) => alpn,
    };
    tracing::Span::current().record("alpn", String::from_utf8_lossy(&alpn).into_owned());

    let Some(handler) = protocols.get(&alpn) else {
        warn!("Ignoring connection: unsupported ALPN protocol");
        return;
    };

    let connection = match handler.on_accepting(accepting).await {
        Ok(connection) => connection,
        Err(err) => {
            warn!("Accepting incoming connection ended with error: {err}");
            return;
        }
    };
    tracing::Span::current().record(
        "remote",
        tracing::field::display(connection.remote_id().fmt_short()),
    );

    if let Err(err) = handler.accept(connection).await {
        warn!("Handling incoming connection ended with error: {err}");
    }
}
/// Protocol handler wrapper that only lets a connection through to the inner
/// handler when a predicate approves the remote endpoint id.
#[derive(derive_more::Debug, Clone)]
pub struct AccessLimit<P: ProtocolHandler + Clone> {
/// The wrapped handler that receives permitted connections.
proto: P,
/// Predicate deciding whether a remote endpoint id is allowed.
/// Shown as the literal text `limiter` in `Debug` output.
#[debug("limiter")]
limiter: Arc<dyn Fn(EndpointId) -> bool + Send + Sync + 'static>,
}
impl<P: ProtocolHandler + Clone> AccessLimit<P> {
    /// Wraps `proto` so that only remotes for which `limiter` returns `true` reach it.
    pub fn new<F>(proto: P, limiter: F) -> Self
    where
        F: Fn(EndpointId) -> bool + Send + Sync + 'static,
    {
        // Type-erase the closure so `AccessLimit` is not generic over it.
        let limiter: Arc<dyn Fn(EndpointId) -> bool + Send + Sync + 'static> = Arc::new(limiter);
        Self { proto, limiter }
    }
}
/// Delegates to the wrapped handler, but gates [`ProtocolHandler::accept`] on the
/// limiter predicate.
impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
    /// Forwards to the wrapped handler; the limiter is not consulted at this stage.
    fn on_accepting(
        &self,
        accepting: Accepting,
    ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
        <P as ProtocolHandler>::on_accepting(&self.proto, accepting)
    }

    /// Passes the connection to the wrapped handler if the limiter approves the
    /// remote id. Otherwise closes the connection with code `0` and reason
    /// `not allowed`, and returns [`AcceptError::NotAllowed`].
    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
        let permitted = (self.limiter)(conn.remote_id());
        if permitted {
            self.proto.accept(conn).await?;
            Ok(())
        } else {
            conn.close(0u32.into(), b"not allowed");
            Err(e!(AcceptError::NotAllowed))
        }
    }

    /// Forwards to the wrapped handler.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        <P as ProtocolHandler>::shutdown(&self.proto)
    }
}
#[cfg(all(test, with_crypto_provider))]
mod tests {
use std::{sync::Mutex, time::Duration};
use n0_error::{Result, StdResultExt};
use n0_tracing_test::traced_test;
use super::*;
use crate::endpoint::{
ApplicationClose, BeforeConnectOutcome, ConnectError, ConnectWithOptsError,
ConnectionError, EndpointHooks, presets,
};
#[tokio::test]
async fn test_shutdown() -> Result {
let endpoint = Endpoint::bind(presets::Minimal).await?;
let router = Router::builder(endpoint.clone()).spawn();
assert!(!router.is_shutdown());
assert!(!endpoint.is_closed());
router.shutdown().await.anyerr()?;
assert!(router.is_shutdown());
assert!(endpoint.is_closed());
Ok(())
}
/// Test handler that echoes one bidirectional stream back to the sender.
#[derive(Debug, Clone)]
struct Echo;

/// ALPN under which the tests register [`Echo`].
const ECHO_ALPN: &[u8] = b"/iroh/echo/1";

impl ProtocolHandler for Echo {
    async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
        println!("accepting echo");
        let (mut tx, mut rx) = connection.accept_bi().await?;
        // Copy everything received straight back; the byte count is not needed.
        tokio::io::copy(&mut rx, &mut tx).await?;
        tx.finish()?;
        // Stay alive until the connection is closed.
        connection.closed().await;
        Ok(())
    }
}
/// An [`AccessLimit`] that denies everyone closes the connection with "not allowed".
#[tokio::test]
async fn test_limiter_router() -> Result {
    let server_ep = Endpoint::bind(presets::Minimal).await?;
    let deny_all = AccessLimit::new(Echo, |_| false);
    let server = Router::builder(server_ep.clone())
        .accept(ECHO_ALPN, deny_all)
        .spawn();
    let addr1 = server.endpoint().addr();
    dbg!(&addr1);

    let client_ep = Endpoint::bind(presets::Minimal).await?;
    println!("connecting");
    let conn = client_ep.connect(addr1, ECHO_ALPN).await?;
    let (_send, mut recv) = conn.open_bi().await.anyerr()?;

    // Reading must fail, and the error must carry the close reason.
    let read_err = recv.read_to_end(1000).await.unwrap_err();
    let rendered = format!("{read_err:#?}");
    assert!(rendered.contains("not allowed"));

    server.shutdown().await.anyerr()?;
    client_ep.close().await;
    Ok(())
}
/// A `before_connect` hook that rejects makes `connect` fail with `LocallyRejected`.
#[tokio::test]
async fn test_limiter_hook() -> Result {
    /// Hook that rejects every outgoing connection attempt.
    #[derive(Debug, Default)]
    struct LimitHook;

    impl EndpointHooks for LimitHook {
        async fn before_connect<'a>(
            &'a self,
            _remote_addr: &'a iroh_base::EndpointAddr,
            alpn: &'a [u8],
        ) -> BeforeConnectOutcome {
            assert_eq!(alpn, ECHO_ALPN);
            BeforeConnectOutcome::Reject
        }
    }

    let server_ep = Endpoint::bind(presets::Minimal).await?;
    let server = Router::builder(server_ep.clone())
        .accept(ECHO_ALPN, Echo)
        .spawn();
    let addr1 = server.endpoint().addr();
    dbg!(&addr1);

    let client_ep = Endpoint::builder(presets::Minimal)
        .hooks(LimitHook)
        .bind()
        .await?;
    println!("connecting");
    let rejection = client_ep.connect(addr1, ECHO_ALPN).await.unwrap_err();
    assert!(matches!(
        rejection,
        ConnectError::Connect {
            source: ConnectWithOptsError::LocallyRejected { .. },
            ..
        }
    ));

    server.shutdown().await.anyerr()?;
    client_ep.close().await;
    Ok(())
}
/// The remote address reported by an `Incoming` matches the one reported after
/// calling `accept()` on it, and is an IP address for a direct connection.
#[tokio::test]
#[traced_test]
async fn test_accepting_remote_addr() -> Result {
    use crate::endpoint::{IncomingAddr, presets};

    let server = Endpoint::builder(presets::Minimal)
        .alpns(vec![ECHO_ALPN.to_vec()])
        .bind()
        .await?;
    let server_addr = server.addr();
    let client = Endpoint::bind(presets::Minimal).await?;

    // Dial in the background so this task can drive the accept side.
    let dialer = client.clone();
    let target = server_addr.clone();
    let connect_task = tokio::spawn(async move { dialer.connect(target, ECHO_ALPN).await });

    let incoming = server.accept().await.expect("accept");
    let seen_addr = incoming.remote_addr();
    assert!(matches!(seen_addr, IncomingAddr::Ip(_)));

    let accepting = incoming.accept().anyerr()?;
    assert_eq!(seen_addr, accepting.remote_addr());

    drop(accepting);
    drop(connect_task);
    server.close().await;
    client.close().await;
    Ok(())
}
mod incoming_filter {
use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering::Relaxed},
},
time::Duration,
};
use n0_error::{Result, StdResultExt};
use n0_tracing_test::traced_test;
use crate::{
Endpoint, EndpointAddr,
endpoint::presets,
protocol::{
IncomingFilterOutcome, Router,
tests::{ECHO_ALPN, Echo},
},
};
/// Builds a router with `filter` installed and an [`Echo`] handler, plus a second
/// endpoint to connect from. Both endpoints are bound to localhost after clearing
/// their IP transports.
///
/// Returns the router, the client endpoint and the router's address.
async fn direct_pair<F>(filter: F) -> Result<(Router, Endpoint, EndpointAddr)>
where
    F: Fn(&crate::endpoint::Incoming) -> IncomingFilterOutcome + Send + Sync + 'static,
{
    let server_ep = Endpoint::builder(presets::Minimal)
        .clear_ip_transports()
        .bind_addr((std::net::Ipv4Addr::LOCALHOST, 0))
        .anyerr()?
        .bind()
        .await?;
    let router = Router::builder(server_ep.clone())
        .incoming_filter(Arc::new(filter))
        .accept(ECHO_ALPN, Echo)
        .spawn();
    let router_addr = router.endpoint().addr();

    let client_ep = Endpoint::builder(presets::Minimal)
        .clear_ip_transports()
        .bind_addr((std::net::Ipv4Addr::LOCALHOST, 0))
        .anyerr()?
        .bind()
        .await?;

    Ok((router, client_ep, router_addr))
}
/// Like `direct_pair`, but both endpoints use a freshly started test relay and the
/// returned address contains only the relay URL.
///
/// The last tuple element is the relay server guard returned by
/// `run_relay_server`; callers keep it alive for the duration of the test.
async fn relay_pair<F>(
    filter: F,
) -> Result<(Router, Endpoint, EndpointAddr, impl std::any::Any)>
where
    F: Fn(&crate::endpoint::Incoming) -> IncomingFilterOutcome + Send + Sync + 'static,
{
    let (_relay_map, relay_url, relay_guard) =
        crate::test_utils::run_relay_server().await.anyerr()?;
    let relay_mode = crate::RelayMode::Custom(crate::RelayMap::from(relay_url.clone()));

    let server_ep = Endpoint::builder(presets::Minimal)
        .relay_mode(relay_mode.clone())
        .ca_roots_config(crate::tls::CaRootsConfig::insecure_skip_verify())
        .bind()
        .await?;
    let router = Router::builder(server_ep.clone())
        .incoming_filter(Arc::new(filter))
        .accept(ECHO_ALPN, Echo)
        .spawn();
    let router_addr = EndpointAddr::new(server_ep.id()).with_relay_url(relay_url);

    let client_ep = Endpoint::builder(presets::Minimal)
        .relay_mode(relay_mode)
        .ca_roots_config(crate::tls::CaRootsConfig::insecure_skip_verify())
        .bind()
        .await?;

    Ok((router, client_ep, router_addr, relay_guard))
}
/// A filter that answers `Retry` until the address is validated still lets the
/// client connect.
#[tokio::test]
#[traced_test]
async fn addr_retry() -> Result {
    let (router, client, addr) = direct_pair(|incoming| {
        if incoming.remote_addr_validated() {
            IncomingFilterOutcome::Accept
        } else {
            IncomingFilterOutcome::Retry
        }
    })
    .await?;

    assert!(client.connect(addr, ECHO_ALPN).await.is_ok());

    router.shutdown().await.anyerr()?;
    client.close().await;
    Ok(())
}
#[tokio::test]
#[traced_test]
async fn addr_reject() -> Result {
let (r1, e2, addr) = direct_pair(|_| IncomingFilterOutcome::Reject).await?;
assert!(e2.connect(addr, ECHO_ALPN).await.is_err());
r1.shutdown().await.anyerr()?;
e2.close().await;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn addr_ignore() -> Result {
let (r1, e2, addr) = direct_pair(|_| IncomingFilterOutcome::Ignore).await?;
let result =
tokio::time::timeout(Duration::from_millis(500), e2.connect(addr, ECHO_ALPN)).await;
assert!(result.is_err(), "expected timeout");
r1.shutdown().await.anyerr()?;
e2.close().await;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn relay_reject() -> Result {
let (r1, e2, addr, _guard) = relay_pair(|_| IncomingFilterOutcome::Reject).await?;
assert!(e2.connect(addr, ECHO_ALPN).await.is_err());
r1.shutdown().await.anyerr()?;
e2.close().await;
Ok(())
}
#[tokio::test]
#[traced_test]
async fn relay_ignore() -> Result {
let (r1, e2, addr, _guard) = relay_pair(|_| IncomingFilterOutcome::Ignore).await?;
let result =
tokio::time::timeout(Duration::from_millis(500), e2.connect(addr, ECHO_ALPN)).await;
assert!(result.is_err(), "expected timeout");
r1.shutdown().await.anyerr()?;
e2.close().await;
Ok(())
}
/// On a direct connection the filter sees the incoming connection first as
/// unvalidated (answering `Retry`) and then as validated (answering `Accept`).
#[tokio::test]
#[traced_test]
async fn addr_retry_then_validated() -> Result {
    let saw_validated = Arc::new(AtomicBool::new(false));
    let saw_unvalidated = Arc::new(AtomicBool::new(false));
    let validated_flag = saw_validated.clone();
    let unvalidated_flag = saw_unvalidated.clone();

    let (router, client, addr) = direct_pair(move |incoming| {
        if !incoming.remote_addr_validated() {
            unvalidated_flag.store(true, Relaxed);
            IncomingFilterOutcome::Retry
        } else {
            validated_flag.store(true, Relaxed);
            IncomingFilterOutcome::Accept
        }
    })
    .await?;

    let _conn = client.connect(addr, ECHO_ALPN).await?;
    assert!(saw_unvalidated.load(Relaxed));
    assert!(saw_validated.load(Relaxed));

    router.shutdown().await.anyerr()?;
    client.close().await;
    Ok(())
}
/// Via the relay the filter sees the incoming connection first as unvalidated
/// (answering `Retry`) and then as validated (answering `Accept`).
#[tokio::test]
#[traced_test]
async fn relay_retry_then_validated() -> Result {
    let saw_validated = Arc::new(AtomicBool::new(false));
    let saw_unvalidated = Arc::new(AtomicBool::new(false));
    let validated_flag = saw_validated.clone();
    let unvalidated_flag = saw_unvalidated.clone();

    let (router, client, addr, _guard) = relay_pair(move |incoming| {
        if !incoming.remote_addr_validated() {
            unvalidated_flag.store(true, Relaxed);
            IncomingFilterOutcome::Retry
        } else {
            validated_flag.store(true, Relaxed);
            IncomingFilterOutcome::Accept
        }
    })
    .await?;

    let _conn = client.connect(addr, ECHO_ALPN).await?;
    assert!(
        saw_unvalidated.load(Relaxed),
        "expected unvalidated incoming"
    );
    assert!(
        saw_validated.load(Relaxed),
        "expected validated incoming after retry"
    );

    router.shutdown().await.anyerr()?;
    client.close().await;
    Ok(())
}
}
/// Router shutdown runs the handler's `shutdown`, so the peer observes the close
/// code and reason the handler chose.
#[tokio::test]
async fn test_graceful_shutdown() -> Result {
    /// Handler that stores accepted connections and closes them on shutdown.
    #[derive(Debug, Clone, Default)]
    struct TestProtocol {
        connections: Arc<Mutex<Vec<Connection>>>,
    }

    const TEST_ALPN: &[u8] = b"/iroh/test/1";

    impl ProtocolHandler for TestProtocol {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let mut held = self.connections.lock().expect("poisoned");
            held.push(connection);
            Ok(())
        }

        async fn shutdown(&self) {
            // Sleep before taking the lock, so the guard is never held across an await.
            tokio::time::sleep(Duration::from_millis(100)).await;
            let mut held = self.connections.lock().expect("poisoned");
            held.drain(..).for_each(|conn| {
                conn.close(42u32.into(), b"shutdown");
            });
        }
    }

    eprintln!("creating ep1");
    let server_ep = Endpoint::bind(presets::Minimal).await?;
    let router = Router::builder(server_ep)
        .accept(TEST_ALPN, TestProtocol::default())
        .spawn();

    eprintln!("waiting for endpoint addr");
    let addr = router.endpoint().addr();

    eprintln!("creating ep2");
    let client_ep = Endpoint::bind(presets::Minimal).await?;

    eprintln!("connecting to {addr:?}");
    let conn = client_ep.connect(addr, TEST_ALPN).await?;

    eprintln!("starting shutdown");
    router.shutdown().await.anyerr()?;

    eprintln!("waiting for closed conn");
    let reason = conn.closed().await;
    let expected = ConnectionError::ApplicationClosed(ApplicationClose {
        error_code: 42u32.into(),
        reason: b"shutdown".to_vec().into(),
    });
    assert_eq!(reason, expected);
    Ok(())
}
}