1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::arena::Arena;
14use crate::codec::Encode;
15use crate::conn::Connection;
16use crate::types::{Config, PgDataRow, QueryResult, SimpleRow, StatementCacheMode};
17use crate::DriverError;
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22pub(crate) enum PoolSlot {
30 Sync(Connection),
32 #[cfg(feature = "async")]
34 Async(AsyncConnection),
35}
36
/// Per-guard heuristic that flags long runs of the *same* statement executed
/// back to back — the usual signature of an N+1 query pattern.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
    /// Hash of the statement in the current run; `0` means nothing tracked yet.
    last_query_hash: u64,
    /// Length of the current run of identical hashes (saturating).
    repeat_count: u16,
    /// A run strictly longer than this triggers a warning.
    threshold: u16,
}
48
#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Creates a detector with an empty run that warns once a run of identical
    /// statements exceeds `threshold` executions.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            threshold,
            last_query_hash: 0,
            repeat_count: 0,
        }
    }

    /// Records one execution of the statement identified by `sql_hash`.
    ///
    /// A different hash ends the current run: it is reported (if it was over
    /// the threshold) before a new run of length 1 starts.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash != self.last_query_hash {
            self.emit_warning();
            self.last_query_hash = sql_hash;
            self.repeat_count = 1;
            return;
        }
        self.repeat_count = self.repeat_count.saturating_add(1);
    }

    /// Returns `(hash, count)` for the current run when it is over the
    /// threshold and a statement has actually been tracked.
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        let over_threshold = self.repeat_count > self.threshold;
        let has_statement = self.last_query_hash != 0;
        (over_threshold && has_statement).then_some((self.last_query_hash, self.repeat_count))
    }

    /// Logs a warning for the current run if `check_final` reports one.
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        let Some((hash, count)) = self.check_final() else {
            return;
        };
        log::warn!(
            "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
            hash,
            count,
            self.threshold,
        );
    }

    /// Reports the last, still-open run (called when the owning guard drops).
    #[cold]
    #[inline(never)]
    pub(crate) fn emit_final_warning(&self) {
        self.emit_warning();
    }
}
104
105pub struct Pool {
121 inner: Arc<PoolInner>,
122}
123
124struct PoolInner {
125 stack: std::sync::Mutex<Vec<PoolSlot>>,
129 max_size: usize,
130 open_count: AtomicUsize,
131 config: Arc<Config>,
132 closed: AtomicBool,
134 release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
137 max_lifetime: Option<Duration>,
140 acquire_timeout: Option<Duration>,
142 min_idle: usize,
144 warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
146 max_stmt_cache_size: usize,
148 stale_timeout: Duration,
151 #[cfg(feature = "detect-n-plus-one")]
154 n_plus_one_threshold: u16,
155}
156
157impl Pool {
158 pub fn connect(url: &str) -> Result<Self, DriverError> {
162 PoolBuilder::new().url(url).build()
163 }
164
165 pub fn builder() -> PoolBuilder {
167 PoolBuilder::new()
168 }
169
170 #[inline]
178 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
179 if self.inner.closed.load(Ordering::Acquire) {
180 return Err(DriverError::Pool("pool is closed".into()));
181 }
182
183 if let Some(guard) = self.try_pop_idle()? {
185 return Ok(guard);
186 }
187
188 loop {
190 let current = self.inner.open_count.load(Ordering::Acquire);
191 if current >= self.inner.max_size {
192 if let Some(timeout) = self.inner.acquire_timeout {
193 let (lock, cvar) = &self.inner.release_pair;
194 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
195 let (_guard, result) = cvar
196 .wait_timeout(guard, timeout)
197 .unwrap_or_else(|e| e.into_inner());
198 if result.timed_out() {
199 return Err(DriverError::Pool(
200 "pool exhausted: acquire timeout expired".into(),
201 ));
202 }
203 if let Some(guard) = self.try_pop_idle()? {
205 return Ok(guard);
206 }
207 continue;
209 }
210 return Err(DriverError::Pool(
211 "pool exhausted: all connections in use".into(),
212 ));
213 }
214 if self
215 .inner
216 .open_count
217 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
218 .is_ok()
219 {
220 break;
221 }
222 }
224
225 let conn_result = Connection::connect_arc(self.inner.config.clone());
227 match conn_result {
228 Ok(mut conn) => {
229 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
231 self.warmup_conn(&mut conn);
233
234 Ok(PoolGuard {
235 conn: Some(PoolSlot::Sync(conn)),
236 pool: self.inner.clone(),
237 discard: false,
238 #[cfg(feature = "detect-n-plus-one")]
239 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
240 })
241 }
242 Err(e) => {
243 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
245 Err(e)
246 }
247 }
248 }
249
250 #[inline]
256 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
257 loop {
261 let (mut slot, needs_health_check) = {
262 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
263 loop {
264 let Some(slot) = stack.pop() else {
265 return Ok(None);
266 };
267 let (created_at, idle_dur) = match &slot {
268 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
269 #[cfg(feature = "async")]
270 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
271 };
272 if let Some(max_lifetime) = self.inner.max_lifetime {
273 if created_at.elapsed() >= max_lifetime {
274 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
275 continue;
276 }
277 }
278 if idle_dur >= self.inner.stale_timeout {
279 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
281 continue;
282 }
283 break (slot, idle_dur > Duration::from_secs(5));
284 }
285 };
286 if needs_health_check {
290 let alive = match &mut slot {
291 PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
292 #[cfg(feature = "async")]
293 PoolSlot::Async(_) => true, };
295 if !alive {
296 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
297 continue; }
299 }
300 return Ok(Some(PoolGuard {
301 conn: Some(slot),
302 pool: self.inner.clone(),
303 discard: false,
304 #[cfg(feature = "detect-n-plus-one")]
305 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
306 }));
307 }
308 }
309
310 pub fn is_uds(&self) -> bool {
315 #[cfg(unix)]
316 {
317 self.inner.config.host_is_uds()
318 }
319 #[cfg(not(unix))]
320 {
321 false
322 }
323 }
324
325 pub fn begin(&self) -> Result<Transaction, DriverError> {
327 let mut guard = self.acquire()?;
328 guard.simple_query("BEGIN")?;
329 Ok(Transaction {
330 guard,
331 committed: false,
332 deferred_buf: Vec::new(),
333 deferred_count: 0,
334 })
335 }
336
337 pub fn open_count(&self) -> usize {
339 self.inner.open_count.load(Ordering::Relaxed)
340 }
341
342 pub fn max_size(&self) -> usize {
344 self.inner.max_size
345 }
346
347 pub fn status(&self) -> PoolStatus {
349 let idle = self
350 .inner
351 .stack
352 .lock()
353 .unwrap_or_else(|e| e.into_inner())
354 .len();
355 let open = self.inner.open_count.load(Ordering::Relaxed);
356 let active = open.saturating_sub(idle);
357 PoolStatus {
358 idle,
359 active,
360 open,
361 max_size: self.inner.max_size,
362 }
363 }
364
365 fn warmup_conn(&self, conn: &mut Connection) {
373 let sqls = self
374 .inner
375 .warmup_sqls
376 .lock()
377 .unwrap_or_else(|e| e.into_inner())
378 .clone();
379
380 if sqls.is_empty() {
381 return;
382 }
383
384 let batch: Vec<(&str, u64)> = sqls
385 .iter()
386 .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
387 .collect();
388
389 let _ = conn.prepare_batch(&batch);
390 }
391
392 pub fn set_warmup_sqls<S: Into<Box<str>>>(&self, sqls: impl IntoIterator<Item = S>) {
419 let boxed: Arc<Vec<Box<str>>> = Arc::new(sqls.into_iter().map(Into::into).collect());
420 *self
421 .inner
422 .warmup_sqls
423 .lock()
424 .unwrap_or_else(|e| e.into_inner()) = boxed;
425 }
426
427 pub fn close(&self) {
430 self.inner.closed.store(true, Ordering::Release);
431 let slots: Vec<PoolSlot> = {
433 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
434 std::mem::take(&mut *stack)
435 };
436 for slot in slots {
437 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
438 match slot {
439 PoolSlot::Sync(conn) => {
440 let _ = conn.close();
441 }
442 #[cfg(feature = "async")]
443 PoolSlot::Async(_conn) => {
444 }
447 }
448 }
449 let (_, cvar) = &self.inner.release_pair;
451 cvar.notify_all();
452 }
453
454 pub fn is_closed(&self) -> bool {
456 self.inner.closed.load(Ordering::Acquire)
457 }
458
459 #[cfg(feature = "async")]
469 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
470 if self.inner.closed.load(Ordering::Acquire) {
471 return Err(DriverError::Pool("pool is closed".into()));
472 }
473
474 if let Some(guard) = self.try_pop_idle()? {
476 return Ok(guard);
477 }
478
479 loop {
481 let current = self.inner.open_count.load(Ordering::Acquire);
482 if current >= self.inner.max_size {
483 if let Some(timeout) = self.inner.acquire_timeout {
484 let (lock, cvar) = &self.inner.release_pair;
485 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
486 let (_guard, result) = cvar
487 .wait_timeout(guard, timeout)
488 .unwrap_or_else(|e| e.into_inner());
489 if result.timed_out() {
490 return Err(DriverError::Pool(
491 "pool exhausted: acquire timeout expired".into(),
492 ));
493 }
494 if let Some(guard) = self.try_pop_idle()? {
495 return Ok(guard);
496 }
497 continue;
498 }
499 return Err(DriverError::Pool(
500 "pool exhausted: all connections in use".into(),
501 ));
502 }
503 if self
504 .inner
505 .open_count
506 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
507 .is_ok()
508 {
509 break;
510 }
511 }
512
513 if self.inner.config.host_is_uds() {
515 let conn_result = Connection::connect_arc(self.inner.config.clone());
517 match conn_result {
518 Ok(mut conn) => {
519 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
520 self.warmup_conn(&mut conn);
521 Ok(PoolGuard {
522 conn: Some(PoolSlot::Sync(conn)),
523 pool: self.inner.clone(),
524 discard: false,
525 #[cfg(feature = "detect-n-plus-one")]
526 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
527 })
528 }
529 Err(e) => {
530 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
531 Err(e)
532 }
533 }
534 } else {
535 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
537 match conn_result {
538 Ok(mut conn) => {
539 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
540 Ok(PoolGuard {
541 conn: Some(PoolSlot::Async(conn)),
542 pool: self.inner.clone(),
543 discard: false,
544 #[cfg(feature = "detect-n-plus-one")]
545 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
546 })
547 }
548 Err(e) => {
549 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
550 Err(e)
551 }
552 }
553 }
554 }
555}
556
557impl Clone for Pool {
558 fn clone(&self) -> Self {
559 Pool {
560 inner: self.inner.clone(),
561 }
562 }
563}
564
/// Occupancy snapshot returned by `Pool::status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Connections parked on the idle stack.
    pub idle: usize,
    /// `open - idle`, saturating at zero: connections currently checked out.
    pub active: usize,
    /// All open connections (idle + checked out).
    pub open: usize,
    /// Configured upper bound on open connections.
    pub max_size: usize,
}
579
580pub struct PoolBuilder {
584 url: Option<String>,
585 max_size: usize,
586 max_lifetime: Option<Duration>,
588 acquire_timeout: Option<Duration>,
590 min_idle: usize,
592 max_stmt_cache_size: usize,
594 stale_timeout: Duration,
596 statement_cache_mode: Option<StatementCacheMode>,
598 #[cfg(feature = "detect-n-plus-one")]
600 n_plus_one_threshold: Option<u16>,
601}
602
603impl PoolBuilder {
604 fn new() -> Self {
605 Self {
606 url: None,
607 max_size: 10,
608 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, stale_timeout: Duration::from_secs(30), statement_cache_mode: None,
614 #[cfg(feature = "detect-n-plus-one")]
615 n_plus_one_threshold: None,
616 }
617 }
618
619 pub fn url(mut self, url: &str) -> Self {
621 self.url = Some(url.to_owned());
622 self
623 }
624
625 pub fn max_size(mut self, size: usize) -> Self {
629 self.max_size = size;
630 self
631 }
632
633 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
636 self.max_lifetime = lifetime;
637 self
638 }
639
640 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
643 self.acquire_timeout = timeout;
644 self
645 }
646
647 pub fn min_idle(mut self, count: usize) -> Self {
650 self.min_idle = count;
651 self
652 }
653
654 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
658 self.max_stmt_cache_size = size;
659 self
660 }
661
662 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
666 self.stale_timeout = timeout;
667 self
668 }
669
670 pub fn statement_cache_mode(mut self, mode: StatementCacheMode) -> Self {
679 self.statement_cache_mode = Some(mode);
680 self
681 }
682
683 #[cfg(feature = "detect-n-plus-one")]
688 pub fn n_plus_one_threshold(mut self, n: u16) -> Self {
689 self.n_plus_one_threshold = Some(n);
690 self
691 }
692
693 pub fn build(self) -> Result<Pool, DriverError> {
695 let url = self
696 .url
697 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
698
699 let mut config = Config::from_url(&url)?;
700 if let Some(mode) = self.statement_cache_mode {
701 config.statement_cache_mode = mode;
702 }
703 let config = Arc::new(config);
704
705 let pool = Pool {
706 inner: Arc::new(PoolInner {
707 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
708 max_size: self.max_size,
709 open_count: AtomicUsize::new(0),
710 config,
711 closed: AtomicBool::new(false),
712 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
713 max_lifetime: self.max_lifetime,
714 acquire_timeout: self.acquire_timeout,
715 min_idle: self.min_idle,
716 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
717 max_stmt_cache_size: self.max_stmt_cache_size,
718 stale_timeout: self.stale_timeout,
719 #[cfg(feature = "detect-n-plus-one")]
720 n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
721 }),
722 };
723
724 if self.min_idle > 0 {
725 let inner = pool.inner.clone();
726 std::thread::spawn(move || {
727 maintain_min_idle(inner);
728 });
729 }
730
731 Ok(pool)
732 }
733}
734
735fn maintain_min_idle(inner: Arc<PoolInner>) {
737 loop {
738 if inner.closed.load(Ordering::Acquire) {
739 return;
740 }
741
742 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
743 let needed = inner.min_idle.saturating_sub(idle_count);
744
745 for _ in 0..needed {
746 if inner.closed.load(Ordering::Acquire) {
747 return;
748 }
749 let current = inner.open_count.load(Ordering::Acquire);
750 if current >= inner.max_size {
751 break;
752 }
753 if inner
754 .open_count
755 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
756 .is_err()
757 {
758 continue;
759 }
760
761 match Connection::connect_arc(inner.config.clone()) {
762 Ok(conn) => {
763 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
764 stack.push(PoolSlot::Sync(conn));
765 let (_, cvar) = &inner.release_pair;
766 cvar.notify_one();
767 }
768 Err(_) => {
769 inner.open_count.fetch_sub(1, Ordering::AcqRel);
770 }
771 }
772 }
773
774 std::thread::sleep(Duration::from_secs(1));
777 }
778}
779
780pub struct PoolGuard {
787 conn: Option<PoolSlot>,
788 pool: Arc<PoolInner>,
789 discard: bool,
791 #[cfg(feature = "detect-n-plus-one")]
793 detector: NPlusOneDetector,
794}
795
796impl PoolGuard {
797 #[inline]
800 fn sync_conn(&self) -> Result<&Connection, DriverError> {
801 match self.conn.as_ref() {
802 Some(PoolSlot::Sync(conn)) => Ok(conn),
803 #[cfg(feature = "async")]
804 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
805 "expected sync connection, got async; use async methods".into(),
806 )),
807 None => Err(DriverError::Pool("connection already taken".into())),
808 }
809 }
810
811 #[inline]
813 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
814 match self.conn.as_mut() {
815 Some(PoolSlot::Sync(conn)) => Ok(conn),
816 #[cfg(feature = "async")]
817 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
818 "expected sync connection, got async; use async methods".into(),
819 )),
820 None => Err(DriverError::Pool("connection already taken".into())),
821 }
822 }
823
824 pub fn mark_discard(&mut self) {
827 self.discard = true;
828 }
829
830 pub fn cancel(&self) -> Result<(), DriverError> {
835 self.sync_conn()?.cancel()
836 }
837
838 pub fn pid(&self) -> i32 {
847 match self.conn.as_ref().expect("connection returned to pool") {
848 PoolSlot::Sync(conn) => conn.pid(),
849 #[cfg(feature = "async")]
850 PoolSlot::Async(conn) => conn.pid(),
851 }
852 }
853
854 pub fn is_idle(&self) -> bool {
861 match self.conn.as_ref().expect("connection returned to pool") {
862 PoolSlot::Sync(conn) => conn.is_idle(),
863 #[cfg(feature = "async")]
864 PoolSlot::Async(conn) => conn.is_idle(),
865 }
866 }
867
868 pub fn is_in_transaction(&self) -> bool {
875 match self.conn.as_ref().expect("connection returned to pool") {
876 PoolSlot::Sync(conn) => conn.is_in_transaction(),
877 #[cfg(feature = "async")]
878 PoolSlot::Async(conn) => conn.is_in_transaction(),
879 }
880 }
881
882 #[inline]
886 pub fn query(
887 &mut self,
888 sql: &str,
889 sql_hash: u64,
890 params: &[&(dyn Encode + Sync)],
891 ) -> Result<QueryResult, DriverError> {
892 #[cfg(feature = "detect-n-plus-one")]
893 self.detector.track(sql_hash);
894 self.sync_conn_mut()?.query(sql, sql_hash, params)
895 }
896
897 #[inline]
899 pub fn execute(
900 &mut self,
901 sql: &str,
902 sql_hash: u64,
903 params: &[&(dyn Encode + Sync)],
904 ) -> Result<u64, DriverError> {
905 #[cfg(feature = "detect-n-plus-one")]
906 self.detector.track(sql_hash);
907 self.sync_conn_mut()?.execute(sql, sql_hash, params)
908 }
909
910 pub fn execute_pipeline(
915 &mut self,
916 sql: &str,
917 sql_hash: u64,
918 param_sets: &[&[&(dyn Encode + Sync)]],
919 ) -> Result<Vec<u64>, DriverError> {
920 #[cfg(feature = "detect-n-plus-one")]
921 self.detector.track(sql_hash);
922 self.sync_conn_mut()?
923 .execute_pipeline(sql, sql_hash, param_sets)
924 }
925
926 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
928 self.sync_conn_mut()?.simple_query(sql)
929 }
930
931 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
935 self.sync_conn_mut()?.simple_query_rows(sql)
936 }
937
938 pub fn for_each<F>(
940 &mut self,
941 sql: &str,
942 sql_hash: u64,
943 params: &[&(dyn Encode + Sync)],
944 f: F,
945 ) -> Result<(), DriverError>
946 where
947 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
948 {
949 #[cfg(feature = "detect-n-plus-one")]
950 self.detector.track(sql_hash);
951 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
952 }
953
954 pub fn for_each_raw<F>(
956 &mut self,
957 sql: &str,
958 sql_hash: u64,
959 params: &[&(dyn Encode + Sync)],
960 f: F,
961 ) -> Result<(), DriverError>
962 where
963 F: FnMut(&[u8]) -> Result<(), DriverError>,
964 {
965 #[cfg(feature = "detect-n-plus-one")]
966 self.detector.track(sql_hash);
967 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
968 }
969
970 pub fn query_streaming_start(
974 &mut self,
975 sql: &str,
976 sql_hash: u64,
977 params: &[&(dyn Encode + Sync)],
978 chunk_size: i32,
979 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
980 #[cfg(feature = "detect-n-plus-one")]
981 self.detector.track(sql_hash);
982 self.sync_conn_mut()?
983 .query_streaming_start(sql, sql_hash, params, chunk_size)
984 }
985
986 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
988 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
989 }
990
991 pub fn streaming_next_chunk(
993 &mut self,
994 arena: &mut Arena,
995 all_col_offsets: &mut Vec<(usize, i32)>,
996 ) -> Result<bool, DriverError> {
997 self.sync_conn_mut()?
998 .streaming_next_chunk(arena, all_col_offsets)
999 }
1000
1001 pub fn copy_in<'a, I>(
1007 &mut self,
1008 table: &str,
1009 columns: &[&str],
1010 rows: I,
1011 ) -> Result<u64, DriverError>
1012 where
1013 I: IntoIterator<Item = &'a str>,
1014 {
1015 self.sync_conn_mut()?.copy_in(table, columns, rows)
1016 }
1017
1018 pub fn copy_out<W: std::io::Write>(
1022 &mut self,
1023 query: &str,
1024 writer: &mut W,
1025 ) -> Result<u64, DriverError> {
1026 self.sync_conn_mut()?.copy_out(query, writer)
1027 }
1028
1029 pub fn is_sync(&self) -> bool {
1031 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
1032 }
1033
1034 #[cfg(feature = "async")]
1036 pub fn is_async(&self) -> bool {
1037 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
1038 }
1039
1040 #[cfg(feature = "async")]
1048 pub async fn query_async(
1049 &mut self,
1050 sql: &str,
1051 sql_hash: u64,
1052 params: &[&(dyn Encode + Sync)],
1053 ) -> Result<QueryResult, DriverError> {
1054 #[cfg(feature = "detect-n-plus-one")]
1055 self.detector.track(sql_hash);
1056 match self.conn.as_mut() {
1057 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
1058 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
1059 None => Err(DriverError::Pool("connection already taken".into())),
1060 }
1061 }
1062
1063 #[cfg(feature = "async")]
1065 pub async fn execute_async(
1066 &mut self,
1067 sql: &str,
1068 sql_hash: u64,
1069 params: &[&(dyn Encode + Sync)],
1070 ) -> Result<u64, DriverError> {
1071 #[cfg(feature = "detect-n-plus-one")]
1072 self.detector.track(sql_hash);
1073 match self.conn.as_mut() {
1074 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
1075 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
1076 None => Err(DriverError::Pool("connection already taken".into())),
1077 }
1078 }
1079
1080 #[cfg(feature = "async")]
1082 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
1083 match self.conn.as_mut() {
1084 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
1085 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
1086 None => Err(DriverError::Pool("connection already taken".into())),
1087 }
1088 }
1089
1090 pub(crate) fn ensure_stmt_prepared(
1094 &mut self,
1095 sql: &str,
1096 sql_hash: u64,
1097 params: &[&(dyn Encode + Sync)],
1098 ) -> Result<[u8; 18], DriverError> {
1099 self.sync_conn_mut()?
1100 .ensure_stmt_prepared(sql, sql_hash, params)
1101 }
1102
1103 pub(crate) fn write_deferred_bind_execute(
1105 &self,
1106 sql: &str,
1107 sql_hash: u64,
1108 params: &[&(dyn Encode + Sync)],
1109 buf: &mut Vec<u8>,
1110 ) -> Result<(), DriverError> {
1111 let conn = self.sync_conn()?;
1112 conn.write_deferred_bind_execute(sql, sql_hash, params, buf)
1113 }
1114
1115 pub(crate) fn flush_deferred_pipeline(
1117 &mut self,
1118 buf: &mut Vec<u8>,
1119 count: usize,
1120 ) -> Result<Vec<u64>, DriverError> {
1121 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
1122 }
1123}
1124
1125impl Drop for PoolGuard {
1126 fn drop(&mut self) {
1127 #[cfg(feature = "detect-n-plus-one")]
1128 self.detector.emit_final_warning();
1129
1130 if let Some(slot) = self.conn.take() {
1131 let should_discard = self.discard
1133 || self.pool.closed.load(Ordering::Acquire)
1134 || match &slot {
1135 PoolSlot::Sync(conn) => {
1136 conn.is_in_failed_transaction()
1137 || conn.is_in_transaction()
1138 || conn.is_streaming()
1139 }
1140 #[cfg(feature = "async")]
1141 PoolSlot::Async(conn) => {
1142 conn.is_in_failed_transaction() || conn.is_in_transaction()
1143 }
1144 };
1145
1146 if should_discard {
1147 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1148 return;
1149 }
1150
1151 let mut slot = slot;
1154 match &mut slot {
1155 PoolSlot::Sync(conn) => {
1156 if conn.query_counter() & 63 == 0 {
1157 conn.touch();
1158 }
1159 }
1160 #[cfg(feature = "async")]
1161 PoolSlot::Async(conn) => {
1162 if conn.query_counter() & 63 == 0 {
1163 conn.touch();
1164 }
1165 }
1166 }
1167
1168 {
1170 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1171 stack.push(slot);
1172 }
1173
1174 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1176 let (_, cvar) = &self.pool.release_pair;
1177 cvar.notify_one();
1178 }
1179 }
1180 }
1181}
1182
1183pub struct Transaction {
1199 guard: PoolGuard,
1200 committed: bool,
1201 deferred_buf: Vec<u8>,
1203 deferred_count: usize,
1205}
1206
1207impl Transaction {
1208 pub fn commit(mut self) -> Result<(), DriverError> {
1212 if self.deferred_count > 0 {
1213 self.flush_deferred()?;
1214 }
1215 self.guard.simple_query("COMMIT")?;
1216 self.committed = true;
1217 Ok(())
1218 }
1219
1220 pub fn rollback(mut self) -> Result<(), DriverError> {
1224 self.deferred_buf.clear();
1225 self.deferred_count = 0;
1226 self.guard.simple_query("ROLLBACK")?;
1227 self.committed = true; Ok(())
1229 }
1230
1231 pub fn query(
1236 &mut self,
1237 sql: &str,
1238 sql_hash: u64,
1239 params: &[&(dyn Encode + Sync)],
1240 ) -> Result<QueryResult, DriverError> {
1241 if self.deferred_count > 0 {
1242 self.flush_deferred()?;
1243 }
1244 self.guard.query(sql, sql_hash, params)
1245 }
1246
1247 pub fn execute(
1249 &mut self,
1250 sql: &str,
1251 sql_hash: u64,
1252 params: &[&(dyn Encode + Sync)],
1253 ) -> Result<u64, DriverError> {
1254 self.guard.execute(sql, sql_hash, params)
1255 }
1256
1257 pub fn execute_pipeline(
1259 &mut self,
1260 sql: &str,
1261 sql_hash: u64,
1262 param_sets: &[&[&(dyn Encode + Sync)]],
1263 ) -> Result<Vec<u64>, DriverError> {
1264 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1265 }
1266
1267 pub fn for_each<F>(
1271 &mut self,
1272 sql: &str,
1273 sql_hash: u64,
1274 params: &[&(dyn Encode + Sync)],
1275 f: F,
1276 ) -> Result<(), DriverError>
1277 where
1278 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1279 {
1280 if self.deferred_count > 0 {
1281 self.flush_deferred()?;
1282 }
1283 self.guard.for_each(sql, sql_hash, params, f)
1284 }
1285
1286 pub fn for_each_raw<F>(
1290 &mut self,
1291 sql: &str,
1292 sql_hash: u64,
1293 params: &[&(dyn Encode + Sync)],
1294 f: F,
1295 ) -> Result<(), DriverError>
1296 where
1297 F: FnMut(&[u8]) -> Result<(), DriverError>,
1298 {
1299 if self.deferred_count > 0 {
1300 self.flush_deferred()?;
1301 }
1302 self.guard.for_each_raw(sql, sql_hash, params, f)
1303 }
1304
1305 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1309 if self.deferred_count > 0 {
1310 self.flush_deferred()?;
1311 }
1312 self.guard.simple_query(sql)
1313 }
1314
1315 pub fn defer_execute(
1344 &mut self,
1345 sql: &str,
1346 sql_hash: u64,
1347 params: &[&(dyn Encode + Sync)],
1348 ) -> Result<(), DriverError> {
1349 if params.len() > i16::MAX as usize {
1350 return Err(DriverError::Protocol(format!(
1351 "parameter count {} exceeds maximum {}",
1352 params.len(),
1353 i16::MAX
1354 )));
1355 }
1356
1357 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1359
1360 self.guard
1362 .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf)?;
1363 self.deferred_count += 1;
1364 Ok(())
1365 }
1366
1367 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1372 let count = self.deferred_count;
1373 self.deferred_count = 0;
1374 self.guard
1375 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1376 }
1377
1378 pub fn deferred_count(&self) -> usize {
1380 self.deferred_count
1381 }
1382}
1383
1384impl Drop for Transaction {
1385 fn drop(&mut self) {
1386 if !self.committed {
1387 if let Some(_slot) = self.guard.conn.take() {
1390 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1391 }
1393 }
1394 }
1395}
1396
1397#[cfg(test)]
1398mod tests {
1399 use super::*;
1400
1401 #[test]
1402 fn pool_builder_requires_url() {
1403 let result = PoolBuilder::new().build();
1404 assert!(result.is_err());
1405 }
1406
1407 #[test]
1408 fn pool_builder_validates_url() {
1409 let result = PoolBuilder::new().url("not_a_url").build();
1410 assert!(result.is_err());
1411 }
1412
1413 #[test]
1414 fn pool_builder_accepts_valid_url() {
1415 let pool = PoolBuilder::new()
1416 .url("postgres://user:pass@localhost/db")
1417 .max_size(5)
1418 .build()
1419 .unwrap();
1420 assert_eq!(pool.max_size(), 5);
1421 assert_eq!(pool.open_count(), 0);
1422 }
1423
1424 #[test]
1425 fn pool_connect_validates_url() {
1426 let result = Pool::connect("not_a_url");
1427 assert!(result.is_err());
1428 }
1429
1430 #[test]
1431 fn pool_max_size_zero() {
1432 let pool = PoolBuilder::new()
1433 .url("postgres://user:pass@localhost/db")
1434 .max_size(0)
1435 .build()
1436 .unwrap();
1437
1438 let result = pool.acquire();
1439 assert!(result.is_err());
1440 match result {
1441 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1442 Err(e) => panic!("expected Pool error, got: {e:?}"),
1443 Ok(_) => panic!("expected error, got Ok"),
1444 }
1445 }
1446
1447 #[test]
1448 fn pool_clone_shares_state() {
1449 let pool = PoolBuilder::new()
1450 .url("postgres://user:pass@localhost/db")
1451 .max_size(5)
1452 .build()
1453 .unwrap();
1454
1455 let pool2 = pool.clone();
1456 assert_eq!(pool.max_size(), pool2.max_size());
1457 }
1458
1459 #[test]
1463 fn pool_builder_max_lifetime() {
1464 let pool = PoolBuilder::new()
1465 .url("postgres://user:pass@localhost/db")
1466 .max_lifetime(Some(Duration::from_secs(60)))
1467 .build()
1468 .unwrap();
1469 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1470 }
1471
1472 #[test]
1474 fn pool_builder_max_lifetime_none() {
1475 let pool = PoolBuilder::new()
1476 .url("postgres://user:pass@localhost/db")
1477 .max_lifetime(None)
1478 .build()
1479 .unwrap();
1480 assert_eq!(pool.inner.max_lifetime, None);
1481 }
1482
1483 #[test]
1485 fn pool_builder_acquire_timeout_none() {
1486 let pool = PoolBuilder::new()
1487 .url("postgres://user:pass@localhost/db")
1488 .acquire_timeout(None)
1489 .build()
1490 .unwrap();
1491 assert_eq!(pool.inner.acquire_timeout, None);
1492 }
1493
1494 #[test]
1496 fn pool_builder_acquire_timeout_custom() {
1497 let pool = PoolBuilder::new()
1498 .url("postgres://user:pass@localhost/db")
1499 .acquire_timeout(Some(Duration::from_secs(10)))
1500 .build()
1501 .unwrap();
1502 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1503 }
1504
1505 #[test]
1507 fn pool_builder_min_idle() {
1508 let pool = PoolBuilder::new()
1509 .url("postgres://user:pass@localhost/db")
1510 .min_idle(2)
1511 .build()
1512 .unwrap();
1513 assert_eq!(pool.inner.min_idle, 2);
1514 }
1515
1516 #[test]
1518 fn pool_close_marks_closed() {
1519 let pool = PoolBuilder::new()
1520 .url("postgres://user:pass@localhost/db")
1521 .max_size(5)
1522 .build()
1523 .unwrap();
1524
1525 assert!(!pool.is_closed());
1526 pool.close();
1527 assert!(pool.is_closed());
1528
1529 let result = pool.acquire();
1531 assert!(result.is_err());
1532 match result {
1533 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1534 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1535 Ok(_) => panic!("expected error, got Ok"),
1536 }
1537 }
1538
1539 #[test]
1541 fn pool_status_initial() {
1542 let pool = PoolBuilder::new()
1543 .url("postgres://user:pass@localhost/db")
1544 .max_size(10)
1545 .build()
1546 .unwrap();
1547
1548 let status = pool.status();
1549 assert_eq!(status.idle, 0);
1550 assert_eq!(status.active, 0);
1551 assert_eq!(status.open, 0);
1552 assert_eq!(status.max_size, 10);
1553 }
1554
1555 #[test]
1557 fn pool_builder_defaults() {
1558 let pool = PoolBuilder::new()
1559 .url("postgres://user:pass@localhost/db")
1560 .build()
1561 .unwrap();
1562
1563 assert_eq!(pool.max_size(), 10);
1564 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1565 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1566 assert_eq!(pool.inner.min_idle, 0);
1567 }
1568
1569 #[test]
1571 fn pool_open_count_initial() {
1572 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1573 assert_eq!(pool.open_count(), 0);
1574 }
1575
1576 #[test]
1579 fn pool_builder_max_stmt_cache_size_default() {
1580 let pool = PoolBuilder::new()
1581 .url("postgres://user:pass@localhost/db")
1582 .build()
1583 .unwrap();
1584 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1585 }
1586
1587 #[test]
1588 fn pool_builder_max_stmt_cache_size_custom() {
1589 let pool = PoolBuilder::new()
1590 .url("postgres://user:pass@localhost/db")
1591 .max_stmt_cache_size(512)
1592 .build()
1593 .unwrap();
1594 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1595 }
1596
1597 #[test]
1600 fn pool_is_uds_false_for_tcp() {
1601 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1602 assert!(!pool.is_uds());
1603 }
1604
1605 #[cfg(unix)]
1606 #[test]
1607 fn pool_is_uds_true_for_unix_socket() {
1608 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1609 assert!(pool.is_uds());
1610 }
1611
1612 #[cfg(unix)]
1613 #[test]
1614 fn pool_is_uds_true_for_var_run_socket() {
1615 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1616 assert!(pool.is_uds());
1617 }
1618
1619 #[test]
1620 fn pool_is_uds_false_for_ip_address() {
1621 let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
1622 assert!(!pool.is_uds());
1623 }
1624
1625 #[cfg(unix)]
1626 #[test]
1627 fn pool_slot_sync_created_for_uds_config() {
1628 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1629 assert!(config.host_is_uds());
1630 }
1631
1632 #[test]
1633 fn pool_slot_tcp_config() {
1634 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1635 assert!(!config.host_is_uds());
1636 }
1637
1638 #[test]
1643 fn pool_is_uds_false_for_hostname() {
1644 let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
1645 assert!(!pool.is_uds());
1646 }
1647
1648 #[cfg(unix)]
1649 #[test]
1650 fn pool_is_uds_true_for_tmp() {
1651 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1652 assert!(pool.is_uds());
1653 }
1654
1655 #[test]
1660 fn pool_close_then_acquire_fails() {
1661 let pool = PoolBuilder::new()
1662 .url("postgres://user:pass@localhost/db")
1663 .max_size(5)
1664 .build()
1665 .unwrap();
1666 pool.close();
1667 let result = pool.acquire();
1668 assert!(result.is_err());
1669 match result {
1670 Err(DriverError::Pool(msg)) => {
1671 assert!(msg.contains("closed"), "should say closed: {msg}")
1672 }
1673 Err(e) => panic!("expected Pool error, got: {e:?}"),
1674 Ok(_) => panic!("expected error"),
1675 }
1676 }
1677
1678 #[test]
1679 fn pool_is_closed_before_and_after() {
1680 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1681 assert!(!pool.is_closed());
1682 pool.close();
1683 assert!(pool.is_closed());
1684 }
1685
1686 #[test]
1691 fn pool_exhausted_no_timeout() {
1692 let pool = PoolBuilder::new()
1693 .url("postgres://user:pass@localhost/db")
1694 .max_size(0)
1695 .acquire_timeout(None) .build()
1697 .unwrap();
1698 let result = pool.acquire();
1699 assert!(result.is_err());
1700 match result {
1701 Err(DriverError::Pool(msg)) => {
1702 assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1703 }
1704 Err(e) => panic!("expected Pool error, got: {e:?}"),
1705 Ok(_) => panic!("expected error"),
1706 }
1707 }
1708
1709 #[test]
1714 fn pool_builder_no_url_error() {
1715 let result = PoolBuilder::new().max_size(5).build();
1716 assert!(result.is_err());
1717 match result {
1718 Err(DriverError::Pool(msg)) => {
1719 assert!(msg.contains("URL"), "should mention URL: {msg}")
1720 }
1721 Err(e) => panic!("expected Pool error, got: {e:?}"),
1722 Ok(_) => panic!("expected error"),
1723 }
1724 }
1725
1726 #[test]
1727 fn pool_builder_invalid_url_error() {
1728 let result = PoolBuilder::new().url("ftp://something").build();
1729 assert!(result.is_err());
1730 }
1731
1732 #[test]
1733 fn pool_builder_stmt_cache_size_zero() {
1734 let pool = PoolBuilder::new()
1735 .url("postgres://user:pass@localhost/db")
1736 .max_stmt_cache_size(0)
1737 .build()
1738 .unwrap();
1739 assert_eq!(pool.inner.max_stmt_cache_size, 0);
1740 }
1741
1742 #[test]
1745 fn pool_builder_stale_timeout_default() {
1746 let pool = PoolBuilder::new()
1747 .url("postgres://user:pass@localhost/db")
1748 .build()
1749 .unwrap();
1750 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(30));
1751 }
1752
1753 #[test]
1754 fn pool_builder_stale_timeout_custom() {
1755 let pool = PoolBuilder::new()
1756 .url("postgres://user:pass@localhost/db")
1757 .stale_timeout(Duration::from_secs(60))
1758 .build()
1759 .unwrap();
1760 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(60));
1761 }
1762
1763 #[test]
1764 fn pool_builder_stale_timeout_zero() {
1765 let pool = PoolBuilder::new()
1766 .url("postgres://user:pass@localhost/db")
1767 .stale_timeout(Duration::from_secs(0))
1768 .build()
1769 .unwrap();
1770 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(0));
1771 }
1772
1773 #[test]
1778 fn pool_status_reflects_max_size() {
1779 let pool = PoolBuilder::new()
1780 .url("postgres://user:pass@localhost/db")
1781 .max_size(20)
1782 .build()
1783 .unwrap();
1784 let status = pool.status();
1785 assert_eq!(status.max_size, 20);
1786 assert_eq!(status.idle, 0);
1787 assert_eq!(status.active, 0);
1788 assert_eq!(status.open, 0);
1789 }
1790
1791 #[test]
1796 fn pool_clone_shares_config() {
1797 let pool = PoolBuilder::new()
1798 .url("postgres://user:pass@localhost/db")
1799 .max_size(7)
1800 .build()
1801 .unwrap();
1802 let p2 = pool.clone();
1803 assert_eq!(pool.max_size(), 7);
1804 assert_eq!(p2.max_size(), 7);
1805 assert_eq!(pool.open_count(), p2.open_count());
1806 }
1807
1808 #[test]
1813 fn pool_set_warmup_sqls_empty() {
1814 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1815 pool.set_warmup_sqls([] as [&str; 0]);
1816 let sqls = pool
1817 .inner
1818 .warmup_sqls
1819 .lock()
1820 .unwrap_or_else(|e| e.into_inner())
1821 .clone();
1822 assert!(sqls.is_empty());
1823 }
1824
1825 #[test]
1826 fn pool_set_warmup_sqls_multiple() {
1827 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1828 pool.set_warmup_sqls(["SELECT 1", "SELECT 2", "SELECT 3"]);
1829 let sqls = pool
1830 .inner
1831 .warmup_sqls
1832 .lock()
1833 .unwrap_or_else(|e| e.into_inner())
1834 .clone();
1835 assert_eq!(sqls.len(), 3);
1836 assert_eq!(&*sqls[0], "SELECT 1");
1837 assert_eq!(&*sqls[1], "SELECT 2");
1838 assert_eq!(&*sqls[2], "SELECT 3");
1839 }
1840
1841 #[test]
1842 fn pool_set_warmup_sqls_overwrite() {
1843 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1844 pool.set_warmup_sqls(["SELECT 1"]);
1845 pool.set_warmup_sqls(["SELECT 99"]);
1846 let sqls = pool
1847 .inner
1848 .warmup_sqls
1849 .lock()
1850 .unwrap_or_else(|e| e.into_inner())
1851 .clone();
1852 assert_eq!(sqls.len(), 1);
1853 assert_eq!(&*sqls[0], "SELECT 99");
1854 }
1855
1856 #[test]
1857 fn pool_set_warmup_sqls_with_iter_empty() {
1858 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1859 pool.set_warmup_sqls(std::iter::empty::<&str>());
1860 let sqls = pool
1861 .inner
1862 .warmup_sqls
1863 .lock()
1864 .unwrap_or_else(|e| e.into_inner())
1865 .clone();
1866 assert!(sqls.is_empty());
1867 }
1868
1869 #[test]
1870 fn pool_set_warmup_sqls_with_owned_string() {
1871 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1872 let dynamic = format!("SET search_path TO test_{}", 42);
1873 pool.set_warmup_sqls([dynamic]);
1874 let sqls = pool
1875 .inner
1876 .warmup_sqls
1877 .lock()
1878 .unwrap_or_else(|e| e.into_inner())
1879 .clone();
1880 assert_eq!(sqls.len(), 1);
1881 assert_eq!(&*sqls[0], "SET search_path TO test_42");
1882 }
1883
1884 #[test]
1885 fn pool_set_warmup_sqls_with_vec_of_strings() {
1886 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1887 let sqls_owned: Vec<String> = vec!["SELECT 1".to_owned(), "SELECT 2".to_owned()];
1888 pool.set_warmup_sqls(sqls_owned);
1889 let sqls = pool
1890 .inner
1891 .warmup_sqls
1892 .lock()
1893 .unwrap_or_else(|e| e.into_inner())
1894 .clone();
1895 assert_eq!(sqls.len(), 2);
1896 assert_eq!(&*sqls[0], "SELECT 1");
1897 }
1898
1899 #[test]
1900 fn pool_set_warmup_sqls_with_boxed_str() {
1901 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1902 let b: Box<str> = "SELECT 1".into();
1903 pool.set_warmup_sqls([b]);
1904 let sqls = pool
1905 .inner
1906 .warmup_sqls
1907 .lock()
1908 .unwrap_or_else(|e| e.into_inner())
1909 .clone();
1910 assert_eq!(&*sqls[0], "SELECT 1");
1911 }
1912
1913 #[test]
1914 fn pool_set_warmup_sqls_single_static_str() {
1915 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1916 pool.set_warmup_sqls(["SET statement_timeout = '30s'"]);
1917 let sqls = pool
1918 .inner
1919 .warmup_sqls
1920 .lock()
1921 .unwrap_or_else(|e| e.into_inner())
1922 .clone();
1923 assert_eq!(sqls.len(), 1);
1924 }
1925
1926 #[test]
1927 fn pool_set_warmup_sqls_preserves_order() {
1928 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1929 pool.set_warmup_sqls(["first", "second", "third"]);
1930 let sqls = pool
1931 .inner
1932 .warmup_sqls
1933 .lock()
1934 .unwrap_or_else(|e| e.into_inner())
1935 .clone();
1936 assert_eq!(&*sqls[0], "first");
1937 assert_eq!(&*sqls[1], "second");
1938 assert_eq!(&*sqls[2], "third");
1939 }
1940
1941 #[test]
1942 fn pool_set_warmup_sqls_unicode() {
1943 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1944 pool.set_warmup_sqls(["SET client_encoding TO 'UTF8'", "SELECT '日本語'"]);
1945 let sqls = pool
1946 .inner
1947 .warmup_sqls
1948 .lock()
1949 .unwrap_or_else(|e| e.into_inner())
1950 .clone();
1951 assert_eq!(&*sqls[1], "SELECT '日本語'");
1952 }
1953
1954 #[test]
1955 fn pool_set_warmup_sqls_empty_string() {
1956 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1957 pool.set_warmup_sqls([""]);
1958 let sqls = pool
1959 .inner
1960 .warmup_sqls
1961 .lock()
1962 .unwrap_or_else(|e| e.into_inner())
1963 .clone();
1964 assert_eq!(sqls.len(), 1);
1965 assert_eq!(&*sqls[0], "");
1966 }
1967
1968 #[test]
1969 fn pool_set_warmup_sqls_long_sql() {
1970 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1971 let long = "SELECT ".to_owned() + &"x, ".repeat(1000) + "1";
1972 pool.set_warmup_sqls([long]);
1973 let sqls = pool
1974 .inner
1975 .warmup_sqls
1976 .lock()
1977 .unwrap_or_else(|e| e.into_inner())
1978 .clone();
1979 assert!(sqls[0].len() > 3000);
1980 }
1981
1982 #[test]
1987 fn pool_status_debug() {
1988 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1989 let status = pool.status();
1990 let dbg = format!("{status:?}");
1991 assert!(dbg.contains("PoolStatus"));
1992 assert!(dbg.contains("idle"));
1993 assert!(dbg.contains("active"));
1994 assert!(dbg.contains("open"));
1995 assert!(dbg.contains("max_size"));
1996 }
1997
1998 #[test]
2003 fn config_host_is_uds_returns_true_for_slash() {
2004 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
2005 assert!(config.host_is_uds());
2006 }
2007
2008 #[test]
2009 fn config_host_is_uds_returns_false_for_tcp() {
2010 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
2011 assert!(!config.host_is_uds());
2012 }
2013
2014 #[test]
2015 fn config_host_is_uds_returns_false_for_ip() {
2016 let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
2017 assert!(!config.host_is_uds());
2018 }
2019
2020 #[test]
2025 fn pool_builder_full_chain() {
2026 let pool = PoolBuilder::new()
2027 .url("postgres://user:pass@localhost/db")
2028 .max_size(3)
2029 .max_lifetime(Some(Duration::from_secs(600)))
2030 .acquire_timeout(Some(Duration::from_secs(5)))
2031 .min_idle(1)
2032 .max_stmt_cache_size(128)
2033 .build()
2034 .unwrap();
2035 assert_eq!(pool.max_size(), 3);
2036 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
2037 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
2038 assert_eq!(pool.inner.min_idle, 1);
2039 assert_eq!(pool.inner.max_stmt_cache_size, 128);
2040 }
2041
2042 #[test]
2045 fn pool_max_size_zero_rejects_all_acquires() {
2046 let pool = PoolBuilder::new()
2047 .url("postgres://user:pass@localhost/db")
2048 .max_size(0)
2049 .build()
2050 .unwrap();
2051 let result = pool.acquire();
2052 assert!(result.is_err());
2053 match &result {
2054 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
2055 _ => panic!("expected pool exhausted error"),
2056 }
2057 }
2058
2059 #[test]
2062 fn url_parse_unknown_sslmode_returns_error() {
2063 let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
2064 assert!(result.is_err());
2065 let msg = format!("{}", result.unwrap_err());
2066 assert!(msg.contains("unknown sslmode"));
2067 }
2068
2069 #[test]
2070 fn url_parse_invalid_port_returns_error() {
2071 let result = Config::from_url("postgres://u:p@h:abc/d");
2072 assert!(result.is_err());
2073 let msg = format!("{}", result.unwrap_err());
2074 assert!(msg.contains("invalid port"));
2075 }
2076
2077 #[test]
2078 fn url_parse_missing_at_sign_returns_error() {
2079 let result = Config::from_url("postgres://u:plocalhost/d");
2080 assert!(result.is_err());
2081 let msg = format!("{}", result.unwrap_err());
2082 assert!(msg.contains("missing @"));
2083 }
2084
2085 #[test]
2086 fn url_parse_empty_host_returns_error() {
2087 let result = Config::from_url("postgres://u:p@/d");
2088 assert!(result.is_err());
2089 }
2090
2091 #[test]
2092 fn url_parse_empty_user_returns_error() {
2093 let result = Config::from_url("postgres://:p@h/d");
2094 assert!(result.is_err());
2095 }
2096
2097 #[test]
2098 fn url_parse_statement_timeout_invalid_uses_default() {
2099 let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
2100 assert_eq!(config.statement_timeout_secs, 30);
2101 }
2102
2103 #[test]
2104 fn url_parse_malformed_percent_encoding() {
2105 let result = Config::from_url("postgres://u%:p@h/d");
2106 assert!(result.is_err());
2107 }
2108
2109 #[test]
2110 fn url_parse_invalid_hex_in_percent_encoding() {
2111 let result = Config::from_url("postgres://u%ZZ:p@h/d");
2112 assert!(result.is_err());
2113 }
2114
2115 #[test]
2120 fn pool_acquire_timeout_no_connections_available() {
2121 let pool = PoolBuilder::new()
2125 .url("postgres://user:pass@localhost/db")
2126 .max_size(0)
2127 .acquire_timeout(Some(Duration::from_millis(50)))
2128 .build()
2129 .unwrap();
2130
2131 let start = std::time::Instant::now();
2132 let result = pool.acquire();
2133 let elapsed = start.elapsed();
2134
2135 assert!(result.is_err());
2136 match result {
2137 Err(DriverError::Pool(msg)) => {
2138 assert!(msg.contains("exhausted"), "should say exhausted: {msg}");
2139 }
2140 Err(e) => panic!("expected Pool error, got: {e:?}"),
2141 Ok(_) => panic!("expected error"),
2142 }
2143 assert!(elapsed < Duration::from_secs(5));
2145 }
2146
2147 #[test]
2152 fn pool_max_lifetime_very_short() {
2153 let pool = PoolBuilder::new()
2154 .url("postgres://user:pass@localhost/db")
2155 .max_lifetime(Some(Duration::from_millis(1)))
2156 .build()
2157 .unwrap();
2158 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_millis(1)));
2159 }
2160
2161 #[test]
2162 fn pool_max_lifetime_zero_duration() {
2163 let pool = PoolBuilder::new()
2165 .url("postgres://user:pass@localhost/db")
2166 .max_lifetime(Some(Duration::from_secs(0)))
2167 .build()
2168 .unwrap();
2169 assert_eq!(pool.inner.max_lifetime, Some(Duration::ZERO));
2170 }
2171
2172 #[test]
2177 fn pool_status_open_equals_idle_plus_active() {
2178 let pool = PoolBuilder::new()
2180 .url("postgres://user:pass@localhost/db")
2181 .max_size(10)
2182 .build()
2183 .unwrap();
2184
2185 let status = pool.status();
2186 assert_eq!(status.open, status.idle + status.active);
2187 assert_eq!(status.open, 0);
2188 }
2189
2190 #[test]
2195 fn pool_close_idempotent() {
2196 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
2197 pool.close();
2198 assert!(pool.is_closed());
2199 pool.close(); assert!(pool.is_closed());
2201 }
2202
2203 #[test]
2204 fn pool_close_then_status_all_zero() {
2205 let pool = PoolBuilder::new()
2206 .url("postgres://user:pass@localhost/db")
2207 .max_size(5)
2208 .build()
2209 .unwrap();
2210 pool.close();
2211 let status = pool.status();
2212 assert_eq!(status.idle, 0);
2213 assert_eq!(status.active, 0);
2214 assert_eq!(status.open, 0);
2215 }
2216
2217 #[test]
2222 fn pool_builder_all_options_maximal() {
2223 let pool = PoolBuilder::new()
2224 .url("postgres://user:pass@localhost/db")
2225 .max_size(100)
2226 .max_lifetime(Some(Duration::from_secs(3600)))
2227 .acquire_timeout(Some(Duration::from_secs(30)))
2228 .min_idle(10)
2229 .max_stmt_cache_size(1024)
2230 .stale_timeout(Duration::from_secs(120))
2231 .build()
2232 .unwrap();
2233 assert_eq!(pool.max_size(), 100);
2234 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(3600)));
2235 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(30)));
2236 assert_eq!(pool.inner.min_idle, 10);
2237 assert_eq!(pool.inner.max_stmt_cache_size, 1024);
2238 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(120));
2239 }
2240
2241 #[test]
2242 fn pool_builder_all_options_minimal() {
2243 let pool = PoolBuilder::new()
2244 .url("postgres://user:pass@localhost/db")
2245 .max_size(1)
2246 .max_lifetime(None)
2247 .acquire_timeout(None)
2248 .min_idle(0)
2249 .max_stmt_cache_size(0)
2250 .stale_timeout(Duration::ZERO)
2251 .build()
2252 .unwrap();
2253 assert_eq!(pool.max_size(), 1);
2254 assert_eq!(pool.inner.max_lifetime, None);
2255 assert_eq!(pool.inner.acquire_timeout, None);
2256 assert_eq!(pool.inner.min_idle, 0);
2257 assert_eq!(pool.inner.max_stmt_cache_size, 0);
2258 assert_eq!(pool.inner.stale_timeout, Duration::ZERO);
2259 }
2260
2261 #[test]
2266 fn pool_close_concurrent_with_failed_acquire() {
2267 let pool = std::sync::Arc::new(
2268 PoolBuilder::new()
2269 .url("postgres://user:pass@localhost/db")
2270 .max_size(0)
2271 .build()
2272 .unwrap(),
2273 );
2274
2275 let pool2 = pool.clone();
2276 let handle = std::thread::spawn(move || {
2277 let result = pool2.acquire();
2279 assert!(result.is_err());
2280 });
2281
2282 pool.close();
2283 handle.join().unwrap();
2284 assert!(pool.is_closed());
2285 }
2286}
2287
#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    //! Unit tests for [`NPlusOneDetector`].
    //!
    //! The detector remembers only the most recent query hash and how many
    //! consecutive times it has been tracked. `check_final` reports that run
    //! once its length is strictly greater than the threshold. A fresh
    //! detector starts with hash `0`, and `check_final` ignores hash `0`, so a
    //! run of hash `0` is never reported.

    use super::NPlusOneDetector;

    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..10 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..11 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 11)));
    }

    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..5 {
            d.track(99);
        }
        // The comparison is strict: a run equal to the threshold is allowed.
        assert!(d.check_final().is_none(), "> not >=");
    }

    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..6 {
            d.track(99);
        }
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn threshold_zero_warns_on_first_query() {
        // With a threshold of 0, a single query already forms a run longer
        // than the threshold.
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn threshold_max_never_warns() {
        // `repeat_count` saturates at `u16::MAX`, so it can never be strictly
        // greater than a `u16::MAX` threshold.
        let mut d = NPlusOneDetector::new(u16::MAX);
        for _ in 0..1000 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42); // reaches u16::MAX
        d.track(42); // would overflow without saturation
        assert_eq!(d.repeat_count, u16::MAX);
    }

    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        for _ in 0..50 {
            d.track(1);
        }
        d.track(2);
        // Switching hashes starts a new run of length 1.
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    #[test]
    fn multiple_n_plus_one_sequences() {
        let mut d = NPlusOneDetector::new(3);
        // The first run (hash 1, length 5) goes through `emit_warning` when
        // the hash changes; it is no longer visible to `check_final`.
        for _ in 0..5 {
            d.track(1);
        }
        // The second run (hash 2, length 4) is the one `check_final` sees.
        for _ in 0..4 {
            d.track(2);
        }
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        d.track(10);
        d.track(10);
        d.track(10); // run of 3 exceeds the threshold of 2
        d.track(20); // switching hashes emits for the previous run and resets
        // NOTE: log output is not captured; this only checks the state after
        // the switch.
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
    }

    #[test]
    fn hash_zero_is_never_reported() {
        // Hash 0 equals the initial "nothing tracked yet" value. The run is
        // still counted, but `check_final` never reports it.
        let mut d = NPlusOneDetector::new(2);
        d.track(0);
        d.track(0);
        d.track(0);
        assert_eq!(d.repeat_count, 3);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..500 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(1);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        d.track(1);
        d.track(2);
        d.track(1);
        d.track(2);
        for _ in 0..5 {
            d.track(5);
        }
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn nonzero_hash_after_zero_init() {
        // The initial hash is 0, so the first nonzero hash starts a new run of 1.
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn independent_detectors_dont_interfere() {
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);

        for _ in 0..10 {
            d1.track(42);
        }
        d2.track(1);
        d2.track(2);
        d2.track(3);

        assert!(d1.check_final().is_some());
        assert!(d2.check_final().is_none());
    }

    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0u64..1000 {
            d.track(i);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_reset_state_after_warning() {
        let mut d = NPlusOneDetector::new(2);
        d.track(1);
        d.track(1);
        d.track(1); // run of 3 exceeds threshold 2
        d.track(2);
        d.track(2); // new run of 2 only equals the threshold
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_with_realistic_orm_pattern() {
        let mut d = NPlusOneDetector::new(5);
        d.track(100); // one parent query
        for _ in 0..20 {
            d.track(200); // one child query per parent row
        }
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    #[test]
    fn detector_with_legitimate_batch_pattern() {
        let mut d = NPlusOneDetector::new(10);
        // The detector only sees repeated hashes, so an intentional batch of
        // identical queries cannot be told apart from an N+1 and is reported.
        for _ in 0..15 {
            d.track(300);
        }
        assert!(d.check_final().is_some());
    }

    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            // A run exactly as long as the threshold must stay silent...
            for _ in 0..threshold {
                d.track(42);
            }
            assert!(
                d.check_final().is_none(),
                "threshold={threshold} should not warn at count={threshold}"
            );
            // ...and one more repetition must be reported.
            d.track(42);
            assert_eq!(
                d.check_final(),
                Some((42, threshold + 1)),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    #[test]
    fn detector_with_deterministic_random_sequences() {
        let mut d = NPlusOneDetector::new(5);
        // (i * 7 + 3) % 4 cycles 3, 2, 1, 0, ... so no two consecutive hashes
        // are equal and every run has length 1.
        let hashes: Vec<u64> = (0..100).map(|i| ((i * 7 + 3) % 4) as u64).collect();
        assert!(hashes.windows(2).all(|w| w[0] != w[1]));
        for &h in &hashes {
            d.track(h);
        }
        assert_eq!(d.repeat_count, 1);
        assert!(d.check_final().is_none());
    }

    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Arbitrary input never panics, and `check_final` matches a
            // reference model: the trailing run of the last hash, reported
            // only when that hash is nonzero and the run exceeds the threshold.
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for h in &hashes {
                    d.track(*h);
                }
                let expected = match hashes.last() {
                    Some(&last) if last != 0 => {
                        let run = hashes.iter().rev().take_while(|&&x| x == last).count();
                        let run = u16::try_from(run).expect("at most 500 hashes");
                        if run > threshold { Some((last, run)) } else { None }
                    }
                    _ => None,
                };
                prop_assert_eq!(d.check_final(), expected);
            }

            // A single run of a nonzero hash is reported exactly when it is
            // longer than the threshold, with its full length.
            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 2u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for _ in 0..count {
                    d.track(hash);
                }
                let expected = if count > threshold { Some((hash, count)) } else { None };
                prop_assert_eq!(d.check_final(), expected);
            }
        }
    }
}