Skip to main content

cdk_sql_common/
pool.rs

//! Very simple connection pool, to avoid an external dependency on r2d2 and other crates. If this
//! ends up working well it can be re-used in other parts of the project and may be promoted to its
//! own generic crate.
4
5use std::fmt::Debug;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11#[cfg(feature = "prometheus")]
12use cdk_prometheus::metrics::METRICS;
13use tokio::sync::{OwnedSemaphorePermit, Semaphore};
14
15use crate::database::DatabaseConnector;
16
17/// Pool error
18#[derive(Debug, thiserror::Error)]
19pub enum Error<E>
20where
21    E: std::error::Error + Send + Sync + 'static,
22{
23    /// Mutex Poison Error
24    #[error("Internal: PoisonError")]
25    Poison,
26
27    /// Timeout error
28    #[error("Timed out waiting for a resource")]
29    Timeout,
30
31    /// Internal database error
32    #[error(transparent)]
33    Resource(#[from] E),
34}
35
/// Settings a pool reads once, at construction time, to size itself and bound its waits.
pub trait DatabaseConfig
where
    Self: Clone + Debug + Send + Sync,
{
    /// Maximum number of resources that may be checked out at the same time.
    fn max_size(&self) -> usize;

    /// Timeout used when the caller does not supply one explicitly.
    fn default_timeout(&self) -> Duration;
}
44
45/// Trait to manage resources
46pub trait DatabasePool: Debug {
47    /// The resource to be pooled
48    type Connection: DatabaseConnector;
49
50    /// The configuration that is needed in order to create the resource
51    type Config: DatabaseConfig;
52
53    /// The error the resource may return when creating a new instance
54    type Error: Debug + std::error::Error + Send + Sync + 'static;
55
56    /// Creates a new resource with a given config.
57    ///
58    /// If `stale` is ever set to TRUE it is assumed the resource is no longer valid and it will be
59    /// dropped.
60    fn new_resource(
61        config: &Self::Config,
62        stale: Arc<AtomicBool>,
63        timeout: Duration,
64    ) -> Result<Self::Connection, Error<Self::Error>>;
65
66    /// The object is dropped
67    fn drop(_resource: Self::Connection) {}
68}
69
70/// Generic connection pool of resources R
71pub struct Pool<RM>
72where
73    RM: DatabasePool,
74{
75    config: RM::Config,
76    queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
77    max_size: usize,
78    default_timeout: Duration,
79    semaphore: Arc<Semaphore>,
80}
81
82impl<RM> Debug for Pool<RM>
83where
84    RM: DatabasePool,
85{
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("Pool")
88            .field("config", &self.config)
89            .field("max_size", &self.max_size)
90            .field("default_timeout", &self.default_timeout)
91            .field("available_permits", &self.semaphore.available_permits())
92            .finish()
93    }
94}
95
96/// The pooled resource
97pub struct PooledResource<RM>
98where
99    RM: DatabasePool,
100{
101    resource: Option<(Arc<AtomicBool>, RM::Connection)>,
102    pool: Arc<Pool<RM>>,
103    _permit: OwnedSemaphorePermit,
104    #[cfg(feature = "prometheus")]
105    start_time: std::time::Instant,
106}
107
108impl<RM> Debug for PooledResource<RM>
109where
110    RM: DatabasePool,
111{
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(f, "Resource: {:?}", self.resource)
114    }
115}
116
117impl<RM> Drop for PooledResource<RM>
118where
119    RM: DatabasePool,
120{
121    fn drop(&mut self) {
122        if let Some(resource) = self.resource.take() {
123            let mut active_resource = self.pool.queue.lock().expect("active_resource");
124            active_resource.push(resource);
125
126            #[cfg(feature = "prometheus")]
127            {
128                let in_use = self.pool.max_size - self.pool.semaphore.available_permits();
129                METRICS.set_db_connections_active(in_use as i64);
130
131                let duration = self.start_time.elapsed().as_secs_f64();
132
133                METRICS.record_db_operation(duration, "drop");
134            }
135
136            // The semaphore permit is dropped automatically after this,
137            // which wakes any async task waiting in `get()`.
138        }
139    }
140}
141
142impl<RM> Deref for PooledResource<RM>
143where
144    RM: DatabasePool,
145{
146    type Target = RM::Connection;
147
148    fn deref(&self) -> &Self::Target {
149        &self.resource.as_ref().expect("resource already dropped").1
150    }
151}
152
153impl<RM> DerefMut for PooledResource<RM>
154where
155    RM: DatabasePool,
156{
157    fn deref_mut(&mut self) -> &mut Self::Target {
158        &mut self.resource.as_mut().expect("resource already dropped").1
159    }
160}
161
162impl<RM> Pool<RM>
163where
164    RM: DatabasePool,
165{
166    /// Creates a new pool
167    pub fn new(config: RM::Config) -> Arc<Self> {
168        let max_size = config.max_size();
169        Arc::new(Self {
170            default_timeout: config.default_timeout(),
171            max_size,
172            config,
173            queue: Default::default(),
174            semaphore: Arc::new(Semaphore::new(max_size)),
175        })
176    }
177
178    /// Similar to get_timeout but uses the default timeout value.
179    #[inline(always)]
180    pub async fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
181        self.get_timeout(self.default_timeout).await
182    }
183
184    /// Get a new resource or fail after timeout is reached.
185    ///
186    /// This function will return a free resource or create a new one if there is still room for it;
187    /// otherwise, it will asynchronously wait for a resource to be released for reuse.
188    #[inline(always)]
189    pub async fn get_timeout(
190        self: &Arc<Self>,
191        timeout: Duration,
192    ) -> Result<PooledResource<RM>, Error<RM::Error>> {
193        // Fast path: try to grab a permit without waiting.
194        let permit = match self.semaphore.clone().try_acquire_owned() {
195            Ok(permit) => permit,
196            Err(tokio::sync::TryAcquireError::Closed) => return Err(Error::Poison),
197            Err(tokio::sync::TryAcquireError::NoPermits) => {
198                // All permits are in use — wait asynchronously.  This yields
199                // the task instead of blocking the OS thread, preventing Tokio
200                // worker thread starvation.
201                tracing::debug!(
202                    "Pool exhausted (size: {}), waiting for a connection",
203                    self.max_size,
204                );
205                tokio::time::timeout(timeout, self.semaphore.clone().acquire_owned())
206                    .await
207                    .map_err(|_| Error::Timeout)?
208                    .map_err(|_| Error::Poison)?
209            }
210        };
211
212        #[cfg(feature = "prometheus")]
213        {
214            let in_use = self.max_size - self.semaphore.available_permits();
215            METRICS.set_db_connections_active(in_use as i64);
216        }
217
218        // Briefly lock the idle queue to try to pop a non-stale connection.
219        // This mutex is held for nanoseconds (just a Vec::pop).
220        {
221            let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
222            while let Some((stale, resource)) = resources.pop() {
223                if !stale.load(Ordering::SeqCst) {
224                    return Ok(PooledResource {
225                        resource: Some((stale, resource)),
226                        pool: self.clone(),
227                        _permit: permit,
228                        #[cfg(feature = "prometheus")]
229                        start_time: std::time::Instant::now(),
230                    });
231                }
232                // Stale connection — drop it and keep looking.
233            }
234        }
235
236        // No idle connection available — create a new one.
237        // The semaphore already guarantees we won't exceed max_size.
238        let stale: Arc<AtomicBool> = Arc::new(false.into());
239        match RM::new_resource(&self.config, stale.clone(), timeout) {
240            Ok(new_resource) => Ok(PooledResource {
241                resource: Some((stale, new_resource)),
242                pool: self.clone(),
243                _permit: permit,
244                #[cfg(feature = "prometheus")]
245                start_time: std::time::Instant::now(),
246            }),
247            Err(e) => {
248                // Permit is dropped here, releasing the slot back to the semaphore.
249                Err(e)
250            }
251        }
252    }
253}
254
255impl<RM> Drop for Pool<RM>
256where
257    RM: DatabasePool,
258{
259    fn drop(&mut self) {
260        // Close the semaphore so no new acquisitions can succeed.
261        self.semaphore.close();
262
263        // Drain all idle connections.
264        if let Ok(mut resources) = self.queue.lock() {
265            while let Some(resource) = resources.pop() {
266                RM::drop(resource.1);
267            }
268        }
269    }
270}