1use std::{collections::HashMap, pin::Pin, task::Poll};
4
5use anyhow::anyhow;
6use futures_lite::Stream;
7use tokio::task::JoinSet;
8use tokio_util::sync::CancellationToken;
9use tracing::error;
10
11use crate::{Endpoint, NodeId};
12
/// Dials nodes in the background and hands back the resulting connections.
///
/// Each dial runs as a task in `pending`. Finished dials are collected with
/// [`Dialer::next_conn`] or by polling the [`Stream`] impl.
#[derive(Debug)]
pub struct Dialer {
    /// Endpoint used to open every connection.
    endpoint: Endpoint,
    /// In-flight dial tasks. Each resolves to the dialed node and its outcome.
    pending: JoinSet<(NodeId, anyhow::Result<quinn::Connection>)>,
    /// Cancellation token for every node currently counted as pending.
    /// An entry is removed when its dial is aborted or when a finished dial
    /// for that node is reaped.
    pending_dials: HashMap<NodeId, CancellationToken>,
}
25
26impl Dialer {
27 pub fn new(endpoint: Endpoint) -> Self {
29 Self {
30 endpoint,
31 pending: Default::default(),
32 pending_dials: Default::default(),
33 }
34 }
35
36 pub fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) {
44 if self.is_pending(node_id) {
45 return;
46 }
47 let cancel = CancellationToken::new();
48 self.pending_dials.insert(node_id, cancel.clone());
49 let endpoint = self.endpoint.clone();
50 self.pending.spawn(async move {
51 let res = tokio::select! {
52 biased;
53 _ = cancel.cancelled() => Err(anyhow!("Cancelled")),
54 res = endpoint.connect(node_id, alpn) => res
55 };
56 (node_id, res)
57 });
58 }
59
60 pub fn abort_dial(&mut self, node_id: NodeId) {
62 if let Some(cancel) = self.pending_dials.remove(&node_id) {
63 cancel.cancel();
64 }
65 }
66
67 pub fn is_pending(&self, node: NodeId) -> bool {
69 self.pending_dials.contains_key(&node)
70 }
71
72 pub async fn next_conn(&mut self) -> (NodeId, anyhow::Result<quinn::Connection>) {
74 match self.pending_dials.is_empty() {
75 false => {
76 let (node_id, res) = loop {
77 match self.pending.join_next().await {
78 Some(Ok((node_id, res))) => {
79 self.pending_dials.remove(&node_id);
80 break (node_id, res);
81 }
82 Some(Err(e)) => {
83 error!("next conn error: {:?}", e);
84 }
85 None => {
86 error!("no more pending conns available");
87 std::future::pending().await
88 }
89 }
90 };
91
92 (node_id, res)
93 }
94 true => std::future::pending().await,
95 }
96 }
97
98 pub fn pending_count(&self) -> usize {
100 self.pending_dials.len()
101 }
102
103 pub fn endpoint(&self) -> &Endpoint {
105 &self.endpoint
106 }
107}
108
109impl Stream for Dialer {
110 type Item = (NodeId, anyhow::Result<quinn::Connection>);
111
112 fn poll_next(
113 mut self: Pin<&mut Self>,
114 cx: &mut std::task::Context<'_>,
115 ) -> Poll<Option<Self::Item>> {
116 match self.pending.poll_join_next(cx) {
117 Poll::Ready(Some(Ok((node_id, result)))) => {
118 self.pending_dials.remove(&node_id);
119 Poll::Ready(Some((node_id, result)))
120 }
121 Poll::Ready(Some(Err(e))) => {
122 error!("dialer error: {:?}", e);
123 Poll::Pending
124 }
125 _ => Poll::Pending,
126 }
127 }
128}