iroh_net/
dialer.rs

1//! A dialer to conveniently dial many nodes.
2
3use std::{collections::HashMap, pin::Pin, task::Poll};
4
5use anyhow::anyhow;
6use futures_lite::Stream;
7use tokio::task::JoinSet;
8use tokio_util::sync::CancellationToken;
9use tracing::error;
10
11use crate::{Endpoint, NodeId};
12
13/// Dials nodes and maintains a queue of pending dials.
14///
15/// The [`Dialer`] wraps an [`Endpoint`], connects to nodes through the endpoint, stores the
16/// pending connect futures and emits finished connect results.
17///
18/// The [`Dialer`] also implements [`Stream`] to retrieve the dialled connections.
19#[derive(Debug)]
20pub struct Dialer {
21    endpoint: Endpoint,
22    pending: JoinSet<(NodeId, anyhow::Result<quinn::Connection>)>,
23    pending_dials: HashMap<NodeId, CancellationToken>,
24}
25
26impl Dialer {
27    /// Create a new dialer for a [`Endpoint`]
28    pub fn new(endpoint: Endpoint) -> Self {
29        Self {
30            endpoint,
31            pending: Default::default(),
32            pending_dials: Default::default(),
33        }
34    }
35
36    /// Starts to dial a node by [`NodeId`].
37    ///
38    /// Since this dials by [`NodeId`] the [`Endpoint`] must know how to contact the node by
39    /// [`NodeId`] only.  This relies on addressing information being provided by either the
40    /// [discovery service] or manually by calling [`Endpoint::add_node_addr`].
41    ///
42    /// [discovery service]: crate::discovery::Discovery
43    pub fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) {
44        if self.is_pending(node_id) {
45            return;
46        }
47        let cancel = CancellationToken::new();
48        self.pending_dials.insert(node_id, cancel.clone());
49        let endpoint = self.endpoint.clone();
50        self.pending.spawn(async move {
51            let res = tokio::select! {
52                biased;
53                _ = cancel.cancelled() => Err(anyhow!("Cancelled")),
54                res = endpoint.connect(node_id, alpn) => res
55            };
56            (node_id, res)
57        });
58    }
59
60    /// Aborts a pending dial.
61    pub fn abort_dial(&mut self, node_id: NodeId) {
62        if let Some(cancel) = self.pending_dials.remove(&node_id) {
63            cancel.cancel();
64        }
65    }
66
67    /// Checks if a node is currently being dialed.
68    pub fn is_pending(&self, node: NodeId) -> bool {
69        self.pending_dials.contains_key(&node)
70    }
71
72    /// Waits for the next dial operation to complete.
73    pub async fn next_conn(&mut self) -> (NodeId, anyhow::Result<quinn::Connection>) {
74        match self.pending_dials.is_empty() {
75            false => {
76                let (node_id, res) = loop {
77                    match self.pending.join_next().await {
78                        Some(Ok((node_id, res))) => {
79                            self.pending_dials.remove(&node_id);
80                            break (node_id, res);
81                        }
82                        Some(Err(e)) => {
83                            error!("next conn error: {:?}", e);
84                        }
85                        None => {
86                            error!("no more pending conns available");
87                            std::future::pending().await
88                        }
89                    }
90                };
91
92                (node_id, res)
93            }
94            true => std::future::pending().await,
95        }
96    }
97
98    /// Number of pending connections to be opened.
99    pub fn pending_count(&self) -> usize {
100        self.pending_dials.len()
101    }
102
103    /// Returns a reference to the endpoint used in this dialer.
104    pub fn endpoint(&self) -> &Endpoint {
105        &self.endpoint
106    }
107}
108
109impl Stream for Dialer {
110    type Item = (NodeId, anyhow::Result<quinn::Connection>);
111
112    fn poll_next(
113        mut self: Pin<&mut Self>,
114        cx: &mut std::task::Context<'_>,
115    ) -> Poll<Option<Self::Item>> {
116        match self.pending.poll_join_next(cx) {
117            Poll::Ready(Some(Ok((node_id, result)))) => {
118                self.pending_dials.remove(&node_id);
119                Poll::Ready(Some((node_id, result)))
120            }
121            Poll::Ready(Some(Err(e))) => {
122                error!("dialer error: {:?}", e);
123                Poll::Pending
124            }
125            _ => Poll::Pending,
126        }
127    }
128}