1use std::fmt::Debug;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11#[cfg(feature = "prometheus")]
12use cdk_prometheus::metrics::METRICS;
13use tokio::sync::{OwnedSemaphorePermit, Semaphore};
14
15use crate::database::DatabaseConnector;
16
17#[derive(Debug, thiserror::Error)]
19pub enum Error<E>
20where
21 E: std::error::Error + Send + Sync + 'static,
22{
23 #[error("Internal: PoisonError")]
25 Poison,
26
27 #[error("Timed out waiting for a resource")]
29 Timeout,
30
31 #[error(transparent)]
33 Resource(#[from] E),
34}
35
/// Settings a connection pool reads when it is constructed.
pub trait DatabaseConfig: Clone + Debug + Send + Sync {
    /// Maximum number of connections that may be checked out at the same time.
    fn max_size(&self) -> usize;

    /// Timeout applied by `Pool::get` when the caller does not pass one explicitly.
    fn default_timeout(&self) -> Duration;
}
44
45pub trait DatabasePool: Debug {
47 type Connection: DatabaseConnector;
49
50 type Config: DatabaseConfig;
52
53 type Error: Debug + std::error::Error + Send + Sync + 'static;
55
56 fn new_resource(
61 config: &Self::Config,
62 stale: Arc<AtomicBool>,
63 timeout: Duration,
64 ) -> Result<Self::Connection, Error<Self::Error>>;
65
66 fn drop(_resource: Self::Connection) {}
68}
69
70pub struct Pool<RM>
72where
73 RM: DatabasePool,
74{
75 config: RM::Config,
76 queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
77 max_size: usize,
78 default_timeout: Duration,
79 semaphore: Arc<Semaphore>,
80}
81
82impl<RM> Debug for Pool<RM>
83where
84 RM: DatabasePool,
85{
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("Pool")
88 .field("config", &self.config)
89 .field("max_size", &self.max_size)
90 .field("default_timeout", &self.default_timeout)
91 .field("available_permits", &self.semaphore.available_permits())
92 .finish()
93 }
94}
95
96pub struct PooledResource<RM>
98where
99 RM: DatabasePool,
100{
101 resource: Option<(Arc<AtomicBool>, RM::Connection)>,
102 pool: Arc<Pool<RM>>,
103 _permit: OwnedSemaphorePermit,
104 #[cfg(feature = "prometheus")]
105 start_time: std::time::Instant,
106}
107
108impl<RM> Debug for PooledResource<RM>
109where
110 RM: DatabasePool,
111{
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 write!(f, "Resource: {:?}", self.resource)
114 }
115}
116
117impl<RM> Drop for PooledResource<RM>
118where
119 RM: DatabasePool,
120{
121 fn drop(&mut self) {
122 if let Some(resource) = self.resource.take() {
123 let mut active_resource = self.pool.queue.lock().expect("active_resource");
124 active_resource.push(resource);
125
126 #[cfg(feature = "prometheus")]
127 {
128 let in_use = self.pool.max_size - self.pool.semaphore.available_permits();
129 METRICS.set_db_connections_active(in_use as i64);
130
131 let duration = self.start_time.elapsed().as_secs_f64();
132
133 METRICS.record_db_operation(duration, "drop");
134 }
135
136 }
139 }
140}
141
142impl<RM> Deref for PooledResource<RM>
143where
144 RM: DatabasePool,
145{
146 type Target = RM::Connection;
147
148 fn deref(&self) -> &Self::Target {
149 &self.resource.as_ref().expect("resource already dropped").1
150 }
151}
152
153impl<RM> DerefMut for PooledResource<RM>
154where
155 RM: DatabasePool,
156{
157 fn deref_mut(&mut self) -> &mut Self::Target {
158 &mut self.resource.as_mut().expect("resource already dropped").1
159 }
160}
161
162impl<RM> Pool<RM>
163where
164 RM: DatabasePool,
165{
166 pub fn new(config: RM::Config) -> Arc<Self> {
168 let max_size = config.max_size();
169 Arc::new(Self {
170 default_timeout: config.default_timeout(),
171 max_size,
172 config,
173 queue: Default::default(),
174 semaphore: Arc::new(Semaphore::new(max_size)),
175 })
176 }
177
178 #[inline(always)]
180 pub async fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
181 self.get_timeout(self.default_timeout).await
182 }
183
184 #[inline(always)]
189 pub async fn get_timeout(
190 self: &Arc<Self>,
191 timeout: Duration,
192 ) -> Result<PooledResource<RM>, Error<RM::Error>> {
193 let permit = match self.semaphore.clone().try_acquire_owned() {
195 Ok(permit) => permit,
196 Err(tokio::sync::TryAcquireError::Closed) => return Err(Error::Poison),
197 Err(tokio::sync::TryAcquireError::NoPermits) => {
198 tracing::debug!(
202 "Pool exhausted (size: {}), waiting for a connection",
203 self.max_size,
204 );
205 tokio::time::timeout(timeout, self.semaphore.clone().acquire_owned())
206 .await
207 .map_err(|_| Error::Timeout)?
208 .map_err(|_| Error::Poison)?
209 }
210 };
211
212 #[cfg(feature = "prometheus")]
213 {
214 let in_use = self.max_size - self.semaphore.available_permits();
215 METRICS.set_db_connections_active(in_use as i64);
216 }
217
218 {
221 let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
222 while let Some((stale, resource)) = resources.pop() {
223 if !stale.load(Ordering::SeqCst) {
224 return Ok(PooledResource {
225 resource: Some((stale, resource)),
226 pool: self.clone(),
227 _permit: permit,
228 #[cfg(feature = "prometheus")]
229 start_time: std::time::Instant::now(),
230 });
231 }
232 }
234 }
235
236 let stale: Arc<AtomicBool> = Arc::new(false.into());
239 match RM::new_resource(&self.config, stale.clone(), timeout) {
240 Ok(new_resource) => Ok(PooledResource {
241 resource: Some((stale, new_resource)),
242 pool: self.clone(),
243 _permit: permit,
244 #[cfg(feature = "prometheus")]
245 start_time: std::time::Instant::now(),
246 }),
247 Err(e) => {
248 Err(e)
250 }
251 }
252 }
253}
254
255impl<RM> Drop for Pool<RM>
256where
257 RM: DatabasePool,
258{
259 fn drop(&mut self) {
260 self.semaphore.close();
262
263 if let Ok(mut resources) = self.queue.lock() {
265 while let Some(resource) = resources.pop() {
266 RM::drop(resource.1);
267 }
268 }
269 }
270}