1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, NodeId,
25};
26use n0_future::{
27 future::{self},
28 FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
/// Callback run on a freshly established connection before the pool hands it out.
///
/// It receives the pool's endpoint and the new connection and returns a boxed
/// future. If that future resolves to an error, the connect attempt fails with
/// [`PoolConnectError::OnConnectError`].
pub type OnConnected =
    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration for a [`ConnectionPool`].
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection may stay idle (no outstanding [`ConnectionRef`]) before
    /// its actor asks the pool to remove it.
    pub idle_timeout: Duration,
    /// Upper bound on the time spent connecting, including the `on_connected` callback.
    pub connect_timeout: Duration,
    /// Maximum number of connections tracked by the pool. Once reached, the oldest
    /// idle connection is evicted to make room; if none is idle the request fails
    /// with [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// Optional callback run on every new connection before it is handed out.
    /// Skipped in the `Debug` output.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58 fn default() -> Self {
59 Self {
60 idle_timeout: Duration::from_secs(5),
61 connect_timeout: Duration::from_secs(1),
62 max_connections: 1024,
63 on_connected: None,
64 }
65 }
66}
67
68impl Options {
69 pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71 where
72 F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73 Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74 {
75 self.on_connected = Some(Arc::new(move |ep, conn| {
76 let ep = ep.clone();
77 let conn = conn.clone();
78 Box::pin(f(ep, conn))
79 }));
80 self
81 }
82}
83
/// A handle to a pooled connection.
///
/// Dereferences to the underlying connection. While the handle is alive the
/// connection counts as in use; dropping the handle releases that count.
#[derive(Debug)]
pub struct ConnectionRef {
    /// The pooled connection.
    connection: iroh::endpoint::Connection,
    /// Usage guard; decrements the per-connection counter when dropped.
    _permit: OneConnection,
}
90
91impl Deref for ConnectionRef {
92 type Target = iroh::endpoint::Connection;
93
94 fn deref(&self) -> &Self::Target {
95 &self.connection
96 }
97}
98
99impl ConnectionRef {
100 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101 Self {
102 connection,
103 _permit: counter,
104 }
105 }
106}
107
/// Errors returned when requesting a connection from the pool.
///
/// The sources are wrapped in `Arc` so the error is `Clone`; a failed connection
/// actor replies to each request it receives with a clone of the same error.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    /// The pool actor could not be reached, or the request was dropped unanswered.
    Shutdown,
    /// Connecting (including the `on_connected` callback) exceeded `connect_timeout`.
    Timeout,
    /// The pool is at `max_connections` and has no idle connection to evict.
    TooManyConnections,
    /// The endpoint failed to establish the connection.
    ConnectError { source: Arc<ConnectError> },
    /// The `on_connected` callback returned an error.
    OnConnectError { source: Arc<io::Error> },
}
126
127impl From<ConnectError> for PoolConnectError {
128 fn from(e: ConnectError) -> Self {
129 PoolConnectError::ConnectError {
130 source: Arc::new(e),
131 }
132 }
133}
134
135impl From<io::Error> for PoolConnectError {
136 fn from(e: io::Error) -> Self {
137 PoolConnectError::OnConnectError {
138 source: Arc::new(e),
139 }
140 }
141}
142
/// Errors returned by pool control operations (`close`, `idle`).
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    /// The message could not be delivered because the pool actor's channel is closed.
    Shutdown,
}
152
/// Messages handled by the pool's main actor.
enum ActorMessage {
    /// A caller asks for a connection handle to a node.
    RequestRef(RequestRef),
    /// The connection to `id` should be recorded as idle.
    ConnectionIdle { id: NodeId },
    /// The connection to `id` should be removed from the pool.
    ConnectionShutdown { id: NodeId },
}
158
/// A request for a connection handle, with the channel on which to answer it.
struct RequestRef {
    /// Node to connect to.
    id: NodeId,
    /// Receives either the connection handle or the reason connecting failed.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
163
/// State shared between the main actor and every per-connection actor.
struct Context {
    /// Pool configuration (timeouts, limits, `on_connected` callback).
    options: Options,
    /// Endpoint used to establish outgoing connections.
    endpoint: Endpoint,
    /// Handle back to the main actor, used to report idle and shutdown events.
    owner: ConnectionPool,
    /// ALPN passed to the endpoint when connecting.
    alpn: Vec<u8>,
}
170
impl Context {
    /// Runs the actor that owns the connection to `node_id`.
    ///
    /// The actor first tries to connect (bounded by `connect_timeout`, including the
    /// optional `on_connected` callback), then serves requests arriving on `rx`:
    /// on success every request receives a [`ConnectionRef`], on failure a clone of
    /// the connect error. It reports idle and shutdown events to the pool through
    /// `owner`, and exits when `rx` is closed or the pool can no longer be reached.
    /// On exit a successfully established connection is closed.
    ///
    /// # Panics
    ///
    /// Panics if a request for a different node id is received.
    async fn run_connection_actor(
        self: Arc<Self>,
        node_id: NodeId,
        mut rx: mpsc::Receiver<RequestRef>,
    ) {
        let context = self;

        // Connect, then run the optional on_connected callback; both failure kinds
        // are converted into `PoolConnectError`.
        let conn_fut = {
            let context = context.clone();
            async move {
                let conn = context
                    .endpoint
                    .connect(node_id, &context.alpn)
                    .await
                    .map_err(PoolConnectError::from)?;
                if let Some(on_connect) = &context.options.on_connected {
                    on_connect(&context.endpoint, &conn)
                        .await
                        .map_err(PoolConnectError::from)?;
                }
                Result::<Connection, PoolConnectError>::Ok(conn)
            }
        };

        // The timeout wraps the whole attempt; `and_then(|r| r)` flattens the nested
        // `Result` (timeout error outside, connect error inside).
        let state = conn_fut
            .timeout(context.options.connect_timeout)
            .await
            .map_err(|_| PoolConnectError::Timeout)
            .and_then(|r| r);
        let conn_close = match &state {
            Ok(conn) => {
                // Watch for the connection being closed so the pool can be told.
                let conn = conn.clone();
                MaybeFuture::Some(async move { conn.closed().await })
            }
            Err(e) => {
                debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
                // Ask the pool to remove us; if the pool is gone there is nobody
                // left to serve, so exit right away.
                if context.owner.close(node_id).await.is_err() {
                    return;
                }
                MaybeFuture::None
            }
        };

        // Counts outstanding `ConnectionRef`s; the idle stream is driven by the
        // counter's notifier.
        let counter = ConnectionCounter::new();
        // Starts with the default value; armed with a sleep once the connection
        // is reported idle.
        let idle_timer = MaybeFuture::default();
        let idle_stream = counter.clone().idle_stream();

        tokio::pin!(idle_timer, idle_stream, conn_close);

        loop {
            tokio::select! {
                // Branches are polled in order, so incoming requests take priority.
                biased;

                handler = rx.recv() => {
                    match handler {
                        Some(RequestRef { id, tx }) => {
                            assert!(id == node_id, "Not for me!");
                            match &state {
                                Ok(state) => {
                                    let res = ConnectionRef::new(state.clone(), counter.get_one());
                                    info!(%node_id, "Handing out ConnectionRef {}", counter.current());

                                    // The connection is in use again: disarm the idle timer.
                                    idle_timer.as_mut().set_none();
                                    // A failed send means the requester went away; ignore it.
                                    tx.send(Ok(res)).ok();
                                }
                                Err(cause) => {
                                    tx.send(Err(cause.clone())).ok();
                                }
                            }
                        }
                        None => {
                            // All senders were dropped: stop the actor.
                            break;
                        }
                    }
                }

                _ = &mut conn_close => {
                    // The connection was closed; ask the pool to remove us.
                    // NOTE(review): the loop keeps running after this. Presumably it
                    // ends once the pool drops our sender and `rx.recv()` yields
                    // `None` — confirm.
                    context.owner.close(node_id).await.ok();
                }

                _ = idle_stream.next() => {
                    // Re-check the count: a new ref may have been handed out since
                    // the notification was sent.
                    if !counter.is_idle() {
                        continue;
                    };
                    trace!(%node_id, "Idle");
                    if context.owner.idle(node_id).await.is_err() {
                        // The pool is unreachable: stop the actor.
                        break;
                    }
                    idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
                }

                _ = &mut idle_timer => {
                    trace!(%node_id, "Idle timer expired, requesting shutdown");
                    // Only request removal here; the loop itself ends when `rx` closes.
                    context.owner.close(node_id).await.ok();
                }
            }
        }

        if let Ok(connection) = state {
            // The close reason records whether refs were still outstanding.
            let reason = if counter.is_idle() { b"idle" } else { b"drop" };
            connection.close(0u32.into(), reason);
        }

        trace!(%node_id, "Connection actor shutting down");
    }
}
288
/// The pool's main actor: routes requests to per-connection actors and tracks
/// which connections exist and which are idle.
struct Actor {
    /// Incoming messages from pool handles and connection actors.
    rx: mpsc::Receiver<ActorMessage>,
    /// Request channel of the connection actor for each node.
    connections: HashMap<NodeId, mpsc::Sender<RequestRef>>,
    /// State shared with the connection actors.
    context: Arc<Context>,
    /// Idle connections, oldest first.
    idle: VecDeque<NodeId>,
    /// The running connection actors, driven from the main loop.
    tasks: FuturesUnordered<future::Boxed<()>>,
}
299
300impl Actor {
301 pub fn new(
302 endpoint: Endpoint,
303 alpn: &[u8],
304 options: Options,
305 ) -> (Self, mpsc::Sender<ActorMessage>) {
306 let (tx, rx) = mpsc::channel(100);
307 (
308 Self {
309 rx,
310 connections: HashMap::new(),
311 idle: VecDeque::new(),
312 context: Arc::new(Context {
313 options,
314 alpn: alpn.to_vec(),
315 endpoint,
316 owner: ConnectionPool { tx: tx.clone() },
317 }),
318 tasks: FuturesUnordered::new(),
319 },
320 tx,
321 )
322 }
323
324 fn add_idle(&mut self, id: NodeId) {
325 self.remove_idle(id);
326 self.idle.push_back(id);
327 }
328
329 fn remove_idle(&mut self, id: NodeId) {
330 self.idle.retain(|&x| x != id);
331 }
332
333 fn pop_oldest_idle(&mut self) -> Option<NodeId> {
334 self.idle.pop_front()
335 }
336
337 fn remove_connection(&mut self, id: NodeId) {
338 self.connections.remove(&id);
339 self.remove_idle(id);
340 }
341
342 async fn handle_msg(&mut self, msg: ActorMessage) {
343 match msg {
344 ActorMessage::RequestRef(mut msg) => {
345 let id = msg.id;
346 self.remove_idle(id);
347 if let Some(conn_tx) = self.connections.get(&id) {
349 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
350 msg = e;
351 } else {
352 return;
353 }
354 self.remove_connection(id);
356 }
357
358 if self.connections.len() >= self.context.options.max_connections {
360 if let Some(idle) = self.pop_oldest_idle() {
361 trace!("removing oldest idle connection {}", idle);
363 self.connections.remove(&idle);
364 } else {
365 msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
366 return;
367 }
368 }
369 let (conn_tx, conn_rx) = mpsc::channel(100);
370 self.connections.insert(id, conn_tx.clone());
371
372 let context = self.context.clone();
373
374 self.tasks
375 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
376
377 if conn_tx.send(msg).await.is_err() {
379 error!(%id, "Failed to send handler to new connection actor");
380 self.connections.remove(&id);
381 }
382 }
383 ActorMessage::ConnectionIdle { id } => {
384 self.add_idle(id);
385 trace!(%id, "connection idle");
386 }
387 ActorMessage::ConnectionShutdown { id } => {
388 self.remove_connection(id);
390 trace!(%id, "removed connection");
391 }
392 }
393 }
394
395 pub async fn run(mut self) {
396 loop {
397 tokio::select! {
398 biased;
399
400 msg = self.rx.recv() => {
401 if let Some(msg) = msg {
402 self.handle_msg(msg).await;
403 } else {
404 break;
405 }
406 }
407
408 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
409 }
410 }
411 }
412}
413
/// A pool of outgoing connections.
///
/// This is a cloneable handle; all clones send to the same pool actor.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Channel to the pool's main actor.
    tx: mpsc::Sender<ActorMessage>,
}
419
420impl ConnectionPool {
421 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
422 let (actor, tx) = Actor::new(endpoint, alpn, options);
423
424 tokio::spawn(actor.run());
426
427 Self { tx }
428 }
429
430 pub async fn get_or_connect(
435 &self,
436 id: NodeId,
437 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
438 let (tx, rx) = oneshot::channel();
439 self.tx
440 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
441 .await
442 .map_err(|_| PoolConnectError::Shutdown)?;
443 rx.await.map_err(|_| PoolConnectError::Shutdown)?
444 }
445
446 pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
451 self.tx
452 .send(ActorMessage::ConnectionShutdown { id })
453 .await
454 .map_err(|_| ConnectionPoolError::Shutdown)?;
455 Ok(())
456 }
457
458 pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
462 self.tx
463 .send(ActorMessage::ConnectionIdle { id })
464 .await
465 .map_err(|_| ConnectionPoolError::Shutdown)?;
466 Ok(())
467 }
468}
469
/// Shared state behind a [`ConnectionCounter`] and its [`OneConnection`] guards.
#[derive(Debug)]
struct ConnectionCounterInner {
    /// Number of live guards.
    count: AtomicUsize,
    /// Signalled when the count drops to zero.
    notify: Notify,
}
475
/// Counts how many handles to one connection are currently alive.
///
/// Clones share the same underlying count.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    /// Shared count and notifier.
    inner: Arc<ConnectionCounterInner>,
}
480
481impl ConnectionCounter {
482 fn new() -> Self {
483 Self {
484 inner: Arc::new(ConnectionCounterInner {
485 count: Default::default(),
486 notify: Notify::new(),
487 }),
488 }
489 }
490
491 fn current(&self) -> usize {
492 self.inner.count.load(Ordering::SeqCst)
493 }
494
495 fn get_one(&self) -> OneConnection {
497 self.inner.count.fetch_add(1, Ordering::SeqCst);
498 OneConnection {
499 inner: self.inner.clone(),
500 }
501 }
502
503 fn is_idle(&self) -> bool {
504 self.inner.count.load(Ordering::SeqCst) == 0
505 }
506
507 fn idle_stream(self) -> impl Stream<Item = ()> {
516 n0_future::stream::unfold(self, |c| async move {
517 c.inner.notify.notified().await;
518 Some(((), c))
519 })
520 }
521}
522
/// Guard representing one live handle to a connection; decrements the shared
/// count when dropped.
#[derive(Debug)]
struct OneConnection {
    /// Count and notifier shared with the originating counter.
    inner: Arc<ConnectionCounterInner>,
}
528
529impl Drop for OneConnection {
530 fn drop(&mut self) {
531 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
532 self.inner.notify.notify_waiters();
533 }
534 }
535}
536
537#[cfg(test)]
538mod tests {
539 use std::{collections::BTreeMap, sync::Arc, time::Duration};
540
541 use iroh::{
542 discovery::static_provider::StaticProvider,
543 endpoint::{Connection, ConnectionType},
544 protocol::{AcceptError, ProtocolHandler, Router},
545 Endpoint, NodeAddr, NodeId, SecretKey, Watcher,
546 };
547 use n0_future::{io, stream, BufferedStreamExt, StreamExt};
548 use n0_snafu::ResultExt;
549 use testresult::TestResult;
550 use tracing::trace;
551
552 use super::{ConnectionPool, Options, PoolConnectError};
553 use crate::util::connection_pool::OnConnected;
554
    /// ALPN used by the echo servers and pools in these tests.
    const ECHO_ALPN: &[u8] = b"echo";
556
    /// Protocol handler that echoes every bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;
559
    impl ProtocolHandler for Echo {
        /// Serves one connection: accepts bidirectional streams in a loop and copies
        /// each stream's input back to its output, until accepting a stream fails.
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_node_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        // Failing to accept ends the loop; it is not reported as an error.
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }
581
    /// Sends `text` on a new bidirectional stream of `conn` and returns the
    /// response, reading at most 1000 bytes.
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_node_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }
593
    /// Starts a single echo server and returns its address together with the
    /// router that keeps it running.
    async fn echo_server() -> TestResult<(NodeAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.node_addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }
607
    /// Starts `n` echo servers (up to 16 concurrently) and returns their node ids,
    /// their routers, and a static discovery provider built from their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<NodeId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        // Fail as a whole if any server failed to start.
        let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.node_id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_node_info(addrs);
        Ok((ids, routers, discovery))
    }
620
    /// Shuts down all `routers`, up to 16 at a time, ignoring shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }
628
629 fn test_options() -> Options {
630 Options {
631 idle_timeout: Duration::from_millis(100),
632 connect_timeout: Duration::from_secs(5),
633 max_connections: 32,
634 on_connected: None,
635 }
636 }
637
    /// Echo client that obtains its connections from a [`ConnectionPool`].
    struct EchoClient {
        /// Pool used to obtain connections.
        pool: ConnectionPool,
    }
641
642 impl EchoClient {
643 async fn echo(
644 &self,
645 id: NodeId,
646 text: Vec<u8>,
647 ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
648 let conn = self.pool.get_or_connect(id).await?;
649 let id = conn.stable_id();
650 match echo_client(&conn, &text).await {
651 Ok(res) => Ok(Ok((id, res))),
652 Err(e) => Ok(Err(e)),
653 }
654 }
655 }
656
    /// Checks the errors reported for a node without any known address
    /// (`ConnectError`) and for a node whose known address does not answer
    /// (`Timeout`).
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            // No address information has been registered for this node yet.
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // Same key as above, now registered with a local address that is
            // expected not to answer.
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            discovery.add_node_info(NodeAddr {
                node_id: non_listening,
                relay_url: None,
                direct_addresses: vec!["127.0.0.1:12121".parse().unwrap()]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }
692
    /// Checks that back-to-back requests to the same node reuse one connection,
    /// and that a request made well after the idle timeout gets a new connection.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // Both requests must have used the same connection.
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // Wait much longer than the 100 ms idle timeout from `test_options`.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // The earlier connection must have been replaced.
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }
725
    /// Checks that more nodes than `max_connections` can be served one after the
    /// other. With a 100 s idle timeout, this only works if idle connections are
    /// evicted to make room.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }
756
    /// Checks that an error returned by the `on_connected` callback is reported
    /// as `OnConnectError`.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        // Callback that always fails.
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }
788
    /// Checks that an `on_connected` callback can wait for the connection to
    /// become direct before the pool hands it out.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        // Resolves successfully once the connection type is reported as direct.
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_node_id() else {
                return Err(io::Error::other("unable to get node id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }
828
    /// Checks that a connection closed by the user is replaced: a later request
    /// for the same node yields a connection with a different stable id.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        // Close the connection behind the pool's back.
        conn.close(0u32.into(), b"test");
        // Give the pool time to react before asking again.
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
854}