1use std::{
36 collections::BTreeMap,
37 future::Future,
38 pin::Pin,
39 sync::{Arc, Mutex},
40};
41
42use iroh_base::NodeId;
43use n0_future::{
44 join_all,
45 task::{self, AbortOnDropHandle, JoinSet},
46};
47use snafu::{Backtrace, Snafu};
48use tokio_util::sync::CancellationToken;
49use tracing::{Instrument, error, field::Empty, info_span, trace, warn};
50
51use crate::{
52 Endpoint,
53 endpoint::{Connecting, Connection, RemoteNodeIdError},
54};
55
/// Accepts incoming connections on an [`Endpoint`] and dispatches each one to the
/// [`ProtocolHandler`] registered for the connection's ALPN.
///
/// Clones share the same accept-loop task handle and the same cancellation token.
#[derive(Clone, Debug)]
pub struct Router {
    endpoint: Endpoint,
    // Handle of the spawned accept loop. `Router::shutdown` takes it out (leaving `None`)
    // to await it; the `AbortOnDropHandle` wrapper aborts the task if it is dropped early.
    task: Arc<Mutex<Option<AbortOnDropHandle<()>>>>,
    // Cancelled to request shutdown. The accept loop also cancels it (via a drop guard)
    // when its future finishes.
    cancel_token: CancellationToken,
}
93
/// Builder for a [`Router`].
///
/// Register protocol handlers with [`RouterBuilder::accept`], then start accepting
/// connections with [`RouterBuilder::spawn`].
#[derive(Debug)]
pub struct RouterBuilder {
    endpoint: Endpoint,
    // Handlers registered so far, keyed by ALPN.
    protocols: ProtocolMap,
}
100
/// Errors returned while accepting or handling an incoming connection.
#[allow(missing_docs)]
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum AcceptError {
    // The underlying connection failed. `?` on a `ConnectionError` converts into this
    // variant and captures a backtrace and span trace at the conversion site.
    #[snafu(transparent)]
    Connection {
        source: crate::endpoint::ConnectionError,
        backtrace: Option<Backtrace>,
        #[snafu(implicit)]
        span_trace: n0_snafu::SpanTrace,
    },
    // The remote node id of the connection could not be determined.
    #[snafu(transparent)]
    MissingRemoteNodeId { source: RemoteNodeIdError },
    // The remote was rejected (returned by `AccessLimit` when its limiter says no).
    #[snafu(display("Not allowed."))]
    NotAllowed {},

    // Any other error raised by a protocol handler; built via `AcceptError::from_err`.
    #[snafu(transparent)]
    User {
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}
122
123impl AcceptError {
124 pub fn from_err<T: std::error::Error + Send + Sync + 'static>(value: T) -> Self {
126 Self::User {
127 source: Box::new(value),
128 }
129 }
130}
131
132impl From<std::io::Error> for AcceptError {
133 fn from(err: std::io::Error) -> Self {
134 Self::from_err(err)
135 }
136}
137
138impl From<quinn::ClosedStream> for AcceptError {
139 fn from(err: quinn::ClosedStream) -> Self {
140 Self::from_err(err)
141 }
142}
143
144pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
156 fn on_connecting(
162 &self,
163 connecting: Connecting,
164 ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
165 async move {
166 let conn = connecting.await?;
167 Ok(conn)
168 }
169 }
170
171 fn accept(
181 &self,
182 connection: Connection,
183 ) -> impl Future<Output = Result<(), AcceptError>> + Send;
184
185 fn shutdown(&self) -> impl Future<Output = ()> + Send {
192 async move {}
193 }
194}
195
196impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
197 async fn on_connecting(&self, conn: Connecting) -> Result<Connection, AcceptError> {
198 self.as_ref().on_connecting(conn).await
199 }
200
201 async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
202 self.as_ref().accept(conn).await
203 }
204
205 async fn shutdown(&self) {
206 self.as_ref().shutdown().await
207 }
208}
209
210impl<T: ProtocolHandler> ProtocolHandler for Box<T> {
211 async fn on_connecting(&self, conn: Connecting) -> Result<Connection, AcceptError> {
212 self.as_ref().on_connecting(conn).await
213 }
214
215 async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
216 self.as_ref().accept(conn).await
217 }
218
219 async fn shutdown(&self) {
220 self.as_ref().shutdown().await
221 }
222}
223
224impl<T: ProtocolHandler> From<T> for Box<dyn DynProtocolHandler> {
225 fn from(value: T) -> Self {
226 Box::new(value)
227 }
228}
229
230pub trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
239 fn on_connecting(
241 &self,
242 connecting: Connecting,
243 ) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
244 Box::pin(async move {
245 let conn = connecting.await?;
246 Ok(conn)
247 })
248 }
249
250 fn accept(
252 &self,
253 connection: Connection,
254 ) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>>;
255
256 fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
258 Box::pin(async move {})
259 }
260}
261
262impl<P: ProtocolHandler> DynProtocolHandler for P {
263 fn accept(
264 &self,
265 connection: Connection,
266 ) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>> {
267 Box::pin(<Self as ProtocolHandler>::accept(self, connection))
268 }
269
270 fn on_connecting(
271 &self,
272 connecting: Connecting,
273 ) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
274 Box::pin(<Self as ProtocolHandler>::on_connecting(self, connecting))
275 }
276
277 fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
278 Box::pin(<Self as ProtocolHandler>::shutdown(self))
279 }
280}
281
/// The protocol handlers registered on a router, keyed by ALPN.
#[derive(Debug, Default)]
pub(crate) struct ProtocolMap(BTreeMap<Vec<u8>, Box<dyn DynProtocolHandler>>);
285
286impl ProtocolMap {
287 pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> {
289 self.0.get(alpn).map(|p| &**p)
290 }
291
292 pub(crate) fn insert(&mut self, alpn: Vec<u8>, handler: Box<dyn DynProtocolHandler>) {
294 self.0.insert(alpn, handler);
295 }
296
297 pub(crate) fn alpns(&self) -> impl Iterator<Item = &Vec<u8>> {
299 self.0.keys()
300 }
301
302 pub(crate) async fn shutdown(&self) {
306 let handlers = self.0.values().map(|p| p.shutdown());
307 join_all(handlers).await;
308 }
309}
310
311impl Router {
312 pub fn builder(endpoint: Endpoint) -> RouterBuilder {
314 RouterBuilder::new(endpoint)
315 }
316
317 pub fn endpoint(&self) -> &Endpoint {
319 &self.endpoint
320 }
321
322 pub fn is_shutdown(&self) -> bool {
324 self.cancel_token.is_cancelled()
325 }
326
327 pub async fn shutdown(&self) -> Result<(), n0_future::task::JoinError> {
337 if self.is_shutdown() {
338 return Ok(());
339 }
340
341 self.cancel_token.cancel();
343
344 let task = self.task.lock().expect("poisoned").take();
348 if let Some(task) = task {
349 task.await?;
350 }
351
352 Ok(())
353 }
354}
355
impl RouterBuilder {
    /// Creates a builder for a router on `endpoint`, with no protocols registered yet.
    pub fn new(endpoint: Endpoint) -> Self {
        Self {
            endpoint,
            protocols: ProtocolMap::default(),
        }
    }

    /// Registers `handler` for connections negotiating the ALPN `alpn`.
    ///
    /// Registering the same ALPN twice replaces the earlier handler.
    pub fn accept(
        mut self,
        alpn: impl AsRef<[u8]>,
        handler: impl Into<Box<dyn DynProtocolHandler>>,
    ) -> Self {
        self.protocols
            .insert(alpn.as_ref().to_vec(), handler.into());
        self
    }

    /// Returns the [`Endpoint`] the router will accept connections on.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }

    /// Consumes the builder, passes the registered ALPNs to the endpoint, and spawns the
    /// accept loop on a new task instrumented with the current tracing span.
    ///
    /// Returns the [`Router`] controlling that task.
    pub fn spawn(self) -> Router {
        // Collect the ALPNs of all registered handlers (sorted, from the BTreeMap keys).
        let alpns = self
            .protocols
            .alpns()
            .map(|alpn| alpn.to_vec())
            .collect::<Vec<_>>();

        let protocols = Arc::new(self.protocols);
        self.endpoint.set_alpns(alpns);

        // One task per accepted connection lives in this set.
        let mut join_set = JoinSet::new();
        let endpoint = self.endpoint.clone();

        // `cancel` goes into the returned `Router`; `cancel_token` is moved into the loop.
        let cancel = CancellationToken::new();
        let cancel_token = cancel.clone();

        let run_loop_fut = async move {
            // Cancels `cancel_token` when this future finishes or is dropped, so
            // `Router::is_shutdown` reports true even if the loop exits on its own.
            let _cancel_guard = cancel_token.clone().drop_guard();
            // Separate token for the per-connection tasks. It is cancelled only after the
            // protocol handlers' `shutdown` has completed (see the sequence after the loop).
            let handler_cancel_token = CancellationToken::new();

            loop {
                // `biased;` polls branches in the order written: shutdown requests first,
                // then reaping finished connection tasks, then new connections.
                tokio::select! {
                    biased;
                    _ = cancel_token.cancelled() => {
                        break;
                    },
                    // Reap finished connection tasks. This branch is skipped when the set is
                    // empty (`join_next` yields `None`). A panicked or otherwise failed task
                    // stops the whole accept loop; a cancelled task is only traced.
                    Some(res) = join_set.join_next() => {
                        match res {
                            Err(outer) => {
                                if outer.is_panic() {
                                    error!("Task panicked: {outer:?}");
                                    break;
                                } else if outer.is_cancelled() {
                                    trace!("Task cancelled: {outer:?}");
                                } else {
                                    error!("Task failed: {outer:?}");
                                    break;
                                }
                            }
                            Ok(Some(())) => {
                                trace!("Task finished");
                            }
                            // `run_until_cancelled` returned `None`: the handler token fired.
                            Ok(None) => {
                                trace!("Task cancelled");
                            }
                        }
                    },

                    incoming = endpoint.accept() => {
                        // `None`: the endpoint stopped yielding connections (presumably
                        // closed) - stop accepting.
                        let Some(incoming) = incoming else {
                            break;
                        };

                        let protocols = protocols.clone();
                        let token = handler_cancel_token.child_token();
                        // `remote` and `alpn` are left empty here and recorded later by
                        // `handle_connection`.
                        let span = info_span!("router.accept", me=%endpoint.node_id().fmt_short(), remote=Empty, alpn=Empty);
                        join_set.spawn(async move {
                            token.run_until_cancelled(handle_connection(incoming, protocols)).await
                        }.instrument(span));
                    },
                }
            }

            // Shutdown sequence; the order matters:
            // 1. let every protocol handler shut down (connection tasks keep running),
            protocols.shutdown().await;
            // 2. stop all per-connection handler futures,
            handler_cancel_token.cancel();
            // 3. close the endpoint,
            endpoint.close().await;
            tracing::debug!("Shutting down remaining tasks");
            // 4. abort whatever is left and drain the set, reporting only panics.
            join_set.abort_all();
            while let Some(res) = join_set.join_next().await {
                match res {
                    Err(err) if err.is_panic() => error!("Task panicked: {err:?}"),
                    _ => {}
                }
            }
        };
        let task = task::spawn(run_loop_fut.instrument(tracing::Span::current()));
        // Abort the accept loop if every holder of the handle drops it before completion.
        let task = AbortOnDropHandle::new(task);

        Router {
            endpoint: self.endpoint,
            task: Arc::new(Mutex::new(Some(task))),
            cancel_token: cancel,
        }
    }
}
485
486async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<ProtocolMap>) {
487 let mut connecting = match incoming.accept() {
488 Ok(conn) => conn,
489 Err(err) => {
490 warn!("Ignoring connection: accepting failed: {err:#}");
491 return;
492 }
493 };
494 let alpn = match connecting.alpn().await {
495 Ok(alpn) => alpn,
496 Err(err) => {
497 warn!("Ignoring connection: invalid handshake: {err:#}");
498 return;
499 }
500 };
501 tracing::Span::current().record("alpn", String::from_utf8_lossy(&alpn).to_string());
502 let Some(handler) = protocols.get(&alpn) else {
503 warn!("Ignoring connection: unsupported ALPN protocol");
504 return;
505 };
506 match handler.on_connecting(connecting).await {
507 Ok(connection) => {
508 if let Ok(remote) = connection.remote_node_id() {
509 tracing::Span::current()
510 .record("remote", tracing::field::display(remote.fmt_short()));
511 };
512 if let Err(err) = handler.accept(connection).await {
513 warn!("Handling incoming connection ended with error: {err}");
514 }
515 }
516 Err(err) => {
517 warn!("Handling incoming connecting ended with error: {err}");
518 }
519 }
520}
521
/// Wraps a [`ProtocolHandler`] and only hands it connections from remote nodes that a
/// user-supplied predicate accepts; rejected connections are closed with "not allowed".
#[derive(derive_more::Debug, Clone)]
pub struct AccessLimit<P: ProtocolHandler + Clone> {
    /// The wrapped protocol handler.
    proto: P,
    /// Predicate deciding, per remote [`NodeId`], whether the connection is allowed.
    /// Rendered as the literal `limiter` in `Debug` output.
    #[debug("limiter")]
    limiter: Arc<dyn Fn(NodeId) -> bool + Send + Sync + 'static>,
}
532
533impl<P: ProtocolHandler + Clone> AccessLimit<P> {
534 pub fn new<F>(proto: P, limiter: F) -> Self
539 where
540 F: Fn(NodeId) -> bool + Send + Sync + 'static,
541 {
542 Self {
543 proto,
544 limiter: Arc::new(limiter),
545 }
546 }
547}
548
549impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
550 fn on_connecting(
551 &self,
552 conn: Connecting,
553 ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
554 self.proto.on_connecting(conn)
555 }
556
557 async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
558 let remote = conn.remote_node_id()?;
559 let is_allowed = (self.limiter)(remote);
560 if !is_allowed {
561 conn.close(0u32.into(), b"not allowed");
562 return Err(NotAllowedSnafu.build());
563 }
564 self.proto.accept(conn).await?;
565 Ok(())
566 }
567
568 fn shutdown(&self) -> impl Future<Output = ()> + Send {
569 self.proto.shutdown()
570 }
571}
572
#[cfg(test)]
mod tests {
    use std::{sync::Mutex, time::Duration};

    use n0_snafu::{Result, ResultExt};
    use quinn::ApplicationClose;

    use super::*;
    use crate::{RelayMode, endpoint::ConnectionError};

    /// Shutting down a router marks it as shut down and closes its endpoint.
    #[tokio::test]
    async fn test_shutdown() -> Result {
        let endpoint = Endpoint::builder().bind().await?;
        let router = Router::builder(endpoint.clone()).spawn();

        assert!(!router.is_shutdown());
        assert!(!endpoint.is_closed());

        router.shutdown().await.e()?;

        assert!(router.is_shutdown());
        assert!(endpoint.is_closed());

        Ok(())
    }

    /// Test protocol: echoes back everything received on the first bi-directional stream.
    #[derive(Debug, Clone)]
    struct Echo;

    const ECHO_ALPN: &[u8] = b"/iroh/echo/1";

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            println!("accepting echo");
            let (mut send, mut recv) = connection.accept_bi().await?;

            let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?;

            send.finish()?;
            connection.closed().await;

            Ok(())
        }
    }

    /// A limiter that rejects every node makes the remote observe a "not allowed" close.
    #[tokio::test]
    async fn test_limiter() -> Result {
        let e1 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        // Deny access to every remote node.
        let proto = AccessLimit::new(Echo, |_node_id| false);
        let r1 = Router::builder(e1.clone()).accept(ECHO_ALPN, proto).spawn();

        let addr1 = r1.endpoint().node_addr();
        let e2 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;

        println!("connecting");
        let conn = e2.connect(addr1, ECHO_ALPN).await?;

        let (_send, mut recv) = conn.open_bi().await.e()?;
        let response = recv.read_to_end(1000).await.unwrap_err();
        assert!(format!("{response:#?}").contains("not allowed"));

        r1.shutdown().await.e()?;
        e2.close().await;

        Ok(())
    }

    /// Handlers' `shutdown` runs before the endpoint is closed, so connections can be closed
    /// gracefully with an application error code.
    #[tokio::test]
    async fn test_graceful_shutdown() -> Result {
        #[derive(Debug, Clone, Default)]
        struct TestProtocol {
            connections: Arc<Mutex<Vec<Connection>>>,
        }

        const TEST_ALPN: &[u8] = b"/iroh/test/1";

        impl ProtocolHandler for TestProtocol {
            async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
                self.connections.lock().expect("poisoned").push(connection);
                Ok(())
            }

            async fn shutdown(&self) {
                tokio::time::sleep(Duration::from_millis(100)).await;
                let mut connections = self.connections.lock().expect("poisoned");
                for conn in connections.drain(..) {
                    conn.close(42u32.into(), b"shutdown");
                }
            }
        }

        eprintln!("creating ep1");
        let endpoint = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        let router = Router::builder(endpoint)
            .accept(TEST_ALPN, TestProtocol::default())
            .spawn();
        eprintln!("waiting for node addr");
        let addr = router.endpoint().node_addr();

        eprintln!("creating ep2");
        let endpoint2 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        eprintln!("connecting to {addr:?}");
        let conn = endpoint2.connect(addr, TEST_ALPN).await?;

        eprintln!("starting shutdown");
        router.shutdown().await.e()?;

        eprintln!("waiting for closed conn");
        let reason = conn.closed().await;
        assert_eq!(
            reason,
            ConnectionError::ApplicationClosed(ApplicationClose {
                error_code: 42u32.into(),
                reason: b"shutdown".to_vec().into()
            })
        );
        Ok(())
    }
}