iroh_blobs/util/
connection_pool.rs

1//! A simple iroh connection pool
2//!
3//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific
4//! ALPN and [`Options`]. Then the pool will manage connections for you.
5//!
6//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which
7//! gives you access to a connection via a [`ConnectionRef`] if possible.
8//!
9//! It is important that you keep the [`ConnectionRef`] alive while you are using
10//! the connection.
11use std::{
12    collections::{HashMap, VecDeque},
13    io,
14    ops::Deref,
15    sync::{
16        atomic::{AtomicUsize, Ordering},
17        Arc,
18    },
19    time::Duration,
20};
21
22use iroh::{
23    endpoint::{ConnectError, Connection},
24    Endpoint, NodeId,
25};
26use n0_future::{
27    future::{self},
28    FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32    mpsc::{self, error::SendError as TokioSendError},
33    oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
/// Callback invoked with the endpoint and the freshly established connection
/// before the pool hands the connection out.
///
/// If the returned future resolves to an error, the connect attempt fails with
/// [`PoolConnectError::OnConnectError`].
pub type OnConnected =
    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration options for the connection pool
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long to keep idle connections around.
    ///
    /// The timer is armed when the last [`ConnectionRef`] of a connection is
    /// dropped and cleared again when the connection is handed out anew.
    pub idle_timeout: Duration,
    /// Timeout for connect. This includes the time spent in the `on_connected`
    /// callback, if set.
    pub connect_timeout: Duration,
    /// Maximum number of connections the pool keeps at the same time.
    ///
    /// When the limit is reached, the oldest idle connection is evicted to make
    /// room for a new one. If no connection is idle, the request fails with
    /// [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// An optional callback that can be used to wait for the connection to enter some state.
    /// An example usage could be to wait for the connection to become direct before handing
    /// it out to the user.
    // Skipped in the `Debug` output.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58    fn default() -> Self {
59        Self {
60            idle_timeout: Duration::from_secs(5),
61            connect_timeout: Duration::from_secs(1),
62            max_connections: 1024,
63            on_connected: None,
64        }
65    }
66}
67
68impl Options {
69    /// Set the on_connected callback
70    pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71    where
72        F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73        Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74    {
75        self.on_connected = Some(Arc::new(move |ep, conn| {
76            let ep = ep.clone();
77            let conn = conn.clone();
78            Box::pin(f(ep, conn))
79        }));
80        self
81    }
82}
83
/// A reference to a connection that is owned by a connection pool.
///
/// Dereferences to the underlying connection. As long as at least one
/// `ConnectionRef` for a connection is alive, the pool does not treat that
/// connection as idle.
#[derive(Debug)]
pub struct ConnectionRef {
    /// The pooled connection.
    connection: iroh::endpoint::Connection,
    /// Guard that decrements the per-connection reference counter on drop.
    _permit: OneConnection,
}
90
impl Deref for ConnectionRef {
    type Target = iroh::endpoint::Connection;

    /// Gives access to the pooled connection.
    fn deref(&self) -> &Self::Target {
        &self.connection
    }
}
98
99impl ConnectionRef {
100    fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101        Self {
102            connection,
103            _permit: counter,
104        }
105    }
106}
107
/// Error when a connection can not be acquired
///
/// This includes the normal iroh connection errors as well as pool specific
/// errors such as timeouts and connection limits.
///
/// The error sources are wrapped in [`Arc`] so the error is [`Clone`]: a
/// connection actor whose connect attempt failed answers every request it
/// receives with a clone of the same error.
// NOTE(review): snafu presumably derives the `Display` text of each variant
// from its doc comment when no `display` attribute is given — confirm before
// rewording the variant docs below.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    /// Connection pool is shut down
    Shutdown,
    /// Timeout during connect
    Timeout,
    /// Too many connections
    TooManyConnections,
    /// Error during connect
    ConnectError { source: Arc<ConnectError> },
    /// Error during on_connect callback
    OnConnectError { source: Arc<io::Error> },
}
126
127impl From<ConnectError> for PoolConnectError {
128    fn from(e: ConnectError) -> Self {
129        PoolConnectError::ConnectError {
130            source: Arc::new(e),
131        }
132    }
133}
134
135impl From<io::Error> for PoolConnectError {
136    fn from(e: io::Error) -> Self {
137        PoolConnectError::OnConnectError {
138            source: Arc::new(e),
139        }
140    }
141}
142
/// Error when calling a fn on the [`ConnectionPool`].
///
/// The only thing that can go wrong is that the connection pool is shut down.
///
/// It is produced when a message can not be delivered to the pool actor.
// NOTE(review): snafu presumably derives the `Display` text of the variant
// from its doc comment — confirm before rewording it.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    /// The connection pool has been shut down
    Shutdown,
}
152
/// Messages handled by the main pool actor.
enum ActorMessage {
    /// A caller asks for a connection to a node.
    RequestRef(RequestRef),
    /// The connection actor for `id` reports that no references are outstanding.
    ConnectionIdle { id: NodeId },
    /// The connection for `id` should be removed from the pool.
    ConnectionShutdown { id: NodeId },
}
158
/// A request for a connection to a node, together with the reply channel.
struct RequestRef {
    /// Node to connect to.
    id: NodeId,
    /// One-shot channel on which the connection or the connect error is returned.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
163
/// State shared between the main actor and all connection actors.
struct Context {
    /// Pool configuration.
    options: Options,
    /// Endpoint used to establish connections.
    endpoint: Endpoint,
    /// Handle to the main actor, used by connection actors to report idle and
    /// shutdown events.
    owner: ConnectionPool,
    /// ALPN used for every connection made by this pool.
    alpn: Vec<u8>,
}
170
impl Context {
    /// Runs the actor that owns the connection to `node_id`.
    ///
    /// The actor first connects (including the optional `on_connected`
    /// callback) within `connect_timeout`. It then serves the requests arriving
    /// on `rx`: on success each request gets a [`ConnectionRef`] to the shared
    /// connection, on failure each request gets a clone of the connect error.
    ///
    /// While running it reports three events to the owning pool: the connection
    /// was closed, the connection became idle, and the idle timeout expired.
    /// The actor exits when `rx` is closed (or when the pool can no longer be
    /// notified) and closes the connection on the way out.
    async fn run_connection_actor(
        self: Arc<Self>,
        node_id: NodeId,
        mut rx: mpsc::Receiver<RequestRef>,
    ) {
        let context = self;

        // Connect and run the optional callback as one future, so that the
        // timeout below covers both steps.
        let conn_fut = {
            let context = context.clone();
            async move {
                let conn = context
                    .endpoint
                    .connect(node_id, &context.alpn)
                    .await
                    .map_err(PoolConnectError::from)?;
                if let Some(on_connect) = &context.options.on_connected {
                    on_connect(&context.endpoint, &conn)
                        .await
                        .map_err(PoolConnectError::from)?;
                }
                Result::<Connection, PoolConnectError>::Ok(conn)
            }
        };

        // Connect to the node. An elapsed timeout is mapped to
        // `PoolConnectError::Timeout`; `and_then` flattens the nested result.
        let state = conn_fut
            .timeout(context.options.connect_timeout)
            .await
            .map_err(|_| PoolConnectError::Timeout)
            .and_then(|r| r);
        // Future that resolves when the connection is closed. In the error case
        // there is nothing to watch.
        // NOTE(review): `MaybeFuture::None` presumably never resolves — confirm.
        let conn_close = match &state {
            Ok(conn) => {
                let conn = conn.clone();
                MaybeFuture::Some(async move { conn.closed().await })
            }
            Err(e) => {
                debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
                // Ask the pool to remove this actor. We keep running afterwards
                // so that requests already queued on `rx` receive the error.
                if context.owner.close(node_id).await.is_err() {
                    return;
                }
                MaybeFuture::None
            }
        };

        // Counts the outstanding `ConnectionRef`s for this connection.
        let counter = ConnectionCounter::new();
        // Not armed until the connection becomes idle.
        let idle_timer = MaybeFuture::default();
        let idle_stream = counter.clone().idle_stream();

        tokio::pin!(idle_timer, idle_stream, conn_close);

        loop {
            // `biased` polls the branches in the order written, so pending
            // requests are always served before close/idle events are handled.
            tokio::select! {
                biased;

                // Handle new work
                handler = rx.recv() => {
                    match handler {
                        Some(RequestRef { id, tx }) => {
                            assert!(id == node_id, "Not for me!");
                            match &state {
                                Ok(state) => {
                                    let res = ConnectionRef::new(state.clone(), counter.get_one());
                                    info!(%node_id, "Handing out ConnectionRef {}", counter.current());

                                    // clear the idle timer
                                    idle_timer.as_mut().set_none();
                                    // the requester may have gone away; that is fine
                                    tx.send(Ok(res)).ok();
                                }
                                Err(cause) => {
                                    tx.send(Err(cause.clone())).ok();
                                }
                            }
                        }
                        None => {
                            // Channel closed - exit
                            break;
                        }
                    }
                }

                _ = &mut conn_close => {
                    // connection was closed by somebody, notify owner that we should be removed
                    // NOTE(review): `state` stays `Ok`, so requests that reach
                    // this actor before the owner drops the channel still get
                    // a reference to the closed connection — confirm this is
                    // acceptable.
                    context.owner.close(node_id).await.ok();
                }

                _ = idle_stream.next() => {
                    // The notification may be stale: a new reference could have
                    // been handed out in the meantime.
                    if !counter.is_idle() {
                        continue;
                    };
                    // notify the pool that we are idle.
                    trace!(%node_id, "Idle");
                    if context.owner.idle(node_id).await.is_err() {
                        // If we can't notify the pool, we are shutting down
                        break;
                    }
                    // set the idle timer
                    idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
                }

                // Idle timeout - request shutdown
                _ = &mut idle_timer => {
                    trace!(%node_id, "Idle timer expired, requesting shutdown");
                    context.owner.close(node_id).await.ok();
                    // Don't break here - wait for main actor to close our channel
                }
            }
        }

        // Close the connection; the reason records whether references were
        // still outstanding ("drop") or not ("idle").
        if let Ok(connection) = state {
            let reason = if counter.is_idle() { b"idle" } else { b"drop" };
            connection.close(0u32.into(), reason);
        }

        trace!(%node_id, "Connection actor shutting down");
    }
}
288
/// Main pool actor: routes connection requests to per-node connection actors
/// and keeps track of which connections are idle.
struct Actor {
    /// Incoming messages from pool handles and from connection actors.
    rx: mpsc::Receiver<ActorMessage>,
    /// Request channel of the connection actor for each node.
    connections: HashMap<NodeId, mpsc::Sender<RequestRef>>,
    /// State shared with all connection actors.
    context: Arc<Context>,
    // idle set (most recent last)
    // todo: use a better data structure if this becomes a performance issue
    idle: VecDeque<NodeId>,
    // per connection tasks
    tasks: FuturesUnordered<future::Boxed<()>>,
}
299
300impl Actor {
301    pub fn new(
302        endpoint: Endpoint,
303        alpn: &[u8],
304        options: Options,
305    ) -> (Self, mpsc::Sender<ActorMessage>) {
306        let (tx, rx) = mpsc::channel(100);
307        (
308            Self {
309                rx,
310                connections: HashMap::new(),
311                idle: VecDeque::new(),
312                context: Arc::new(Context {
313                    options,
314                    alpn: alpn.to_vec(),
315                    endpoint,
316                    owner: ConnectionPool { tx: tx.clone() },
317                }),
318                tasks: FuturesUnordered::new(),
319            },
320            tx,
321        )
322    }
323
324    fn add_idle(&mut self, id: NodeId) {
325        self.remove_idle(id);
326        self.idle.push_back(id);
327    }
328
329    fn remove_idle(&mut self, id: NodeId) {
330        self.idle.retain(|&x| x != id);
331    }
332
333    fn pop_oldest_idle(&mut self) -> Option<NodeId> {
334        self.idle.pop_front()
335    }
336
337    fn remove_connection(&mut self, id: NodeId) {
338        self.connections.remove(&id);
339        self.remove_idle(id);
340    }
341
342    async fn handle_msg(&mut self, msg: ActorMessage) {
343        match msg {
344            ActorMessage::RequestRef(mut msg) => {
345                let id = msg.id;
346                self.remove_idle(id);
347                // Try to send to existing connection actor
348                if let Some(conn_tx) = self.connections.get(&id) {
349                    if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
350                        msg = e;
351                    } else {
352                        return;
353                    }
354                    // Connection actor died, remove it
355                    self.remove_connection(id);
356                }
357
358                // No connection actor or it died - check limits
359                if self.connections.len() >= self.context.options.max_connections {
360                    if let Some(idle) = self.pop_oldest_idle() {
361                        // remove the oldest idle connection to make room for one more
362                        trace!("removing oldest idle connection {}", idle);
363                        self.connections.remove(&idle);
364                    } else {
365                        msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
366                        return;
367                    }
368                }
369                let (conn_tx, conn_rx) = mpsc::channel(100);
370                self.connections.insert(id, conn_tx.clone());
371
372                let context = self.context.clone();
373
374                self.tasks
375                    .push(Box::pin(context.run_connection_actor(id, conn_rx)));
376
377                // Send the handler to the new actor
378                if conn_tx.send(msg).await.is_err() {
379                    error!(%id, "Failed to send handler to new connection actor");
380                    self.connections.remove(&id);
381                }
382            }
383            ActorMessage::ConnectionIdle { id } => {
384                self.add_idle(id);
385                trace!(%id, "connection idle");
386            }
387            ActorMessage::ConnectionShutdown { id } => {
388                // Remove the connection from our map - this closes the channel
389                self.remove_connection(id);
390                trace!(%id, "removed connection");
391            }
392        }
393    }
394
395    pub async fn run(mut self) {
396        loop {
397            tokio::select! {
398                biased;
399
400                msg = self.rx.recv() => {
401                    if let Some(msg) = msg {
402                        self.handle_msg(msg).await;
403                    } else {
404                        break;
405                    }
406                }
407
408                _ = self.tasks.next(), if !self.tasks.is_empty() => {}
409            }
410        }
411    }
412}
413
/// A connection pool
///
/// This is a cheap handle to the pool's actor: it only holds the sender side
/// of the actor's message channel. Clones refer to the same pool.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Channel to the main pool actor.
    tx: mpsc::Sender<ActorMessage>,
}
419
420impl ConnectionPool {
421    pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
422        let (actor, tx) = Actor::new(endpoint, alpn, options);
423
424        // Spawn the main actor
425        tokio::spawn(actor.run());
426
427        Self { tx }
428    }
429
430    /// Returns either a fresh connection or a reference to an existing one.
431    ///
432    /// This is guaranteed to return after approximately [Options::connect_timeout]
433    /// with either an error or a connection.
434    pub async fn get_or_connect(
435        &self,
436        id: NodeId,
437    ) -> std::result::Result<ConnectionRef, PoolConnectError> {
438        let (tx, rx) = oneshot::channel();
439        self.tx
440            .send(ActorMessage::RequestRef(RequestRef { id, tx }))
441            .await
442            .map_err(|_| PoolConnectError::Shutdown)?;
443        rx.await.map_err(|_| PoolConnectError::Shutdown)?
444    }
445
446    /// Close an existing connection, if it exists
447    ///
448    /// This will finish pending tasks and close the connection. New tasks will
449    /// get a new connection if they are submitted after this call
450    pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
451        self.tx
452            .send(ActorMessage::ConnectionShutdown { id })
453            .await
454            .map_err(|_| ConnectionPoolError::Shutdown)?;
455        Ok(())
456    }
457
458    /// Notify the connection pool that a connection is idle.
459    ///
460    /// Should only be called from connection handlers.
461    pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
462        self.tx
463            .send(ActorMessage::ConnectionIdle { id })
464            .await
465            .map_err(|_| ConnectionPoolError::Shutdown)?;
466        Ok(())
467    }
468}
469
/// Shared state behind [`ConnectionCounter`] and [`OneConnection`].
#[derive(Debug)]
struct ConnectionCounterInner {
    /// Number of live [`OneConnection`] guards.
    count: AtomicUsize,
    /// Signalled when `count` drops to zero.
    notify: Notify,
}
475
/// Counts the references handed out for one connection.
///
/// Clones share the same counter.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
480
481impl ConnectionCounter {
482    fn new() -> Self {
483        Self {
484            inner: Arc::new(ConnectionCounterInner {
485                count: Default::default(),
486                notify: Notify::new(),
487            }),
488        }
489    }
490
491    fn current(&self) -> usize {
492        self.inner.count.load(Ordering::SeqCst)
493    }
494
495    /// Increase the connection count and return a guard for the new connection
496    fn get_one(&self) -> OneConnection {
497        self.inner.count.fetch_add(1, Ordering::SeqCst);
498        OneConnection {
499            inner: self.inner.clone(),
500        }
501    }
502
503    fn is_idle(&self) -> bool {
504        self.inner.count.load(Ordering::SeqCst) == 0
505    }
506
507    /// Infinite stream that yields when the connection is briefly idle.
508    ///
509    /// Note that you still have to check if the connection is still idle when
510    /// you get the notification.
511    ///
512    /// Also note that this stream is triggered on [OneConnection::drop], so it
513    /// won't trigger initially even though a [ConnectionCounter] starts up as
514    /// idle.
515    fn idle_stream(self) -> impl Stream<Item = ()> {
516        n0_future::stream::unfold(self, |c| async move {
517            c.inner.notify.notified().await;
518            Some(((), c))
519        })
520    }
521}
522
/// Guard for one connection
///
/// Created by [`ConnectionCounter::get_one`]; dropping it decrements the
/// shared count again.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
528
529impl Drop for OneConnection {
530    fn drop(&mut self) {
531        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
532            self.inner.notify.notify_waiters();
533        }
534    }
535}
536
537#[cfg(test)]
538mod tests {
539    use std::{collections::BTreeMap, sync::Arc, time::Duration};
540
541    use iroh::{
542        discovery::static_provider::StaticProvider,
543        endpoint::{Connection, ConnectionType},
544        protocol::{AcceptError, ProtocolHandler, Router},
545        Endpoint, NodeAddr, NodeId, SecretKey, Watcher,
546    };
547    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
548    use n0_snafu::ResultExt;
549    use testresult::TestResult;
550    use tracing::trace;
551
552    use super::{ConnectionPool, Options, PoolConnectError};
553    use crate::util::connection_pool::OnConnected;
554
    /// ALPN used by the echo servers and by the pools under test.
    const ECHO_ALPN: &[u8] = b"echo";
556
    /// Protocol handler that copies everything received on a bidirectional
    /// stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;
559
560    impl ProtocolHandler for Echo {
561        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
562            let conn_id = connection.stable_id();
563            let id = connection.remote_node_id().map_err(AcceptError::from_err)?;
564            trace!(%id, %conn_id, "Accepting echo connection");
565            loop {
566                match connection.accept_bi().await {
567                    Ok((mut send, mut recv)) => {
568                        trace!(%id, %conn_id, "Accepted echo request");
569                        tokio::io::copy(&mut recv, &mut send).await?;
570                        send.finish().map_err(AcceptError::from_err)?;
571                    }
572                    Err(e) => {
573                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
574                        break;
575                    }
576                }
577            }
578            Ok(())
579        }
580    }
581
    /// Sends `text` on a new bidirectional stream of `conn` and returns the
    /// response.
    ///
    /// At most 1000 bytes of response are read.
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_node_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        // finish the send side before reading the response
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }
593
594    async fn echo_server() -> TestResult<(NodeAddr, Router)> {
595        let endpoint = iroh::Endpoint::builder()
596            .alpns(vec![ECHO_ALPN.to_vec()])
597            .bind()
598            .await?;
599        endpoint.online().await;
600        let addr = endpoint.node_addr();
601        let router = iroh::protocol::Router::builder(endpoint)
602            .accept(ECHO_ALPN, Echo)
603            .spawn();
604
605        Ok((addr, router))
606    }
607
608    async fn echo_servers(n: usize) -> TestResult<(Vec<NodeId>, Vec<Router>, StaticProvider)> {
609        let res = stream::iter(0..n)
610            .map(|_| echo_server())
611            .buffered_unordered(16)
612            .collect::<Vec<_>>()
613            .await;
614        let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
615        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
616        let ids = addrs.iter().map(|a| a.node_id).collect::<Vec<_>>();
617        let discovery = StaticProvider::from_node_info(addrs);
618        Ok((ids, routers, discovery))
619    }
620
621    async fn shutdown_routers(routers: Vec<Router>) {
622        stream::iter(routers)
623            .for_each_concurrent(16, |router| async move {
624                let _ = router.shutdown().await;
625            })
626            .await;
627    }
628
629    fn test_options() -> Options {
630        Options {
631            idle_timeout: Duration::from_millis(100),
632            connect_timeout: Duration::from_secs(5),
633            max_connections: 32,
634            on_connected: None,
635        }
636    }
637
    /// Echo client that gets its connections from a [`ConnectionPool`].
    struct EchoClient {
        pool: ConnectionPool,
    }
641
642    impl EchoClient {
643        async fn echo(
644            &self,
645            id: NodeId,
646            text: Vec<u8>,
647        ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
648            let conn = self.pool.get_or_connect(id).await?;
649            let id = conn.stable_id();
650            match echo_client(&conn, &text).await {
651                Ok(res) => Ok(Ok((id, res))),
652                Err(e) => Ok(Err(e)),
653            }
654        }
655    }
656
    /// Checks the errors reported for an unknown node and for a node that is
    /// known but not listening.
    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_errors() -> TestResult<()> {
        // set up static discovery for all addrs
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            // trying to connect to a non-existing id will fail with ConnectError
            // because we don't have any information about the node
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // NOTE: this is the same node id as `non_existing` above (both are
            // derived from the all-zero key); only the discovery info differs.
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            // make up fake node info
            discovery.add_node_info(NodeAddr {
                node_id: non_listening,
                relay_url: None,
                direct_addresses: vec!["127.0.0.1:12121".parse().unwrap()]
                    .into_iter()
                    .collect(),
            });
            // trying to connect to an id for which we have info, but the other
            // end is not listening, will lead to a timeout.
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }
692
    /// Checks that connections are reused while they are alive and replaced
    /// after they have been idle for longer than the idle timeout.
    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        // build a client endpoint that can resolve all the node ids
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // two requests in a row must share the same connection
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // wait much longer than the 100 ms idle timeout of `test_options`
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // the old connection timed out, so this must be a new one
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }
725
    /// Tests that idle connections are being reclaimed to make room if we hit the
    /// maximum connection limit.
    ///
    /// The pool allows 8 connections but 32 nodes are contacted one after the
    /// other. The idle timeout is far longer than the test, so room can only
    /// be made by evicting idle connections.
    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        // build a client endpoint that can resolve all the node ids
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            // every request must succeed even though the pool is full
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }
756
    /// Uses an on_connected callback that just errors out every time.
    ///
    /// This is a basic smoke test that on_connected gets called at all.
    #[tokio::test]
    // #[traced_test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        // callback that rejects every connection
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            // the callback error must surface as OnConnectError
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }
788
    /// Uses an on_connected callback to ensure that the connection is direct.
    #[tokio::test]
    // #[traced_test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        // Callback that only succeeds once the connection type reported by
        // the endpoint is `Direct`.
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_node_id() else {
                return Err(io::Error::other("unable to get node id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            // NOTE: this only checks that a connection could be obtained; the
            // inner result of the echo exchange is not inspected.
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }
828
    /// Check that when a connection is closed, the pool will give you a new
    /// connection next time you want one.
    ///
    /// This test fails if the connection watch is disabled.
    #[tokio::test]
    // #[traced_test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        // close the connection behind the pool's back
        conn.close(0u32.into(), b"test");
        // give the pool time to notice the close and drop the connection
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        // the pool must have established a new connection
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
854}