1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Connection and sizing settings for a [`PgPool`].
///
/// Build one with [`PoolConfig::new`] and the chained setter methods, or load it
/// from the QAIL configuration with [`PoolConfig::from_qail_config`].
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address (also used when building cancel tokens).
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role to log in as.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password; `None` connects without supplying one.
    pub password: Option<String>,
    /// Size of the checkout semaphore and cap on the idle list.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded on acquire.
    pub idle_timeout: Duration,
    /// How long an acquire waits for a free slot before failing.
    pub acquire_timeout: Duration,
    /// Connect-timeout setting (default 10 s); see `PgPool::create_connection` for its handling.
    pub connect_timeout: Duration,
    /// Optional maximum age; idle connections whose recorded creation time is older are
    /// discarded on acquire.
    pub max_lifetime: Option<Duration>,
    /// Test-on-acquire flag (default `false`); see `PgPool::acquire_raw` for its handling.
    pub test_on_acquire: bool,
}
27
28impl PoolConfig {
29 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31 Self {
32 host: host.to_string(),
33 port,
34 user: user.to_string(),
35 database: database.to_string(),
36 password: None,
37 max_connections: 10,
38 min_connections: 1,
39 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, }
45 }
46
47 pub fn password(mut self, password: &str) -> Self {
49 self.password = Some(password.to_string());
50 self
51 }
52
53 pub fn max_connections(mut self, max: usize) -> Self {
54 self.max_connections = max;
55 self
56 }
57
58 pub fn min_connections(mut self, min: usize) -> Self {
60 self.min_connections = min;
61 self
62 }
63
64 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66 self.idle_timeout = timeout;
67 self
68 }
69
70 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72 self.acquire_timeout = timeout;
73 self
74 }
75
76 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78 self.connect_timeout = timeout;
79 self
80 }
81
82 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84 self.max_lifetime = Some(lifetime);
85 self
86 }
87
88 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90 self.test_on_acquire = enabled;
91 self
92 }
93
94 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
99 let pg = &qail.postgres;
100 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
101
102 let mut config = PoolConfig::new(&host, port, &user, &database)
103 .max_connections(pg.max_connections)
104 .min_connections(pg.min_connections)
105 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
106 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
107 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
108 .test_on_acquire(pg.test_on_acquire);
109
110 if let Some(ref pw) = password {
111 config = config.password(pw);
112 }
113
114 Ok(config)
115 }
116}
117
118fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
120 let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
121
122 let (credentials, host_part) = if url.contains('@') {
123 let mut parts = url.splitn(2, '@');
124 let creds = parts.next().unwrap_or("");
125 let host = parts.next().unwrap_or("localhost/postgres");
126 (Some(creds), host)
127 } else {
128 (None, url)
129 };
130
131 let (host_port, database) = if host_part.contains('/') {
132 let mut parts = host_part.splitn(2, '/');
133 (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
134 } else {
135 (host_part, "postgres".to_string())
136 };
137
138 let (host, port) = if host_port.contains(':') {
139 let mut parts = host_port.split(':');
140 let h = parts.next().unwrap_or("localhost").to_string();
141 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
142 (h, p)
143 } else {
144 (host_port.to_string(), 5432u16)
145 };
146
147 let (user, password) = if let Some(creds) = credentials {
148 if creds.contains(':') {
149 let mut parts = creds.splitn(2, ':');
150 let u = parts.next().unwrap_or("postgres").to_string();
151 let p = parts.next().map(|s| s.to_string());
152 (u, p)
153 } else {
154 (creds.to_string(), None)
155 }
156 } else {
157 ("postgres".to_string(), None)
158 };
159
160 Ok((host, port, user, database, password))
161}
162
/// Point-in-time snapshot of pool counters, as produced by `PgPool::stats`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out (the pool's active counter).
    pub active: usize,
    /// Connections sitting in the idle list.
    pub idle: usize,
    /// Semaphore permits in use that are not (yet) counted as active.
    pub pending: usize,
    /// The configured `max_connections`.
    pub max_size: usize,
    /// Connections opened over the pool's lifetime, including the initial ones.
    pub total_created: usize,
}
173
174struct PooledConn {
176 conn: PgConnection,
177 created_at: Instant,
178 last_used: Instant,
179}
180
181pub struct PooledConnection {
187 conn: Option<PgConnection>,
188 pool: Arc<PgPoolInner>,
189 rls_dirty: bool,
190}
191
192impl PooledConnection {
193 pub fn get_mut(&mut self) -> &mut PgConnection {
195 self.conn
196 .as_mut()
197 .expect("Connection should always be present")
198 }
199
200 pub fn cancel_token(&self) -> crate::driver::CancelToken {
202 let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
203 crate::driver::CancelToken {
204 host: self.pool.config.host.clone(),
205 port: self.pool.config.port,
206 process_id,
207 secret_key,
208 }
209 }
210
211 pub async fn release(mut self) {
228 if let Some(mut conn) = self.conn.take() {
229 if let Err(e) = conn.execute_simple(super::rls::reset_sql()).await {
234 eprintln!(
235 "[CRITICAL] pool_release_failed: COMMIT failed — \
236 dropping connection to prevent state leak: {}",
237 e
238 );
239 return; }
241
242 self.pool.return_connection(conn).await;
243 }
244 }
245
246 pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
249 use crate::protocol::AstEncoder;
250 use super::ColumnInfo;
251
252 let conn = self.conn.as_mut().expect("Connection should always be present");
253
254 let wire_bytes = AstEncoder::encode_cmd_reuse(
255 cmd,
256 &mut conn.sql_buf,
257 &mut conn.params_buf,
258 )
259 .map_err(|e| PgError::Encode(e.to_string()))?;
260
261 conn.send_bytes(&wire_bytes).await?;
262
263 let mut rows: Vec<super::PgRow> = Vec::new();
264 let mut column_info: Option<Arc<ColumnInfo>> = None;
265 let mut error: Option<PgError> = None;
266
267 loop {
268 let msg = conn.recv().await?;
269 match msg {
270 crate::protocol::BackendMessage::ParseComplete
271 | crate::protocol::BackendMessage::BindComplete => {}
272 crate::protocol::BackendMessage::RowDescription(fields) => {
273 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
274 }
275 crate::protocol::BackendMessage::DataRow(data) => {
276 if error.is_none() {
277 rows.push(super::PgRow {
278 columns: data,
279 column_info: column_info.clone(),
280 });
281 }
282 }
283 crate::protocol::BackendMessage::CommandComplete(_) => {}
284 crate::protocol::BackendMessage::ReadyForQuery(_) => {
285 if let Some(err) = error {
286 return Err(err);
287 }
288 return Ok(rows);
289 }
290 crate::protocol::BackendMessage::ErrorResponse(err) => {
291 if error.is_none() {
292 error = Some(PgError::Query(err.message));
293 }
294 }
295 _ => {}
296 }
297 }
298 }
299
300 pub async fn fetch_all_fast(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
304 use crate::protocol::AstEncoder;
305
306 let conn = self.conn.as_mut().expect("Connection should always be present");
307
308 AstEncoder::encode_cmd_reuse_into(
309 cmd,
310 &mut conn.sql_buf,
311 &mut conn.params_buf,
312 &mut conn.write_buf,
313 )
314 .map_err(|e| PgError::Encode(e.to_string()))?;
315
316 conn.flush_write_buf().await?;
317
318 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
319 let mut error: Option<PgError> = None;
320
321 loop {
322 let res = conn.recv_with_data_fast().await;
323 match res {
324 Ok((msg_type, data)) => {
325 match msg_type {
326 b'D' => {
327 if error.is_none() && let Some(columns) = data {
328 rows.push(super::PgRow {
329 columns,
330 column_info: None,
331 });
332 }
333 }
334 b'Z' => {
335 if let Some(err) = error {
336 return Err(err);
337 }
338 return Ok(rows);
339 }
340 _ => {}
341 }
342 }
343 Err(e) => {
344 if error.is_none() {
345 error = Some(e);
346 }
347 }
348 }
349 }
350 }
351
352 pub async fn fetch_all_cached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
357 use super::ColumnInfo;
358 use std::collections::hash_map::DefaultHasher;
359 use std::hash::{Hash, Hasher};
360
361 let conn = self.conn.as_mut().expect("Connection should always be present");
362
363 conn.sql_buf.clear();
364 conn.params_buf.clear();
365
366 match cmd.action {
368 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
369 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
370 }
371 qail_core::ast::Action::Add => {
372 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
373 }
374 qail_core::ast::Action::Set => {
375 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
376 }
377 qail_core::ast::Action::Del => {
378 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
379 }
380 _ => {
381 return self.fetch_all_uncached(cmd).await;
383 }
384 }
385
386 let mut hasher = DefaultHasher::new();
387 conn.sql_buf.hash(&mut hasher);
388 let sql_hash = hasher.finish();
389
390 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
391
392 conn.write_buf.clear();
393
394 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
395 name.clone()
396 } else {
397 let name = format!("qail_{:x}", sql_hash);
398
399 conn.evict_prepared_if_full();
400
401 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
402
403 use crate::protocol::PgEncoder;
404 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
405 let describe_msg = PgEncoder::encode_describe(false, &name);
406 conn.write_buf.extend_from_slice(&parse_msg);
407 conn.write_buf.extend_from_slice(&describe_msg);
408
409 conn.stmt_cache.put(sql_hash, name.clone());
410 conn.prepared_statements.insert(name.clone(), sql_str.to_string());
411
412 if let Ok(mut hot) = self.pool.hot_statements.write()
414 && hot.len() < MAX_HOT_STATEMENTS
415 {
416 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
417 }
418
419 name
420 };
421
422 use crate::protocol::PgEncoder;
423 PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
424 .map_err(|e| PgError::Encode(e.to_string()))?;
425 PgEncoder::encode_execute_to(&mut conn.write_buf);
426 PgEncoder::encode_sync_to(&mut conn.write_buf);
427
428 conn.flush_write_buf().await?;
429
430 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
431
432 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
433 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
434 let mut error: Option<PgError> = None;
435
436 loop {
437 let msg = conn.recv().await?;
438 match msg {
439 crate::protocol::BackendMessage::ParseComplete
440 | crate::protocol::BackendMessage::BindComplete => {}
441 crate::protocol::BackendMessage::ParameterDescription(_) => {}
442 crate::protocol::BackendMessage::RowDescription(fields) => {
443 let info = Arc::new(ColumnInfo::from_fields(&fields));
444 if is_cache_miss {
445 conn.column_info_cache.insert(sql_hash, info.clone());
446 }
447 column_info = Some(info);
448 }
449 crate::protocol::BackendMessage::DataRow(data) => {
450 if error.is_none() {
451 rows.push(super::PgRow {
452 columns: data,
453 column_info: column_info.clone(),
454 });
455 }
456 }
457 crate::protocol::BackendMessage::CommandComplete(_) => {}
458 crate::protocol::BackendMessage::ReadyForQuery(_) => {
459 if let Some(err) = error {
460 return Err(err);
461 }
462 return Ok(rows);
463 }
464 crate::protocol::BackendMessage::ErrorResponse(err) => {
465 if error.is_none() {
466 error = Some(PgError::Query(err.message));
467 }
468 }
469 _ => {}
470 }
471 }
472 }
473
474 pub async fn fetch_all_with_rls(
493 &mut self,
494 cmd: &qail_core::ast::Qail,
495 rls_sql: &str,
496 ) -> PgResult<Vec<super::PgRow>> {
497 use super::ColumnInfo;
498 use std::collections::hash_map::DefaultHasher;
499 use std::hash::{Hash, Hasher};
500
501 let conn = self.conn.as_mut().expect("Connection should always be present");
502
503 conn.sql_buf.clear();
504 conn.params_buf.clear();
505
506 match cmd.action {
508 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
509 crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
510 }
511 qail_core::ast::Action::Add => {
512 crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
513 }
514 qail_core::ast::Action::Set => {
515 crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
516 }
517 qail_core::ast::Action::Del => {
518 crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
519 }
520 _ => {
521 conn.execute_simple(rls_sql).await?;
523 self.rls_dirty = true;
524 return self.fetch_all_uncached(cmd).await;
525 }
526 }
527
528 let mut hasher = DefaultHasher::new();
529 conn.sql_buf.hash(&mut hasher);
530 let sql_hash = hasher.finish();
531
532 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
533
534 conn.write_buf.clear();
535
536 let rls_msg = crate::protocol::PgEncoder::encode_query_string(rls_sql);
540 conn.write_buf.extend_from_slice(&rls_msg);
541
542 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
544 name.clone()
545 } else {
546 let name = format!("qail_{:x}", sql_hash);
547
548 conn.evict_prepared_if_full();
549
550 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
551
552 use crate::protocol::PgEncoder;
553 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
554 let describe_msg = PgEncoder::encode_describe(false, &name);
555 conn.write_buf.extend_from_slice(&parse_msg);
556 conn.write_buf.extend_from_slice(&describe_msg);
557
558 conn.stmt_cache.put(sql_hash, name.clone());
559 conn.prepared_statements.insert(name.clone(), sql_str.to_string());
560
561 if let Ok(mut hot) = self.pool.hot_statements.write()
562 && hot.len() < MAX_HOT_STATEMENTS
563 {
564 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
565 }
566
567 name
568 };
569
570 use crate::protocol::PgEncoder;
571 PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
572 .map_err(|e| PgError::Encode(e.to_string()))?;
573 PgEncoder::encode_execute_to(&mut conn.write_buf);
574 PgEncoder::encode_sync_to(&mut conn.write_buf);
575
576 conn.flush_write_buf().await?;
578
579 self.rls_dirty = true;
581
582 let mut rls_error: Option<PgError> = None;
586 loop {
587 let msg = conn.recv().await?;
588 match msg {
589 crate::protocol::BackendMessage::ReadyForQuery(_) => {
590 if let Some(err) = rls_error {
592 return Err(err);
593 }
594 break;
595 }
596 crate::protocol::BackendMessage::ErrorResponse(err) => {
597 if rls_error.is_none() {
598 rls_error = Some(PgError::Query(err.message));
599 }
600 }
601 _ => {}
603 }
604 }
605
606 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
608
609 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
610 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
611 let mut error: Option<PgError> = None;
612
613 loop {
614 let msg = conn.recv().await?;
615 match msg {
616 crate::protocol::BackendMessage::ParseComplete
617 | crate::protocol::BackendMessage::BindComplete => {}
618 crate::protocol::BackendMessage::ParameterDescription(_) => {}
619 crate::protocol::BackendMessage::RowDescription(fields) => {
620 let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
621 if is_cache_miss {
622 conn.column_info_cache.insert(sql_hash, info.clone());
623 }
624 column_info = Some(info);
625 }
626 crate::protocol::BackendMessage::DataRow(data) => {
627 if error.is_none() {
628 rows.push(super::PgRow {
629 columns: data,
630 column_info: column_info.clone(),
631 });
632 }
633 }
634 crate::protocol::BackendMessage::CommandComplete(_) => {}
635 crate::protocol::BackendMessage::ReadyForQuery(_) => {
636 if let Some(err) = error {
637 return Err(err);
638 }
639 return Ok(rows);
640 }
641 crate::protocol::BackendMessage::ErrorResponse(err) => {
642 if error.is_none() {
643 error = Some(PgError::Query(err.message));
644 }
645 }
646 _ => {}
647 }
648 }
649 }
650
651 pub async fn pipeline_ast(
659 &mut self,
660 cmds: &[qail_core::ast::Qail],
661 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
662 let conn = self.conn.as_mut().expect("Connection should always be present");
663 conn.pipeline_ast(cmds).await
664 }
665
666 pub async fn explain_estimate(
672 &mut self,
673 cmd: &qail_core::ast::Qail,
674 ) -> PgResult<Option<super::explain::ExplainEstimate>> {
675 use qail_core::transpiler::ToSql;
676
677 let sql = cmd.to_sql();
678 let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", sql);
679
680 let rows = self.simple_query(&explain_sql).await?;
681
682 let mut json_output = String::new();
684 for row in &rows {
685 if let Some(Some(val)) = row.columns.first()
686 && let Ok(text) = std::str::from_utf8(val)
687 {
688 json_output.push_str(text);
689 }
690 }
691
692 Ok(super::explain::parse_explain_json(&json_output))
693 }
694}
695
696impl Drop for PooledConnection {
697 fn drop(&mut self) {
698 if self.conn.is_some() {
699 eprintln!(
711 "[WARN] pool_connection_leaked: PooledConnection dropped without release() — \
712 connection destroyed to prevent state leak (rls_dirty={}). \
713 Use conn.release().await for deterministic cleanup.",
714 self.rls_dirty
715 );
716 self.pool.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
718 }
719 }
720}
721
722impl std::ops::Deref for PooledConnection {
723 type Target = PgConnection;
724
725 fn deref(&self) -> &Self::Target {
726 self.conn
727 .as_ref()
728 .expect("Connection should always be present")
729 }
730}
731
732impl std::ops::DerefMut for PooledConnection {
733 fn deref_mut(&mut self) -> &mut Self::Target {
734 self.conn
735 .as_mut()
736 .expect("Connection should always be present")
737 }
738}
739
/// Upper bound on the number of pool-wide "hot" prepared statements recorded for
/// pre-warming newly acquired connections.
const MAX_HOT_STATEMENTS: usize = 32;
742
743struct PgPoolInner {
745 config: PoolConfig,
746 connections: Mutex<Vec<PooledConn>>,
747 semaphore: Semaphore,
748 closed: AtomicBool,
749 active_count: AtomicUsize,
750 total_created: AtomicUsize,
751 hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
755}
756
757impl PgPoolInner {
758 async fn return_connection(&self, conn: PgConnection) {
759
760 self.active_count.fetch_sub(1, Ordering::Relaxed);
761
762
763 if self.closed.load(Ordering::Relaxed) {
764 return;
765 }
766
767 let mut connections = self.connections.lock().await;
768 if connections.len() < self.config.max_connections {
769 connections.push(PooledConn {
770 conn,
771 created_at: Instant::now(),
772 last_used: Instant::now(),
773 });
774 }
775
776 self.semaphore.add_permits(1);
777 }
778
779 async fn get_healthy_connection(&self) -> Option<PgConnection> {
781 let mut connections = self.connections.lock().await;
782
783 while let Some(pooled) = connections.pop() {
784 if pooled.last_used.elapsed() > self.config.idle_timeout {
785 continue;
787 }
788
789 if let Some(max_life) = self.config.max_lifetime
790 && pooled.created_at.elapsed() > max_life
791 {
792 continue;
794 }
795
796 return Some(pooled.conn);
797 }
798
799 None
800 }
801}
802
803#[derive(Clone)]
814pub struct PgPool {
815 inner: Arc<PgPoolInner>,
816}
817
818impl PgPool {
819 pub async fn from_config() -> PgResult<Self> {
826 let qail = qail_core::config::QailConfig::load()
827 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
828 let config = PoolConfig::from_qail_config(&qail)?;
829 Self::connect(config).await
830 }
831
832 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
834 let semaphore = Semaphore::new(config.max_connections);
836
837 let mut initial_connections = Vec::new();
838 for _ in 0..config.min_connections {
839 let conn = Self::create_connection(&config).await?;
840 initial_connections.push(PooledConn {
841 conn,
842 created_at: Instant::now(),
843 last_used: Instant::now(),
844 });
845 }
846
847 let initial_count = initial_connections.len();
848
849 let inner = Arc::new(PgPoolInner {
850 config,
851 connections: Mutex::new(initial_connections),
852 semaphore,
853 closed: AtomicBool::new(false),
854 active_count: AtomicUsize::new(0),
855 total_created: AtomicUsize::new(initial_count),
856 hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
857 });
858
859 Ok(Self { inner })
860 }
861
862 pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
877 if self.inner.closed.load(Ordering::Relaxed) {
878 return Err(PgError::Connection("Pool is closed".to_string()));
879 }
880
881 let acquire_timeout = self.inner.config.acquire_timeout;
883 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
884 .await
885 .map_err(|_| {
886 PgError::Connection(format!(
887 "Timed out waiting for connection ({}s)",
888 acquire_timeout.as_secs()
889 ))
890 })?
891 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
892 permit.forget();
893
894 let mut conn = if let Some(conn) = self.inner.get_healthy_connection().await {
896 conn
897 } else {
898 let conn = Self::create_connection(&self.inner.config).await?;
899 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
900 conn
901 };
902
903 let missing: Vec<(u64, String, String)> = {
906 if let Ok(hot) = self.inner.hot_statements.read() {
907 hot.iter()
908 .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
909 .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
910 .collect()
911 } else {
912 Vec::new()
913 }
914 }; if !missing.is_empty() {
917 use crate::protocol::PgEncoder;
918 let mut buf = bytes::BytesMut::new();
919 for (_, name, sql) in &missing {
920 let parse_msg = PgEncoder::encode_parse(name, sql, &[]);
921 buf.extend_from_slice(&parse_msg);
922 }
923 PgEncoder::encode_sync_to(&mut buf);
924 if conn.send_bytes(&buf).await.is_ok() {
925 loop {
927 match conn.recv().await {
928 Ok(crate::protocol::BackendMessage::ReadyForQuery(_)) => break,
929 Ok(_) => continue,
930 Err(_) => break,
931 }
932 }
933 for (hash, name, sql) in &missing {
935 conn.stmt_cache.put(*hash, name.clone());
936 conn.prepared_statements.insert(name.clone(), sql.clone());
937 }
938 }
939 }
940
941 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
942
943 Ok(PooledConnection {
944 conn: Some(conn),
945 pool: self.inner.clone(),
946 rls_dirty: false,
947 })
948 }
949
950 pub async fn acquire_with_rls(
966 &self,
967 ctx: qail_core::rls::RlsContext,
968 ) -> PgResult<PooledConnection> {
969 let mut conn = self.acquire_raw().await?;
971
972 let sql = super::rls::context_to_sql(&ctx);
974 let pg_conn = conn.get_mut();
975 pg_conn.execute_simple(&sql).await?;
976
977 conn.rls_dirty = true;
979
980 Ok(conn)
981 }
982
983 pub async fn acquire_with_rls_timeout(
988 &self,
989 ctx: qail_core::rls::RlsContext,
990 timeout_ms: u32,
991 ) -> PgResult<PooledConnection> {
992 let mut conn = self.acquire_raw().await?;
994
995 let sql = super::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
997 let pg_conn = conn.get_mut();
998 pg_conn.execute_simple(&sql).await?;
999
1000 conn.rls_dirty = true;
1002
1003 Ok(conn)
1004 }
1005
1006 pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
1016 let ctx = qail_core::rls::RlsContext::empty();
1017 self.acquire_with_rls(ctx).await
1018 }
1019
1020 pub async fn acquire_with_branch(
1034 &self,
1035 ctx: &qail_core::branch::BranchContext,
1036 ) -> PgResult<PooledConnection> {
1037 let mut conn = self.acquire_raw().await?;
1039
1040 if let Some(branch_name) = ctx.branch_name() {
1041 let sql = super::branch_sql::branch_context_sql(branch_name);
1042 let pg_conn = conn.get_mut();
1043 pg_conn.execute_simple(&sql).await?;
1044 conn.rls_dirty = true; }
1046
1047 Ok(conn)
1048 }
1049
1050 pub async fn idle_count(&self) -> usize {
1052 self.inner.connections.lock().await.len()
1053 }
1054
1055 pub fn active_count(&self) -> usize {
1057 self.inner.active_count.load(Ordering::Relaxed)
1058 }
1059
1060 pub fn max_connections(&self) -> usize {
1062 self.inner.config.max_connections
1063 }
1064
1065 pub async fn stats(&self) -> PoolStats {
1067 let idle = self.inner.connections.lock().await.len();
1068 PoolStats {
1069 active: self.inner.active_count.load(Ordering::Relaxed),
1070 idle,
1071 pending: self.inner.config.max_connections
1072 - self.inner.semaphore.available_permits()
1073 - self.active_count(),
1074 max_size: self.inner.config.max_connections,
1075 total_created: self.inner.total_created.load(Ordering::Relaxed),
1076 }
1077 }
1078
1079 pub fn is_closed(&self) -> bool {
1081 self.inner.closed.load(Ordering::Relaxed)
1082 }
1083
1084 pub async fn close(&self) {
1086 self.inner.closed.store(true, Ordering::Relaxed);
1087
1088 let mut connections = self.inner.connections.lock().await;
1089 connections.clear();
1090 }
1091
1092 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
1094 match &config.password {
1095 Some(password) => {
1096 PgConnection::connect_with_password(
1097 &config.host,
1098 config.port,
1099 &config.user,
1100 &config.database,
1101 Some(password),
1102 )
1103 .await
1104 }
1105 None => {
1106 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
1107 .await
1108 }
1109 }
1110 }
1111}
1112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("db", 6543, "app", "main");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    #[test]
    fn test_parse_pg_url_full() {
        let Ok((host, port, user, database, password)) =
            parse_pg_url("postgres://alice:secret@db.example.com:6543/app")
        else {
            panic!("valid URL failed to parse");
        };

        assert_eq!(host, "db.example.com");
        assert_eq!(port, 6543);
        assert_eq!(user, "alice");
        assert_eq!(database, "app");
        assert_eq!(password, Some("secret".to_string()));
    }

    #[test]
    fn test_parse_pg_url_defaults() {
        let Ok((host, port, user, database, password)) = parse_pg_url("postgres://localhost")
        else {
            panic!("valid URL failed to parse");
        };

        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(user, "postgres");
        assert_eq!(database, "postgres");
        assert_eq!(password, None);
    }

    #[test]
    fn test_parse_pg_url_postgresql_scheme_without_password() {
        let Ok((host, port, user, database, password)) =
            parse_pg_url("postgresql://bob@localhost/inventory")
        else {
            panic!("valid URL failed to parse");
        };

        assert_eq!(host, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(user, "bob");
        assert_eq!(database, "inventory");
        assert_eq!(password, None);
    }
}