1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, NodeId,
25};
26use n0_future::{
27 future::{self},
28 FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
/// Hook run on every freshly established pool connection before it is handed out.
///
/// It receives the pool's endpoint and the new connection. If the returned future resolves to an
/// error, the connection attempt fails with [`PoolConnectError::OnConnectError`]. The hook runs
/// inside the same timeout as the dial itself ([`Options::connect_timeout`]).
pub type OnConnected =
    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration for a [`ConnectionPool`].
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection with no outstanding [`ConnectionRef`]s stays open before the pool
    /// shuts it down.
    pub idle_timeout: Duration,
    /// Upper bound for establishing a connection, including the [`Options::on_connected`] hook.
    pub connect_timeout: Duration,
    /// Maximum number of per-node connections the pool manages at once.
    ///
    /// When the limit is reached, the oldest idle connection is evicted. If none is idle, the
    /// request fails with [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// Optional hook run on each new connection before it is handed out.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58 fn default() -> Self {
59 Self {
60 idle_timeout: Duration::from_secs(5),
61 connect_timeout: Duration::from_secs(1),
62 max_connections: 1024,
63 on_connected: None,
64 }
65 }
66}
67
/// A pooled connection handed out by [`ConnectionPool::get_or_connect`].
///
/// Dereferences to the underlying iroh [`Connection`]. While any `ConnectionRef` for a node is
/// alive, that node's connection counts as in use, and the pool does not start its idle timer.
#[derive(Debug)]
pub struct ConnectionRef {
    connection: iroh::endpoint::Connection,
    // Keeps the per-connection usage count incremented for as long as this ref lives.
    _permit: OneConnection,
}
74
75impl Deref for ConnectionRef {
76 type Target = iroh::endpoint::Connection;
77
78 fn deref(&self) -> &Self::Target {
79 &self.connection
80 }
81}
82
83impl ConnectionRef {
84 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
85 Self {
86 connection,
87 _permit: counter,
88 }
89 }
90}
91
/// Errors returned by [`ConnectionPool::get_or_connect`].
///
/// Sources are wrapped in `Arc` so the error is `Clone`. A single failed connection attempt is
/// cloned out to every request queued for that node.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    // Plain `//` comments on the variants, not `///`: snafu derives `Display` from variant doc
    // comments, so doc comments here would change the error messages.
    //
    // The pool actor is gone (request could not be sent, or the reply channel was dropped).
    Shutdown,
    // Dialing plus the `on_connected` hook exceeded `Options::connect_timeout`.
    Timeout,
    // `Options::max_connections` was reached and no idle connection could be evicted.
    TooManyConnections,
    // `Endpoint::connect` failed.
    ConnectError { source: Arc<ConnectError> },
    // The `Options::on_connected` hook returned an error.
    OnConnectError { source: Arc<io::Error> },
}
110
111impl From<ConnectError> for PoolConnectError {
112 fn from(e: ConnectError) -> Self {
113 PoolConnectError::ConnectError {
114 source: Arc::new(e),
115 }
116 }
117}
118
119impl From<io::Error> for PoolConnectError {
120 fn from(e: io::Error) -> Self {
121 PoolConnectError::OnConnectError {
122 source: Arc::new(e),
123 }
124 }
125}
126
/// Error returned by [`ConnectionPool::close`] and the crate-internal idle notification.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    // `//`, not `///`: snafu would use a variant doc comment as the `Display` text.
    //
    // The pool actor's channel is closed, so the message could not be delivered.
    Shutdown,
}
136
/// Messages processed by the pool actor.
enum ActorMessage {
    /// A caller (via `get_or_connect`) wants a connection to a node.
    RequestRef(RequestRef),
    /// The connection to `id` has no outstanding refs, so it may be evicted.
    ConnectionIdle { id: NodeId },
    /// The pool should forget its connection actor for `id`.
    ConnectionShutdown { id: NodeId },
}
142
/// A pending request for a connection to a node.
struct RequestRef {
    /// Node to connect to.
    id: NodeId,
    /// One-shot reply channel for the resulting ref or error.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
147
148struct Context {
149 options: Options,
150 endpoint: Endpoint,
151 owner: ConnectionPool,
152 alpn: Vec<u8>,
153}
154
155impl Context {
156 async fn run_connection_actor(
157 self: Arc<Self>,
158 node_id: NodeId,
159 mut rx: mpsc::Receiver<RequestRef>,
160 ) {
161 let context = self;
162
163 let conn_fut = {
164 let context = context.clone();
165 async move {
166 let conn = context
167 .endpoint
168 .connect(node_id, &context.alpn)
169 .await
170 .map_err(PoolConnectError::from)?;
171 if let Some(on_connect) = &context.options.on_connected {
172 on_connect(&context.endpoint, &conn)
173 .await
174 .map_err(PoolConnectError::from)?;
175 }
176 Result::<Connection, PoolConnectError>::Ok(conn)
177 }
178 };
179
180 let state = conn_fut
182 .timeout(context.options.connect_timeout)
183 .await
184 .map_err(|_| PoolConnectError::Timeout)
185 .and_then(|r| r);
186 let conn_close = match &state {
187 Ok(conn) => {
188 let conn = conn.clone();
189 MaybeFuture::Some(async move { conn.closed().await })
190 }
191 Err(e) => {
192 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
193 if context.owner.close(node_id).await.is_err() {
194 return;
195 }
196 MaybeFuture::None
197 }
198 };
199
200 let counter = ConnectionCounter::new();
201 let idle_timer = MaybeFuture::default();
202 let idle_stream = counter.clone().idle_stream();
203
204 tokio::pin!(idle_timer, idle_stream, conn_close);
205
206 loop {
207 tokio::select! {
208 biased;
209
210 handler = rx.recv() => {
212 match handler {
213 Some(RequestRef { id, tx }) => {
214 assert!(id == node_id, "Not for me!");
215 match &state {
216 Ok(state) => {
217 let res = ConnectionRef::new(state.clone(), counter.get_one());
218 info!(%node_id, "Handing out ConnectionRef {}", counter.current());
219
220 idle_timer.as_mut().set_none();
222 tx.send(Ok(res)).ok();
223 }
224 Err(cause) => {
225 tx.send(Err(cause.clone())).ok();
226 }
227 }
228 }
229 None => {
230 break;
232 }
233 }
234 }
235
236 _ = &mut conn_close => {
237 context.owner.close(node_id).await.ok();
239 }
240
241 _ = idle_stream.next() => {
242 if !counter.is_idle() {
243 continue;
244 };
245 trace!(%node_id, "Idle");
247 if context.owner.idle(node_id).await.is_err() {
248 break;
250 }
251 idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
253 }
254
255 _ = &mut idle_timer => {
257 trace!(%node_id, "Idle timer expired, requesting shutdown");
258 context.owner.close(node_id).await.ok();
259 }
261 }
262 }
263
264 if let Ok(connection) = state {
265 let reason = if counter.is_idle() { b"idle" } else { b"drop" };
266 connection.close(0u32.into(), reason);
267 }
268
269 trace!(%node_id, "Connection actor shutting down");
270 }
271}
272
273struct Actor {
274 rx: mpsc::Receiver<ActorMessage>,
275 connections: HashMap<NodeId, mpsc::Sender<RequestRef>>,
276 context: Arc<Context>,
277 idle: VecDeque<NodeId>,
280 tasks: FuturesUnordered<future::Boxed<()>>,
282}
283
284impl Actor {
285 pub fn new(
286 endpoint: Endpoint,
287 alpn: &[u8],
288 options: Options,
289 ) -> (Self, mpsc::Sender<ActorMessage>) {
290 let (tx, rx) = mpsc::channel(100);
291 (
292 Self {
293 rx,
294 connections: HashMap::new(),
295 idle: VecDeque::new(),
296 context: Arc::new(Context {
297 options,
298 alpn: alpn.to_vec(),
299 endpoint,
300 owner: ConnectionPool { tx: tx.clone() },
301 }),
302 tasks: FuturesUnordered::new(),
303 },
304 tx,
305 )
306 }
307
308 fn add_idle(&mut self, id: NodeId) {
309 self.remove_idle(id);
310 self.idle.push_back(id);
311 }
312
313 fn remove_idle(&mut self, id: NodeId) {
314 self.idle.retain(|&x| x != id);
315 }
316
317 fn pop_oldest_idle(&mut self) -> Option<NodeId> {
318 self.idle.pop_front()
319 }
320
321 fn remove_connection(&mut self, id: NodeId) {
322 self.connections.remove(&id);
323 self.remove_idle(id);
324 }
325
326 async fn handle_msg(&mut self, msg: ActorMessage) {
327 match msg {
328 ActorMessage::RequestRef(mut msg) => {
329 let id = msg.id;
330 self.remove_idle(id);
331 if let Some(conn_tx) = self.connections.get(&id) {
333 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
334 msg = e;
335 } else {
336 return;
337 }
338 self.remove_connection(id);
340 }
341
342 if self.connections.len() >= self.context.options.max_connections {
344 if let Some(idle) = self.pop_oldest_idle() {
345 trace!("removing oldest idle connection {}", idle);
347 self.connections.remove(&idle);
348 } else {
349 msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
350 return;
351 }
352 }
353 let (conn_tx, conn_rx) = mpsc::channel(100);
354 self.connections.insert(id, conn_tx.clone());
355
356 let context = self.context.clone();
357
358 self.tasks
359 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
360
361 if conn_tx.send(msg).await.is_err() {
363 error!(%id, "Failed to send handler to new connection actor");
364 self.connections.remove(&id);
365 }
366 }
367 ActorMessage::ConnectionIdle { id } => {
368 self.add_idle(id);
369 trace!(%id, "connection idle");
370 }
371 ActorMessage::ConnectionShutdown { id } => {
372 self.remove_connection(id);
374 trace!(%id, "removed connection");
375 }
376 }
377 }
378
379 pub async fn run(mut self) {
380 loop {
381 tokio::select! {
382 biased;
383
384 msg = self.rx.recv() => {
385 if let Some(msg) = msg {
386 self.handle_msg(msg).await;
387 } else {
388 break;
389 }
390 }
391
392 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
393 }
394 }
395 }
396}
397
/// A pool of iroh connections for a single ALPN, keyed by remote [`NodeId`].
///
/// Cloning is cheap: every clone shares the same sender to one background actor.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    tx: mpsc::Sender<ActorMessage>,
}
403
404impl ConnectionPool {
405 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
406 let (actor, tx) = Actor::new(endpoint, alpn, options);
407
408 tokio::spawn(actor.run());
410
411 Self { tx }
412 }
413
414 pub async fn get_or_connect(
419 &self,
420 id: NodeId,
421 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
422 let (tx, rx) = oneshot::channel();
423 self.tx
424 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
425 .await
426 .map_err(|_| PoolConnectError::Shutdown)?;
427 rx.await.map_err(|_| PoolConnectError::Shutdown)?
428 }
429
430 pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
435 self.tx
436 .send(ActorMessage::ConnectionShutdown { id })
437 .await
438 .map_err(|_| ConnectionPoolError::Shutdown)?;
439 Ok(())
440 }
441
442 pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
446 self.tx
447 .send(ActorMessage::ConnectionIdle { id })
448 .await
449 .map_err(|_| ConnectionPoolError::Shutdown)?;
450 Ok(())
451 }
452}
453
/// Shared state behind [`ConnectionCounter`].
#[derive(Debug)]
struct ConnectionCounterInner {
    /// Number of live [`OneConnection`] permits.
    count: AtomicUsize,
    /// Signalled by [`OneConnection`]'s `Drop` when `count` drops to zero.
    notify: Notify,
}
459
/// Cloneable handle to one connection's usage counter. All clones share the same state.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
464
465impl ConnectionCounter {
466 fn new() -> Self {
467 Self {
468 inner: Arc::new(ConnectionCounterInner {
469 count: Default::default(),
470 notify: Notify::new(),
471 }),
472 }
473 }
474
475 fn current(&self) -> usize {
476 self.inner.count.load(Ordering::SeqCst)
477 }
478
479 fn get_one(&self) -> OneConnection {
481 self.inner.count.fetch_add(1, Ordering::SeqCst);
482 OneConnection {
483 inner: self.inner.clone(),
484 }
485 }
486
487 fn is_idle(&self) -> bool {
488 self.inner.count.load(Ordering::SeqCst) == 0
489 }
490
491 fn idle_stream(self) -> impl Stream<Item = ()> {
500 n0_future::stream::unfold(self, |c| async move {
501 c.inner.notify.notified().await;
502 Some(((), c))
503 })
504 }
505}
506
/// Permit representing one outstanding [`ConnectionRef`]. It decrements the shared count when
/// dropped.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
512
513impl Drop for OneConnection {
514 fn drop(&mut self) {
515 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
516 self.inner.notify.notify_waiters();
517 }
518 }
519}
520
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::Connection,
        protocol::{AcceptError, ProtocolHandler, Router},
        NodeAddr, NodeId, SecretKey, Watcher,
    };
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use n0_snafu::ResultExt;
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies each accepted bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_node_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            // Serve streams until `accept_bi` fails, e.g. because the connection was closed.
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bidirectional stream and returns the echoed bytes (at most 1000).
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_node_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Binds an endpoint serving [`Echo`] and waits until its relay and address are known.
    async fn echo_server() -> TestResult<(NodeAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.home_relay().initialized().await;
        let addr = endpoint.node_addr().initialized().await;
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers (up to 16 concurrently). Returns their ids, their routers, and a
    /// static discovery provider that knows all their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<NodeId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.node_id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_node_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers, up to 16 concurrently, ignoring errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Options with a short (100 ms) idle timeout so tests can observe idle closing quickly.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Test client that performs one echo round trip through the pool.
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Outer `Err`: the pool could not provide a connection.
        /// Inner `Err`: the echo exchange itself failed.
        /// On success, returns the connection's `stable_id` and the echoed bytes.
        async fn echo(
            &self,
            id: NodeId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        // Empty discovery: nodes are unknown until added below.
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            // A node with no known address fails to connect.
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // NOTE(review): this derives the same node id as `non_existing` above. As written,
            // this case therefore also depends on the failed first attempt not being cached by
            // the pool. Confirm whether that is intended or whether a distinct key was meant.
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            // Presumably nothing listens on this port, so the dial runs into `connect_timeout`.
            discovery.add_node_info(NodeAddr {
                node_id: non_listening,
                relay_url: None,
                direct_addresses: vec!["127.0.0.1:12121".parse().unwrap()]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }

    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        // Back-to-back requests to the same node must reuse the pooled connection.
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // Wait well past the 100 ms idle timeout; every connection should have been closed.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    #[tokio::test]
    // More nodes (32) than connection slots (8), with an idle timeout long enough that nothing
    // closes on its own: requests must succeed by evicting idle connections.
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    #[tokio::test]
    // A failing `on_connected` hook surfaces as `OnConnectError`.
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    #[tokio::test]
    // An `on_connected` hook (20 s) slower than `connect_timeout` (5 s) surfaces as `Timeout`.
    async fn on_connected_timeout() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected = Arc::new(|_, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(20)).await;
                Ok(())
            })
        });
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    #[tokio::test]
    // Closing a handed-out connection makes the pool drop it and dial a fresh one next time.
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::builder()
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        // Give the connection actor time to observe the close and deregister.
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}