1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings used to open and manage the connections of a `PgPool`.
///
/// Normally built with [`PoolConfig::new`] followed by the chained setter methods.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// User name to authenticate as.
    pub user: String,
    /// Name of the database to connect to.
    pub database: String,
    /// Password; `None` means connections are opened without one.
    pub password: Option<String>,
    /// Maximum number of connections handed out at once; also caps the idle list.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// How long `acquire` waits for a free slot before giving up.
    pub acquire_timeout: Duration,
    /// Timeout setting for establishing a single new connection.
    pub connect_timeout: Duration,
    /// Maximum connection-age setting; `None` disables the age check.
    pub max_lifetime: Option<Duration>,
    /// Whether connections should be validated before being handed out.
    /// NOTE(review): not read anywhere in this module as shown.
    pub test_on_acquire: bool,
}
40
41impl PoolConfig {
42 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
44 Self {
45 host: host.to_string(),
46 port,
47 user: user.to_string(),
48 database: database.to_string(),
49 password: None,
50 max_connections: 10,
51 min_connections: 1,
52 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, }
58 }
59
60 pub fn password(mut self, password: &str) -> Self {
62 self.password = Some(password.to_string());
63 self
64 }
65
66 pub fn max_connections(mut self, max: usize) -> Self {
68 self.max_connections = max;
69 self
70 }
71
72 pub fn min_connections(mut self, min: usize) -> Self {
74 self.min_connections = min;
75 self
76 }
77
78 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
80 self.idle_timeout = timeout;
81 self
82 }
83
84 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
86 self.acquire_timeout = timeout;
87 self
88 }
89
90 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
92 self.connect_timeout = timeout;
93 self
94 }
95
96 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
99 self.max_lifetime = Some(lifetime);
100 self
101 }
102
103 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
106 self.test_on_acquire = enabled;
107 self
108 }
109}
110
/// Point-in-time snapshot of pool usage, as produced by `PgPool::stats`.
///
/// The counters are read separately, so under concurrency they may not be
/// mutually consistent.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out.
    pub active: usize,
    /// Connections waiting in the idle list.
    pub idle: usize,
    /// Semaphore slots in use that are not counted as active.
    pub pending: usize,
    /// The configured `max_connections`.
    pub max_size: usize,
    /// Connections opened over the pool's lifetime.
    pub total_created: usize,
}
125
126struct PooledConn {
128 conn: PgConnection,
129 created_at: Instant,
130 last_used: Instant,
131}
132
133pub struct PooledConnection {
135 conn: Option<PgConnection>,
136 pool: Arc<PgPoolInner>,
137}
138
139impl PooledConnection {
140 pub fn get_mut(&mut self) -> &mut PgConnection {
142 self.conn
143 .as_mut()
144 .expect("Connection should always be present")
145 }
146}
147
148impl Drop for PooledConnection {
149 fn drop(&mut self) {
150 if let Some(conn) = self.conn.take() {
151 let pool = self.pool.clone();
153 tokio::spawn(async move {
154 pool.return_connection(conn).await;
155 });
156 }
157 }
158}
159
160impl std::ops::Deref for PooledConnection {
161 type Target = PgConnection;
162
163 fn deref(&self) -> &Self::Target {
164 self.conn
165 .as_ref()
166 .expect("Connection should always be present")
167 }
168}
169
170impl std::ops::DerefMut for PooledConnection {
171 fn deref_mut(&mut self) -> &mut Self::Target {
172 self.conn
173 .as_mut()
174 .expect("Connection should always be present")
175 }
176}
177
178struct PgPoolInner {
180 config: PoolConfig,
181 connections: Mutex<Vec<PooledConn>>,
182 semaphore: Semaphore,
183 closed: AtomicBool,
185 active_count: AtomicUsize,
187 total_created: AtomicUsize,
189}
190
191impl PgPoolInner {
192 async fn return_connection(&self, conn: PgConnection) {
193 self.active_count.fetch_sub(1, Ordering::Relaxed);
195
196 if self.closed.load(Ordering::Relaxed) {
198 return;
199 }
200
201 let mut connections = self.connections.lock().await;
202 if connections.len() < self.config.max_connections {
203 connections.push(PooledConn {
204 conn,
205 created_at: Instant::now(),
206 last_used: Instant::now(),
207 });
208 }
209 self.semaphore.add_permits(1);
211 }
212
213 async fn get_healthy_connection(&self) -> Option<PgConnection> {
215 let mut connections = self.connections.lock().await;
216
217 while let Some(pooled) = connections.pop() {
218 if pooled.last_used.elapsed() > self.config.idle_timeout {
220 continue;
222 }
223
224 if let Some(max_life) = self.config.max_lifetime
226 && pooled.created_at.elapsed() > max_life
227 {
228 continue;
230 }
231
232 return Some(pooled.conn);
234 }
235
236 None
237 }
238}
239
240#[derive(Clone)]
256pub struct PgPool {
257 inner: Arc<PgPoolInner>,
258}
259
260impl PgPool {
261 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
263 let semaphore = Semaphore::new(config.max_connections);
265
266 let mut initial_connections = Vec::new();
268 for _ in 0..config.min_connections {
269 let conn = Self::create_connection(&config).await?;
270 initial_connections.push(PooledConn {
271 conn,
272 created_at: Instant::now(),
273 last_used: Instant::now(),
274 });
275 }
276
277 let initial_count = initial_connections.len();
278
279 let inner = Arc::new(PgPoolInner {
280 config,
281 connections: Mutex::new(initial_connections),
282 semaphore,
283 closed: AtomicBool::new(false),
284 active_count: AtomicUsize::new(0),
285 total_created: AtomicUsize::new(initial_count),
286 });
287
288 Ok(Self { inner })
289 }
290
291 pub async fn acquire(&self) -> PgResult<PooledConnection> {
297 if self.inner.closed.load(Ordering::Relaxed) {
299 return Err(PgError::Connection("Pool is closed".to_string()));
300 }
301
302 let acquire_timeout = self.inner.config.acquire_timeout;
304 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
305 .await
306 .map_err(|_| {
307 PgError::Connection(format!(
308 "Timed out waiting for connection ({}s)",
309 acquire_timeout.as_secs()
310 ))
311 })?
312 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
313 permit.forget();
314
315 let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
317 conn
318 } else {
319 let conn = Self::create_connection(&self.inner.config).await?;
321 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
322 conn
323 };
324
325 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
327
328 Ok(PooledConnection {
329 conn: Some(conn),
330 pool: self.inner.clone(),
331 })
332 }
333
334 pub async fn idle_count(&self) -> usize {
336 self.inner.connections.lock().await.len()
337 }
338
339 pub fn active_count(&self) -> usize {
341 self.inner.active_count.load(Ordering::Relaxed)
342 }
343
344 pub fn max_connections(&self) -> usize {
346 self.inner.config.max_connections
347 }
348
349 pub async fn stats(&self) -> PoolStats {
351 let idle = self.inner.connections.lock().await.len();
352 PoolStats {
353 active: self.inner.active_count.load(Ordering::Relaxed),
354 idle,
355 pending: self.inner.config.max_connections
356 - self.inner.semaphore.available_permits()
357 - self.active_count(),
358 max_size: self.inner.config.max_connections,
359 total_created: self.inner.total_created.load(Ordering::Relaxed),
360 }
361 }
362
363 pub fn is_closed(&self) -> bool {
365 self.inner.closed.load(Ordering::Relaxed)
366 }
367
368 pub async fn close(&self) {
373 self.inner.closed.store(true, Ordering::Relaxed);
374 let mut connections = self.inner.connections.lock().await;
376 connections.clear();
377 }
378
379 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
381 match &config.password {
382 Some(password) => {
383 PgConnection::connect_with_password(
384 &config.host,
385 config.port,
386 &config.user,
387 &config.database,
388 Some(password),
389 )
390 .await
391 }
392 None => {
393 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
394 .await
395 }
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters for connection identity and pool sizing are stored verbatim.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// `PoolConfig::new` fills in the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// Timeout, lifetime and validation setters are stored verbatim.
    #[test]
    fn test_pool_config_timeouts() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(1))
            .acquire_timeout(Duration::from_secs(2))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(4))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(1));
        assert_eq!(config.acquire_timeout, Duration::from_secs(2));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(4)));
        assert!(config.test_on_acquire);
    }

    /// A default `PoolStats` is all zeros.
    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();

        assert_eq!(stats.active, 0);
        assert_eq!(stats.idle, 0);
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.max_size, 0);
        assert_eq!(stats.total_created, 0);
    }
}