iroh/
protocol.rs

1//! Tools for spawning an accept loop that routes incoming requests to the right protocol.
2//!
3//! ## Example
4//!
5//! ```no_run
6//! # use iroh::{endpoint::{Connection, BindError}, protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeAddr};
7//! #
8//! # async fn test_compile() -> Result<(), BindError> {
9//! let endpoint = Endpoint::builder().discovery_n0().bind().await?;
10//!
11//! let router = Router::builder(endpoint)
12//!     .accept(b"/my/alpn", Echo)
13//!     .spawn();
14//! # Ok(())
15//! # }
16//!
17//! // The protocol definition:
18//! #[derive(Debug, Clone)]
19//! struct Echo;
20//!
21//! impl ProtocolHandler for Echo {
22//!     async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
23//!         let (mut send, mut recv) = connection.accept_bi().await?;
24//!
25//!         // Echo any bytes received back directly.
26//!         let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?;
27//!
28//!         send.finish()?;
29//!         connection.closed().await;
30//!
31//!         Ok(())
32//!     }
33//! }
34//! ```
35use std::{
36    collections::BTreeMap,
37    future::Future,
38    pin::Pin,
39    sync::{Arc, Mutex},
40};
41
42use iroh_base::NodeId;
43use n0_future::{
44    join_all,
45    task::{self, AbortOnDropHandle, JoinSet},
46};
47use snafu::{Backtrace, Snafu};
48use tokio_util::sync::CancellationToken;
49use tracing::{Instrument, error, field::Empty, info_span, trace, warn};
50
51use crate::{
52    Endpoint,
53    endpoint::{Connecting, Connection, RemoteNodeIdError},
54};
55
/// The built router.
///
/// Construct this using [`Router::builder`].
///
/// When dropped, this will abort the listening tasks, so make sure to store it.
///
/// Even with this abort-on-drop behaviour, it's recommended to call and await
/// [`Router::shutdown`] before ending the process.
///
/// As an example for graceful shutdown, e.g. for tests or CLI tools,
/// wait for [`tokio::signal::ctrl_c()`]:
///
/// ```no_run
/// # use std::sync::Arc;
/// # use n0_snafu::ResultExt;
/// # use iroh::{endpoint::Connecting, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr};
/// #
/// # async fn test_compile() -> n0_snafu::Result<()> {
/// let endpoint = Endpoint::builder().discovery_n0().bind().await?;
///
/// let router = Router::builder(endpoint)
///     // .accept(&ALPN, <something>)
///     .spawn();
///
/// // wait until the user wants to shut down
/// tokio::signal::ctrl_c().await.context("ctrl+c")?;
/// router.shutdown().await.context("shutdown")?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct Router {
    // The endpoint whose incoming connections are routed; exposed via `Router::endpoint`.
    endpoint: Endpoint,
    // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
    // The `Option` lets `shutdown()` take the handle out of the shared slot before awaiting it.
    task: Arc<Mutex<Option<AbortOnDropHandle<()>>>>,
    // Cancelling this token asks the accept loop to stop; `is_shutdown()` reports its state.
    cancel_token: CancellationToken,
}
93
/// Builder for creating a [`Router`] for accepting protocols.
///
/// Register handlers with [`RouterBuilder::accept`], then call [`RouterBuilder::spawn`]
/// to start the accept loop.
#[derive(Debug)]
pub struct RouterBuilder {
    // The endpoint the spawned router will accept connections on.
    endpoint: Endpoint,
    // Handlers registered so far, keyed by ALPN.
    protocols: ProtocolMap,
}
100
/// Error type returned by [`ProtocolHandler::on_connecting`] and [`ProtocolHandler::accept`].
#[allow(missing_docs)]
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum AcceptError {
    // Wraps a `ConnectionError` from the endpoint, e.g. when awaiting a `Connecting` fails.
    #[snafu(transparent)]
    Connection {
        source: crate::endpoint::ConnectionError,
        backtrace: Option<Backtrace>,
        #[snafu(implicit)]
        span_trace: n0_snafu::SpanTrace,
    },
    // Wraps a `RemoteNodeIdError`, e.g. from `Connection::remote_node_id` in `AccessLimit`.
    #[snafu(transparent)]
    MissingRemoteNodeId { source: RemoteNodeIdError },
    // Returned by `AccessLimit` when its limiter function rejects the remote node.
    #[snafu(display("Not allowed."))]
    NotAllowed {},

    // Any other error raised by a protocol implementation; see `AcceptError::from_err`.
    #[snafu(transparent)]
    User {
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}
122
123impl AcceptError {
124    /// Creates a new user error from an arbitrary error type.
125    pub fn from_err<T: std::error::Error + Send + Sync + 'static>(value: T) -> Self {
126        Self::User {
127            source: Box::new(value),
128        }
129    }
130}
131
132impl From<std::io::Error> for AcceptError {
133    fn from(err: std::io::Error) -> Self {
134        Self::from_err(err)
135    }
136}
137
138impl From<quinn::ClosedStream> for AcceptError {
139    fn from(err: quinn::ClosedStream) -> Self {
140        Self::from_err(err)
141    }
142}
143
/// Handler for incoming connections.
///
/// A router accepts connections for arbitrary ALPN protocols.
///
/// With this trait, you can handle incoming connections for any protocol.
///
/// Implement this trait on a struct that should handle incoming connections.
/// The protocol handler must then be registered on the node for an ALPN protocol with
/// [`crate::protocol::RouterBuilder::accept`].
///
/// See the [module documentation](crate::protocol) for an example.
pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
    /// Optional interception point to handle the `Connecting` state.
    ///
    /// Can be implemented as
    /// `async fn on_connecting(&self, connecting: Connecting) -> Result<Connection, AcceptError>`.
    ///
    /// This enables accepting 0-RTT data from clients, among other things.
    ///
    /// The default implementation simply awaits `connecting` and returns the resulting
    /// [`Connection`].
    fn on_connecting(
        &self,
        connecting: Connecting,
    ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
        async move {
            let conn = connecting.await?;
            Ok(conn)
        }
    }

    /// Handle an incoming connection.
    ///
    /// Can be implemented as
    /// `async fn accept(&self, connection: Connection) -> Result<(), AcceptError>`.
    ///
    /// The returned future runs on a freshly spawned tokio task so it can be long-running.
    ///
    /// When [`Router::shutdown`] is called, no further connections will be accepted, and
    /// the futures returned by [`Self::accept`] will be aborted after the future returned
    /// from [`ProtocolHandler::shutdown`] completes.
    fn accept(
        &self,
        connection: Connection,
    ) -> impl Future<Output = Result<(), AcceptError>> + Send;

    /// Called when the router shuts down.
    ///
    /// Can be implemented as `async fn shutdown(&self)`.
    ///
    /// This is called from [`Router::shutdown`]. The returned future is awaited before
    /// the router closes the endpoint.
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> impl Future<Output = ()> + Send {
        async move {}
    }
}
195
196impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
197    async fn on_connecting(&self, conn: Connecting) -> Result<Connection, AcceptError> {
198        self.as_ref().on_connecting(conn).await
199    }
200
201    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
202        self.as_ref().accept(conn).await
203    }
204
205    async fn shutdown(&self) {
206        self.as_ref().shutdown().await
207    }
208}
209
210impl<T: ProtocolHandler> ProtocolHandler for Box<T> {
211    async fn on_connecting(&self, conn: Connecting) -> Result<Connection, AcceptError> {
212        self.as_ref().on_connecting(conn).await
213    }
214
215    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
216        self.as_ref().accept(conn).await
217    }
218
219    async fn shutdown(&self) {
220        self.as_ref().shutdown().await
221    }
222}
223
224impl<T: ProtocolHandler> From<T> for Box<dyn DynProtocolHandler> {
225    fn from(value: T) -> Self {
226        Box::new(value)
227    }
228}
229
/// A dyn-compatible version of [`ProtocolHandler`] that returns boxed futures.
///
/// Any type that implements [`ProtocolHandler`] automatically also implements [`DynProtocolHandler`].
/// There is also a [`From`] impl to turn any type that implements [`ProtocolHandler`] into a
/// `Box<dyn DynProtocolHandler>`.
//
// We are not using [`n0_future::boxed::BoxFuture`] because we don't need a `'static` bound
// on these futures.
pub trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
    /// See [`ProtocolHandler::on_connecting`].
    ///
    /// The default implementation awaits `connecting` and returns the resulting connection.
    fn on_connecting(
        &self,
        connecting: Connecting,
    ) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
        Box::pin(async move {
            let conn = connecting.await?;
            Ok(conn)
        })
    }

    /// See [`ProtocolHandler::accept`].
    fn accept(
        &self,
        connection: Connection,
    ) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>>;

    /// See [`ProtocolHandler::shutdown`].
    ///
    /// The default implementation does nothing.
    fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(async move {})
    }
}
261
262impl<P: ProtocolHandler> DynProtocolHandler for P {
263    fn accept(
264        &self,
265        connection: Connection,
266    ) -> Pin<Box<dyn Future<Output = Result<(), AcceptError>> + Send + '_>> {
267        Box::pin(<Self as ProtocolHandler>::accept(self, connection))
268    }
269
270    fn on_connecting(
271        &self,
272        connecting: Connecting,
273    ) -> Pin<Box<dyn Future<Output = Result<Connection, AcceptError>> + Send + '_>> {
274        Box::pin(<Self as ProtocolHandler>::on_connecting(self, connecting))
275    }
276
277    fn shutdown(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
278        Box::pin(<Self as ProtocolHandler>::shutdown(self))
279    }
280}
281
/// A map of protocol handlers, keyed by the ALPN they are registered for.
///
/// Handlers are stored as `Box<dyn DynProtocolHandler>`.
#[derive(Debug, Default)]
pub(crate) struct ProtocolMap(BTreeMap<Vec<u8>, Box<dyn DynProtocolHandler>>);
285
286impl ProtocolMap {
287    /// Returns the registered protocol handler for an ALPN as a [`Arc<dyn ProtocolHandler>`].
288    pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> {
289        self.0.get(alpn).map(|p| &**p)
290    }
291
292    /// Inserts a protocol handler.
293    pub(crate) fn insert(&mut self, alpn: Vec<u8>, handler: Box<dyn DynProtocolHandler>) {
294        self.0.insert(alpn, handler);
295    }
296
297    /// Returns an iterator of all registered ALPN protocol identifiers.
298    pub(crate) fn alpns(&self) -> impl Iterator<Item = &Vec<u8>> {
299        self.0.keys()
300    }
301
302    /// Shuts down all protocol handlers.
303    ///
304    /// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently.
305    pub(crate) async fn shutdown(&self) {
306        let handlers = self.0.values().map(|p| p.shutdown());
307        join_all(handlers).await;
308    }
309}
310
311impl Router {
312    /// Creates a new [`Router`] using given [`Endpoint`].
313    pub fn builder(endpoint: Endpoint) -> RouterBuilder {
314        RouterBuilder::new(endpoint)
315    }
316
317    /// Returns the [`Endpoint`] stored in this router.
318    pub fn endpoint(&self) -> &Endpoint {
319        &self.endpoint
320    }
321
322    /// Checks if the router is already shutdown.
323    pub fn is_shutdown(&self) -> bool {
324        self.cancel_token.is_cancelled()
325    }
326
327    /// Shuts down the accept loop cleanly.
328    ///
329    /// When this function returns, all [`ProtocolHandler`]s will be shutdown and
330    /// `Endpoint::close` will have been called.
331    ///
332    /// If already shutdown, it returns `Ok`.
333    ///
334    /// If some [`ProtocolHandler`] panicked in the accept loop, this will propagate
335    /// that panic into the result here.
336    pub async fn shutdown(&self) -> Result<(), n0_future::task::JoinError> {
337        if self.is_shutdown() {
338            return Ok(());
339        }
340
341        // Trigger shutdown of the main run task by activating the cancel token.
342        self.cancel_token.cancel();
343
344        // Wait for the main task to terminate.
345
346        // MutexGuard is not held across await point
347        let task = self.task.lock().expect("poisoned").take();
348        if let Some(task) = task {
349            task.await?;
350        }
351
352        Ok(())
353    }
354}
355
impl RouterBuilder {
    /// Creates a new router builder using given [`Endpoint`].
    ///
    /// The builder starts without any registered protocol handlers.
    pub fn new(endpoint: Endpoint) -> Self {
        Self {
            endpoint,
            protocols: ProtocolMap::default(),
        }
    }

    /// Configures the router to accept the [`ProtocolHandler`] when receiving a connection
    /// with this `alpn`.
    ///
    /// `handler` can either be a type that implements [`ProtocolHandler`] or a
    /// [`Box<dyn DynProtocolHandler>`].
    ///
    /// Registering a second handler for the same `alpn` replaces the earlier one.
    ///
    /// [`Box<dyn DynProtocolHandler>`]: DynProtocolHandler
    pub fn accept(
        mut self,
        alpn: impl AsRef<[u8]>,
        handler: impl Into<Box<dyn DynProtocolHandler>>,
    ) -> Self {
        self.protocols
            .insert(alpn.as_ref().to_vec(), handler.into());
        self
    }

    /// Returns the [`Endpoint`] of the node.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }

    /// Spawns an accept loop and returns a handle to it encapsulated as the [`Router`].
    ///
    /// Before the loop starts, the endpoint's ALPNs are set to exactly the ALPNs that
    /// have a registered handler. Every incoming connection is then handled on its own
    /// spawned task.
    ///
    /// The loop stops when the router's cancellation token is cancelled, when the
    /// endpoint stops yielding incoming connections, or when a connection task panics
    /// or fails. After the loop stops, the handlers are shut down, still-running
    /// `accept` futures are cancelled, the endpoint is closed, and remaining tasks are
    /// aborted, in that order.
    pub fn spawn(self) -> Router {
        // Update the endpoint with our alpns.
        let alpns = self
            .protocols
            .alpns()
            .map(|alpn| alpn.to_vec())
            .collect::<Vec<_>>();

        // The handler map is shared with every per-connection task.
        let protocols = Arc::new(self.protocols);
        self.endpoint.set_alpns(alpns);

        // Holds one task per incoming connection.
        let mut join_set = JoinSet::new();
        let endpoint = self.endpoint.clone();

        // Our own shutdown works with a cancellation token.
        // `cancel` is stored in the returned `Router`; `cancel_token` moves into the loop.
        let cancel = CancellationToken::new();
        let cancel_token = cancel.clone();

        let run_loop_fut = async move {
            // Make sure to cancel the token, if this future ever exits.
            let _cancel_guard = cancel_token.clone().drop_guard();
            // We create a separate cancellation token to stop any `ProtocolHandler::accept` futures
            // that are still running after `ProtocolHandler::shutdown` was called.
            let handler_cancel_token = CancellationToken::new();

            loop {
                tokio::select! {
                    // `biased` polls the branches in order, so a shutdown request is
                    // checked before anything else.
                    biased;
                    _ = cancel_token.cancelled() => {
                        break;
                    },
                    // handle task terminations and quit on panics.
                    Some(res) = join_set.join_next() => {
                        match res {
                            Err(outer) => {
                                if outer.is_panic() {
                                    error!("Task panicked: {outer:?}");
                                    break;
                                } else if outer.is_cancelled() {
                                    trace!("Task cancelled: {outer:?}");
                                } else {
                                    error!("Task failed: {outer:?}");
                                    break;
                                }
                            }
                            // `Some(())` means `handle_connection` ran to completion.
                            Ok(Some(())) => {
                                trace!("Task finished");
                            }
                            // `None` means the handler token was cancelled first.
                            Ok(None) => {
                                trace!("Task cancelled");
                            }
                        }
                    },

                    // handle incoming p2p connections.
                    incoming = endpoint.accept() => {
                        let Some(incoming) = incoming else {
                            break; // Endpoint is closed.
                        };

                        let protocols = protocols.clone();
                        let token = handler_cancel_token.child_token();
                        // `remote` and `alpn` start empty and are recorded by `handle_connection`.
                        let span = info_span!("router.accept", me=%endpoint.node_id().fmt_short(), remote=Empty, alpn=Empty);
                        join_set.spawn(async move {
                            token.run_until_cancelled(handle_connection(incoming, protocols)).await
                        }.instrument(span));
                    },
                }
            }

            // We first shutdown the protocol handlers to give them a chance to close connections gracefully.
            protocols.shutdown().await;
            // We now cancel the remaining `ProtocolHandler::accept` futures.
            handler_cancel_token.cancel();
            // Now we close the endpoint. This will force-close all connections that are not yet closed.
            endpoint.close().await;
            // Finally, we abort the remaining accept tasks. This should be a noop because we already cancelled
            // the futures above.
            tracing::debug!("Shutting down remaining tasks");
            join_set.abort_all();
            // Drain the set so that panics in the remaining tasks are at least logged.
            while let Some(res) = join_set.join_next().await {
                match res {
                    Err(err) if err.is_panic() => error!("Task panicked: {err:?}"),
                    _ => {}
                }
            }
        };
        // Run the loop inside the span that is current when `spawn` is called.
        let task = task::spawn(run_loop_fut.instrument(tracing::Span::current()));
        // Dropping the handle aborts the loop task.
        let task = AbortOnDropHandle::new(task);

        Router {
            endpoint: self.endpoint,
            task: Arc::new(Mutex::new(Some(task))),
            cancel_token: cancel,
        }
    }
}
485
486async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<ProtocolMap>) {
487    let mut connecting = match incoming.accept() {
488        Ok(conn) => conn,
489        Err(err) => {
490            warn!("Ignoring connection: accepting failed: {err:#}");
491            return;
492        }
493    };
494    let alpn = match connecting.alpn().await {
495        Ok(alpn) => alpn,
496        Err(err) => {
497            warn!("Ignoring connection: invalid handshake: {err:#}");
498            return;
499        }
500    };
501    tracing::Span::current().record("alpn", String::from_utf8_lossy(&alpn).to_string());
502    let Some(handler) = protocols.get(&alpn) else {
503        warn!("Ignoring connection: unsupported ALPN protocol");
504        return;
505    };
506    match handler.on_connecting(connecting).await {
507        Ok(connection) => {
508            if let Ok(remote) = connection.remote_node_id() {
509                tracing::Span::current()
510                    .record("remote", tracing::field::display(remote.fmt_short()));
511            };
512            if let Err(err) = handler.accept(connection).await {
513                warn!("Handling incoming connection ended with error: {err}");
514            }
515        }
516        Err(err) => {
517            warn!("Handling incoming connecting ended with error: {err}");
518        }
519    }
520}
521
/// Wraps an existing protocol, limiting its access,
/// based on the provided function.
///
/// Any refused connection will be closed with an error code of `0` and reason `not allowed`.
#[derive(derive_more::Debug, Clone)]
pub struct AccessLimit<P: ProtocolHandler + Clone> {
    // The wrapped protocol handler that allowed connections are forwarded to.
    proto: P,
    // Returns `true` for node ids that may connect. Shown as `limiter` in `Debug` output.
    #[debug("limiter")]
    limiter: Arc<dyn Fn(NodeId) -> bool + Send + Sync + 'static>,
}
532
533impl<P: ProtocolHandler + Clone> AccessLimit<P> {
534    /// Create a new `AccessLimit`.
535    ///
536    /// The function should return `true` for nodes that are allowed to
537    /// connect, and `false` otherwise.
538    pub fn new<F>(proto: P, limiter: F) -> Self
539    where
540        F: Fn(NodeId) -> bool + Send + Sync + 'static,
541    {
542        Self {
543            proto,
544            limiter: Arc::new(limiter),
545        }
546    }
547}
548
549impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
550    fn on_connecting(
551        &self,
552        conn: Connecting,
553    ) -> impl Future<Output = Result<Connection, AcceptError>> + Send {
554        self.proto.on_connecting(conn)
555    }
556
557    async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
558        let remote = conn.remote_node_id()?;
559        let is_allowed = (self.limiter)(remote);
560        if !is_allowed {
561            conn.close(0u32.into(), b"not allowed");
562            return Err(NotAllowedSnafu.build());
563        }
564        self.proto.accept(conn).await?;
565        Ok(())
566    }
567
568    fn shutdown(&self) -> impl Future<Output = ()> + Send {
569        self.proto.shutdown()
570    }
571}
572
#[cfg(test)]
mod tests {
    use std::{sync::Mutex, time::Duration};

    use n0_snafu::{Result, ResultExt};
    use quinn::ApplicationClose;

    use super::*;
    use crate::{RelayMode, endpoint::ConnectionError};

    /// Shutting down a router with no handlers marks it as shut down and closes the endpoint.
    #[tokio::test]
    async fn test_shutdown() -> Result {
        let endpoint = Endpoint::builder().bind().await?;
        let router = Router::builder(endpoint.clone()).spawn();

        assert!(!router.is_shutdown());
        assert!(!endpoint.is_closed());

        router.shutdown().await.e()?;

        assert!(router.is_shutdown());
        assert!(endpoint.is_closed());

        Ok(())
    }

    // The protocol definition:
    // Accepts one bidirectional stream and copies everything received back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    const ECHO_ALPN: &[u8] = b"/iroh/echo/1";

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            println!("accepting echo");
            let (mut send, mut recv) = connection.accept_bi().await?;

            // Echo any bytes received back directly.
            let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?;

            send.finish()?;
            // Keep the handler alive until the connection is closed.
            connection.closed().await;

            Ok(())
        }
    }

    /// An `AccessLimit` that denies every node closes the connection with `not allowed`.
    #[tokio::test]
    async fn test_limiter() -> Result {
        // tracing_subscriber::fmt::try_init().ok();
        let e1 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        // deny all access
        let proto = AccessLimit::new(Echo, |_node_id| false);
        let r1 = Router::builder(e1.clone()).accept(ECHO_ALPN, proto).spawn();

        let addr1 = r1.endpoint().node_addr();
        dbg!(&addr1);
        let e2 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;

        println!("connecting");
        let conn = e2.connect(addr1, ECHO_ALPN).await?;

        // Reading must fail, and the error must carry the close reason set by `AccessLimit`.
        let (_send, mut recv) = conn.open_bi().await.e()?;
        let response = recv.read_to_end(1000).await.unwrap_err();
        assert!(format!("{response:#?}").contains("not allowed"));

        r1.shutdown().await.e()?;
        e2.close().await;

        Ok(())
    }

    /// A handler's `shutdown` gets to close its connections with its own code and reason
    /// before the router closes the endpoint.
    #[tokio::test]
    async fn test_graceful_shutdown() -> Result {
        // Stores every accepted connection so that `shutdown` can close them itself.
        #[derive(Debug, Clone, Default)]
        struct TestProtocol {
            connections: Arc<Mutex<Vec<Connection>>>,
        }

        const TEST_ALPN: &[u8] = b"/iroh/test/1";

        impl ProtocolHandler for TestProtocol {
            async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
                self.connections.lock().expect("poisoned").push(connection);
                Ok(())
            }

            async fn shutdown(&self) {
                // Delay first, so the test only passes if the router really waits for
                // this future before closing the endpoint.
                tokio::time::sleep(Duration::from_millis(100)).await;
                let mut connections = self.connections.lock().expect("poisoned");
                for conn in connections.drain(..) {
                    conn.close(42u32.into(), b"shutdown");
                }
            }
        }

        eprintln!("creating ep1");
        let endpoint = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        let router = Router::builder(endpoint)
            .accept(TEST_ALPN, TestProtocol::default())
            .spawn();
        eprintln!("waiting for node addr");
        let addr = router.endpoint().node_addr();

        eprintln!("creating ep2");
        let endpoint2 = Endpoint::builder()
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await?;
        eprintln!("connecting to {addr:?}");
        let conn = endpoint2.connect(addr, TEST_ALPN).await?;

        eprintln!("starting shutdown");
        router.shutdown().await.e()?;

        eprintln!("waiting for closed conn");
        // The client must see the handler's close code and reason.
        let reason = conn.closed().await;
        assert_eq!(
            reason,
            ConnectionError::ApplicationClosed(ApplicationClose {
                error_code: 42u32.into(),
                reason: b"shutdown".to_vec().into()
            })
        );
        Ok(())
    }
}