Skip to main content

cdk_sql_common/
pool.rs

//! Very simple connection pool, to avoid an external dependency on r2d2 and other crates. If this
//! ends up working well, it can be reused in other parts of the project and may be promoted to its
//! own generic crate.
4
5use std::fmt::Debug;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11#[cfg(feature = "prometheus")]
12use cdk_prometheus::metrics::METRICS;
13use tokio::sync::{OwnedSemaphorePermit, Semaphore};
14
15use crate::database::DatabaseConnector;
16
17/// Pool error
18#[derive(Debug, thiserror::Error)]
19pub enum Error<E>
20where
21    E: std::error::Error + Send + Sync + 'static,
22{
23    /// Mutex Poison Error
24    #[error("Internal: PoisonError")]
25    Poison,
26
27    /// Timeout error
28    #[error("Timed out waiting for a resource")]
29    Timeout,
30
31    /// Internal database error
32    #[error(transparent)]
33    Resource(#[from] E),
34}
35
/// Settings a pool needs in order to size itself and bound its waits.
pub trait DatabaseConfig: Clone + Debug + Send + Sync {
    /// Timeout applied when the caller does not provide an explicit one.
    fn default_timeout(&self) -> Duration;

    /// Maximum number of resources the pool may hand out at the same time.
    fn max_size(&self) -> usize;
}
44
45/// Trait to manage resources
46pub trait DatabasePool: Debug {
47    /// The resource to be pooled
48    type Connection: DatabaseConnector;
49
50    /// The configuration that is needed in order to create the resource
51    type Config: DatabaseConfig;
52
53    /// The error the resource may return when creating a new instance
54    type Error: Debug + std::error::Error + Send + Sync + 'static;
55
56    /// Creates a new resource with a given config.
57    ///
58    /// If `stale` is ever set to TRUE it is assumed the resource is no longer valid and it will be
59    /// dropped.
60    fn new_resource(
61        config: &Self::Config,
62        stale: Arc<AtomicBool>,
63        timeout: Duration,
64    ) -> Result<Self::Connection, Error<Self::Error>>;
65
66    /// The object is dropped
67    fn drop(_resource: Self::Connection) {}
68}
69
70/// Generic connection pool of resources R
71pub struct Pool<RM>
72where
73    RM: DatabasePool,
74{
75    config: RM::Config,
76    queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
77    max_size: usize,
78    default_timeout: Duration,
79    semaphore: Arc<Semaphore>,
80}
81
82impl<RM> Debug for Pool<RM>
83where
84    RM: DatabasePool,
85{
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("Pool")
88            .field("config", &self.config)
89            .field("max_size", &self.max_size)
90            .field("default_timeout", &self.default_timeout)
91            .field("available_permits", &self.semaphore.available_permits())
92            .finish()
93    }
94}
95
96/// The pooled resource
97pub struct PooledResource<RM>
98where
99    RM: DatabasePool,
100{
101    resource: Option<(Arc<AtomicBool>, RM::Connection)>,
102    pool: Arc<Pool<RM>>,
103    _permit: OwnedSemaphorePermit,
104    #[cfg(feature = "prometheus")]
105    start_time: std::time::Instant,
106}
107
108impl<RM> Debug for PooledResource<RM>
109where
110    RM: DatabasePool,
111{
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(f, "Resource: {:?}", self.resource)
114    }
115}
116
117impl<RM> Drop for PooledResource<RM>
118where
119    RM: DatabasePool,
120{
121    fn drop(&mut self) {
122        if let Some(resource) = self.resource.take() {
123            let mut active_resource = self.pool.queue.lock().expect("active_resource");
124            active_resource.push(resource);
125
126            #[cfg(feature = "prometheus")]
127            {
128                let in_use = self
129                    .pool
130                    .max_size
131                    .saturating_sub(self.pool.semaphore.available_permits())
132                    .saturating_sub(1);
133                METRICS.set_db_connections_active(in_use as i64);
134
135                let duration = self.start_time.elapsed().as_secs_f64();
136
137                METRICS.record_db_operation(duration, "drop");
138            }
139
140            // The semaphore permit is dropped automatically after this,
141            // which wakes any async task waiting in `get()`.
142        }
143    }
144}
145
146impl<RM> Deref for PooledResource<RM>
147where
148    RM: DatabasePool,
149{
150    type Target = RM::Connection;
151
152    fn deref(&self) -> &Self::Target {
153        &self.resource.as_ref().expect("resource already dropped").1
154    }
155}
156
157impl<RM> DerefMut for PooledResource<RM>
158where
159    RM: DatabasePool,
160{
161    fn deref_mut(&mut self) -> &mut Self::Target {
162        &mut self.resource.as_mut().expect("resource already dropped").1
163    }
164}
165
166impl<RM> Pool<RM>
167where
168    RM: DatabasePool,
169{
170    /// Creates a new pool
171    pub fn new(config: RM::Config) -> Arc<Self> {
172        let max_size = config.max_size();
173        Arc::new(Self {
174            default_timeout: config.default_timeout(),
175            max_size,
176            config,
177            queue: Default::default(),
178            semaphore: Arc::new(Semaphore::new(max_size)),
179        })
180    }
181
182    /// Similar to get_timeout but uses the default timeout value.
183    #[inline(always)]
184    pub async fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
185        self.get_timeout(self.default_timeout).await
186    }
187
188    /// Get a new resource or fail after timeout is reached.
189    ///
190    /// This function will return a free resource or create a new one if there is still room for it;
191    /// otherwise, it will asynchronously wait for a resource to be released for reuse.
192    #[inline(always)]
193    pub async fn get_timeout(
194        self: &Arc<Self>,
195        timeout: Duration,
196    ) -> Result<PooledResource<RM>, Error<RM::Error>> {
197        // Fast path: try to grab a permit without waiting.
198        let permit = match self.semaphore.clone().try_acquire_owned() {
199            Ok(permit) => permit,
200            Err(tokio::sync::TryAcquireError::Closed) => return Err(Error::Poison),
201            Err(tokio::sync::TryAcquireError::NoPermits) => {
202                // All permits are in use — wait asynchronously.  This yields
203                // the task instead of blocking the OS thread, preventing Tokio
204                // worker thread starvation.
205                tracing::debug!(
206                    "Pool exhausted (size: {}), waiting for a connection",
207                    self.max_size,
208                );
209                tokio::time::timeout(timeout, self.semaphore.clone().acquire_owned())
210                    .await
211                    .map_err(|_| Error::Timeout)?
212                    .map_err(|_| Error::Poison)?
213            }
214        };
215
216        #[cfg(feature = "prometheus")]
217        {
218            let in_use = self.max_size - self.semaphore.available_permits();
219            METRICS.set_db_connections_active(in_use as i64);
220        }
221
222        // Briefly lock the idle queue to try to pop a non-stale connection.
223        // This mutex is held for nanoseconds (just a Vec::pop).
224        {
225            let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
226            while let Some((stale, resource)) = resources.pop() {
227                if !stale.load(Ordering::SeqCst) {
228                    return Ok(PooledResource {
229                        resource: Some((stale, resource)),
230                        pool: self.clone(),
231                        _permit: permit,
232                        #[cfg(feature = "prometheus")]
233                        start_time: std::time::Instant::now(),
234                    });
235                }
236                // Stale connection — drop it and keep looking.
237            }
238        }
239
240        // No idle connection available — create a new one.
241        // The semaphore already guarantees we won't exceed max_size.
242        let stale: Arc<AtomicBool> = Arc::new(false.into());
243        match RM::new_resource(&self.config, stale.clone(), timeout) {
244            Ok(new_resource) => Ok(PooledResource {
245                resource: Some((stale, new_resource)),
246                pool: self.clone(),
247                _permit: permit,
248                #[cfg(feature = "prometheus")]
249                start_time: std::time::Instant::now(),
250            }),
251            Err(e) => {
252                #[cfg(feature = "prometheus")]
253                {
254                    let in_use = self
255                        .max_size
256                        .saturating_sub(self.semaphore.available_permits())
257                        .saturating_sub(1);
258                    METRICS.set_db_connections_active(in_use as i64);
259                }
260
261                // Permit is dropped here, releasing the slot back to the semaphore.
262                Err(e)
263            }
264        }
265    }
266}
267
268impl<RM> Drop for Pool<RM>
269where
270    RM: DatabasePool,
271{
272    fn drop(&mut self) {
273        // Close the semaphore so no new acquisitions can succeed.
274        self.semaphore.close();
275
276        // Drain all idle connections.
277        if let Ok(mut resources) = self.queue.lock() {
278            while let Some(resource) = resources.pop() {
279                RM::drop(resource.1);
280            }
281        }
282    }
283}
284
#[cfg(all(test, feature = "prometheus"))]
mod tests {
    use std::fmt;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::time::Duration;

    use cdk_common::database::Error as DatabaseError;
    use cdk_prometheus::METRICS;

    use super::{DatabaseConfig, DatabasePool, Error, Pool};
    use crate::database::{DatabaseConnector, DatabaseExecutor, DatabaseTransaction};
    use crate::stmt::{Column, Statement};

    /// Pool configuration whose resource factory can be told to fail.
    #[derive(Debug, Clone)]
    struct TestConfig {
        max_size: usize,
        default_timeout: Duration,
        fail_new_resource: bool,
    }

    impl DatabaseConfig for TestConfig {
        fn max_size(&self) -> usize {
            self.max_size
        }

        fn default_timeout(&self) -> Duration {
            self.default_timeout
        }
    }

    /// Error returned by the failing resource factory.
    #[derive(Debug)]
    struct TestResourceError;

    impl fmt::Display for TestResourceError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "test resource error")
        }
    }

    impl std::error::Error for TestResourceError {}

    /// Connection stub: every query succeeds and yields an empty result.
    #[derive(Debug)]
    struct TestConnection;

    #[async_trait::async_trait]
    impl DatabaseExecutor for TestConnection {
        fn name() -> &'static str {
            "test"
        }

        async fn execute(&self, _statement: Statement) -> Result<usize, DatabaseError> {
            Ok(0)
        }

        async fn fetch_one(
            &self,
            _statement: Statement,
        ) -> Result<Option<Vec<Column>>, DatabaseError> {
            Ok(None)
        }

        async fn fetch_all(
            &self,
            _statement: Statement,
        ) -> Result<Vec<Vec<Column>>, DatabaseError> {
            Ok(vec![])
        }

        async fn pluck(&self, _statement: Statement) -> Result<Option<Column>, DatabaseError> {
            Ok(None)
        }

        async fn batch(&self, _statement: Statement) -> Result<(), DatabaseError> {
            Ok(())
        }
    }

    /// Transaction stub whose operations always succeed.
    #[derive(Debug)]
    struct TestTransaction;

    #[async_trait::async_trait]
    impl DatabaseTransaction<TestConnection> for TestTransaction {
        async fn commit(_conn: &mut TestConnection) -> Result<(), DatabaseError> {
            Ok(())
        }

        async fn begin(_conn: &mut TestConnection) -> Result<(), DatabaseError> {
            Ok(())
        }

        async fn rollback(_conn: &mut TestConnection) -> Result<(), DatabaseError> {
            Ok(())
        }
    }

    impl DatabaseConnector for TestConnection {
        type Transaction = TestTransaction;
    }

    /// Resource manager producing `TestConnection`s, or failing when the config says so.
    #[derive(Debug)]
    struct TestPool;

    impl DatabasePool for TestPool {
        type Connection = TestConnection;
        type Config = TestConfig;
        type Error = TestResourceError;

        fn new_resource(
            config: &Self::Config,
            _stale: Arc<AtomicBool>,
            _timeout: Duration,
        ) -> Result<Self::Connection, Error<Self::Error>> {
            match config.fail_new_resource {
                true => Err(Error::Resource(TestResourceError)),
                false => Ok(TestConnection),
            }
        }
    }

    /// Builds a config with a 10 ms default timeout.
    fn test_config(max_size: usize, fail_new_resource: bool) -> TestConfig {
        let default_timeout = Duration::from_millis(10);
        TestConfig {
            max_size,
            default_timeout,
            fail_new_resource,
        }
    }

    /// Reads the current value of the `cdk_db_connections_active` gauge from the registry.
    fn db_connections_active() -> f64 {
        let family = METRICS
            .registry()
            .gather()
            .into_iter()
            .find(|family| family.get_name() == "cdk_db_connections_active")
            .expect("active connections metric should be registered");

        family
            .get_metric()
            .first()
            .expect("active connections metric should exist")
            .get_gauge()
            .get_value()
    }

    #[tokio::test(flavor = "current_thread")]
    async fn active_connections_gauge_tracks_current_checkout_and_drop_counts() {
        let _lock = crate::metrics_test_lock::lock().await;
        METRICS.set_db_connections_active(0);

        let pool = Pool::<TestPool>::new(test_config(2, false));

        // Each checkout raises the gauge by one...
        let conn_a = pool
            .get()
            .await
            .expect("first resource should be checked out");
        assert_eq!(db_connections_active(), 1.0);

        let conn_b = pool
            .get()
            .await
            .expect("second resource should be checked out");
        assert_eq!(db_connections_active(), 2.0);

        // ...and each drop lowers it again.
        drop(conn_a);
        assert_eq!(db_connections_active(), 1.0);

        drop(conn_b);
        assert_eq!(db_connections_active(), 0.0);
        assert_eq!(pool.semaphore.available_permits(), pool.max_size);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn active_connections_gauge_is_restored_when_resource_creation_fails() {
        let _lock = crate::metrics_test_lock::lock().await;
        METRICS.set_db_connections_active(0);

        let pool = Pool::<TestPool>::new(test_config(1, true));
        let outcome = pool.get().await;

        // The failed checkout must leave neither the gauge nor the permit count changed.
        assert!(matches!(outcome, Err(Error::Resource(_))));
        assert_eq!(db_connections_active(), 0.0);
        assert_eq!(pool.semaphore.available_permits(), pool.max_size);
    }
}