1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Connection and sizing settings for a [`PgPool`].
///
/// Construct with [`PoolConfig::new`] and adjust with the chained setters.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host passed to `PgConnection::connect*`.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Role to authenticate as.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password; when `None` the pool connects without one.
    pub password: Option<String>,
    /// Maximum number of connections that may be checked out at once (default 10).
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created (default 1).
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded rather than
    /// reused (default 600 s).
    pub idle_timeout: Duration,
    /// How long `acquire` waits for a free slot before failing (default 30 s).
    pub acquire_timeout: Duration,
    /// Timeout setting for establishing a new connection (default 10 s).
    pub connect_timeout: Duration,
    /// Maximum connection age before it is discarded rather than reused
    /// (default `None`: no limit).
    pub max_lifetime: Option<Duration>,
    /// Flag requesting that connections be validated when acquired (default `false`).
    pub test_on_acquire: bool,
}

// Hand-written rather than derived so that logging or debug-printing a config
// never reveals the password; only its presence is shown.
impl std::fmt::Debug for PoolConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("database", &self.database)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("idle_timeout", &self.idle_timeout)
            .field("acquire_timeout", &self.acquire_timeout)
            .field("connect_timeout", &self.connect_timeout)
            .field("max_lifetime", &self.max_lifetime)
            .field("test_on_acquire", &self.test_on_acquire)
            .finish()
    }
}

impl PoolConfig {
    /// Creates a config for `user` on `host:port`, database `database`.
    ///
    /// Defaults: no password, min 1 / max 10 connections, 600 s idle timeout,
    /// 30 s acquire timeout, 10 s connect timeout, no max lifetime, and
    /// `test_on_acquire` disabled.
    #[must_use]
    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
        Self {
            host: host.to_string(),
            port,
            user: user.to_string(),
            database: database.to_string(),
            password: None,
            max_connections: 10,
            min_connections: 1,
            idle_timeout: Duration::from_secs(600),
            acquire_timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_lifetime: None,
            test_on_acquire: false,
        }
    }

    /// Sets the password used when connecting.
    #[must_use]
    pub fn password(mut self, password: &str) -> Self {
        self.password = Some(password.to_string());
        self
    }

    /// Sets the maximum number of simultaneously checked-out connections.
    #[must_use]
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets how many connections are opened eagerly at pool creation.
    #[must_use]
    pub fn min_connections(mut self, min: usize) -> Self {
        self.min_connections = min;
        self
    }

    /// Sets how long a connection may sit idle before it is discarded.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Sets how long `acquire` waits for a free slot.
    #[must_use]
    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Sets the connection-establishment timeout.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the maximum age of a pooled connection.
    #[must_use]
    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = Some(lifetime);
        self
    }

    /// Enables or disables the test-on-acquire flag.
    #[must_use]
    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
        self.test_on_acquire = enabled;
        self
    }
}
94
/// Point-in-time snapshot of pool usage, as produced by `PgPool::stats`.
///
/// The counters are read independently, so the values are not guaranteed to be
/// mutually consistent while other tasks are acquiring or returning connections.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections currently checked out.
    pub active: usize,
    /// Connections sitting in the idle list.
    pub idle: usize,
    /// Semaphore permits in use that are not accounted for by `active` at
    /// snapshot time (acquires or returns in flight).
    pub pending: usize,
    /// The configured `max_connections`.
    pub max_size: usize,
    /// Connections opened over the pool's lifetime, including ones since discarded.
    pub total_created: usize,
}
105
106struct PooledConn {
108 conn: PgConnection,
109 created_at: Instant,
110 last_used: Instant,
111}
112
113pub struct PooledConnection {
115 conn: Option<PgConnection>,
116 pool: Arc<PgPoolInner>,
117}
118
119impl PooledConnection {
120 pub fn get_mut(&mut self) -> &mut PgConnection {
122 self.conn
123 .as_mut()
124 .expect("Connection should always be present")
125 }
126
127 pub fn cancel_token(&self) -> crate::driver::CancelToken {
129 let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
130 crate::driver::CancelToken {
131 host: self.pool.config.host.clone(),
132 port: self.pool.config.port,
133 process_id,
134 secret_key,
135 }
136 }
137}
138
139impl Drop for PooledConnection {
140 fn drop(&mut self) {
141 if let Some(conn) = self.conn.take() {
142 let pool = self.pool.clone();
143 tokio::spawn(async move {
144 pool.return_connection(conn).await;
145 });
146 }
147 }
148}
149
150impl std::ops::Deref for PooledConnection {
151 type Target = PgConnection;
152
153 fn deref(&self) -> &Self::Target {
154 self.conn
155 .as_ref()
156 .expect("Connection should always be present")
157 }
158}
159
160impl std::ops::DerefMut for PooledConnection {
161 fn deref_mut(&mut self) -> &mut Self::Target {
162 self.conn
163 .as_mut()
164 .expect("Connection should always be present")
165 }
166}
167
168struct PgPoolInner {
170 config: PoolConfig,
171 connections: Mutex<Vec<PooledConn>>,
172 semaphore: Semaphore,
173 closed: AtomicBool,
174 active_count: AtomicUsize,
175 total_created: AtomicUsize,
176}
177
178impl PgPoolInner {
179 async fn return_connection(&self, conn: PgConnection) {
180
181 self.active_count.fetch_sub(1, Ordering::Relaxed);
182
183
184 if self.closed.load(Ordering::Relaxed) {
185 return;
186 }
187
188 let mut connections = self.connections.lock().await;
189 if connections.len() < self.config.max_connections {
190 connections.push(PooledConn {
191 conn,
192 created_at: Instant::now(),
193 last_used: Instant::now(),
194 });
195 }
196
197 self.semaphore.add_permits(1);
198 }
199
200 async fn get_healthy_connection(&self) -> Option<PgConnection> {
202 let mut connections = self.connections.lock().await;
203
204 while let Some(pooled) = connections.pop() {
205 if pooled.last_used.elapsed() > self.config.idle_timeout {
206 continue;
208 }
209
210 if let Some(max_life) = self.config.max_lifetime
211 && pooled.created_at.elapsed() > max_life
212 {
213 continue;
215 }
216
217 return Some(pooled.conn);
218 }
219
220 None
221 }
222}
223
224#[derive(Clone)]
235pub struct PgPool {
236 inner: Arc<PgPoolInner>,
237}
238
239impl PgPool {
240 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
242 let semaphore = Semaphore::new(config.max_connections);
244
245 let mut initial_connections = Vec::new();
246 for _ in 0..config.min_connections {
247 let conn = Self::create_connection(&config).await?;
248 initial_connections.push(PooledConn {
249 conn,
250 created_at: Instant::now(),
251 last_used: Instant::now(),
252 });
253 }
254
255 let initial_count = initial_connections.len();
256
257 let inner = Arc::new(PgPoolInner {
258 config,
259 connections: Mutex::new(initial_connections),
260 semaphore,
261 closed: AtomicBool::new(false),
262 active_count: AtomicUsize::new(0),
263 total_created: AtomicUsize::new(initial_count),
264 });
265
266 Ok(Self { inner })
267 }
268
269 pub async fn acquire(&self) -> PgResult<PooledConnection> {
271 if self.inner.closed.load(Ordering::Relaxed) {
272 return Err(PgError::Connection("Pool is closed".to_string()));
273 }
274
275 let acquire_timeout = self.inner.config.acquire_timeout;
277 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
278 .await
279 .map_err(|_| {
280 PgError::Connection(format!(
281 "Timed out waiting for connection ({}s)",
282 acquire_timeout.as_secs()
283 ))
284 })?
285 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
286 permit.forget();
287
288 let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
290 conn
291 } else {
292 let conn = Self::create_connection(&self.inner.config).await?;
293 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
294 conn
295 };
296
297
298 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
299
300 Ok(PooledConnection {
301 conn: Some(conn),
302 pool: self.inner.clone(),
303 })
304 }
305
306 pub async fn idle_count(&self) -> usize {
308 self.inner.connections.lock().await.len()
309 }
310
311 pub fn active_count(&self) -> usize {
313 self.inner.active_count.load(Ordering::Relaxed)
314 }
315
316 pub fn max_connections(&self) -> usize {
318 self.inner.config.max_connections
319 }
320
321 pub async fn stats(&self) -> PoolStats {
323 let idle = self.inner.connections.lock().await.len();
324 PoolStats {
325 active: self.inner.active_count.load(Ordering::Relaxed),
326 idle,
327 pending: self.inner.config.max_connections
328 - self.inner.semaphore.available_permits()
329 - self.active_count(),
330 max_size: self.inner.config.max_connections,
331 total_created: self.inner.total_created.load(Ordering::Relaxed),
332 }
333 }
334
335 pub fn is_closed(&self) -> bool {
337 self.inner.closed.load(Ordering::Relaxed)
338 }
339
340 pub async fn close(&self) {
342 self.inner.closed.store(true, Ordering::Relaxed);
343
344 let mut connections = self.inner.connections.lock().await;
345 connections.clear();
346 }
347
348 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
350 match &config.password {
351 Some(password) => {
352 PgConnection::connect_with_password(
353 &config.host,
354 config.port,
355 &config.user,
356 &config.database,
357 Some(password),
358 )
359 .await
360 }
361 None => {
362 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
363 .await
364 }
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters for credentials and pool sizing.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// Defaults produced by `PoolConfig::new` when no setter is called.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// Timeout, lifetime and test-on-acquire setters.
    #[test]
    fn test_pool_config_timeouts() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(5))
            .acquire_timeout(Duration::from_secs(6))
            .connect_timeout(Duration::from_secs(7))
            .max_lifetime(Duration::from_secs(8))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_secs(6));
        assert_eq!(config.connect_timeout, Duration::from_secs(7));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(8)));
        assert!(config.test_on_acquire);
    }
}