1use parking_lot::Mutex;
7use std::collections::VecDeque;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::Semaphore;
12
/// Limits and timeouts that govern a [`ConnectionPool`].
///
/// Use struct-update syntax on top of [`Default`] to override only the
/// fields you care about:
/// `PoolConfig { max_connections: 8, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Upper bound on connections checked out at the same time. The pool's
    /// semaphore is created with exactly this many permits.
    pub max_connections: usize,
    /// Desired minimum number of idle connections.
    ///
    /// NOTE(review): nothing in the visible part of this file reads this
    /// field; confirm whether it is enforced elsewhere.
    pub min_idle: usize,
    /// How long `acquire` waits for a free slot before giving up with
    /// [`PoolError::Timeout`].
    pub acquire_timeout: Duration,
    /// Intended maximum time a connection may sit idle.
    ///
    /// NOTE(review): nothing in the visible part of this file reads this
    /// field; confirm whether it is enforced elsewhere.
    pub idle_timeout: Duration,
    /// Maximum age of a connection, measured from its creation. Connections
    /// at or past this age are closed instead of being reused.
    pub max_lifetime: Duration,
}
27
28impl Default for PoolConfig {
29 fn default() -> Self {
30 Self {
31 max_connections: 100,
32 min_idle: 10,
33 acquire_timeout: Duration::from_secs(30),
34 idle_timeout: Duration::from_secs(600),
35 max_lifetime: Duration::from_secs(3600),
36 }
37 }
38}
39
/// Point-in-time snapshot of pool counters, as produced by
/// `ConnectionPool::stats`.
///
/// The individual values are read one after another, so under concurrent use
/// they may not describe a single consistent instant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections currently checked out.
    pub active: usize,
    /// Connections sitting in the idle queue, ready for reuse.
    pub idle: usize,
    /// Connections produced by the factory since the pool was built.
    pub total_created: u64,
    /// Connections the pool has closed (expired or explicitly cleared).
    pub total_closed: u64,
    /// `acquire` calls that gave up after waiting for the acquire timeout.
    pub timeouts: u64,
    /// As populated by `ConnectionPool::stats`: `max_connections` minus the
    /// semaphore's currently available permits. This is the number of permits
    /// in use, not a count of callers blocked in `acquire`.
    pub waiters: usize,
}
56
57pub struct PooledConnection<T> {
59 conn: Option<T>,
60 pool: Arc<ConnectionPoolInner<T>>,
61 created_at: Instant,
62 last_used: Instant,
63}
64
65impl<T> PooledConnection<T> {
66 pub fn get(&self) -> &T {
68 self.conn.as_ref().unwrap()
69 }
70
71 pub fn get_mut(&mut self) -> &mut T {
73 self.conn.as_mut().unwrap()
74 }
75
76 pub fn age(&self) -> Duration {
78 self.created_at.elapsed()
79 }
80
81 pub fn idle_time(&self) -> Duration {
83 self.last_used.elapsed()
84 }
85}
86
87impl<T> Drop for PooledConnection<T> {
88 fn drop(&mut self) {
89 if let Some(conn) = self.conn.take() {
90 let pool = Arc::clone(&self.pool);
91 let created_at = self.created_at;
92
93 if created_at.elapsed() < pool.config.max_lifetime {
95 pool.return_connection(conn, created_at);
96 } else {
97 pool.discard_connection();
98 }
99 }
100 }
101}
102
103impl<T> std::ops::Deref for PooledConnection<T> {
104 type Target = T;
105
106 fn deref(&self) -> &Self::Target {
107 self.get()
108 }
109}
110
111impl<T> std::ops::DerefMut for PooledConnection<T> {
112 fn deref_mut(&mut self) -> &mut Self::Target {
113 self.get_mut()
114 }
115}
116
/// An idle connection parked in the pool's queue, together with the moment
/// it was created so its age can be checked before reuse.
struct PoolEntry<T> {
    /// The idle connection itself.
    conn: T,
    /// Creation time of `conn`; compared against `max_lifetime`.
    created_at: Instant,
}
121
122struct ConnectionPoolInner<T> {
123 config: PoolConfig,
124 available: Mutex<VecDeque<PoolEntry<T>>>,
125 semaphore: Semaphore,
126 active_count: AtomicUsize,
127 total_created: AtomicUsize,
128 total_closed: AtomicUsize,
129 timeouts: AtomicUsize,
130}
131
132impl<T> ConnectionPoolInner<T> {
133 fn return_connection(&self, conn: T, created_at: Instant) {
134 let mut available = self.available.lock();
135
136 if created_at.elapsed() < self.config.max_lifetime {
138 available.push_back(PoolEntry { conn, created_at });
139 } else {
140 self.total_closed.fetch_add(1, Ordering::Relaxed);
141 }
142
143 drop(available);
144 self.active_count.fetch_sub(1, Ordering::Relaxed);
145 self.semaphore.add_permits(1);
146 }
147
148 fn discard_connection(&self) {
149 self.active_count.fetch_sub(1, Ordering::Relaxed);
150 self.total_closed.fetch_add(1, Ordering::Relaxed);
151 self.semaphore.add_permits(1);
152 }
153}
154
155pub struct ConnectionPool<T, F>
157where
158 F: Fn() -> T + Send + Sync,
159{
160 inner: Arc<ConnectionPoolInner<T>>,
161 factory: F,
162}
163
164impl<T, F> ConnectionPool<T, F>
165where
166 T: Send + 'static,
167 F: Fn() -> T + Send + Sync,
168{
169 pub fn new(factory: F, config: PoolConfig) -> Self {
171 let inner = Arc::new(ConnectionPoolInner {
172 semaphore: Semaphore::new(config.max_connections),
173 config,
174 available: Mutex::new(VecDeque::new()),
175 active_count: AtomicUsize::new(0),
176 total_created: AtomicUsize::new(0),
177 total_closed: AtomicUsize::new(0),
178 timeouts: AtomicUsize::new(0),
179 });
180
181 Self { inner, factory }
182 }
183
184 pub fn with_defaults(factory: F) -> Self {
186 Self::new(factory, PoolConfig::default())
187 }
188
189 pub async fn acquire(&self) -> Result<PooledConnection<T>, PoolError> {
191 let permit = tokio::time::timeout(
193 self.inner.config.acquire_timeout,
194 self.inner.semaphore.acquire(),
195 )
196 .await
197 .map_err(|_| {
198 self.inner.timeouts.fetch_add(1, Ordering::Relaxed);
199 PoolError::Timeout
200 })?
201 .map_err(|_| PoolError::Closed)?;
202
203 permit.forget();
205
206 let entry = {
208 let mut available = self.inner.available.lock();
209
210 loop {
212 match available.pop_front() {
213 Some(entry) => {
214 if entry.created_at.elapsed() < self.inner.config.max_lifetime {
215 break Some(entry);
216 } else {
217 self.inner.total_closed.fetch_add(1, Ordering::Relaxed);
219 }
220 }
221 None => break None,
222 }
223 }
224 };
225
226 let (conn, created_at) = match entry {
227 Some(entry) => (entry.conn, entry.created_at),
228 None => {
229 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
231 ((self.factory)(), Instant::now())
232 }
233 };
234
235 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
236
237 Ok(PooledConnection {
238 conn: Some(conn),
239 pool: Arc::clone(&self.inner),
240 created_at,
241 last_used: Instant::now(),
242 })
243 }
244
245 pub fn stats(&self) -> PoolStats {
247 let available = self.inner.available.lock();
248 PoolStats {
249 active: self.inner.active_count.load(Ordering::Relaxed),
250 idle: available.len(),
251 total_created: self.inner.total_created.load(Ordering::Relaxed) as u64,
252 total_closed: self.inner.total_closed.load(Ordering::Relaxed) as u64,
253 timeouts: self.inner.timeouts.load(Ordering::Relaxed) as u64,
254 waiters: self.inner.config.max_connections - self.inner.semaphore.available_permits(),
255 }
256 }
257
258 pub fn clear_idle(&self) {
260 let mut available = self.inner.available.lock();
261 let count = available.len();
262 available.clear();
263 self.inner.total_closed.fetch_add(count, Ordering::Relaxed);
264 }
265}
266
267#[derive(Debug, Clone, thiserror::Error)]
269pub enum PoolError {
270 #[error("connection acquisition timed out")]
271 Timeout,
272 #[error("pool is closed")]
273 Closed,
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    /// A returned connection is reused instead of a new one being created.
    #[tokio::test]
    async fn test_pool_acquire_release() {
        // The factory hands out 0, 1, 2, ... so every "connection" carries
        // the order in which it was created.
        let created = Arc::new(AtomicU32::new(0));
        let factory_counter = Arc::clone(&created);
        let config = PoolConfig {
            max_connections: 2,
            ..Default::default()
        };
        let pool = ConnectionPool::new(
            move || factory_counter.fetch_add(1, Ordering::Relaxed),
            config,
        );

        let first = pool.acquire().await.unwrap();
        assert_eq!(*first, 0);

        let second = pool.acquire().await.unwrap();
        assert_eq!(*second, 1);

        let both_out = pool.stats();
        assert_eq!(both_out.active, 2);
        assert_eq!(both_out.idle, 0);

        // Handing the first connection back moves it to the idle queue.
        drop(first);

        let one_back = pool.stats();
        assert_eq!(one_back.active, 1);
        assert_eq!(one_back.idle, 1);

        // The next checkout reuses connection 0; the factory is not called
        // a third time.
        let reused = pool.acquire().await.unwrap();
        assert_eq!(*reused, 0);
        assert_eq!(created.load(Ordering::Relaxed), 2);
    }

    /// With the only slot taken, a second acquire fails with `Timeout`.
    #[tokio::test]
    async fn test_pool_timeout() {
        let config = PoolConfig {
            max_connections: 1,
            acquire_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let pool = ConnectionPool::new(|| 42, config);

        // Keep the single slot occupied for the rest of the test.
        let _held = pool.acquire().await.unwrap();

        let outcome = pool.acquire().await;
        assert!(matches!(outcome, Err(PoolError::Timeout)));
    }

    /// Counters follow a single checkout and return.
    #[tokio::test]
    async fn test_pool_stats() {
        let pool = ConnectionPool::new(|| (), PoolConfig::default());

        let guard = pool.acquire().await.unwrap();
        let while_held = pool.stats();

        assert_eq!(while_held.active, 1);
        assert_eq!(while_held.total_created, 1);

        drop(guard);

        let after_return = pool.stats();
        assert_eq!(after_return.active, 0);
        assert_eq!(after_return.idle, 1);
    }
}