cdk_sql_common/
pool.rs

//! Very simple connection pool, to avoid an external dependency on r2d2 and other crates. If this
//! ends up working well, it can be reused in other parts of the project and may be promoted to its
//! own generic crate.
4
5use std::fmt::Debug;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::{Arc, Condvar, Mutex};
9use std::time::{Duration, Instant};
10
11#[cfg(feature = "prometheus")]
12use cdk_prometheus::metrics::METRICS;
13
14use crate::database::DatabaseConnector;
15
16/// Pool error
17#[derive(Debug, thiserror::Error)]
18pub enum Error<E>
19where
20    E: std::error::Error + Send + Sync + 'static,
21{
22    /// Mutex Poison Error
23    #[error("Internal: PoisonError")]
24    Poison,
25
26    /// Timeout error
27    #[error("Timed out waiting for a resource")]
28    Timeout,
29
30    /// Internal database error
31    #[error(transparent)]
32    Resource(#[from] E),
33}
34
/// Configuration consumed by a [`Pool`].
pub trait DatabaseConfig: Debug + Clone + Send + Sync {
    /// Maximum number of resources that may be checked out of the pool at the same time.
    fn max_size(&self) -> usize;

    /// Timeout used by [`Pool::get`] when no explicit timeout is given.
    fn default_timeout(&self) -> Duration;
}
43
44/// Trait to manage resources
45pub trait DatabasePool: Debug {
46    /// The resource to be pooled
47    type Connection: DatabaseConnector;
48
49    /// The configuration that is needed in order to create the resource
50    type Config: DatabaseConfig;
51
52    /// The error the resource may return when creating a new instance
53    type Error: Debug + std::error::Error + Send + Sync + 'static;
54
55    /// Creates a new resource with a given config.
56    ///
57    /// If `stale` is ever set to TRUE it is assumed the resource is no longer valid and it will be
58    /// dropped.
59    fn new_resource(
60        config: &Self::Config,
61        stale: Arc<AtomicBool>,
62        timeout: Duration,
63    ) -> Result<Self::Connection, Error<Self::Error>>;
64
65    /// The object is dropped
66    fn drop(_resource: Self::Connection) {}
67}
68
69/// Generic connection pool of resources R
70#[derive(Debug)]
71pub struct Pool<RM>
72where
73    RM: DatabasePool,
74{
75    config: RM::Config,
76    queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
77    in_use: AtomicUsize,
78    max_size: usize,
79    default_timeout: Duration,
80    waiter: Condvar,
81}
82
83/// The pooled resource
84pub struct PooledResource<RM>
85where
86    RM: DatabasePool,
87{
88    resource: Option<(Arc<AtomicBool>, RM::Connection)>,
89    pool: Arc<Pool<RM>>,
90    #[cfg(feature = "prometheus")]
91    start_time: Instant,
92}
93
94impl<RM> Debug for PooledResource<RM>
95where
96    RM: DatabasePool,
97{
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(f, "Resource: {:?}", self.resource)
100    }
101}
102
103impl<RM> Drop for PooledResource<RM>
104where
105    RM: DatabasePool,
106{
107    fn drop(&mut self) {
108        if let Some(resource) = self.resource.take() {
109            let mut active_resource = self.pool.queue.lock().expect("active_resource");
110            active_resource.push(resource);
111            let _in_use = self.pool.in_use.fetch_sub(1, Ordering::AcqRel);
112
113            #[cfg(feature = "prometheus")]
114            {
115                METRICS.set_db_connections_active(_in_use as i64);
116
117                let duration = self.start_time.elapsed().as_secs_f64();
118
119                METRICS.record_db_operation(duration, "drop");
120            }
121
122            // Notify a waiting thread
123            self.pool.waiter.notify_one();
124        }
125    }
126}
127
128impl<RM> Deref for PooledResource<RM>
129where
130    RM: DatabasePool,
131{
132    type Target = RM::Connection;
133
134    fn deref(&self) -> &Self::Target {
135        &self.resource.as_ref().expect("resource already dropped").1
136    }
137}
138
139impl<RM> DerefMut for PooledResource<RM>
140where
141    RM: DatabasePool,
142{
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        &mut self.resource.as_mut().expect("resource already dropped").1
145    }
146}
147
148impl<RM> Pool<RM>
149where
150    RM: DatabasePool,
151{
152    /// Creates a new pool
153    pub fn new(config: RM::Config) -> Arc<Self> {
154        Arc::new(Self {
155            default_timeout: config.default_timeout(),
156            max_size: config.max_size(),
157            config,
158            queue: Default::default(),
159            in_use: Default::default(),
160            waiter: Default::default(),
161        })
162    }
163
164    /// Similar to get_timeout but uses the default timeout value.
165    #[inline(always)]
166    pub fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
167        self.get_timeout(self.default_timeout)
168    }
169
170    /// Increments the in_use connection counter and updates the metric
171    fn increment_connection_counter(&self) -> usize {
172        let in_use = self.in_use.fetch_add(1, Ordering::AcqRel);
173
174        #[cfg(feature = "prometheus")]
175        {
176            METRICS.set_db_connections_active(in_use as i64);
177        }
178
179        in_use
180    }
181
182    /// Get a new resource or fail after timeout is reached.
183    ///
184    /// This function will return a free resource or create a new one if there is still room for it;
185    /// otherwise, it will wait for a resource to be released for reuse.
186    #[inline(always)]
187    pub fn get_timeout(
188        self: &Arc<Self>,
189        timeout: Duration,
190    ) -> Result<PooledResource<RM>, Error<RM::Error>> {
191        let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
192        let time = Instant::now();
193
194        loop {
195            if let Some((stale, resource)) = resources.pop() {
196                if !stale.load(Ordering::SeqCst) {
197                    drop(resources);
198                    self.increment_connection_counter();
199
200                    return Ok(PooledResource {
201                        resource: Some((stale, resource)),
202                        pool: self.clone(),
203                        #[cfg(feature = "prometheus")]
204                        start_time: Instant::now(),
205                    });
206                }
207            }
208
209            if self.in_use.load(Ordering::Relaxed) < self.max_size {
210                drop(resources);
211                self.increment_connection_counter();
212                let stale: Arc<AtomicBool> = Arc::new(false.into());
213
214                return Ok(PooledResource {
215                    resource: Some((
216                        stale.clone(),
217                        RM::new_resource(&self.config, stale, timeout)?,
218                    )),
219                    pool: self.clone(),
220                    #[cfg(feature = "prometheus")]
221                    start_time: Instant::now(),
222                });
223            }
224
225            resources = self
226                .waiter
227                .wait_timeout(resources, timeout)
228                .map_err(|_| Error::Poison)
229                .and_then(|(lock, timeout_result)| {
230                    if timeout_result.timed_out() {
231                        tracing::warn!(
232                            "Timeout waiting for the resource (pool size: {}). Waited {} ms",
233                            self.max_size,
234                            time.elapsed().as_millis()
235                        );
236                        Err(Error::Timeout)
237                    } else {
238                        Ok(lock)
239                    }
240                })?;
241        }
242    }
243}
244
245impl<RM> Drop for Pool<RM>
246where
247    RM: DatabasePool,
248{
249    fn drop(&mut self) {
250        if let Ok(mut resources) = self.queue.lock() {
251            loop {
252                while let Some(resource) = resources.pop() {
253                    RM::drop(resource.1);
254                }
255
256                if self.in_use.load(Ordering::Relaxed) == 0 {
257                    break;
258                }
259
260                resources = if let Ok(resources) = self.waiter.wait(resources) {
261                    resources
262                } else {
263                    break;
264                };
265            }
266        }
267    }
268}