1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings for a [`PgPool`]: server address, credentials and pool sizing/timeouts.
///
/// Build one with `PoolConfig::new` plus the chained setters, or from the Qail
/// project configuration with `PoolConfig::from_qail_config`.
#[derive(Clone)]
pub struct PoolConfig {
    // --- server address & credentials ---
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to log in.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password for authentication; `None` connects without one.
    pub password: Option<String>,

    // --- pool sizing ---
    /// Upper bound on connections (size of the pool's semaphore and cap on the idle list).
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,

    // --- timing ---
    /// Idle connections unused for longer than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// How long `acquire` waits for a free slot before failing with a timeout.
    pub acquire_timeout: Duration,
    /// Intended limit on establishing a new server connection.
    pub connect_timeout: Duration,
    /// When set, pooled connections older than this are discarded instead of reused.
    pub max_lifetime: Option<Duration>,

    // --- health checking ---
    /// Whether connections should be tested on acquire.
    /// NOTE(review): not read by the pool code shown here.
    pub test_on_acquire: bool,
}
51
52impl PoolConfig {
53 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
62 Self {
63 host: host.to_string(),
64 port,
65 user: user.to_string(),
66 database: database.to_string(),
67 password: None,
68 max_connections: 10,
69 min_connections: 1,
70 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, }
76 }
77
78 pub fn password(mut self, password: &str) -> Self {
80 self.password = Some(password.to_string());
81 self
82 }
83
84 pub fn max_connections(mut self, max: usize) -> Self {
86 self.max_connections = max;
87 self
88 }
89
90 pub fn min_connections(mut self, min: usize) -> Self {
92 self.min_connections = min;
93 self
94 }
95
96 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
98 self.idle_timeout = timeout;
99 self
100 }
101
102 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
104 self.acquire_timeout = timeout;
105 self
106 }
107
108 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
110 self.connect_timeout = timeout;
111 self
112 }
113
114 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
116 self.max_lifetime = Some(lifetime);
117 self
118 }
119
120 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
122 self.test_on_acquire = enabled;
123 self
124 }
125
126 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
131 let pg = &qail.postgres;
132 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
133
134 let mut config = PoolConfig::new(&host, port, &user, &database)
135 .max_connections(pg.max_connections)
136 .min_connections(pg.min_connections)
137 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
138 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
139 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
140 .test_on_acquire(pg.test_on_acquire);
141
142 if let Some(ref pw) = password {
143 config = config.password(pw);
144 }
145
146 Ok(config)
147 }
148}
149
150fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
152 let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
153
154 let (credentials, host_part) = if url.contains('@') {
155 let mut parts = url.splitn(2, '@');
156 let creds = parts.next().unwrap_or("");
157 let host = parts.next().unwrap_or("localhost/postgres");
158 (Some(creds), host)
159 } else {
160 (None, url)
161 };
162
163 let (host_port, database) = if host_part.contains('/') {
164 let mut parts = host_part.splitn(2, '/');
165 (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
166 } else {
167 (host_part, "postgres".to_string())
168 };
169
170 let (host, port) = if host_port.contains(':') {
171 let mut parts = host_port.split(':');
172 let h = parts.next().unwrap_or("localhost").to_string();
173 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
174 (h, p)
175 } else {
176 (host_port.to_string(), 5432u16)
177 };
178
179 let (user, password) = if let Some(creds) = credentials {
180 if creds.contains(':') {
181 let mut parts = creds.splitn(2, ':');
182 let u = parts.next().unwrap_or("postgres").to_string();
183 let p = parts.next().map(|s| s.to_string());
184 (u, p)
185 } else {
186 (creds.to_string(), None)
187 }
188 } else {
189 ("postgres".to_string(), None)
190 };
191
192 Ok((host, port, user, database, password))
193}
194
/// Point-in-time snapshot of pool counters, as produced by `PgPool::stats`.
#[derive(Default, Clone, Debug)]
pub struct PoolStats {
    /// Connections currently checked out by callers.
    pub active: usize,
    /// Connections sitting in the idle list, ready for reuse.
    pub idle: usize,
    /// Slots taken from the semaphore but not counted as active.
    /// Computed as `max_size - available permits - active`, so it includes
    /// connections that are still being established.
    pub pending: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
    /// Connections opened by the pool since it was created.
    pub total_created: usize,
}
209
210struct PooledConn {
212 conn: PgConnection,
213 created_at: Instant,
214 last_used: Instant,
215}
216
217pub struct PooledConnection {
223 conn: Option<PgConnection>,
224 pool: Arc<PgPoolInner>,
225 rls_dirty: bool,
226}
227
228impl PooledConnection {
229 fn conn_ref(&self) -> PgResult<&PgConnection> {
232 self.conn.as_ref().ok_or_else(|| PgError::Connection(
233 "Connection already released back to pool".into()
234 ))
235 }
236
237 fn conn_mut(&mut self) -> PgResult<&mut PgConnection> {
240 self.conn.as_mut().ok_or_else(|| PgError::Connection(
241 "Connection already released back to pool".into()
242 ))
243 }
244
245 pub fn get_mut(&mut self) -> &mut PgConnection {
248 self.conn
251 .as_mut()
252 .expect("Connection should always be present")
253 }
254
255 pub fn cancel_token(&self) -> PgResult<crate::driver::CancelToken> {
257 let conn = self.conn_ref()?;
258 let (process_id, secret_key) = conn.get_cancel_key();
259 Ok(crate::driver::CancelToken {
260 host: self.pool.config.host.clone(),
261 port: self.pool.config.port,
262 process_id,
263 secret_key,
264 })
265 }
266
267 pub async fn release(mut self) {
284 if let Some(mut conn) = self.conn.take() {
285 if let Err(e) = conn.execute_simple(super::rls::reset_sql()).await {
290 eprintln!(
291 "[CRITICAL] pool_release_failed: COMMIT failed — \
292 dropping connection to prevent state leak: {}",
293 e
294 );
295 return; }
297
298 self.pool.return_connection(conn).await;
299 }
300 }
301
302 pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
305 use crate::protocol::AstEncoder;
306 use super::ColumnInfo;
307
308 let conn = self.conn_mut()?;
309
310 let wire_bytes = AstEncoder::encode_cmd_reuse(
311 cmd,
312 &mut conn.sql_buf,
313 &mut conn.params_buf,
314 )
315 .map_err(|e| PgError::Encode(e.to_string()))?;
316
317 conn.send_bytes(&wire_bytes).await?;
318
319 let mut rows: Vec<super::PgRow> = Vec::new();
320 let mut column_info: Option<Arc<ColumnInfo>> = None;
321 let mut error: Option<PgError> = None;
322
323 loop {
324 let msg = conn.recv().await?;
325 match msg {
326 crate::protocol::BackendMessage::ParseComplete
327 | crate::protocol::BackendMessage::BindComplete => {}
328 crate::protocol::BackendMessage::RowDescription(fields) => {
329 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
330 }
331 crate::protocol::BackendMessage::DataRow(data) => {
332 if error.is_none() {
333 rows.push(super::PgRow {
334 columns: data,
335 column_info: column_info.clone(),
336 });
337 }
338 }
339 crate::protocol::BackendMessage::CommandComplete(_) => {}
340 crate::protocol::BackendMessage::ReadyForQuery(_) => {
341 if let Some(err) = error {
342 return Err(err);
343 }
344 return Ok(rows);
345 }
346 crate::protocol::BackendMessage::ErrorResponse(err) => {
347 if error.is_none() {
348 error = Some(PgError::Query(err.message));
349 }
350 }
351 _ => {}
352 }
353 }
354 }
355
356 pub async fn fetch_all_fast(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
360 use crate::protocol::AstEncoder;
361
362 let conn = self.conn_mut()?;
363
364 AstEncoder::encode_cmd_reuse_into(
365 cmd,
366 &mut conn.sql_buf,
367 &mut conn.params_buf,
368 &mut conn.write_buf,
369 )
370 .map_err(|e| PgError::Encode(e.to_string()))?;
371
372 conn.flush_write_buf().await?;
373
374 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
375 let mut error: Option<PgError> = None;
376
377 loop {
378 let res = conn.recv_with_data_fast().await;
379 match res {
380 Ok((msg_type, data)) => {
381 match msg_type {
382 b'D' => {
383 if error.is_none() && let Some(columns) = data {
384 rows.push(super::PgRow {
385 columns,
386 column_info: None,
387 });
388 }
389 }
390 b'Z' => {
391 if let Some(err) = error {
392 return Err(err);
393 }
394 return Ok(rows);
395 }
396 _ => {}
397 }
398 }
399 Err(e) => {
400 if error.is_none() {
401 error = Some(e);
402 }
403 }
404 }
405 }
406 }
407
408 pub async fn fetch_all_cached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
413 use super::ColumnInfo;
414 use std::collections::hash_map::DefaultHasher;
415 use std::hash::{Hash, Hasher};
416
417 let conn = self.conn.as_mut().ok_or_else(|| PgError::Connection(
418 "Connection already released back to pool".into()
419 ))?;
420
421 conn.sql_buf.clear();
422 conn.params_buf.clear();
423
424 match cmd.action {
426 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
427 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
428 }
429 qail_core::ast::Action::Add => {
430 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
431 }
432 qail_core::ast::Action::Set => {
433 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
434 }
435 qail_core::ast::Action::Del => {
436 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
437 }
438 _ => {
439 return self.fetch_all_uncached(cmd).await;
441 }
442 }
443
444 let mut hasher = DefaultHasher::new();
445 conn.sql_buf.hash(&mut hasher);
446 let sql_hash = hasher.finish();
447
448 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
449
450 conn.write_buf.clear();
451
452 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
453 name.clone()
454 } else {
455 let name = format!("qail_{:x}", sql_hash);
456
457 conn.evict_prepared_if_full();
458
459 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
460
461 use crate::protocol::PgEncoder;
462 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
463 let describe_msg = PgEncoder::encode_describe(false, &name);
464 conn.write_buf.extend_from_slice(&parse_msg);
465 conn.write_buf.extend_from_slice(&describe_msg);
466
467 conn.stmt_cache.put(sql_hash, name.clone());
468 conn.prepared_statements.insert(name.clone(), sql_str.to_string());
469
470 if let Ok(mut hot) = self.pool.hot_statements.write()
472 && hot.len() < MAX_HOT_STATEMENTS
473 {
474 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
475 }
476
477 name
478 };
479
480 use crate::protocol::PgEncoder;
481 PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
482 .map_err(|e| PgError::Encode(e.to_string()))?;
483 PgEncoder::encode_execute_to(&mut conn.write_buf);
484 PgEncoder::encode_sync_to(&mut conn.write_buf);
485
486 conn.flush_write_buf().await?;
487
488 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
489
490 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
491 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
492 let mut error: Option<PgError> = None;
493
494 loop {
495 let msg = conn.recv().await?;
496 match msg {
497 crate::protocol::BackendMessage::ParseComplete
498 | crate::protocol::BackendMessage::BindComplete => {}
499 crate::protocol::BackendMessage::ParameterDescription(_) => {}
500 crate::protocol::BackendMessage::RowDescription(fields) => {
501 let info = Arc::new(ColumnInfo::from_fields(&fields));
502 if is_cache_miss {
503 conn.column_info_cache.insert(sql_hash, info.clone());
504 }
505 column_info = Some(info);
506 }
507 crate::protocol::BackendMessage::DataRow(data) => {
508 if error.is_none() {
509 rows.push(super::PgRow {
510 columns: data,
511 column_info: column_info.clone(),
512 });
513 }
514 }
515 crate::protocol::BackendMessage::CommandComplete(_) => {}
516 crate::protocol::BackendMessage::ReadyForQuery(_) => {
517 if let Some(err) = error {
518 return Err(err);
519 }
520 return Ok(rows);
521 }
522 crate::protocol::BackendMessage::ErrorResponse(err) => {
523 if error.is_none() {
524 error = Some(PgError::Query(err.message));
525 }
526 }
527 _ => {}
528 }
529 }
530 }
531
532 pub async fn fetch_all_with_rls(
551 &mut self,
552 cmd: &qail_core::ast::Qail,
553 rls_sql: &str,
554 ) -> PgResult<Vec<super::PgRow>> {
555 use super::ColumnInfo;
556 use std::collections::hash_map::DefaultHasher;
557 use std::hash::{Hash, Hasher};
558
559 let conn = self.conn.as_mut().ok_or_else(|| PgError::Connection(
560 "Connection already released back to pool".into()
561 ))?;
562
563 conn.sql_buf.clear();
564 conn.params_buf.clear();
565
566 match cmd.action {
568 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
569 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
570 }
571 qail_core::ast::Action::Add => {
572 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
573 }
574 qail_core::ast::Action::Set => {
575 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
576 }
577 qail_core::ast::Action::Del => {
578 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
579 }
580 _ => {
581 conn.execute_simple(rls_sql).await?;
583 self.rls_dirty = true;
584 return self.fetch_all_uncached(cmd).await;
585 }
586 }
587
588 let mut hasher = DefaultHasher::new();
589 conn.sql_buf.hash(&mut hasher);
590 let sql_hash = hasher.finish();
591
592 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
593
594 conn.write_buf.clear();
595
596 let rls_msg = crate::protocol::PgEncoder::encode_query_string(rls_sql);
600 conn.write_buf.extend_from_slice(&rls_msg);
601
602 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
604 name.clone()
605 } else {
606 let name = format!("qail_{:x}", sql_hash);
607
608 conn.evict_prepared_if_full();
609
610 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
611
612 use crate::protocol::PgEncoder;
613 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
614 let describe_msg = PgEncoder::encode_describe(false, &name);
615 conn.write_buf.extend_from_slice(&parse_msg);
616 conn.write_buf.extend_from_slice(&describe_msg);
617
618 conn.stmt_cache.put(sql_hash, name.clone());
619 conn.prepared_statements.insert(name.clone(), sql_str.to_string());
620
621 if let Ok(mut hot) = self.pool.hot_statements.write()
622 && hot.len() < MAX_HOT_STATEMENTS
623 {
624 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
625 }
626
627 name
628 };
629
630 use crate::protocol::PgEncoder;
631 PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
632 .map_err(|e| PgError::Encode(e.to_string()))?;
633 PgEncoder::encode_execute_to(&mut conn.write_buf);
634 PgEncoder::encode_sync_to(&mut conn.write_buf);
635
636 conn.flush_write_buf().await?;
638
639 self.rls_dirty = true;
641
642 let mut rls_error: Option<PgError> = None;
646 loop {
647 let msg = conn.recv().await?;
648 match msg {
649 crate::protocol::BackendMessage::ReadyForQuery(_) => {
650 if let Some(err) = rls_error {
652 return Err(err);
653 }
654 break;
655 }
656 crate::protocol::BackendMessage::ErrorResponse(err) => {
657 if rls_error.is_none() {
658 rls_error = Some(PgError::Query(err.message));
659 }
660 }
661 _ => {}
663 }
664 }
665
666 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
668
669 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
670 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
671 let mut error: Option<PgError> = None;
672
673 loop {
674 let msg = conn.recv().await?;
675 match msg {
676 crate::protocol::BackendMessage::ParseComplete
677 | crate::protocol::BackendMessage::BindComplete => {}
678 crate::protocol::BackendMessage::ParameterDescription(_) => {}
679 crate::protocol::BackendMessage::RowDescription(fields) => {
680 let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
681 if is_cache_miss {
682 conn.column_info_cache.insert(sql_hash, info.clone());
683 }
684 column_info = Some(info);
685 }
686 crate::protocol::BackendMessage::DataRow(data) => {
687 if error.is_none() {
688 rows.push(super::PgRow {
689 columns: data,
690 column_info: column_info.clone(),
691 });
692 }
693 }
694 crate::protocol::BackendMessage::CommandComplete(_) => {}
695 crate::protocol::BackendMessage::ReadyForQuery(_) => {
696 if let Some(err) = error {
697 return Err(err);
698 }
699 return Ok(rows);
700 }
701 crate::protocol::BackendMessage::ErrorResponse(err) => {
702 if error.is_none() {
703 error = Some(PgError::Query(err.message));
704 }
705 }
706 _ => {}
707 }
708 }
709 }
710
711 pub async fn pipeline_ast(
719 &mut self,
720 cmds: &[qail_core::ast::Qail],
721 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
722 let conn = self.conn_mut()?;
723 conn.pipeline_ast(cmds).await
724 }
725
726 pub async fn explain_estimate(
732 &mut self,
733 cmd: &qail_core::ast::Qail,
734 ) -> PgResult<Option<super::explain::ExplainEstimate>> {
735 use qail_core::transpiler::ToSql;
736
737 let sql = cmd.to_sql();
738 let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", sql);
739
740 let rows = self.simple_query(&explain_sql).await?;
741
742 let mut json_output = String::new();
744 for row in &rows {
745 if let Some(Some(val)) = row.columns.first()
746 && let Ok(text) = std::str::from_utf8(val)
747 {
748 json_output.push_str(text);
749 }
750 }
751
752 Ok(super::explain::parse_explain_json(&json_output))
753 }
754}
755
756impl Drop for PooledConnection {
757 fn drop(&mut self) {
758 if self.conn.is_some() {
759 eprintln!(
771 "[WARN] pool_connection_leaked: PooledConnection dropped without release() — \
772 connection destroyed to prevent state leak (rls_dirty={}). \
773 Use conn.release().await for deterministic cleanup.",
774 self.rls_dirty
775 );
776 self.pool.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
778 }
779 }
780}
781
782impl std::ops::Deref for PooledConnection {
783 type Target = PgConnection;
784
785 fn deref(&self) -> &Self::Target {
786 self.conn
789 .as_ref()
790 .expect("PooledConnection::deref called after release — this is a bug")
791 }
792}
793
794impl std::ops::DerefMut for PooledConnection {
795 fn deref_mut(&mut self) -> &mut Self::Target {
796 self.conn
799 .as_mut()
800 .expect("PooledConnection::deref_mut called after release — this is a bug")
801 }
802}
803
804const MAX_HOT_STATEMENTS: usize = 32;
806
807struct PgPoolInner {
809 config: PoolConfig,
810 connections: Mutex<Vec<PooledConn>>,
811 semaphore: Semaphore,
812 closed: AtomicBool,
813 active_count: AtomicUsize,
814 total_created: AtomicUsize,
815 hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
819}
820
821impl PgPoolInner {
822 async fn return_connection(&self, conn: PgConnection) {
823
824 self.active_count.fetch_sub(1, Ordering::Relaxed);
825
826
827 if self.closed.load(Ordering::Relaxed) {
828 return;
829 }
830
831 let mut connections = self.connections.lock().await;
832 if connections.len() < self.config.max_connections {
833 connections.push(PooledConn {
834 conn,
835 created_at: Instant::now(),
836 last_used: Instant::now(),
837 });
838 }
839
840 self.semaphore.add_permits(1);
841 }
842
843 async fn get_healthy_connection(&self) -> Option<PgConnection> {
845 let mut connections = self.connections.lock().await;
846
847 while let Some(pooled) = connections.pop() {
848 if pooled.last_used.elapsed() > self.config.idle_timeout {
849 continue;
851 }
852
853 if let Some(max_life) = self.config.max_lifetime
854 && pooled.created_at.elapsed() > max_life
855 {
856 continue;
858 }
859
860 return Some(pooled.conn);
861 }
862
863 None
864 }
865}
866
867#[derive(Clone)]
878pub struct PgPool {
879 inner: Arc<PgPoolInner>,
880}
881
882impl PgPool {
883 pub async fn from_config() -> PgResult<Self> {
890 let qail = qail_core::config::QailConfig::load()
891 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
892 let config = PoolConfig::from_qail_config(&qail)?;
893 Self::connect(config).await
894 }
895
896 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
898 let semaphore = Semaphore::new(config.max_connections);
900
901 let mut initial_connections = Vec::new();
902 for _ in 0..config.min_connections {
903 let conn = Self::create_connection(&config).await?;
904 initial_connections.push(PooledConn {
905 conn,
906 created_at: Instant::now(),
907 last_used: Instant::now(),
908 });
909 }
910
911 let initial_count = initial_connections.len();
912
913 let inner = Arc::new(PgPoolInner {
914 config,
915 connections: Mutex::new(initial_connections),
916 semaphore,
917 closed: AtomicBool::new(false),
918 active_count: AtomicUsize::new(0),
919 total_created: AtomicUsize::new(initial_count),
920 hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
921 });
922
923 Ok(Self { inner })
924 }
925
926 pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
941 if self.inner.closed.load(Ordering::Relaxed) {
942 return Err(PgError::PoolClosed);
943 }
944
945 let acquire_timeout = self.inner.config.acquire_timeout;
947 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
948 .await
949 .map_err(|_| {
950 PgError::Timeout(format!(
951 "pool acquire after {}s ({} max connections)",
952 acquire_timeout.as_secs(),
953 self.inner.config.max_connections
954 ))
955 })?
956 .map_err(|_| PgError::PoolClosed)?;
957 permit.forget();
958
959 let mut conn = if let Some(conn) = self.inner.get_healthy_connection().await {
961 conn
962 } else {
963 let conn = Self::create_connection(&self.inner.config).await?;
964 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
965 conn
966 };
967
968 let missing: Vec<(u64, String, String)> = {
971 if let Ok(hot) = self.inner.hot_statements.read() {
972 hot.iter()
973 .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
974 .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
975 .collect()
976 } else {
977 Vec::new()
978 }
979 }; if !missing.is_empty() {
982 use crate::protocol::PgEncoder;
983 let mut buf = bytes::BytesMut::new();
984 for (_, name, sql) in &missing {
985 let parse_msg = PgEncoder::encode_parse(name, sql, &[]);
986 buf.extend_from_slice(&parse_msg);
987 }
988 PgEncoder::encode_sync_to(&mut buf);
989 if conn.send_bytes(&buf).await.is_ok() {
990 loop {
992 match conn.recv().await {
993 Ok(crate::protocol::BackendMessage::ReadyForQuery(_)) => break,
994 Ok(_) => continue,
995 Err(_) => break,
996 }
997 }
998 for (hash, name, sql) in &missing {
1000 conn.stmt_cache.put(*hash, name.clone());
1001 conn.prepared_statements.insert(name.clone(), sql.clone());
1002 }
1003 }
1004 }
1005
1006 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
1007
1008 Ok(PooledConnection {
1009 conn: Some(conn),
1010 pool: self.inner.clone(),
1011 rls_dirty: false,
1012 })
1013 }
1014
1015 pub async fn acquire_with_rls(
1031 &self,
1032 ctx: qail_core::rls::RlsContext,
1033 ) -> PgResult<PooledConnection> {
1034 let mut conn = self.acquire_raw().await?;
1036
1037 let sql = super::rls::context_to_sql(&ctx);
1039 let pg_conn = conn.get_mut();
1040 pg_conn.execute_simple(&sql).await?;
1041
1042 conn.rls_dirty = true;
1044
1045 Ok(conn)
1046 }
1047
1048 pub async fn acquire_with_rls_timeout(
1053 &self,
1054 ctx: qail_core::rls::RlsContext,
1055 timeout_ms: u32,
1056 ) -> PgResult<PooledConnection> {
1057 let mut conn = self.acquire_raw().await?;
1059
1060 let sql = super::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
1062 let pg_conn = conn.get_mut();
1063 pg_conn.execute_simple(&sql).await?;
1064
1065 conn.rls_dirty = true;
1067
1068 Ok(conn)
1069 }
1070
1071 pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
1081 let ctx = qail_core::rls::RlsContext::empty();
1082 self.acquire_with_rls(ctx).await
1083 }
1084
1085 pub async fn acquire_with_branch(
1099 &self,
1100 ctx: &qail_core::branch::BranchContext,
1101 ) -> PgResult<PooledConnection> {
1102 let mut conn = self.acquire_raw().await?;
1104
1105 if let Some(branch_name) = ctx.branch_name() {
1106 let sql = super::branch_sql::branch_context_sql(branch_name);
1107 let pg_conn = conn.get_mut();
1108 pg_conn.execute_simple(&sql).await?;
1109 conn.rls_dirty = true; }
1111
1112 Ok(conn)
1113 }
1114
1115 pub async fn idle_count(&self) -> usize {
1117 self.inner.connections.lock().await.len()
1118 }
1119
1120 pub fn active_count(&self) -> usize {
1122 self.inner.active_count.load(Ordering::Relaxed)
1123 }
1124
1125 pub fn max_connections(&self) -> usize {
1127 self.inner.config.max_connections
1128 }
1129
1130 pub async fn stats(&self) -> PoolStats {
1132 let idle = self.inner.connections.lock().await.len();
1133 PoolStats {
1134 active: self.inner.active_count.load(Ordering::Relaxed),
1135 idle,
1136 pending: self.inner.config.max_connections
1137 - self.inner.semaphore.available_permits()
1138 - self.active_count(),
1139 max_size: self.inner.config.max_connections,
1140 total_created: self.inner.total_created.load(Ordering::Relaxed),
1141 }
1142 }
1143
1144 pub fn is_closed(&self) -> bool {
1146 self.inner.closed.load(Ordering::Relaxed)
1147 }
1148
1149 pub async fn close(&self) {
1151 self.inner.closed.store(true, Ordering::Relaxed);
1152
1153 let mut connections = self.inner.connections.lock().await;
1154 connections.clear();
1155 }
1156
1157 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
1159 match &config.password {
1160 Some(password) => {
1161 PgConnection::connect_with_password(
1162 &config.host,
1163 config.port,
1164 &config.user,
1165 &config.database,
1166 Some(password),
1167 )
1168 .await
1169 }
1170 None => {
1171 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
1172 .await
1173 }
1174 }
1175 }
1176}
1177
#[cfg(test)]
mod tests {
    //! Unit tests for pool configuration, URL parsing and error display.
    //! None of these tests need a running server.

    use super::*;

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert!(config.password.is_none());
        assert!(config.max_lifetime.is_none());
        assert!(!config.test_on_acquire);
    }

    #[test]
    fn test_pool_config_builder_chaining() {
        let config = PoolConfig::new("db.example.com", 5433, "admin", "prod")
            .password("p@ss")
            .max_connections(50)
            .min_connections(10)
            .idle_timeout(Duration::from_secs(300))
            .acquire_timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(3600))
            .test_on_acquire(false);

        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.port, 5433);
        assert_eq!(config.max_connections, 50);
        assert_eq!(config.min_connections, 10);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.acquire_timeout, Duration::from_secs(5));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(3600)));
        assert!(!config.test_on_acquire);
    }

    #[test]
    fn test_parse_pg_url_full() {
        let (host, port, user, database, password) =
            parse_pg_url("postgres://alice:secret@db.example.com:5433/app").unwrap();
        assert_eq!(host, "db.example.com");
        assert_eq!(port, 5433);
        assert_eq!(user, "alice");
        assert_eq!(database, "app");
        assert_eq!(password, Some("secret".to_string()));
    }

    #[test]
    fn test_parse_pg_url_postgresql_scheme_without_credentials() {
        let (host, port, user, database, password) =
            parse_pg_url("postgresql://localhost/mydb").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(user, "postgres");
        assert_eq!(database, "mydb");
        assert!(password.is_none());
    }

    #[test]
    fn test_parse_pg_url_user_without_password_or_database() {
        let (host, port, user, database, password) = parse_pg_url("postgres://bob@db").unwrap();
        assert_eq!(host, "db");
        assert_eq!(port, 5432);
        assert_eq!(user, "bob");
        assert_eq!(database, "postgres");
        assert!(password.is_none());
    }

    #[test]
    fn test_timeout_error_display() {
        let err = PgError::Timeout("pool acquire after 30s (10 max connections)".to_string());
        let msg = err.to_string();
        assert!(msg.contains("Timeout"));
        assert!(msg.contains("30s"));
        assert!(msg.contains("10 max connections"));
    }

    #[test]
    fn test_pool_closed_error_display() {
        let err = PgError::PoolClosed;
        assert_eq!(err.to_string(), "Connection pool is closed");
    }

    #[test]
    fn test_pool_exhausted_error_display() {
        let err = PgError::PoolExhausted { max: 20 };
        let msg = err.to_string();
        assert!(msg.contains("exhausted"));
        assert!(msg.contains("20"));
    }

    #[test]
    fn test_io_error_source_chaining() {
        use std::error::Error;
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "peer reset");
        let pg_err = PgError::Io(io_err);
        let source = pg_err.source().expect("Io variant should have source");
        assert!(source.to_string().contains("peer reset"));
    }

    #[test]
    fn test_non_io_errors_have_no_source() {
        use std::error::Error;
        assert!(PgError::Connection("test".into()).source().is_none());
        assert!(PgError::Query("test".into()).source().is_none());
        assert!(PgError::Timeout("test".into()).source().is_none());
        assert!(PgError::PoolClosed.source().is_none());
        assert!(PgError::NoRows.source().is_none());
    }

    #[test]
    fn test_io_error_from_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let pg_err: PgError = io_err.into();
        assert!(matches!(pg_err, PgError::Io(_)));
        assert!(pg_err.to_string().contains("broken"));
    }

    #[test]
    fn test_error_variants_are_distinct() {
        let errors: Vec<PgError> = vec![
            PgError::Connection("conn".into()),
            PgError::Protocol("proto".into()),
            PgError::Auth("auth".into()),
            PgError::Query("query".into()),
            PgError::NoRows,
            PgError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io")),
            PgError::Encode("enc".into()),
            PgError::Timeout("timeout".into()),
            PgError::PoolExhausted { max: 10 },
            PgError::PoolClosed,
        ];
        // Every variant must render a non-empty, human-readable message.
        for err in &errors {
            assert!(!err.to_string().is_empty());
        }
        assert_eq!(errors.len(), 10);
    }
}
1305