1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
12#[derive(Clone)]
13pub struct PoolConfig {
14 pub host: String,
15 pub port: u16,
16 pub user: String,
17 pub database: String,
18 pub password: Option<String>,
19 pub max_connections: usize,
20 pub min_connections: usize,
21 pub idle_timeout: Duration,
22 pub acquire_timeout: Duration,
23 pub connect_timeout: Duration,
24 pub max_lifetime: Option<Duration>,
25 pub test_on_acquire: bool,
26}
27
28impl PoolConfig {
29 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31 Self {
32 host: host.to_string(),
33 port,
34 user: user.to_string(),
35 database: database.to_string(),
36 password: None,
37 max_connections: 10,
38 min_connections: 1,
39 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, }
45 }
46
47 pub fn password(mut self, password: &str) -> Self {
49 self.password = Some(password.to_string());
50 self
51 }
52
53 pub fn max_connections(mut self, max: usize) -> Self {
54 self.max_connections = max;
55 self
56 }
57
58 pub fn min_connections(mut self, min: usize) -> Self {
60 self.min_connections = min;
61 self
62 }
63
64 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66 self.idle_timeout = timeout;
67 self
68 }
69
70 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72 self.acquire_timeout = timeout;
73 self
74 }
75
76 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78 self.connect_timeout = timeout;
79 self
80 }
81
82 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84 self.max_lifetime = Some(lifetime);
85 self
86 }
87
88 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90 self.test_on_acquire = enabled;
91 self
92 }
93
94 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
99 let pg = &qail.postgres;
100 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
101
102 let mut config = PoolConfig::new(&host, port, &user, &database)
103 .max_connections(pg.max_connections)
104 .min_connections(pg.min_connections)
105 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
106 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
107 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
108 .test_on_acquire(pg.test_on_acquire);
109
110 if let Some(ref pw) = password {
111 config = config.password(pw);
112 }
113
114 Ok(config)
115 }
116}
117
118fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
120 let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
121
122 let (credentials, host_part) = if url.contains('@') {
123 let mut parts = url.splitn(2, '@');
124 let creds = parts.next().unwrap_or("");
125 let host = parts.next().unwrap_or("localhost/postgres");
126 (Some(creds), host)
127 } else {
128 (None, url)
129 };
130
131 let (host_port, database) = if host_part.contains('/') {
132 let mut parts = host_part.splitn(2, '/');
133 (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
134 } else {
135 (host_part, "postgres".to_string())
136 };
137
138 let (host, port) = if host_port.contains(':') {
139 let mut parts = host_port.split(':');
140 let h = parts.next().unwrap_or("localhost").to_string();
141 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
142 (h, p)
143 } else {
144 (host_port.to_string(), 5432u16)
145 };
146
147 let (user, password) = if let Some(creds) = credentials {
148 if creds.contains(':') {
149 let mut parts = creds.splitn(2, ':');
150 let u = parts.next().unwrap_or("postgres").to_string();
151 let p = parts.next().map(|s| s.to_string());
152 (u, p)
153 } else {
154 (creds.to_string(), None)
155 }
156 } else {
157 ("postgres".to_string(), None)
158 };
159
160 Ok((host, port, user, database, password))
161}
162
163#[derive(Debug, Clone, Default)]
165pub struct PoolStats {
166 pub active: usize,
167 pub idle: usize,
168 pub pending: usize,
169 pub max_size: usize,
171 pub total_created: usize,
172}
173
174struct PooledConn {
176 conn: PgConnection,
177 created_at: Instant,
178 last_used: Instant,
179}
180
181pub struct PooledConnection {
187 conn: Option<PgConnection>,
188 pool: Arc<PgPoolInner>,
189 rls_dirty: bool,
190}
191
192impl PooledConnection {
193 pub fn get_mut(&mut self) -> &mut PgConnection {
195 self.conn
196 .as_mut()
197 .expect("Connection should always be present")
198 }
199
200 pub fn cancel_token(&self) -> crate::driver::CancelToken {
202 let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
203 crate::driver::CancelToken {
204 host: self.pool.config.host.clone(),
205 port: self.pool.config.port,
206 process_id,
207 secret_key,
208 }
209 }
210
211 pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
214 use crate::protocol::AstEncoder;
215 use super::ColumnInfo;
216
217 let conn = self.conn.as_mut().expect("Connection should always be present");
218
219 let wire_bytes = AstEncoder::encode_cmd_reuse(
220 cmd,
221 &mut conn.sql_buf,
222 &mut conn.params_buf,
223 );
224
225 conn.send_bytes(&wire_bytes).await?;
226
227 let mut rows: Vec<super::PgRow> = Vec::new();
228 let mut column_info: Option<Arc<ColumnInfo>> = None;
229 let mut error: Option<PgError> = None;
230
231 loop {
232 let msg = conn.recv().await?;
233 match msg {
234 crate::protocol::BackendMessage::ParseComplete
235 | crate::protocol::BackendMessage::BindComplete => {}
236 crate::protocol::BackendMessage::RowDescription(fields) => {
237 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
238 }
239 crate::protocol::BackendMessage::DataRow(data) => {
240 if error.is_none() {
241 rows.push(super::PgRow {
242 columns: data,
243 column_info: column_info.clone(),
244 });
245 }
246 }
247 crate::protocol::BackendMessage::CommandComplete(_) => {}
248 crate::protocol::BackendMessage::ReadyForQuery(_) => {
249 if let Some(err) = error {
250 return Err(err);
251 }
252 return Ok(rows);
253 }
254 crate::protocol::BackendMessage::ErrorResponse(err) => {
255 if error.is_none() {
256 error = Some(PgError::Query(err.message));
257 }
258 }
259 _ => {}
260 }
261 }
262 }
263}
264
265impl Drop for PooledConnection {
266 fn drop(&mut self) {
267 if let Some(conn) = self.conn.take() {
268 let pool = self.pool.clone();
269 let rls_dirty = self.rls_dirty;
270 tokio::spawn(async move {
271 if rls_dirty {
272 let mut conn = conn;
276 let _ = conn.execute_simple(super::rls::reset_sql()).await;
277 pool.return_connection(conn).await;
278 } else {
279 pool.return_connection(conn).await;
280 }
281 });
282 }
283 }
284}
285
286impl std::ops::Deref for PooledConnection {
287 type Target = PgConnection;
288
289 fn deref(&self) -> &Self::Target {
290 self.conn
291 .as_ref()
292 .expect("Connection should always be present")
293 }
294}
295
296impl std::ops::DerefMut for PooledConnection {
297 fn deref_mut(&mut self) -> &mut Self::Target {
298 self.conn
299 .as_mut()
300 .expect("Connection should always be present")
301 }
302}
303
304struct PgPoolInner {
306 config: PoolConfig,
307 connections: Mutex<Vec<PooledConn>>,
308 semaphore: Semaphore,
309 closed: AtomicBool,
310 active_count: AtomicUsize,
311 total_created: AtomicUsize,
312}
313
314impl PgPoolInner {
315 async fn return_connection(&self, conn: PgConnection) {
316
317 self.active_count.fetch_sub(1, Ordering::Relaxed);
318
319
320 if self.closed.load(Ordering::Relaxed) {
321 return;
322 }
323
324 let mut connections = self.connections.lock().await;
325 if connections.len() < self.config.max_connections {
326 connections.push(PooledConn {
327 conn,
328 created_at: Instant::now(),
329 last_used: Instant::now(),
330 });
331 }
332
333 self.semaphore.add_permits(1);
334 }
335
336 async fn get_healthy_connection(&self) -> Option<PgConnection> {
338 let mut connections = self.connections.lock().await;
339
340 while let Some(pooled) = connections.pop() {
341 if pooled.last_used.elapsed() > self.config.idle_timeout {
342 continue;
344 }
345
346 if let Some(max_life) = self.config.max_lifetime
347 && pooled.created_at.elapsed() > max_life
348 {
349 continue;
351 }
352
353 return Some(pooled.conn);
354 }
355
356 None
357 }
358}
359
360#[derive(Clone)]
371pub struct PgPool {
372 inner: Arc<PgPoolInner>,
373}
374
375impl PgPool {
376 pub async fn from_config() -> PgResult<Self> {
383 let qail = qail_core::config::QailConfig::load()
384 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
385 let config = PoolConfig::from_qail_config(&qail)?;
386 Self::connect(config).await
387 }
388
389 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
391 let semaphore = Semaphore::new(config.max_connections);
393
394 let mut initial_connections = Vec::new();
395 for _ in 0..config.min_connections {
396 let conn = Self::create_connection(&config).await?;
397 initial_connections.push(PooledConn {
398 conn,
399 created_at: Instant::now(),
400 last_used: Instant::now(),
401 });
402 }
403
404 let initial_count = initial_connections.len();
405
406 let inner = Arc::new(PgPoolInner {
407 config,
408 connections: Mutex::new(initial_connections),
409 semaphore,
410 closed: AtomicBool::new(false),
411 active_count: AtomicUsize::new(0),
412 total_created: AtomicUsize::new(initial_count),
413 });
414
415 Ok(Self { inner })
416 }
417
418 pub async fn acquire(&self) -> PgResult<PooledConnection> {
420 if self.inner.closed.load(Ordering::Relaxed) {
421 return Err(PgError::Connection("Pool is closed".to_string()));
422 }
423
424 let acquire_timeout = self.inner.config.acquire_timeout;
426 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
427 .await
428 .map_err(|_| {
429 PgError::Connection(format!(
430 "Timed out waiting for connection ({}s)",
431 acquire_timeout.as_secs()
432 ))
433 })?
434 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
435 permit.forget();
436
437 let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
439 conn
440 } else {
441 let conn = Self::create_connection(&self.inner.config).await?;
442 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
443 conn
444 };
445
446
447 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
448
449 Ok(PooledConnection {
450 conn: Some(conn),
451 pool: self.inner.clone(),
452 rls_dirty: false,
453 })
454 }
455
456 pub async fn acquire_with_rls(
472 &self,
473 ctx: qail_core::rls::RlsContext,
474 ) -> PgResult<PooledConnection> {
475 let mut conn = self.acquire().await?;
476
477 let sql = super::rls::context_to_sql(&ctx);
479 let pg_conn = conn.get_mut();
480 pg_conn.execute_simple(&sql).await?;
481
482 conn.rls_dirty = true;
484
485 Ok(conn)
486 }
487
488 pub async fn idle_count(&self) -> usize {
490 self.inner.connections.lock().await.len()
491 }
492
493 pub fn active_count(&self) -> usize {
495 self.inner.active_count.load(Ordering::Relaxed)
496 }
497
498 pub fn max_connections(&self) -> usize {
500 self.inner.config.max_connections
501 }
502
503 pub async fn stats(&self) -> PoolStats {
505 let idle = self.inner.connections.lock().await.len();
506 PoolStats {
507 active: self.inner.active_count.load(Ordering::Relaxed),
508 idle,
509 pending: self.inner.config.max_connections
510 - self.inner.semaphore.available_permits()
511 - self.active_count(),
512 max_size: self.inner.config.max_connections,
513 total_created: self.inner.total_created.load(Ordering::Relaxed),
514 }
515 }
516
517 pub fn is_closed(&self) -> bool {
519 self.inner.closed.load(Ordering::Relaxed)
520 }
521
522 pub async fn close(&self) {
524 self.inner.closed.store(true, Ordering::Relaxed);
525
526 let mut connections = self.inner.connections.lock().await;
527 connections.clear();
528 }
529
530 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
532 match &config.password {
533 Some(password) => {
534 PgConnection::connect_with_password(
535 &config.host,
536 config.port,
537 &config.user,
538 &config.database,
539 Some(password),
540 )
541 .await
542 }
543 None => {
544 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
545 .await
546 }
547 }
548 }
549}
550
551#[cfg(test)]
552mod tests {
553 use super::*;
554
555 #[test]
556 fn test_pool_config() {
557 let config = PoolConfig::new("localhost", 5432, "user", "testdb")
558 .password("secret123")
559 .max_connections(20)
560 .min_connections(5);
561
562 assert_eq!(config.host, "localhost");
563 assert_eq!(config.port, 5432);
564 assert_eq!(config.user, "user");
565 assert_eq!(config.database, "testdb");
566 assert_eq!(config.password, Some("secret123".to_string()));
567 assert_eq!(config.max_connections, 20);
568 assert_eq!(config.min_connections, 5);
569 }
570}