1use super::conn::{ConnectorService, EstablishedClientConnection};
2use crate::stream::Socket;
3use parking_lot::Mutex;
4use rama_core::error::{BoxError, ErrorContext, OpaqueError};
5use rama_core::{Context, Layer, Service};
6use rama_utils::macros::generate_field_setters;
7use std::collections::VecDeque;
8use std::fmt::Debug;
9use std::ops::{Deref, DerefMut};
10use std::pin::Pin;
11use std::sync::OnceLock;
12use std::sync::{Arc, Weak};
13use std::time::Duration;
14use std::{future::Future, net::SocketAddr};
15use tokio::io::{AsyncRead, AsyncWrite};
16use tokio::sync::{OwnedSemaphorePermit, Semaphore};
17use tokio::time::timeout;
18
19pub trait Pool<C, ID>: Send + Sync + 'static {
24    type Connection: Send;
25    type CreatePermit: Send;
26
27    fn get_conn(
33        &self,
34        id: &ID,
35    ) -> impl Future<
36        Output = Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError>,
37    > + Send;
38
39    fn create(
44        &self,
45        id: ID,
46        conn: C,
47        create_permit: Self::CreatePermit,
48    ) -> impl Future<Output = Self::Connection> + Send;
49}
50
/// Outcome of a pool lookup (see `Pool::get_conn`).
///
/// The derived `Debug` prints `Connection(..)` / `CreatePermit(..)` tuples,
/// exactly like the previous hand-written implementation did.
#[derive(Debug)]
pub enum ConnectionResult<C, P> {
    /// A connection that can be used right away.
    Connection(C),
    /// Nothing was reusable; the permit authorises creating a new connection,
    /// which is then handed to `Pool::create` together with this permit.
    CreatePermit(P),
}
68
69#[derive(Debug, Clone, Default)]
70#[non_exhaustive]
71pub struct NoPool;
77
78impl<C, ID> Pool<C, ID> for NoPool
79where
80    C: Send + 'static,
81    ID: Clone + Send + Sync + PartialEq + 'static,
82{
83    type Connection = C;
84    type CreatePermit = ();
85
86    async fn get_conn(
87        &self,
88        _id: &ID,
89    ) -> Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError> {
90        Ok(ConnectionResult::CreatePermit(()))
91    }
92
93    async fn create(&self, _id: ID, conn: C, _permit: Self::CreatePermit) -> Self::Connection {
94        conn
95    }
96}
97
98pub struct LeasedConnection<C, ID> {
105    pooled_conn: Option<PooledConnection<C, ID>>,
106    active_slot: ActiveSlot,
107    returner: Weak<dyn Fn(PooledConnection<C, ID>) + Send + Sync>,
108}
109
110impl<C, ID> LeasedConnection<C, ID> {
111    pub fn into_connection(mut self) -> C {
112        self.pooled_conn.take().expect("only None after drop").conn
113    }
114}
115
116struct PooledConnection<C, ID> {
121    conn: C,
122    id: ID,
123    pool_slot: PoolSlot,
124}
125
126pub struct FiFoReuseLruDropPool<C, ID> {
128    storage: Arc<Mutex<VecDeque<PooledConnection<C, ID>>>>,
129    total_slots: Arc<Semaphore>,
130    active_slots: Arc<Semaphore>,
131    returner: OnceLock<Arc<dyn Fn(PooledConnection<C, ID>) + Send + Sync>>,
132}
133
134impl<C, ID> Clone for FiFoReuseLruDropPool<C, ID> {
135    fn clone(&self) -> Self {
136        Self {
137            storage: self.storage.clone(),
138            total_slots: self.total_slots.clone(),
139            active_slots: self.active_slots.clone(),
140            returner: self.returner.clone(),
141        }
142    }
143}
144
145impl<C, ID> FiFoReuseLruDropPool<C, ID> {
146    pub fn new(max_active: usize, max_total: usize) -> Result<Self, OpaqueError> {
147        if max_active == 0 || max_total == 0 {
148            return Err(OpaqueError::from_display(
149                "max_active or max_total of 0 will make this pool unusable",
150            ));
151        }
152        if max_active > max_total {
153            return Err(OpaqueError::from_display(
154                "max_active should be smaller or equal to max_total",
155            ));
156        }
157        let storage = Arc::new(Mutex::new(VecDeque::with_capacity(max_total)));
158        Ok(Self {
159            storage,
160            returner: OnceLock::new(),
161            total_slots: Arc::new(Semaphore::const_new(max_total)),
162            active_slots: Arc::new(Semaphore::const_new(max_active)),
163        })
164    }
165}
166
167impl<C, ID> FiFoReuseLruDropPool<C, ID>
168where
169    C: Send + 'static,
170    ID: Send + 'static,
171{
172    fn returner(&self) -> Weak<dyn Fn(PooledConnection<C, ID>) + Send + Sync> {
175        let returner = self.returner.get_or_init(|| {
176            let weak_storage = Arc::downgrade(&self.storage);
177            Arc::new(move |conn| {
178                if let Some(storage) = weak_storage.upgrade() {
179                    storage.lock().push_front(conn)
180                }
181            })
182        });
183
184        Arc::downgrade(returner)
185    }
186}
187
188impl<C, ID> Pool<C, ID> for FiFoReuseLruDropPool<C, ID>
189where
190    C: Send + 'static,
191    ID: Clone + Send + Sync + PartialEq + 'static,
192{
193    type Connection = LeasedConnection<C, ID>;
194    type CreatePermit = (ActiveSlot, PoolSlot);
195
196    async fn get_conn(
197        &self,
198        id: &ID,
199    ) -> Result<ConnectionResult<Self::Connection, Self::CreatePermit>, OpaqueError> {
200        let active_slot = ActiveSlot(
201            self.active_slots
202                .clone()
203                .acquire_owned()
204                .await
205                .context("get active pool slot")?,
206        );
207
208        let mut storage = self.storage.lock();
209        let pooled_conn = {
210            storage
211                .iter()
212                .position(|stored| &stored.id == id)
213                .and_then(|idx| storage.remove(idx))
214        };
215
216        if let Some(pooled_conn) = pooled_conn {
217            return Ok(ConnectionResult::Connection(LeasedConnection {
218                active_slot,
219                pooled_conn: Some(pooled_conn),
220                returner: self.returner(),
221            }));
222        }
223
224        let pool_slot = match self.total_slots.clone().try_acquire_owned() {
225            Ok(permit) => PoolSlot(permit),
226            Err(_) => {
227                storage
229                    .pop_back()
230                    .context("get least recently used connection from storage")?
231                    .pool_slot
232            }
233        };
234
235        Ok(ConnectionResult::CreatePermit((active_slot, pool_slot)))
236    }
237
238    async fn create(&self, id: ID, conn: C, permit: Self::CreatePermit) -> Self::Connection {
239        let (active_slot, pool_slot) = permit;
240        LeasedConnection {
241            active_slot,
242            returner: self.returner(),
243            pooled_conn: Some(PooledConnection {
244                id,
245                conn,
246                pool_slot,
247            }),
248        }
249    }
250}
251
252#[expect(dead_code)]
253#[derive(Debug)]
254pub struct ActiveSlot(OwnedSemaphorePermit);
257
258#[expect(dead_code)]
259#[derive(Debug)]
260pub struct PoolSlot(OwnedSemaphorePermit);
264
265impl<C: Debug, ID: Debug> Debug for PooledConnection<C, ID> {
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        f.debug_struct("PooledConnection")
268            .field("conn", &self.conn)
269            .field("id", &self.id)
270            .field("pool_slot", &self.pool_slot)
271            .finish()
272    }
273}
274
275impl<C, ID> Debug for LeasedConnection<C, ID>
276where
277    C: Debug,
278    ID: Debug,
279{
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.debug_struct("LeasedConnection")
282            .field("pooled_conn", &self.pooled_conn)
283            .field("active_slot", &self.active_slot)
284            .finish()
285    }
286}
287
288impl<C, ID> Deref for LeasedConnection<C, ID> {
289    type Target = C;
290
291    fn deref(&self) -> &Self::Target {
292        &self
293            .pooled_conn
294            .as_ref()
295            .expect("only None after drop")
296            .conn
297    }
298}
299
300impl<C, ID> DerefMut for LeasedConnection<C, ID> {
301    fn deref_mut(&mut self) -> &mut Self::Target {
302        &mut self
303            .pooled_conn
304            .as_mut()
305            .expect("only None after drop")
306            .conn
307    }
308}
309
310impl<C, ID> AsRef<C> for LeasedConnection<C, ID> {
311    fn as_ref(&self) -> &C {
312        self
313    }
314}
315
316impl<C, ID> AsMut<C> for LeasedConnection<C, ID> {
317    fn as_mut(&mut self) -> &mut C {
318        self
319    }
320}
321
322impl<C, ID> Drop for LeasedConnection<C, ID> {
323    fn drop(&mut self) {
324        if let Some(returner) = self.returner.upgrade() {
325            if let Some(pooled_conn) = self.pooled_conn.take() {
326                (returner)(pooled_conn);
327            }
328        }
329    }
330}
331
332impl<C, ID> Socket for LeasedConnection<C, ID>
336where
337    ID: Send + Sync + 'static,
338    C: Socket,
339{
340    fn local_addr(&self) -> std::io::Result<SocketAddr> {
341        self.as_ref().local_addr()
342    }
343
344    fn peer_addr(&self) -> std::io::Result<SocketAddr> {
345        self.as_ref().peer_addr()
346    }
347}
348
349impl<C, ID> AsyncWrite for LeasedConnection<C, ID>
350where
351    C: AsyncWrite + Unpin,
352    ID: Unpin,
353{
354    fn poll_write(
355        mut self: std::pin::Pin<&mut Self>,
356        cx: &mut std::task::Context<'_>,
357        buf: &[u8],
358    ) -> std::task::Poll<Result<usize, std::io::Error>> {
359        Pin::new(self.deref_mut().as_mut()).poll_write(cx, buf)
360    }
361
362    fn poll_flush(
363        mut self: std::pin::Pin<&mut Self>,
364        cx: &mut std::task::Context<'_>,
365    ) -> std::task::Poll<Result<(), std::io::Error>> {
366        Pin::new(self.deref_mut().as_mut()).poll_flush(cx)
367    }
368
369    fn poll_shutdown(
370        mut self: std::pin::Pin<&mut Self>,
371        cx: &mut std::task::Context<'_>,
372    ) -> std::task::Poll<Result<(), std::io::Error>> {
373        Pin::new(self.deref_mut().as_mut()).poll_shutdown(cx)
374    }
375
376    fn is_write_vectored(&self) -> bool {
377        self.deref().is_write_vectored()
378    }
379
380    fn poll_write_vectored(
381        mut self: Pin<&mut Self>,
382        cx: &mut std::task::Context<'_>,
383        bufs: &[std::io::IoSlice<'_>],
384    ) -> std::task::Poll<Result<usize, std::io::Error>> {
385        Pin::new(self.deref_mut().as_mut()).poll_write_vectored(cx, bufs)
386    }
387}
388
389impl<C, ID> AsyncRead for LeasedConnection<C, ID>
390where
391    C: AsyncRead + Unpin,
392    ID: Unpin,
393{
394    fn poll_read(
395        mut self: Pin<&mut Self>,
396        cx: &mut std::task::Context<'_>,
397        buf: &mut tokio::io::ReadBuf<'_>,
398    ) -> std::task::Poll<std::io::Result<()>> {
399        Pin::new(self.deref_mut().as_mut()).poll_read(cx, buf)
400    }
401}
402
403impl<State, Request, C, ID> Service<State, Request> for LeasedConnection<C, ID>
404where
405    ID: Send + Sync + 'static,
406    C: Service<State, Request>,
407{
408    type Response = C::Response;
409    type Error = C::Error;
410
411    fn serve(
412        &self,
413        ctx: Context<State>,
414        req: Request,
415    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
416        self.as_ref().serve(ctx, req)
417    }
418}
419
420impl<C, ID: Debug> Debug for FiFoReuseLruDropPool<C, ID> {
421    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422        f.debug_list()
423            .entries(self.storage.lock().iter().map(|item| &item.id))
424            .finish()
425    }
426}
427
428pub trait ReqToConnID<State, Request>: Sized + Clone + Send + Sync + 'static {
433    type ID: Send + Sync + PartialEq + Clone + 'static;
434
435    fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ID, OpaqueError>;
436}
437
438impl<State, Request, ID, F> ReqToConnID<State, Request> for F
439where
440    F: Fn(&Context<State>, &Request) -> Result<ID, OpaqueError> + Clone + Send + Sync + 'static,
441    ID: Send + Sync + PartialEq + Clone + 'static,
442{
443    type ID = ID;
444
445    fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ID, OpaqueError> {
446        self(ctx, request)
447    }
448}
449
450pub struct PooledConnector<S, P, R> {
451    inner: S,
452    pool: P,
453    req_to_conn_id: R,
454    wait_for_pool_timeout: Option<Duration>,
455}
456
457impl<S, P, R> PooledConnector<S, P, R> {
458    pub fn new(inner: S, pool: P, req_to_conn_id: R) -> PooledConnector<S, P, R> {
459        PooledConnector {
460            inner,
461            pool,
462            req_to_conn_id,
463            wait_for_pool_timeout: None,
464        }
465    }
466
467    generate_field_setters!(wait_for_pool_timeout, Duration);
468}
469
470impl<State, Request, S, P, R> Service<State, Request> for PooledConnector<S, P, R>
471where
472    S: ConnectorService<State, Request, Connection: Send, Error: Send + 'static>,
473    State: Send + Sync + 'static,
474    Request: Send + 'static,
475    P: Pool<S::Connection, R::ID>,
476    R: ReqToConnID<State, Request>,
477{
478    type Response = EstablishedClientConnection<P::Connection, State, Request>;
479    type Error = BoxError;
480
481    async fn serve(
482        &self,
483        ctx: Context<State>,
484        req: Request,
485    ) -> Result<Self::Response, Self::Error> {
486        let conn_id = self.req_to_conn_id.id(&ctx, &req)?;
487
488        let create_permit = {
491            let pool = ctx.get::<P>().unwrap_or(&self.pool);
492
493            let pool_result = if let Some(duration) = self.wait_for_pool_timeout {
494                timeout(duration, pool.get_conn(&conn_id))
495                    .await
496                    .map_err(OpaqueError::from_std)?
497            } else {
498                pool.get_conn(&conn_id).await
499            };
500
501            match pool_result? {
502                ConnectionResult::Connection(c) => {
503                    return Ok(EstablishedClientConnection { ctx, conn: c, req });
504                }
505                ConnectionResult::CreatePermit(permit) => permit,
506            }
507        };
508
509        let EstablishedClientConnection { ctx, req, conn } =
510            self.inner.connect(ctx, req).await.map_err(Into::into)?;
511
512        let pool = ctx.get::<P>().unwrap_or(&self.pool);
513        let conn = pool.create(conn_id, conn, create_permit).await;
514        Ok(EstablishedClientConnection { ctx, req, conn })
515    }
516}
517
518pub struct PooledConnectorLayer<P, R> {
519    pool: P,
520    req_to_conn_id: R,
521    wait_for_pool_timeout: Option<Duration>,
522}
523
524impl<P, R> PooledConnectorLayer<P, R> {
525    pub fn new(pool: P, req_to_conn_id: R) -> Self {
526        Self {
527            pool,
528            req_to_conn_id,
529            wait_for_pool_timeout: None,
530        }
531    }
532
533    generate_field_setters!(wait_for_pool_timeout, Duration);
534}
535
536impl<S, P: Clone, R: Clone> Layer<S> for PooledConnectorLayer<P, R> {
537    type Service = PooledConnector<S, P, R>;
538
539    fn layer(&self, inner: S) -> Self::Service {
540        PooledConnector::new(inner, self.pool.clone(), self.req_to_conn_id.clone())
541            .maybe_with_wait_for_pool_timeout(self.wait_for_pool_timeout)
542    }
543
544    fn into_layer(self, inner: S) -> Self::Service {
545        PooledConnector::new(inner, self.pool, self.req_to_conn_id)
546            .maybe_with_wait_for_pool_timeout(self.wait_for_pool_timeout)
547    }
548}
549
#[cfg(feature = "http")]
pub mod http {
    use std::time::Duration;

    use super::{FiFoReuseLruDropPool, PooledConnector, ReqToConnID};
    use crate::{Protocol, address::Authority, client::pool::OpaqueError, http::RequestContext};
    use rama_core::Context;
    use rama_http_types::Request;

    /// [`ReqToConnID`] keying HTTP connections on protocol and authority.
    #[derive(Clone, Debug, Default)]
    #[non_exhaustive]
    pub struct BasicHttpConnIdentifier;

    /// Connection ID produced by [`BasicHttpConnIdentifier`].
    pub type BasicHttpConId = (Protocol, Authority);

    impl<State, Body> ReqToConnID<State, Request<Body>> for BasicHttpConnIdentifier {
        type ID = BasicHttpConId;

        /// Uses the [`RequestContext`] already stored in `ctx` if present, and
        /// otherwise derives one from `ctx` and `req`.
        fn id(&self, ctx: &Context<State>, req: &Request<Body>) -> Result<Self::ID, OpaqueError> {
            let key_of = |rc: &RequestContext| (rc.protocol.clone(), rc.authority.clone());

            if let Some(existing) = ctx.get::<RequestContext>() {
                return Ok(key_of(existing));
            }
            let derived = RequestContext::try_from((ctx, req))?;
            Ok(key_of(&derived))
        }
    }

    /// Builder for a [`PooledConnector`] backed by a [`FiFoReuseLruDropPool`]
    /// and keyed with [`BasicHttpConnIdentifier`].
    ///
    /// Defaults: `max_total` 100, `max_active` 20, no pool wait timeout.
    pub struct HttpPooledConnectorBuilder {
        max_total: usize,
        max_active: usize,
        wait_for_pool_timeout: Option<Duration>,
    }

    impl Default for HttpPooledConnectorBuilder {
        fn default() -> Self {
            Self {
                max_total: 100,
                max_active: 20,
                wait_for_pool_timeout: None,
            }
        }
    }

    impl HttpPooledConnectorBuilder {
        /// Builder with the default limits.
        pub fn new() -> Self {
            Self::default()
        }

        /// Set the `max_total` limit passed to [`FiFoReuseLruDropPool::new`].
        pub fn max_total(self, max: usize) -> Self {
            Self {
                max_total: max,
                ..self
            }
        }

        /// Set the `max_active` limit passed to [`FiFoReuseLruDropPool::new`].
        pub fn max_active(self, max: usize) -> Self {
            Self {
                max_active: max,
                ..self
            }
        }

        /// Bound how long a request waits for the pool.
        pub fn with_wait_for_pool_timeout(self, duration: Duration) -> Self {
            self.maybe_with_wait_for_pool_timeout(Some(duration))
        }

        /// Set (or clear, with `None`) the pool wait timeout.
        pub fn maybe_with_wait_for_pool_timeout(self, duration: Option<Duration>) -> Self {
            Self {
                wait_for_pool_timeout: duration,
                ..self
            }
        }

        /// Build the pooled connector around `inner`.
        ///
        /// # Errors
        ///
        /// Propagates the error from [`FiFoReuseLruDropPool::new`] (a limit of 0,
        /// or `max_active > max_total`).
        pub fn build<C, S>(
            self,
            inner: S,
        ) -> Result<
            PooledConnector<S, FiFoReuseLruDropPool<C, BasicHttpConId>, BasicHttpConnIdentifier>,
            OpaqueError,
        > {
            let Self {
                max_total,
                max_active,
                wait_for_pool_timeout,
            } = self;
            FiFoReuseLruDropPool::new(max_active, max_total).map(|pool| {
                PooledConnector::new(inner, pool, BasicHttpConnIdentifier)
                    .maybe_with_wait_for_pool_timeout(wait_for_pool_timeout)
            })
        }
    }
}
643
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::EstablishedClientConnection;
    use rama_core::{Context, Service};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicI16, Ordering},
    };
    use tokio_test::assert_err;

    /// Connector stub: every "connection" is an empty `Vec<u32>`, and it counts
    /// how many it has created.
    #[derive(Default)]
    struct TestService {
        pub created_connection: AtomicI16,
    }

    impl TestService {
        fn created(&self) -> i16 {
            self.created_connection.load(Ordering::Relaxed)
        }
    }

    impl<State, Request> Service<State, Request> for TestService
    where
        State: Clone + Send + Sync + 'static,
        Request: Send + 'static,
    {
        type Response = EstablishedClientConnection<Vec<u32>, State, Request>;
        type Error = Infallible;

        async fn serve(
            &self,
            ctx: Context<State>,
            req: Request,
        ) -> Result<Self::Response, Self::Error> {
            self.created_connection.fetch_add(1, Ordering::Relaxed);
            Ok(EstablishedClientConnection {
                ctx,
                req,
                conn: Vec::new(),
            })
        }
    }

    /// Uses the number of characters of the request string as connection ID.
    #[derive(Clone)]
    struct StringRequestLengthID;

    impl<State> ReqToConnID<State, String> for StringRequestLengthID {
        type ID = usize;

        fn id(&self, _ctx: &Context<State>, req: &String) -> Result<Self::ID, OpaqueError> {
            Ok(req.chars().count())
        }
    }

    #[tokio::test]
    async fn test_should_reuse_connections() {
        let pool = FiFoReuseLruDropPool::new(5, 10).unwrap();
        let svc = PooledConnector::new(
            TestService::default(),
            pool,
            |_ctx: &Context<()>, _req: &String| Ok(()),
        );

        // Each lease ends before the next request, so one connection serves all.
        for _ in 0..10 {
            let _conn = svc
                .connect(Context::default(), String::new())
                .await
                .unwrap();
        }

        assert_eq!(svc.inner.created(), 1);
    }

    #[tokio::test]
    async fn test_conn_id_to_separate() {
        let pool = FiFoReuseLruDropPool::new(5, 10).unwrap();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID);

        // (request, value pushed, expected contents, connections created so far).
        // Requests of equal length share an ID and therefore a connection.
        let steps: [(&str, u32, Vec<u32>, i16); 4] = [
            ("a", 1, vec![1], 1),
            ("B", 2, vec![1, 2], 1),
            ("aa", 3, vec![3], 2),
            ("bb", 4, vec![3, 4], 2),
        ];

        for (request, value, expected, created) in steps {
            let mut conn = svc
                .connect(Context::default(), String::from(request))
                .await
                .unwrap()
                .conn;

            conn.push(value);
            assert_eq!(conn.as_ref(), &expected);
            assert_eq!(svc.inner.created(), created);
        }
    }

    #[tokio::test]
    async fn test_pool_max_size() {
        let pool = FiFoReuseLruDropPool::new(1, 1).unwrap();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID)
            .with_wait_for_pool_timeout(Duration::from_millis(50));

        // The single slot is taken by this lease...
        let first = svc
            .connect(Context::default(), String::from("a"))
            .await
            .unwrap();

        // ...so a second request times out waiting for the pool...
        let second = svc.connect(Context::default(), String::from("a")).await;
        assert_err!(second);

        // ...until the lease ends; then even a different ID succeeds by evicting
        // the idle connection.
        drop(first);
        let _third = svc
            .connect(Context::default(), String::from("aaa"))
            .await
            .unwrap();
    }
}
800}