1use super::conn::{ConnectorService, EstablishedClientConnection};
2use crate::stream::Socket;
3use parking_lot::Mutex;
4use rama_core::error::{BoxError, ErrorContext, OpaqueError};
5use rama_core::{Context, Layer, Service};
6use rama_utils::macros::{generate_field_setters, nz};
7use std::collections::VecDeque;
8use std::fmt::Debug;
9use std::num::NonZeroU16;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14use std::{future::Future, net::SocketAddr};
15use tokio::io::{AsyncRead, AsyncWrite};
16use tokio::sync::{OwnedSemaphorePermit, Semaphore};
17use tokio::time::timeout;
18use tracing::trace;
19
20pub trait PoolStorage: Sized + Send + Sync + 'static {
25 type ConnID: PartialEq + Clone + Debug + Send + Sync + 'static;
26 type Connection: Send;
27
28 fn new(capacity: NonZeroU16) -> Self;
32
33 fn add_connection(&self, conn: PooledConnection<Self::Connection, Self::ConnID>);
35
36 fn get_connection(
40 &self,
41 id: &Self::ConnID,
42 ) -> Option<PooledConnection<Self::Connection, Self::ConnID>>;
43
44 fn get_connection_to_drop(
49 &self,
50 ) -> Result<PooledConnection<Self::Connection, Self::ConnID>, OpaqueError>;
51}
52
53#[expect(dead_code)]
54#[derive(Debug)]
55struct ActiveSlot(OwnedSemaphorePermit);
58
59#[expect(dead_code)]
60#[derive(Debug)]
61struct PoolSlot(OwnedSemaphorePermit);
65
66pub struct PooledConnection<C, ConnID> {
70 conn: C,
72 id: ConnID,
74 slot: PoolSlot,
76}
77
78impl<C: Debug, ConnID: Debug> Debug for PooledConnection<C, ConnID> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("PooledConnection")
81 .field("conn", &self.conn)
82 .field("id", &self.id)
83 .field("slot", &self.slot)
84 .finish()
85 }
86}
87
88pub struct LeasedConnection<C, ConnID> {
94 pooled_conn: Option<PooledConnection<C, ConnID>>,
97 returner: Arc<dyn Fn(PooledConnection<C, ConnID>) + Send + Sync>,
99 _slot: ActiveSlot,
101}
102
103impl<C, ConnID> Debug for LeasedConnection<C, ConnID>
104where
105 C: Debug,
106 ConnID: Debug,
107{
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 f.debug_struct("LeasedConnection")
110 .field("pooled_conn", &self.pooled_conn)
111 .field("_slot", &self._slot)
112 .finish()
113 }
114}
115
116impl<C, ConnID> LeasedConnection<C, ConnID> {
117 pub fn take(mut self) -> C {
119 self.pooled_conn.take().expect("only None after drop").conn
120 }
121}
122
123impl<C, ConnID> Deref for LeasedConnection<C, ConnID> {
124 type Target = C;
125
126 fn deref(&self) -> &Self::Target {
127 &self
128 .pooled_conn
129 .as_ref()
130 .expect("only None after drop")
131 .conn
132 }
133}
134
135impl<C, ConnID> DerefMut for LeasedConnection<C, ConnID> {
136 fn deref_mut(&mut self) -> &mut Self::Target {
137 &mut self
138 .pooled_conn
139 .as_mut()
140 .expect("only None after drop")
141 .conn
142 }
143}
144
145impl<C, ConnID> AsRef<C> for LeasedConnection<C, ConnID> {
146 fn as_ref(&self) -> &C {
147 self
148 }
149}
150
151impl<C, ConnID> AsMut<C> for LeasedConnection<C, ConnID> {
152 fn as_mut(&mut self) -> &mut C {
153 self
154 }
155}
156
157impl<C, ConnID> Drop for LeasedConnection<C, ConnID> {
158 fn drop(&mut self) {
159 if let Some(pooled_conn) = self.pooled_conn.take() {
160 (self.returner)(pooled_conn);
161 }
163 }
164}
165
166impl<C, ConnID> Socket for LeasedConnection<C, ConnID>
170where
171 ConnID: Send + Sync + 'static,
172 C: Socket,
173{
174 fn local_addr(&self) -> std::io::Result<SocketAddr> {
175 self.as_ref().local_addr()
176 }
177
178 fn peer_addr(&self) -> std::io::Result<SocketAddr> {
179 self.as_ref().peer_addr()
180 }
181}
182
183impl<C, ConnID> AsyncWrite for LeasedConnection<C, ConnID>
184where
185 C: AsyncWrite + Unpin,
186 ConnID: Unpin,
187{
188 fn poll_write(
189 mut self: std::pin::Pin<&mut Self>,
190 cx: &mut std::task::Context<'_>,
191 buf: &[u8],
192 ) -> std::task::Poll<Result<usize, std::io::Error>> {
193 Pin::new(self.deref_mut().as_mut()).poll_write(cx, buf)
194 }
195
196 fn poll_flush(
197 mut self: std::pin::Pin<&mut Self>,
198 cx: &mut std::task::Context<'_>,
199 ) -> std::task::Poll<Result<(), std::io::Error>> {
200 Pin::new(self.deref_mut().as_mut()).poll_flush(cx)
201 }
202
203 fn poll_shutdown(
204 mut self: std::pin::Pin<&mut Self>,
205 cx: &mut std::task::Context<'_>,
206 ) -> std::task::Poll<Result<(), std::io::Error>> {
207 Pin::new(self.deref_mut().as_mut()).poll_shutdown(cx)
208 }
209
210 fn is_write_vectored(&self) -> bool {
211 self.deref().is_write_vectored()
212 }
213
214 fn poll_write_vectored(
215 mut self: Pin<&mut Self>,
216 cx: &mut std::task::Context<'_>,
217 bufs: &[std::io::IoSlice<'_>],
218 ) -> std::task::Poll<Result<usize, std::io::Error>> {
219 Pin::new(self.deref_mut().as_mut()).poll_write_vectored(cx, bufs)
220 }
221}
222
223impl<C, ConnID> AsyncRead for LeasedConnection<C, ConnID>
224where
225 C: AsyncRead + Unpin,
226 ConnID: Unpin,
227{
228 fn poll_read(
229 mut self: Pin<&mut Self>,
230 cx: &mut std::task::Context<'_>,
231 buf: &mut tokio::io::ReadBuf<'_>,
232 ) -> std::task::Poll<std::io::Result<()>> {
233 Pin::new(self.deref_mut().as_mut()).poll_read(cx, buf)
234 }
235}
236
237impl<State, Request, C, ConnID> Service<State, Request> for LeasedConnection<C, ConnID>
238where
239 ConnID: Send + Sync + 'static,
240 C: Service<State, Request>,
241{
242 type Response = C::Response;
243 type Error = C::Error;
244
245 fn serve(
246 &self,
247 ctx: Context<State>,
248 req: Request,
249 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
250 self.as_ref().serve(ctx, req)
251 }
252}
253
254struct PoolInner<S> {
255 total_slots: Arc<Semaphore>,
256 active_slots: Arc<Semaphore>,
257 storage: S,
258}
259
260impl<S: Debug> Debug for PoolInner<S> {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("PoolInner")
263 .field("total_slots", &self.total_slots)
264 .field("active_slots", &self.active_slots)
265 .field("storage", &self.storage)
266 .finish()
267 }
268}
269
270pub struct ConnStoreFiFoReuseLruDrop<C, ConnID>(Arc<Mutex<VecDeque<PooledConnection<C, ConnID>>>>);
273
274impl<C, ConnID> PoolStorage for ConnStoreFiFoReuseLruDrop<C, ConnID>
275where
276 C: Send + 'static,
277 ConnID: PartialEq + Clone + Debug + Send + Sync + 'static,
278{
279 type ConnID = ConnID;
280
281 type Connection = C;
282
283 fn new(capacity: NonZeroU16) -> Self {
284 Self(Arc::new(Mutex::new(VecDeque::with_capacity(
285 Into::<u16>::into(capacity).into(),
286 ))))
287 }
288
289 fn add_connection(&self, conn: PooledConnection<Self::Connection, Self::ConnID>) {
290 trace!(conn_id = ?conn.id, "adding connection back to pool");
291 self.0.lock().push_front(conn);
292 }
293
294 fn get_connection(
295 &self,
296 id: &Self::ConnID,
297 ) -> Option<PooledConnection<Self::Connection, Self::ConnID>> {
298 trace!(conn_id = ?id, "getting connection from pool");
299 let mut connections = self.0.lock();
300 connections
301 .iter()
302 .position(|stored| &stored.id == id)
303 .and_then(|idx| connections.remove(idx))
304 }
305
306 fn get_connection_to_drop(
307 &self,
308 ) -> Result<PooledConnection<Self::Connection, Self::ConnID>, OpaqueError> {
309 trace!("getting connection to drop from pool");
310 self.0.lock().pop_back().context("None, this function should only be called when pool is full, in which case this should always return a connection")
311 }
312}
313
314impl<C, ConnID: Debug> Debug for ConnStoreFiFoReuseLruDrop<C, ConnID> {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 f.debug_list()
317 .entries(self.0.lock().iter().map(|item| &item.id))
318 .finish()
319 }
320}
321
322pub struct Pool<S> {
325 inner: Arc<PoolInner<S>>,
326}
327
328impl<S> Clone for Pool<S> {
329 fn clone(&self) -> Self {
330 Self {
331 inner: self.inner.clone(),
332 }
333 }
334}
335
336impl<S: Debug> Debug for Pool<S> {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 f.debug_struct("Pool").field("inner", &self.inner).finish()
339 }
340}
341
342impl<C, ConnID> Default for Pool<ConnStoreFiFoReuseLruDrop<C, ConnID>>
343where
344 C: Send + 'static,
345 ConnID: PartialEq + Clone + Debug + Send + Sync + 'static,
346{
347 fn default() -> Self {
348 Self::new(nz!(10), nz!(20)).unwrap()
349 }
350}
351
352pub enum GetConnectionOrCreate<F, C, ConnID>
354where
355 F: FnOnce(C) -> LeasedConnection<C, ConnID>,
356{
357 LeasedConnection(LeasedConnection<C, ConnID>),
359 AddConnection(F),
364}
365
366impl<S: PoolStorage> Pool<S> {
367 pub fn new(max_active: NonZeroU16, max_total: NonZeroU16) -> Result<Pool<S>, OpaqueError> {
368 if max_active > max_total {
369 return Err(OpaqueError::from_display(
370 "max_active should be <= then max_total connection",
371 ));
372 }
373
374 let storage = S::new(max_total);
375 let max_total: usize = Into::<u16>::into(max_total).into();
376
377 Ok(Pool {
378 inner: Arc::new(PoolInner {
379 total_slots: Arc::new(Semaphore::new(max_total)),
380 active_slots: Arc::new(Semaphore::new(Into::<u16>::into(max_active).into())),
381 storage,
382 }),
383 })
384 }
385}
386
387impl<S: PoolStorage> Pool<S> {
388 pub async fn get_connection_or_create<F, Fut>(
390 &self,
391 id: &S::ConnID,
392 create_conn: F,
393 ) -> Result<LeasedConnection<S::Connection, S::ConnID>, OpaqueError>
394 where
395 F: FnOnce() -> Fut,
396 Fut: Future<Output = Result<S::Connection, OpaqueError>>,
397 {
398 match self.get_connection_or_create_cb(id).await? {
399 GetConnectionOrCreate::LeasedConnection(leased_connection) => Ok(leased_connection),
400 GetConnectionOrCreate::AddConnection(add) => {
401 let conn = create_conn().await?;
402 Ok(add(conn))
403 }
404 }
405 }
406
407 pub async fn get_connection_or_create_cb(
409 &self,
410 id: &S::ConnID,
411 ) -> Result<
412 GetConnectionOrCreate<
413 impl FnOnce(S::Connection) -> LeasedConnection<S::Connection, S::ConnID>,
414 S::Connection,
415 S::ConnID,
416 >,
417 OpaqueError,
418 > {
419 let active_permit = self
420 .inner
421 .active_slots
422 .clone()
423 .acquire_owned()
424 .await
425 .context("failed to acquire active slot permit")?;
426
427 let active_slot = ActiveSlot(active_permit);
428
429 let pooled_conn = self.inner.storage.get_connection(id);
430
431 let pool = Arc::downgrade(&self.inner);
432 let returner = Arc::new(move |conn| {
433 if let Some(pool) = pool.upgrade() {
434 pool.storage.add_connection(conn);
435 }
436 });
437
438 if pooled_conn.is_some() {
439 trace!(conn_id = ?id, "creating leased connection from stored pooled connection");
440 let leased_conn = LeasedConnection {
441 _slot: active_slot,
442 pooled_conn,
443 returner,
444 };
445 return Ok(GetConnectionOrCreate::LeasedConnection(leased_conn));
446 };
447
448 let pool_slot = match self.inner.total_slots.clone().try_acquire_owned() {
451 Ok(pool_permit) => PoolSlot(pool_permit),
452 Err(_) => {
453 let pooled_conn = self.inner.storage.get_connection_to_drop()?;
454 pooled_conn.slot
455 }
456 };
457
458 trace!(conn_id = ?id, "no pooled connection found, returning callback to create leased connection");
459 Ok(GetConnectionOrCreate::AddConnection(
460 move |conn: S::Connection| LeasedConnection {
461 _slot: active_slot,
462 returner,
463 pooled_conn: Some(PooledConnection {
464 conn,
465 id: id.clone(),
466 slot: pool_slot,
467 }),
468 },
469 ))
470 }
471}
472
473pub trait ReqToConnID<State, Request>: Sized + Send + Sync + 'static {
478 type ConnID: Send + Sync + PartialEq + Clone + 'static;
479
480 fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ConnID, OpaqueError>;
481}
482
483impl<State, Request, ConnID, F> ReqToConnID<State, Request> for F
484where
485 F: Fn(&Context<State>, &Request) -> Result<ConnID, OpaqueError> + Send + Sync + 'static,
486 ConnID: Send + Sync + PartialEq + Clone + 'static,
487{
488 type ConnID = ConnID;
489
490 fn id(&self, ctx: &Context<State>, request: &Request) -> Result<Self::ConnID, OpaqueError> {
491 self(ctx, request)
492 }
493}
494
495pub struct PooledConnector<S, Storage, R> {
499 inner: S,
500 pool: Pool<Storage>,
501 req_to_conn_id: R,
502 wait_for_pool_timeout: Option<Duration>,
503}
504
505impl<S, Storage, R> PooledConnector<S, Storage, R> {
506 pub fn new(inner: S, pool: Pool<Storage>, req_to_conn_id: R) -> PooledConnector<S, Storage, R> {
507 PooledConnector {
508 inner,
509 pool,
510 req_to_conn_id,
511 wait_for_pool_timeout: None,
512 }
513 }
514
515 generate_field_setters!(wait_for_pool_timeout, Duration);
516}
517
518impl<State, Request, S, Storage, R> Service<State, Request> for PooledConnector<S, Storage, R>
519where
520 S: ConnectorService<State, Request, Connection: Send, Error: Send + 'static>,
521 State: Send + Sync + 'static,
522 Request: Send + 'static,
523 Storage: PoolStorage<ConnID = R::ConnID, Connection = S::Connection>,
524 R: ReqToConnID<State, Request>,
525{
526 type Response = EstablishedClientConnection<
527 LeasedConnection<Storage::Connection, Storage::ConnID>,
528 State,
529 Request,
530 >;
531 type Error = BoxError;
532
533 async fn serve(
534 &self,
535 ctx: Context<State>,
536 req: Request,
537 ) -> Result<Self::Response, Self::Error> {
538 let conn_id = self.req_to_conn_id.id(&ctx, &req)?;
539
540 let pool = match ctx.get::<Pool<Storage>>() {
541 Some(pool) => &pool.clone(),
542 None => &self.pool,
543 };
544 let pool_result = if let Some(duration) = self.wait_for_pool_timeout {
545 timeout(duration, pool.get_connection_or_create_cb(&conn_id))
546 .await
547 .map_err(OpaqueError::from_std)?
548 } else {
549 pool.get_connection_or_create_cb(&conn_id).await
550 }?;
551
552 let (ctx, req, leased_conn) = match pool_result {
553 GetConnectionOrCreate::LeasedConnection(leased_conn) => (ctx, req, leased_conn),
554 GetConnectionOrCreate::AddConnection(cb) => {
555 let EstablishedClientConnection { ctx, req, conn } =
556 self.inner.connect(ctx, req).await.map_err(Into::into)?;
557 let leased_conn = cb(conn);
558 (ctx, req, leased_conn)
559 }
560 };
561
562 Ok(EstablishedClientConnection {
563 ctx,
564 req,
565 conn: leased_conn,
566 })
567 }
568}
569
570pub struct PooledConnectorLayer<Storage, R> {
571 pool: Pool<Storage>,
572 req_to_conn_id: R,
573 wait_for_pool_timeout: Option<Duration>,
574}
575
576impl<Storage, R> PooledConnectorLayer<Storage, R> {
577 pub fn new(pool: Pool<Storage>, req_to_conn_id: R) -> Self {
578 Self {
579 pool,
580 req_to_conn_id,
581 wait_for_pool_timeout: None,
582 }
583 }
584
585 generate_field_setters!(wait_for_pool_timeout, Duration);
586}
587
588impl<S, Storage, R: Clone> Layer<S> for PooledConnectorLayer<Storage, R> {
589 type Service = PooledConnector<S, Storage, R>;
590
591 fn layer(&self, inner: S) -> Self::Service {
592 PooledConnector::new(inner, self.pool.clone(), self.req_to_conn_id.clone())
593 .maybe_with_wait_for_pool_timeout(self.wait_for_pool_timeout)
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::EstablishedClientConnection;
    use rama_core::{Context, Service};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicI16, Ordering},
    };
    use tokio_test::assert_err;

    /// Connector that produces empty `Vec<u32>` "connections" and counts how many it made.
    #[derive(Default)]
    struct TestService {
        pub created_connection: AtomicI16,
    }

    impl TestService {
        /// Number of connections created so far.
        fn created(&self) -> i16 {
            self.created_connection.load(Ordering::Relaxed)
        }
    }

    impl<State, Request> Service<State, Request> for TestService
    where
        State: Clone + Send + Sync + 'static,
        Request: Send + 'static,
    {
        type Response = EstablishedClientConnection<Vec<u32>, State, Request>;
        type Error = Infallible;

        async fn serve(
            &self,
            ctx: Context<State>,
            req: Request,
        ) -> Result<Self::Response, Self::Error> {
            self.created_connection.fetch_add(1, Ordering::Relaxed);
            Ok(EstablishedClientConnection {
                ctx,
                req,
                conn: vec![],
            })
        }
    }

    /// Uses the number of characters in the request string as the connection id.
    struct StringRequestLengthID;

    impl<State> ReqToConnID<State, String> for StringRequestLengthID {
        type ConnID = usize;

        fn id(&self, _ctx: &Context<State>, req: &String) -> Result<Self::ConnID, OpaqueError> {
            Ok(req.chars().count())
        }
    }

    #[tokio::test]
    async fn test_should_reuse_connections() {
        let pool = Pool::<ConnStoreFiFoReuseLruDrop<_, _>>::default();
        // Every request maps to the same id, so a single connection should be reused.
        let svc = PooledConnector::new(
            TestService::default(),
            pool,
            |_ctx: &Context<()>, _req: &String| Ok(()),
        );

        for _ in 0..10 {
            let _conn = svc
                .connect(Context::default(), String::new())
                .await
                .unwrap();
        }

        assert_eq!(svc.inner.created(), 1);
    }

    #[tokio::test]
    async fn test_conn_id_to_separate() {
        let pool = Pool::<ConnStoreFiFoReuseLruDrop<_, _>>::default();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID {});

        // (request, value to push, expected connection contents, expected created count)
        let steps: [(&str, u32, &[u32], i16); 4] = [
            ("a", 1, &[1], 1),
            // same length as "a": the first connection is reused
            ("B", 2, &[1, 2], 1),
            // new length: a second connection is created
            ("aa", 3, &[3], 2),
            // same length as "aa": the second connection is reused
            ("bb", 4, &[3, 4], 2),
        ];

        for (request, value, expected, created) in steps {
            let mut conn = svc
                .connect(Context::default(), String::from(request))
                .await
                .unwrap()
                .conn;

            conn.push(value);
            assert_eq!(conn.as_ref(), &expected.to_vec());
            assert_eq!(svc.inner.created(), created);
            // `conn` is dropped here, returning it to the pool before the next step.
        }
    }

    #[tokio::test]
    async fn test_pool_max_size() {
        let pool = Pool::<ConnStoreFiFoReuseLruDrop<_, _>>::new(nz!(1), nz!(1)).unwrap();
        let svc = PooledConnector::new(TestService::default(), pool, StringRequestLengthID {})
            .with_wait_for_pool_timeout(Duration::from_millis(50));

        let first = svc
            .connect(Context::default(), String::from("a"))
            .await
            .unwrap();

        // The only active slot is held by `first`, so this request must time out.
        let blocked = svc.connect(Context::default(), String::from("a")).await;
        assert_err!(blocked);

        // After release, a request with a different id evicts the idle connection.
        drop(first);
        let _third = svc
            .connect(Context::default(), String::from("aaa"))
            .await
            .unwrap();
    }
}
752}