1use std::fmt::Debug;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::{Arc, Condvar, Mutex};
9use std::time::{Duration, Instant};
10
11#[cfg(feature = "prometheus")]
12use cdk_prometheus::metrics::METRICS;
13
14use crate::database::DatabaseConnector;
15
16#[derive(Debug, thiserror::Error)]
18pub enum Error<E>
19where
20 E: std::error::Error + Send + Sync + 'static,
21{
22 #[error("Internal: PoisonError")]
24 Poison,
25
26 #[error("Timed out waiting for a resource")]
28 Timeout,
29
30 #[error(transparent)]
32 Resource(#[from] E),
33}
34
/// Configuration a [`Pool`] reads once, when it is constructed.
pub trait DatabaseConfig: Clone + Debug + Send + Sync {
    /// Upper bound on the number of connections that may be checked out at
    /// the same time.
    fn max_size(&self) -> usize;

    /// Timeout applied by [`Pool::get`] when the caller does not supply one.
    fn default_timeout(&self) -> Duration;
}
43
44pub trait DatabasePool: Debug {
46 type Connection: DatabaseConnector;
48
49 type Config: DatabaseConfig;
51
52 type Error: Debug + std::error::Error + Send + Sync + 'static;
54
55 fn new_resource(
60 config: &Self::Config,
61 stale: Arc<AtomicBool>,
62 timeout: Duration,
63 ) -> Result<Self::Connection, Error<Self::Error>>;
64
65 fn drop(_resource: Self::Connection) {}
67}
68
69#[derive(Debug)]
71pub struct Pool<RM>
72where
73 RM: DatabasePool,
74{
75 config: RM::Config,
76 queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
77 in_use: AtomicUsize,
78 max_size: usize,
79 default_timeout: Duration,
80 waiter: Condvar,
81}
82
83pub struct PooledResource<RM>
85where
86 RM: DatabasePool,
87{
88 resource: Option<(Arc<AtomicBool>, RM::Connection)>,
89 pool: Arc<Pool<RM>>,
90 #[cfg(feature = "prometheus")]
91 start_time: Instant,
92}
93
94impl<RM> Debug for PooledResource<RM>
95where
96 RM: DatabasePool,
97{
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(f, "Resource: {:?}", self.resource)
100 }
101}
102
103impl<RM> Drop for PooledResource<RM>
104where
105 RM: DatabasePool,
106{
107 fn drop(&mut self) {
108 if let Some(resource) = self.resource.take() {
109 let mut active_resource = self.pool.queue.lock().expect("active_resource");
110 active_resource.push(resource);
111 let _in_use = self.pool.in_use.fetch_sub(1, Ordering::AcqRel);
112
113 #[cfg(feature = "prometheus")]
114 {
115 METRICS.set_db_connections_active(_in_use as i64);
116
117 let duration = self.start_time.elapsed().as_secs_f64();
118
119 METRICS.record_db_operation(duration, "drop");
120 }
121
122 self.pool.waiter.notify_one();
124 }
125 }
126}
127
128impl<RM> Deref for PooledResource<RM>
129where
130 RM: DatabasePool,
131{
132 type Target = RM::Connection;
133
134 fn deref(&self) -> &Self::Target {
135 &self.resource.as_ref().expect("resource already dropped").1
136 }
137}
138
139impl<RM> DerefMut for PooledResource<RM>
140where
141 RM: DatabasePool,
142{
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 &mut self.resource.as_mut().expect("resource already dropped").1
145 }
146}
147
148impl<RM> Pool<RM>
149where
150 RM: DatabasePool,
151{
152 pub fn new(config: RM::Config) -> Arc<Self> {
154 Arc::new(Self {
155 default_timeout: config.default_timeout(),
156 max_size: config.max_size(),
157 config,
158 queue: Default::default(),
159 in_use: Default::default(),
160 waiter: Default::default(),
161 })
162 }
163
164 #[inline(always)]
166 pub fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
167 self.get_timeout(self.default_timeout)
168 }
169
170 fn increment_connection_counter(&self) -> usize {
172 let in_use = self.in_use.fetch_add(1, Ordering::AcqRel);
173
174 #[cfg(feature = "prometheus")]
175 {
176 METRICS.set_db_connections_active(in_use as i64);
177 }
178
179 in_use
180 }
181
182 #[inline(always)]
187 pub fn get_timeout(
188 self: &Arc<Self>,
189 timeout: Duration,
190 ) -> Result<PooledResource<RM>, Error<RM::Error>> {
191 let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
192 let time = Instant::now();
193
194 loop {
195 if let Some((stale, resource)) = resources.pop() {
196 if !stale.load(Ordering::SeqCst) {
197 drop(resources);
198 self.increment_connection_counter();
199
200 return Ok(PooledResource {
201 resource: Some((stale, resource)),
202 pool: self.clone(),
203 #[cfg(feature = "prometheus")]
204 start_time: Instant::now(),
205 });
206 }
207 }
208
209 if self.in_use.load(Ordering::Relaxed) < self.max_size {
210 drop(resources);
211 self.increment_connection_counter();
212 let stale: Arc<AtomicBool> = Arc::new(false.into());
213
214 return Ok(PooledResource {
215 resource: Some((
216 stale.clone(),
217 RM::new_resource(&self.config, stale, timeout)?,
218 )),
219 pool: self.clone(),
220 #[cfg(feature = "prometheus")]
221 start_time: Instant::now(),
222 });
223 }
224
225 resources = self
226 .waiter
227 .wait_timeout(resources, timeout)
228 .map_err(|_| Error::Poison)
229 .and_then(|(lock, timeout_result)| {
230 if timeout_result.timed_out() {
231 tracing::warn!(
232 "Timeout waiting for the resource (pool size: {}). Waited {} ms",
233 self.max_size,
234 time.elapsed().as_millis()
235 );
236 Err(Error::Timeout)
237 } else {
238 Ok(lock)
239 }
240 })?;
241 }
242 }
243}
244
245impl<RM> Drop for Pool<RM>
246where
247 RM: DatabasePool,
248{
249 fn drop(&mut self) {
250 if let Ok(mut resources) = self.queue.lock() {
251 loop {
252 while let Some(resource) = resources.pop() {
253 RM::drop(resource.1);
254 }
255
256 if self.in_use.load(Ordering::Relaxed) == 0 {
257 break;
258 }
259
260 resources = if let Ok(resources) = self.waiter.wait(resources) {
261 resources
262 } else {
263 break;
264 };
265 }
266 }
267 }
268}