1use super::conn::{ConnectorService, EstablishedClientConnection};
2use crate::stream::Socket;
3use parking_lot::Mutex;
4use rama_core::error::{BoxError, ErrorContext, OpaqueError};
5use rama_core::{Context, Layer, Service};
6use rama_utils::macros::generate_field_setters;
7use std::collections::VecDeque;
8use std::fmt::Debug;
9use std::ops::{Deref, DerefMut};
10use std::pin::Pin;
11use std::sync::OnceLock;
12use std::sync::{Arc, Weak};
13use std::time::Duration;
14use std::{future::Future, net::SocketAddr};
15use tokio::io::{AsyncRead, AsyncWrite};
16use tokio::sync::{OwnedSemaphorePermit, Semaphore};
17use tokio::time::timeout;
18
19pub trait Pool<C, ID>: Send + Sync + 'static {
24 type Connection: Send;
25 type CreatePermit: Send;
26
27 fn get_conn(
33 &self,
34 id: &ID,
35 ) -> impl Future<
36 Output = Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError>,
37 > + Send;
38
39 fn create(
44 &self,
45 id: ID,
46 conn: C,
47 create_permit: Self::CreatePermit,
48 ) -> impl Future<Output = Self::Connection> + Send;
49}
50
/// What a [`Pool`] answers when asked for a connection.
///
/// The `Debug` output is the variant name wrapping its payload, e.g.
/// `Connection(..)`.
#[derive(Debug)]
pub enum ConnectionResult<C, P> {
    /// A connection that is ready to be used.
    Connection(C),
    /// No connection was handed out; the permit allows the caller to create
    /// one via [`Pool::create`].
    CreatePermit(P),
}
68
/// A [`Pool`] implementation that never stores anything: every lookup yields
/// a create permit and every created connection is handed back unchanged.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct NoPool;
77
78impl<C, ID> Pool<C, ID> for NoPool
79where
80 C: Send + 'static,
81 ID: Clone + Send + Sync + PartialEq + 'static,
82{
83 type Connection = C;
84 type CreatePermit = ();
85
86 async fn get_conn(
87 &self,
88 _id: &ID,
89 ) -> Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError> {
90 Ok(ConnectionResult::CreatePermit(()))
91 }
92
93 async fn create(&self, _id: ID, conn: C, _permit: Self::CreatePermit) -> Self::Connection {
94 conn
95 }
96}
97
98pub struct LeasedConnection<C, ID> {
105 pooled_conn: Option<PooledConnection<C, ID>>,
106 active_slot: ActiveSlot,
107 returner: Weak<dyn Fn(PooledConnection<C, ID>) + Send + Sync>,
108}
109
110impl<C, ID> LeasedConnection<C, ID> {
111 pub fn into_connection(mut self) -> C {
112 self.pooled_conn.take().expect("only None after drop").conn
113 }
114}
115
116struct PooledConnection<C, ID> {
121 conn: C,
122 id: ID,
123 pool_slot: PoolSlot,
124}
125
126pub struct FiFoReuseLruDropPool<C, ID> {
128 storage: Arc<Mutex<VecDeque<PooledConnection<C, ID>>>>,
129 total_slots: Arc<Semaphore>,
130 active_slots: Arc<Semaphore>,
131 returner: OnceLock<Arc<dyn Fn(PooledConnection<C, ID>) + Send + Sync>>,
132}
133
134impl<C, ID> Clone for FiFoReuseLruDropPool<C, ID> {
135 fn clone(&self) -> Self {
136 Self {
137 storage: self.storage.clone(),
138 total_slots: self.total_slots.clone(),
139 active_slots: self.active_slots.clone(),
140 returner: self.returner.clone(),
141 }
142 }
143}
144
145impl<C, ID> FiFoReuseLruDropPool<C, ID> {
146 pub fn new(max_active: usize, max_total: usize) -> Result<Self, OpaqueError> {
147 if max_active == 0 || max_total == 0 {
148 return Err(OpaqueError::from_display(
149 "max_active or max_total of 0 will make this pool unusable",
150 ));
151 }
152 if max_active > max_total {
153 return Err(OpaqueError::from_display(
154 "max_active should be smaller or equal to max_total",
155 ));
156 }
157 let storage = Arc::new(Mutex::new(VecDeque::with_capacity(max_total)));
158 Ok(Self {
159 storage,
160 returner: OnceLock::new(),
161 total_slots: Arc::new(Semaphore::const_new(max_total)),
162 active_slots: Arc::new(Semaphore::const_new(max_active)),
163 })
164 }
165}
166
167impl<C, ID> FiFoReuseLruDropPool<C, ID>
168where
169 C: Send + 'static,
170 ID: Send + 'static,
171{
172 fn returner(&self) -> Weak<dyn Fn(PooledConnection<C, ID>) + Send + Sync> {
175 let returner = self.returner.get_or_init(|| {
176 let weak_storage = Arc::downgrade(&self.storage);
177 Arc::new(move |conn| {
178 if let Some(storage) = weak_storage.upgrade() {
179 storage.lock().push_front(conn)
180 }
181 })
182 });
183
184 Arc::downgrade(returner)
185 }
186}
187
188impl<C, ID> Pool<C, ID> for FiFoReuseLruDropPool<C, ID>
189where
190 C: Send + 'static,
191 ID: Clone + Send + Sync + PartialEq + 'static,
192{
193 type Connection = LeasedConnection<C, ID>;
194 type CreatePermit = (ActiveSlot, PoolSlot);
195
196 async fn get_conn(
197 &self,
198 id: &ID,
199 ) -> Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError> {
200 let active_slot = ActiveSlot(
201 self.active_slots
202 .clone()
203 .acquire_owned()
204 .await
205 .context("get active pool slot")?,
206 );
207
208 let mut storage = self.storage.lock();
209 let pooled_conn = {
210 storage
211 .iter()
212 .position(|stored| &stored.id == id)
213 .and_then(|idx| storage.remove(idx))
214 };
215
216 if let Some(pooled_conn) = pooled_conn {
217 return Ok(ConnectionResult::Connection(LeasedConnection {
218 active_slot,
219 pooled_conn: Some(pooled_conn),
220 returner: self.returner(),
221 }));
222 }
223
224 let pool_slot = match self.total_slots.clone().try_acquire_owned() {
225 Ok(permit) => PoolSlot(permit),
226 Err(_) => {
227 storage
229 .pop_back()
230 .context("get least recently used connection from storage")?
231 .pool_slot
232 }
233 };
234
235 Ok(ConnectionResult::CreatePermit((active_slot, pool_slot)))
236 }
237
238 async fn create(&self, id: ID, conn: C, permit: Self::CreatePermit) -> Self::Connection {
239 let (active_slot, pool_slot) = permit;
240 LeasedConnection {
241 active_slot,
242 returner: self.returner(),
243 pooled_conn: Some(PooledConnection {
244 id,
245 conn,
246 pool_slot,
247 }),
248 }
249 }
250}
251
252#[expect(dead_code)]
253#[derive(Debug)]
254pub struct ActiveSlot(OwnedSemaphorePermit);
257
258#[expect(dead_code)]
259#[derive(Debug)]
260pub struct PoolSlot(OwnedSemaphorePermit);
264
265impl<C: Debug, ID: Debug> Debug for PooledConnection<C, ID> {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("PooledConnection")
268 .field("conn", &self.conn)
269 .field("id", &self.id)
270 .field("pool_slot", &self.pool_slot)
271 .finish()
272 }
273}
274
275impl<C, ID> Debug for LeasedConnection<C, ID>
276where
277 C: Debug,
278 ID: Debug,
279{
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.debug_struct("LeasedConnection")
282 .field("pooled_conn", &self.pooled_conn)
283 .field("active_slot", &self.active_slot)
284 .finish()
285 }
286}
287
288impl<C, ID> Deref for LeasedConnection<C, ID> {
289 type Target = C;
290
291 fn deref(&self) -> &Self::Target {
292 &self
293 .pooled_conn
294 .as_ref()
295 .expect("only None after drop")
296 .conn
297 }
298}
299
300impl<C, ID> DerefMut for LeasedConnection<C, ID> {
301 fn deref_mut(&mut self) -> &mut Self::Target {
302 &mut self
303 .pooled_conn
304 .as_mut()
305 .expect("only None after drop")
306 .conn
307 }
308}
309
310impl<C, ID> AsRef<C> for LeasedConnection<C, ID> {
311 fn as_ref(&self) -> &C {
312 self
313 }
314}
315
316impl<C, ID> AsMut<C> for LeasedConnection<C, ID> {
317 fn as_mut(&mut self) -> &mut C {
318 self
319 }
320}
321
322impl<C, ID> Drop for LeasedConnection<C, ID> {
323 fn drop(&mut self) {
324 if let Some(returner) = self.returner.upgrade() {
325 if let Some(pooled_conn) = self.pooled_conn.take() {
326 (returner)(pooled_conn);
327 }
328 }
329 }
330}
331
332impl<C, ID> Socket for LeasedConnection<C, ID>
336where
337 ID: Send + Sync + 'static,
338 C: Socket,
339{
340 fn local_addr(&self) -> std::io::Result<SocketAddr> {
341 self.as_ref().local_addr()
342 }
343
344 fn peer_addr(&self) -> std::io::Result<SocketAddr> {
345 self.as_ref().peer_addr()
346 }
347}
348
349impl<C, ID> AsyncWrite for LeasedConnection<C, ID>
350where
351 C: AsyncWrite + Unpin,
352 ID: Unpin,
353{
354 fn poll_write(
355 mut self: std::pin::Pin<&mut Self>,
356 cx: &mut std::task::Context<'_>,
357 buf: &[u8],
358 ) -> std::task::Poll<Result<usize, std::io::Error>> {
359 Pin::new(self.deref_mut().as_mut()).poll_write(cx, buf)
360 }
361
362 fn poll_flush(
363 mut self: std::pin::Pin<&mut Self>,
364 cx: &mut std::task::Context<'_>,
365 ) -> std::task::Poll<Result<(), std::io::Error>> {
366 Pin::new(self.deref_mut().as_mut()).poll_flush(cx)
367 }
368
369 fn poll_shutdown(
370 mut self: std::pin::Pin<&mut Self>,
371 cx: &mut std::task::Context<'_>,
372 ) -> std::task::Poll<Result<(), std::io::Error>> {
373 Pin::new(self.deref_mut().as_mut()).poll_shutdown(cx)
374 }
375
376 fn is_write_vectored(&self) -> bool {
377 self.deref().is_write_vectored()
378 }
379
380 fn poll_write_vectored(
381 mut self: Pin<&mut Self>,
382 cx: &mut std::task::Context<'_>,
383 bufs: &[std::io::IoSlice<'_>],
384 ) -> std::task::Poll<Result<usize, std::io::Error>> {
385 Pin::new(self.deref_mut().as_mut()).poll_write_vectored(cx, bufs)
386 }
387}
388
389impl<C, ID> AsyncRead for LeasedConnection<C, ID>
390where
391 C: AsyncRead + Unpin,
392 ID: Unpin,
393{
394 fn poll_read(
395 mut self: Pin<&mut Self>,
396 cx: &mut std::task::Context<'_>,
397 buf: &mut tokio::io::ReadBuf<'_>,
398 ) -> std::task::Poll<std::io::Result<()>> {
399 Pin::new(self.deref_mut().as_mut()).poll_read(cx, buf)
400 }
401}
402
403impl<State, Request, C, ID> Service<State, Request> for LeasedConnection<C, ID>
404where
405 ID: Send + Sync + 'static,
406 C: Service<State, Request>,
407{
408 type Response = C::Response;
409 type Error = C::Error;
410
411 fn serve(
412 &self,
413 ctx: Context<State>,
414 req: Request,
415 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
416 self.as_ref().serve(ctx, req)
417 }
418}
419
420impl<C, ID: Debug> Debug for FiFoReuseLruDropPool<C, ID> {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 f.debug_list()
423 .entries(self.storage.lock().iter().map(|item| &item.id))
424 .finish()
425 }
426}
427
428pub trait ReqToConnID<State, Request>: Sized + Clone + Send + Sync + 'static {
433 type ID: Send + Sync + PartialEq + Clone + 'static;
434
435 fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ID, OpaqueError>;
436}
437
438impl<State, Request, ID, F> ReqToConnID<State, Request> for F
439where
440 F: Fn(&Context<State>, &Request) -> Result<ID, OpaqueError> + Clone + Send + Sync + 'static,
441 ID: Send + Sync + PartialEq + Clone + 'static,
442{
443 type ID = ID;
444
445 fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ID, OpaqueError> {
446 self(ctx, request)
447 }
448}
449
/// Connector that first asks a pool for a connection and only falls back to
/// the inner connector when the pool hands out a create permit.
pub struct PooledConnector<S, P, R> {
    /// Connector used to establish new connections.
    inner: S,
    /// Pool used unless the request context carries a pool of the same type.
    pool: P,
    /// Maps a request to the id used for the pool lookup.
    req_to_conn_id: R,
    /// Upper bound on how long to wait for the pool; `None` waits
    /// indefinitely.
    wait_for_pool_timeout: Option<Duration>,
}
456
457impl<S, P, R> PooledConnector<S, P, R> {
458 pub fn new(inner: S, pool: P, req_to_conn_id: R) -> PooledConnector<S, P, R> {
459 PooledConnector {
460 inner,
461 pool,
462 req_to_conn_id,
463 wait_for_pool_timeout: None,
464 }
465 }
466
467 generate_field_setters!(wait_for_pool_timeout, Duration);
468}
469
470impl<State, Request, S, P, R> Service<State, Request> for PooledConnector<S, P, R>
471where
472 S: ConnectorService<State, Request, Connection: Send, Error: Send + 'static>,
473 State: Send + Sync + 'static,
474 Request: Send + 'static,
475 P: Pool<S::Connection, R::ID>,
476 R: ReqToConnID<State, Request>,
477{
478 type Response = EstablishedClientConnection<P::Connection, State, Request>;
479 type Error = BoxError;
480
481 async fn serve(
482 &self,
483 ctx: Context<State>,
484 req: Request,
485 ) -> Result<Self::Response, Self::Error> {
486 let conn_id = self.req_to_conn_id.id(&ctx, &req)?;
487
488 let create_permit = {
491 let pool = ctx.get::<P>().unwrap_or(&self.pool);
492
493 let pool_result = if let Some(duration) = self.wait_for_pool_timeout {
494 timeout(duration, pool.get_conn(&conn_id))
495 .await
496 .map_err(OpaqueError::from_std)?
497 } else {
498 pool.get_conn(&conn_id).await
499 };
500
501 match pool_result? {
502 ConnectionResult::Connection(c) => {
503 return Ok(EstablishedClientConnection { ctx, conn: c, req });
504 }
505 ConnectionResult::CreatePermit(permit) => permit,
506 }
507 };
508
509 let EstablishedClientConnection { ctx, req, conn } =
510 self.inner.connect(ctx, req).await.map_err(Into::into)?;
511
512 let pool = ctx.get::<P>().unwrap_or(&self.pool);
513 let conn = pool.create(conn_id, conn, create_permit).await;
514 Ok(EstablishedClientConnection { ctx, req, conn })
515 }
516}
517
/// Layer that wraps a connector into a [`PooledConnector`].
pub struct PooledConnectorLayer<P, R> {
    /// Pool handed to every connector produced by this layer.
    pool: P,
    /// Request-to-id mapping handed to every produced connector.
    req_to_conn_id: R,
    /// Pool timeout forwarded to every produced connector.
    wait_for_pool_timeout: Option<Duration>,
}
523
524impl<P, R> PooledConnectorLayer<P, R> {
525 pub fn new(pool: P, req_to_conn_id: R) -> Self {
526 Self {
527 pool,
528 req_to_conn_id,
529 wait_for_pool_timeout: None,
530 }
531 }
532
533 generate_field_setters!(wait_for_pool_timeout, Duration);
534}
535
536impl<S, P: Clone, R: Clone> Layer<S> for PooledConnectorLayer<P, R> {
537 type Service = PooledConnector<S, P, R>;
538
539 fn layer(&self, inner: S) -> Self::Service {
540 PooledConnector::new(inner, self.pool.clone(), self.req_to_conn_id.clone())
541 .maybe_with_wait_for_pool_timeout(self.wait_for_pool_timeout)
542 }
543
544 fn into_layer(self, inner: S) -> Self::Service {
545 PooledConnector::new(inner, self.pool, self.req_to_conn_id)
546 .maybe_with_wait_for_pool_timeout(self.wait_for_pool_timeout)
547 }
548}
549
#[cfg(feature = "http")]
pub mod http {
    //! HTTP-specific helpers for building a pooled connector.

    use std::time::Duration;

    use super::{FiFoReuseLruDropPool, PooledConnector, ReqToConnID};
    use crate::{Protocol, address::Authority, client::pool::OpaqueError, http::RequestContext};
    use rama_core::Context;
    use rama_http_types::Request;

    /// Identifies HTTP connections by protocol and authority.
    #[non_exhaustive]
    #[derive(Debug, Clone, Default)]
    pub struct BasicHttpConnIdentifier;

    /// Connection id produced by [`BasicHttpConnIdentifier`].
    pub type BasicHttpConId = (Protocol, Authority);

    impl<State, Body> ReqToConnID<State, Request<Body>> for BasicHttpConnIdentifier {
        type ID = BasicHttpConId;

        /// Uses the `RequestContext` stored in `ctx` when there is one, and
        /// otherwise derives one from the context and the request.
        fn id(&self, ctx: &Context<State>, req: &Request<Body>) -> Result<Self::ID, OpaqueError> {
            // Keeps a freshly derived context alive for the borrow below.
            let derived;
            let req_ctx = match ctx.get::<RequestContext>() {
                Some(existing) => existing,
                None => {
                    derived = RequestContext::try_from((ctx, req))?;
                    &derived
                }
            };

            let protocol = req_ctx.protocol.clone();
            let authority = req_ctx.authority.clone();
            Ok((protocol, authority))
        }
    }

    /// Builder for a [`PooledConnector`] backed by a
    /// [`FiFoReuseLruDropPool`] and keyed by [`BasicHttpConnIdentifier`].
    pub struct HttpPooledConnectorBuilder {
        max_total: usize,
        max_active: usize,
        wait_for_pool_timeout: Option<Duration>,
    }

    impl Default for HttpPooledConnectorBuilder {
        /// Defaults to 100 total connections, 20 active connections and no
        /// pool timeout.
        fn default() -> Self {
            Self {
                max_total: 100,
                max_active: 20,
                wait_for_pool_timeout: None,
            }
        }
    }

    impl HttpPooledConnectorBuilder {
        /// Creates a builder with the default settings.
        pub fn new() -> Self {
            Self::default()
        }

        /// Sets the `max_total` limit passed to the pool.
        pub fn max_total(self, max: usize) -> Self {
            Self {
                max_total: max,
                ..self
            }
        }

        /// Sets the `max_active` limit passed to the pool.
        pub fn max_active(self, max: usize) -> Self {
            Self {
                max_active: max,
                ..self
            }
        }

        /// Sets the pool timeout handed to the connector.
        pub fn with_wait_for_pool_timeout(self, duration: Duration) -> Self {
            Self {
                wait_for_pool_timeout: Some(duration),
                ..self
            }
        }

        /// Sets or clears the pool timeout handed to the connector.
        pub fn maybe_with_wait_for_pool_timeout(self, duration: Option<Duration>) -> Self {
            Self {
                wait_for_pool_timeout: duration,
                ..self
            }
        }

        /// Creates the pool and wraps `inner` into a pooled connector.
        ///
        /// # Errors
        ///
        /// Fails when the pool rejects the configured limits.
        pub fn build<C, S>(
            self,
            inner: S,
        ) -> Result<
            PooledConnector<S, FiFoReuseLruDropPool<C, BasicHttpConId>, BasicHttpConnIdentifier>,
            OpaqueError,
        > {
            let Self {
                max_total,
                max_active,
                wait_for_pool_timeout,
            } = self;
            let pool = FiFoReuseLruDropPool::new(max_active, max_total)?;
            let connector = PooledConnector::new(inner, pool, BasicHttpConnIdentifier);
            Ok(connector.maybe_with_wait_for_pool_timeout(wait_for_pool_timeout))
        }
    }
}
643
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::EstablishedClientConnection;
    use rama_core::{Context, Service};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicI16, Ordering},
    };
    use tokio_test::assert_err;

    /// Fake connector: every call yields an empty vector as "connection" and
    /// bumps a counter, so tests can tell new connections from reused ones.
    #[derive(Default)]
    struct TestService {
        pub created_connection: AtomicI16,
    }

    impl<State, Request> Service<State, Request> for TestService
    where
        State: Clone + Send + Sync + 'static,
        Request: Send + 'static,
    {
        type Response = EstablishedClientConnection<Vec<u32>, State, Request>;
        type Error = Infallible;

        async fn serve(
            &self,
            ctx: Context<State>,
            req: Request,
        ) -> Result<Self::Response, Self::Error> {
            self.created_connection.fetch_add(1, Ordering::Relaxed);
            Ok(EstablishedClientConnection {
                ctx,
                req,
                conn: Vec::new(),
            })
        }
    }

    /// Uses the number of characters in the request string as connection id.
    #[derive(Clone)]
    struct StringRequestLengthID;

    impl<State> ReqToConnID<State, String> for StringRequestLengthID {
        type ID = usize;

        fn id(&self, _ctx: &Context<State>, req: &String) -> Result<Self::ID, OpaqueError> {
            let char_count = req.chars().count();
            Ok(char_count)
        }
    }

    #[tokio::test]
    async fn test_should_reuse_connections() {
        let pool = FiFoReuseLruDropPool::new(5, 10).unwrap();
        // Every request maps to the same id, so one connection should serve all.
        let svc = PooledConnector::new(
            TestService::default(),
            pool,
            |_ctx: &Context<()>, _req: &String| Ok(()),
        );

        for _ in 0..10 {
            // The lease is dropped at the end of each iteration.
            let _lease = svc
                .connect(Context::default(), String::new())
                .await
                .unwrap();
        }

        assert_eq!(svc.inner.created_connection.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_conn_id_to_separate() {
        let pool = FiFoReuseLruDropPool::new(5, 10).unwrap();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID {});
        let created = || svc.inner.created_connection.load(Ordering::Relaxed);

        {
            // First request of length 1: a new connection.
            let established = svc
                .connect(Context::default(), String::from("a"))
                .await
                .unwrap();
            let mut conn = established.conn;

            conn.push(1);
            assert_eq!(conn.as_ref(), &vec![1]);
            assert_eq!(created(), 1);
        }

        {
            // Same length: the previous connection comes back.
            let established = svc
                .connect(Context::default(), String::from("B"))
                .await
                .unwrap();
            let mut conn = established.conn;

            conn.push(2);
            assert_eq!(conn.as_ref(), &vec![1, 2]);
            assert_eq!(created(), 1);
        }

        {
            // Different length: a second connection is created.
            let established = svc
                .connect(Context::default(), String::from("aa"))
                .await
                .unwrap();
            let mut conn = established.conn;

            conn.push(3);
            assert_eq!(conn.as_ref(), &vec![3]);
            assert_eq!(created(), 2);
        }

        {
            // Length 2 again: the second connection is reused.
            let established = svc
                .connect(Context::default(), String::from("bb"))
                .await
                .unwrap();
            let mut conn = established.conn;

            conn.push(4);
            assert_eq!(conn.as_ref(), &vec![3, 4]);
            assert_eq!(created(), 2);
        }
    }

    #[tokio::test]
    async fn test_pool_max_size() {
        let pool = FiFoReuseLruDropPool::new(1, 1).unwrap();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID {})
            .with_wait_for_pool_timeout(Duration::from_millis(50));

        let first = svc
            .connect(Context::default(), String::from("a"))
            .await
            .unwrap();

        // The single active slot is taken, so this request times out.
        let second = svc.connect(Context::default(), String::from("a")).await;
        assert_err!(second);

        // Once the first lease is gone a request with another id succeeds.
        drop(first);
        let _third = svc
            .connect(Context::default(), String::from("aaa"))
            .await
            .unwrap();
    }
}