1use std::mem::ManuallyDrop;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crossbeam_queue::ArrayQueue;
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7
8use crate::error::Result;
9use crate::opts::Opts;
10
11use super::Conn;
12
/// A pool of reusable [`Conn`]s.
///
/// Idle connections wait in a bounded queue for reuse; an optional semaphore caps how
/// many connections may be checked out at the same time.
pub struct Pool {
    // Options used to open new connections (cloned into `Conn::new`) and to
    // configure the pool itself (idle capacity, concurrency limit, reset-on-return).
    opts: Opts,
    // Idle connections available for reuse; capacity comes from `opts.pool_max_idle_conn`.
    conns: ArrayQueue<Conn>,
    // `Some` only when `opts.pool_max_concurrency` is set; each checked-out
    // `PooledConn` holds one permit from it.
    semaphore: Option<Arc<Semaphore>>,
}
18
19impl Pool {
20 pub fn new(opts: Opts) -> Self {
21 let semaphore = opts
22 .pool_max_concurrency
23 .map(|n| Arc::new(Semaphore::new(n)));
24 Self {
25 conns: ArrayQueue::new(opts.pool_max_idle_conn),
26 opts,
27 semaphore,
28 }
29 }
30
31 pub async fn get(self: &Arc<Self>) -> Result<PooledConn> {
32 let permit =
33 match &self.semaphore {
34 Some(sem) => Some(Arc::clone(sem).acquire_owned().await.map_err(
35 |_acquire_err| {
36 crate::error::Error::LibraryBug(color_eyre::eyre::eyre!("semaphore closed"))
37 },
38 )?),
39 None => None,
40 };
41 let mut conn = match self.conns.pop() {
42 Some(c) => c,
43 None => Conn::new(self.opts.clone()).await?,
44 };
45 conn.ping().await?;
46 Ok(PooledConn {
47 conn: ManuallyDrop::new(conn),
48 pool: Arc::clone(self),
49 _permit: permit,
50 })
51 }
52
53 fn check_in(self: &Arc<Self>, mut conn: Conn) {
54 if conn.is_broken() {
55 return;
56 }
57 if self.opts.pool_reset_conn {
58 let Ok(handle) = tokio::runtime::Handle::try_current() else {
59 return;
60 };
61 let pool = Arc::clone(self);
62 handle.spawn(async move {
63 if conn.reset().await.is_ok() {
64 let _ = pool.conns.push(conn);
65 }
66 });
67 } else {
68 let _ = self.conns.push(conn);
69 }
70 }
71}
72
/// A connection checked out of a [`Pool`].
///
/// Dereferences to [`Conn`]. On drop, the connection is passed back to the pool's
/// `check_in`. `Drop::drop` runs before the fields are dropped, so the concurrency
/// permit (if any) is released only after the connection has been checked in.
pub struct PooledConn {
    // Keeps the owning pool alive so the connection can be returned on drop.
    pool: Arc<Pool>,
    // Wrapped in `ManuallyDrop` so `Drop for PooledConn` can move the connection out
    // and hand it to the pool instead of having it destroyed in place.
    conn: ManuallyDrop<Conn>,
    // Held only for its drop side effect: releasing the concurrency slot.
    _permit: Option<OwnedSemaphorePermit>,
}
78
79impl Deref for PooledConn {
80 type Target = Conn;
81 fn deref(&self) -> &Self::Target {
82 &self.conn
83 }
84}
85
86impl DerefMut for PooledConn {
87 fn deref_mut(&mut self) -> &mut Self::Target {
88 &mut self.conn
89 }
90}
91
92impl Drop for PooledConn {
93 fn drop(&mut self) {
94 let conn = unsafe { ManuallyDrop::take(&mut self.conn) };
96 self.pool.check_in(conn);
97 }
98}