1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings used by [`PgPool`] to open, hand out and recycle connections.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Role to authenticate as.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password; `None` selects the password-less connect path.
    pub password: Option<String>,
    /// Upper bound on simultaneously checked-out connections (semaphore size).
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Idle connections older than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// Maximum time `acquire` waits for a free slot.
    pub acquire_timeout: Duration,
    /// Time budget for establishing a single new connection.
    pub connect_timeout: Duration,
    /// Optional age limit after which a pooled connection is discarded.
    pub max_lifetime: Option<Duration>,
    /// Whether connections should be validated on checkout.
    pub test_on_acquire: bool,
}
27
28impl PoolConfig {
29 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31 Self {
32 host: host.to_string(),
33 port,
34 user: user.to_string(),
35 database: database.to_string(),
36 password: None,
37 max_connections: 10,
38 min_connections: 1,
39 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, }
45 }
46
47 pub fn password(mut self, password: &str) -> Self {
49 self.password = Some(password.to_string());
50 self
51 }
52
53 pub fn max_connections(mut self, max: usize) -> Self {
54 self.max_connections = max;
55 self
56 }
57
58 pub fn min_connections(mut self, min: usize) -> Self {
60 self.min_connections = min;
61 self
62 }
63
64 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66 self.idle_timeout = timeout;
67 self
68 }
69
70 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72 self.acquire_timeout = timeout;
73 self
74 }
75
76 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78 self.connect_timeout = timeout;
79 self
80 }
81
82 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84 self.max_lifetime = Some(lifetime);
85 self
86 }
87
88 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90 self.test_on_acquire = enabled;
91 self
92 }
93}
94
/// Point-in-time snapshot of pool counters, as produced by `PgPool::stats`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out by callers.
    pub active: usize,
    /// Connections sitting in the idle list.
    pub idle: usize,
    /// Semaphore permits in use that are not reflected in `active`
    /// (computed as `max_size - available_permits - active`).
    pub pending: usize,
    /// Configured `max_connections`.
    pub max_size: usize,
    /// Total connections opened over the pool's lifetime.
    pub total_created: usize,
}
105
106struct PooledConn {
108 conn: PgConnection,
109 created_at: Instant,
110 last_used: Instant,
111}
112
113pub struct PooledConnection {
115 conn: Option<PgConnection>,
116 pool: Arc<PgPoolInner>,
117}
118
119impl PooledConnection {
120 pub fn get_mut(&mut self) -> &mut PgConnection {
122 self.conn
123 .as_mut()
124 .expect("Connection should always be present")
125 }
126
127 pub fn cancel_token(&self) -> crate::driver::CancelToken {
129 let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
130 crate::driver::CancelToken {
131 host: self.pool.config.host.clone(),
132 port: self.pool.config.port,
133 process_id,
134 secret_key,
135 }
136 }
137
138 pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
141 use crate::protocol::AstEncoder;
142 use super::ColumnInfo;
143
144 let conn = self.conn.as_mut().expect("Connection should always be present");
145
146 let wire_bytes = AstEncoder::encode_cmd_reuse(
147 cmd,
148 &mut conn.sql_buf,
149 &mut conn.params_buf,
150 );
151
152 conn.send_bytes(&wire_bytes).await?;
153
154 let mut rows: Vec<super::PgRow> = Vec::new();
155 let mut column_info: Option<Arc<ColumnInfo>> = None;
156 let mut error: Option<PgError> = None;
157
158 loop {
159 let msg = conn.recv().await?;
160 match msg {
161 crate::protocol::BackendMessage::ParseComplete
162 | crate::protocol::BackendMessage::BindComplete => {}
163 crate::protocol::BackendMessage::RowDescription(fields) => {
164 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
165 }
166 crate::protocol::BackendMessage::DataRow(data) => {
167 if error.is_none() {
168 rows.push(super::PgRow {
169 columns: data,
170 column_info: column_info.clone(),
171 });
172 }
173 }
174 crate::protocol::BackendMessage::CommandComplete(_) => {}
175 crate::protocol::BackendMessage::ReadyForQuery(_) => {
176 if let Some(err) = error {
177 return Err(err);
178 }
179 return Ok(rows);
180 }
181 crate::protocol::BackendMessage::ErrorResponse(err) => {
182 if error.is_none() {
183 error = Some(PgError::Query(err.message));
184 }
185 }
186 _ => {}
187 }
188 }
189 }
190}
191
192impl Drop for PooledConnection {
193 fn drop(&mut self) {
194 if let Some(conn) = self.conn.take() {
195 let pool = self.pool.clone();
196 tokio::spawn(async move {
197 pool.return_connection(conn).await;
198 });
199 }
200 }
201}
202
203impl std::ops::Deref for PooledConnection {
204 type Target = PgConnection;
205
206 fn deref(&self) -> &Self::Target {
207 self.conn
208 .as_ref()
209 .expect("Connection should always be present")
210 }
211}
212
213impl std::ops::DerefMut for PooledConnection {
214 fn deref_mut(&mut self) -> &mut Self::Target {
215 self.conn
216 .as_mut()
217 .expect("Connection should always be present")
218 }
219}
220
221struct PgPoolInner {
223 config: PoolConfig,
224 connections: Mutex<Vec<PooledConn>>,
225 semaphore: Semaphore,
226 closed: AtomicBool,
227 active_count: AtomicUsize,
228 total_created: AtomicUsize,
229}
230
231impl PgPoolInner {
232 async fn return_connection(&self, conn: PgConnection) {
233
234 self.active_count.fetch_sub(1, Ordering::Relaxed);
235
236
237 if self.closed.load(Ordering::Relaxed) {
238 return;
239 }
240
241 let mut connections = self.connections.lock().await;
242 if connections.len() < self.config.max_connections {
243 connections.push(PooledConn {
244 conn,
245 created_at: Instant::now(),
246 last_used: Instant::now(),
247 });
248 }
249
250 self.semaphore.add_permits(1);
251 }
252
253 async fn get_healthy_connection(&self) -> Option<PgConnection> {
255 let mut connections = self.connections.lock().await;
256
257 while let Some(pooled) = connections.pop() {
258 if pooled.last_used.elapsed() > self.config.idle_timeout {
259 continue;
261 }
262
263 if let Some(max_life) = self.config.max_lifetime
264 && pooled.created_at.elapsed() > max_life
265 {
266 continue;
268 }
269
270 return Some(pooled.conn);
271 }
272
273 None
274 }
275}
276
277#[derive(Clone)]
288pub struct PgPool {
289 inner: Arc<PgPoolInner>,
290}
291
292impl PgPool {
293 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
295 let semaphore = Semaphore::new(config.max_connections);
297
298 let mut initial_connections = Vec::new();
299 for _ in 0..config.min_connections {
300 let conn = Self::create_connection(&config).await?;
301 initial_connections.push(PooledConn {
302 conn,
303 created_at: Instant::now(),
304 last_used: Instant::now(),
305 });
306 }
307
308 let initial_count = initial_connections.len();
309
310 let inner = Arc::new(PgPoolInner {
311 config,
312 connections: Mutex::new(initial_connections),
313 semaphore,
314 closed: AtomicBool::new(false),
315 active_count: AtomicUsize::new(0),
316 total_created: AtomicUsize::new(initial_count),
317 });
318
319 Ok(Self { inner })
320 }
321
322 pub async fn acquire(&self) -> PgResult<PooledConnection> {
324 if self.inner.closed.load(Ordering::Relaxed) {
325 return Err(PgError::Connection("Pool is closed".to_string()));
326 }
327
328 let acquire_timeout = self.inner.config.acquire_timeout;
330 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
331 .await
332 .map_err(|_| {
333 PgError::Connection(format!(
334 "Timed out waiting for connection ({}s)",
335 acquire_timeout.as_secs()
336 ))
337 })?
338 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
339 permit.forget();
340
341 let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
343 conn
344 } else {
345 let conn = Self::create_connection(&self.inner.config).await?;
346 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
347 conn
348 };
349
350
351 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
352
353 Ok(PooledConnection {
354 conn: Some(conn),
355 pool: self.inner.clone(),
356 })
357 }
358
359 pub async fn idle_count(&self) -> usize {
361 self.inner.connections.lock().await.len()
362 }
363
364 pub fn active_count(&self) -> usize {
366 self.inner.active_count.load(Ordering::Relaxed)
367 }
368
369 pub fn max_connections(&self) -> usize {
371 self.inner.config.max_connections
372 }
373
374 pub async fn stats(&self) -> PoolStats {
376 let idle = self.inner.connections.lock().await.len();
377 PoolStats {
378 active: self.inner.active_count.load(Ordering::Relaxed),
379 idle,
380 pending: self.inner.config.max_connections
381 - self.inner.semaphore.available_permits()
382 - self.active_count(),
383 max_size: self.inner.config.max_connections,
384 total_created: self.inner.total_created.load(Ordering::Relaxed),
385 }
386 }
387
388 pub fn is_closed(&self) -> bool {
390 self.inner.closed.load(Ordering::Relaxed)
391 }
392
393 pub async fn close(&self) {
395 self.inner.closed.store(true, Ordering::Relaxed);
396
397 let mut connections = self.inner.connections.lock().await;
398 connections.clear();
399 }
400
401 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
403 match &config.password {
404 Some(password) => {
405 PgConnection::connect_with_password(
406 &config.host,
407 config.port,
408 &config.user,
409 &config.database,
410 Some(password),
411 )
412 .await
413 }
414 None => {
415 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
416 .await
417 }
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The original builder test: connection identity plus pool sizing.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// Pins the defaults chosen by `PoolConfig::new`.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// Covers the timeout/lifetime/validation builders not exercised above.
    #[test]
    fn test_pool_config_timeouts() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(1))
            .acquire_timeout(Duration::from_secs(2))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(4))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(1));
        assert_eq!(config.acquire_timeout, Duration::from_secs(2));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(4)));
        assert!(config.test_on_acquire);
    }

    /// `PoolStats::default` starts every counter at zero.
    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();

        assert_eq!(stats.active, 0);
        assert_eq!(stats.idle, 0);
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.max_size, 0);
        assert_eq!(stats.total_created, 0);
    }
}