1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::arena::Arena;
14use crate::codec::Encode;
15use crate::conn::Connection;
16use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
17use crate::DriverError;
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22pub(crate) enum PoolSlot {
30 Sync(Connection),
32 #[cfg(feature = "async")]
34 Async(AsyncConnection),
35}
36
/// Per-guard detector for the N+1 pattern: the same statement executed many
/// times back to back on one checked-out connection.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
    /// Hash of the statement in the current run; `0` doubles as "no run yet".
    last_query_hash: u64,
    /// Length of the current run (saturating).
    repeat_count: u16,
    /// Runs strictly longer than this are reported.
    threshold: u16,
}
48
#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Create an empty detector that reports runs longer than `threshold`.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            last_query_hash: 0,
            repeat_count: 0,
            threshold,
        }
    }

    /// Record one execution of the statement identified by `sql_hash`.
    ///
    /// Repeating the previous hash extends the current run. A different hash
    /// closes the run (logging a warning if it exceeded the threshold) and
    /// starts a new run of length 1.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash == self.last_query_hash {
            self.repeat_count = self.repeat_count.saturating_add(1);
            return;
        }
        // Inline pre-check: the common "different statement" path must not
        // pay for an out-of-line call into the cold logging function.
        if self.repeat_count > self.threshold {
            self.emit_warning();
        }
        self.last_query_hash = sql_hash;
        self.repeat_count = 1;
    }

    /// Return the current run as `(hash, count)` if it exceeds the threshold.
    ///
    /// A run for hash `0` is never reported, because `0` is also the
    /// "no run yet" marker set by `new`.
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        if self.repeat_count > self.threshold && self.last_query_hash != 0 {
            Some((self.last_query_hash, self.repeat_count))
        } else {
            None
        }
    }

    /// Log the current run if it is over the threshold.
    ///
    /// Only reached once the inline pre-checks in `track` and
    /// `emit_final_warning` have passed.
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        if let Some((hash, count)) = self.check_final() {
            log::warn!(
                "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
                hash,
                count,
                self.threshold,
            );
        }
    }

    /// Report the still-open run. Called from `PoolGuard`'s destructor.
    #[inline]
    pub(crate) fn emit_final_warning(&self) {
        if self.repeat_count > self.threshold {
            self.emit_warning();
        }
    }
}
104
105pub struct Pool {
121 inner: Arc<PoolInner>,
122}
123
124struct PoolInner {
125 stack: std::sync::Mutex<Vec<PoolSlot>>,
129 max_size: usize,
130 open_count: AtomicUsize,
131 config: Arc<Config>,
132 closed: AtomicBool,
134 release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
137 max_lifetime: Option<Duration>,
140 acquire_timeout: Option<Duration>,
142 min_idle: usize,
144 warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
146 max_stmt_cache_size: usize,
148 stale_timeout: Duration,
151 #[cfg(feature = "detect-n-plus-one")]
154 n_plus_one_threshold: u16,
155}
156
157impl Pool {
158 pub fn connect(url: &str) -> Result<Self, DriverError> {
162 PoolBuilder::new().url(url).build()
163 }
164
165 pub fn builder() -> PoolBuilder {
167 PoolBuilder::new()
168 }
169
170 #[inline]
178 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
179 if self.inner.closed.load(Ordering::Acquire) {
180 return Err(DriverError::Pool("pool is closed".into()));
181 }
182
183 if let Some(guard) = self.try_pop_idle()? {
185 return Ok(guard);
186 }
187
188 loop {
190 let current = self.inner.open_count.load(Ordering::Acquire);
191 if current >= self.inner.max_size {
192 if let Some(timeout) = self.inner.acquire_timeout {
193 let (lock, cvar) = &self.inner.release_pair;
194 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
195 let (_guard, result) = cvar
196 .wait_timeout(guard, timeout)
197 .unwrap_or_else(|e| e.into_inner());
198 if result.timed_out() {
199 return Err(DriverError::Pool(
200 "pool exhausted: acquire timeout expired".into(),
201 ));
202 }
203 if let Some(guard) = self.try_pop_idle()? {
205 return Ok(guard);
206 }
207 continue;
209 }
210 return Err(DriverError::Pool(
211 "pool exhausted: all connections in use".into(),
212 ));
213 }
214 if self
215 .inner
216 .open_count
217 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
218 .is_ok()
219 {
220 break;
221 }
222 }
224
225 let conn_result = Connection::connect_arc(self.inner.config.clone());
227 match conn_result {
228 Ok(mut conn) => {
229 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
231 self.warmup_conn(&mut conn);
233
234 Ok(PoolGuard {
235 conn: Some(PoolSlot::Sync(conn)),
236 pool: self.inner.clone(),
237 discard: false,
238 #[cfg(feature = "detect-n-plus-one")]
239 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
240 })
241 }
242 Err(e) => {
243 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
245 Err(e)
246 }
247 }
248 }
249
250 #[inline]
256 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
257 loop {
261 let (mut slot, needs_health_check) = {
262 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
263 loop {
264 let Some(slot) = stack.pop() else {
265 return Ok(None);
266 };
267 let (created_at, idle_dur) = match &slot {
268 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
269 #[cfg(feature = "async")]
270 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
271 };
272 if let Some(max_lifetime) = self.inner.max_lifetime {
273 if created_at.elapsed() >= max_lifetime {
274 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
275 continue;
276 }
277 }
278 if idle_dur >= self.inner.stale_timeout {
279 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
281 continue;
282 }
283 break (slot, idle_dur > Duration::from_secs(5));
284 }
285 };
286 if needs_health_check {
290 let alive = match &mut slot {
291 PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
292 #[cfg(feature = "async")]
293 PoolSlot::Async(_) => true, };
295 if !alive {
296 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
297 continue; }
299 }
300 return Ok(Some(PoolGuard {
301 conn: Some(slot),
302 pool: self.inner.clone(),
303 discard: false,
304 #[cfg(feature = "detect-n-plus-one")]
305 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
306 }));
307 }
308 }
309
310 pub fn is_uds(&self) -> bool {
315 #[cfg(unix)]
316 {
317 self.inner.config.host_is_uds()
318 }
319 #[cfg(not(unix))]
320 {
321 false
322 }
323 }
324
325 pub fn begin(&self) -> Result<Transaction, DriverError> {
327 let mut guard = self.acquire()?;
328 guard.simple_query("BEGIN")?;
329 Ok(Transaction {
330 guard,
331 committed: false,
332 deferred_buf: Vec::new(),
333 deferred_count: 0,
334 })
335 }
336
337 pub fn open_count(&self) -> usize {
339 self.inner.open_count.load(Ordering::Relaxed)
340 }
341
342 pub fn max_size(&self) -> usize {
344 self.inner.max_size
345 }
346
347 pub fn status(&self) -> PoolStatus {
349 let idle = self
350 .inner
351 .stack
352 .lock()
353 .unwrap_or_else(|e| e.into_inner())
354 .len();
355 let open = self.inner.open_count.load(Ordering::Relaxed);
356 let active = open.saturating_sub(idle);
357 PoolStatus {
358 idle,
359 active,
360 open,
361 max_size: self.inner.max_size,
362 }
363 }
364
365 fn warmup_conn(&self, conn: &mut Connection) {
373 let sqls = self
374 .inner
375 .warmup_sqls
376 .lock()
377 .unwrap_or_else(|e| e.into_inner())
378 .clone();
379
380 if sqls.is_empty() {
381 return;
382 }
383
384 let batch: Vec<(&str, u64)> = sqls
385 .iter()
386 .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
387 .collect();
388
389 let _ = conn.prepare_batch(&batch);
390 }
391
392 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
414 let boxed: Arc<Vec<Box<str>>> =
415 Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
416 *self
417 .inner
418 .warmup_sqls
419 .lock()
420 .unwrap_or_else(|e| e.into_inner()) = boxed;
421 }
422
423 pub fn close(&self) {
426 self.inner.closed.store(true, Ordering::Release);
427 let slots: Vec<PoolSlot> = {
429 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
430 std::mem::take(&mut *stack)
431 };
432 for slot in slots {
433 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
434 match slot {
435 PoolSlot::Sync(conn) => {
436 let _ = conn.close();
437 }
438 #[cfg(feature = "async")]
439 PoolSlot::Async(_conn) => {
440 }
443 }
444 }
445 let (_, cvar) = &self.inner.release_pair;
447 cvar.notify_all();
448 }
449
450 pub fn is_closed(&self) -> bool {
452 self.inner.closed.load(Ordering::Acquire)
453 }
454
455 #[cfg(feature = "async")]
465 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
466 if self.inner.closed.load(Ordering::Acquire) {
467 return Err(DriverError::Pool("pool is closed".into()));
468 }
469
470 if let Some(guard) = self.try_pop_idle()? {
472 return Ok(guard);
473 }
474
475 loop {
477 let current = self.inner.open_count.load(Ordering::Acquire);
478 if current >= self.inner.max_size {
479 if let Some(timeout) = self.inner.acquire_timeout {
480 let (lock, cvar) = &self.inner.release_pair;
481 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
482 let (_guard, result) = cvar
483 .wait_timeout(guard, timeout)
484 .unwrap_or_else(|e| e.into_inner());
485 if result.timed_out() {
486 return Err(DriverError::Pool(
487 "pool exhausted: acquire timeout expired".into(),
488 ));
489 }
490 if let Some(guard) = self.try_pop_idle()? {
491 return Ok(guard);
492 }
493 continue;
494 }
495 return Err(DriverError::Pool(
496 "pool exhausted: all connections in use".into(),
497 ));
498 }
499 if self
500 .inner
501 .open_count
502 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
503 .is_ok()
504 {
505 break;
506 }
507 }
508
509 if self.inner.config.host_is_uds() {
511 let conn_result = Connection::connect_arc(self.inner.config.clone());
513 match conn_result {
514 Ok(mut conn) => {
515 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
516 self.warmup_conn(&mut conn);
517 Ok(PoolGuard {
518 conn: Some(PoolSlot::Sync(conn)),
519 pool: self.inner.clone(),
520 discard: false,
521 #[cfg(feature = "detect-n-plus-one")]
522 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
523 })
524 }
525 Err(e) => {
526 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
527 Err(e)
528 }
529 }
530 } else {
531 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
533 match conn_result {
534 Ok(mut conn) => {
535 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
536 Ok(PoolGuard {
537 conn: Some(PoolSlot::Async(conn)),
538 pool: self.inner.clone(),
539 discard: false,
540 #[cfg(feature = "detect-n-plus-one")]
541 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
542 })
543 }
544 Err(e) => {
545 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
546 Err(e)
547 }
548 }
549 }
550 }
551}
552
553impl Clone for Pool {
554 fn clone(&self) -> Self {
555 Pool {
556 inner: self.inner.clone(),
557 }
558 }
559}
560
/// Point-in-time counters describing a pool, as returned by `Pool::status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Connections sitting in the idle stack.
    pub idle: usize,
    /// Open connections not in the idle stack (`open - idle`, saturating at zero).
    pub active: usize,
    /// Connections counted as open, including checked-out ones.
    pub open: usize,
    /// Configured upper bound on open connections.
    pub max_size: usize,
}
575
/// Builder for [`Pool`], created by `Pool::builder()`.
pub struct PoolBuilder {
    /// Connection URL; required by `build`.
    url: Option<String>,
    /// Maximum number of open connections.
    max_size: usize,
    /// Maximum connection age; `None` disables the check.
    max_lifetime: Option<Duration>,
    /// How long `acquire` waits on a full pool; `None` fails immediately.
    acquire_timeout: Option<Duration>,
    /// Idle connections kept open by a background thread (0 disables it).
    min_idle: usize,
    /// Statement-cache limit for each connection.
    max_stmt_cache_size: usize,
    /// Idle duration after which a pooled connection is discarded on checkout.
    stale_timeout: Duration,
    /// Repeat threshold for N+1 detection; `build` defaults it to 10.
    #[cfg(feature = "detect-n-plus-one")]
    n_plus_one_threshold: Option<u16>,
}
596
597impl PoolBuilder {
598 fn new() -> Self {
599 Self {
600 url: None,
601 max_size: 10,
602 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, stale_timeout: Duration::from_secs(30), #[cfg(feature = "detect-n-plus-one")]
608 n_plus_one_threshold: None,
609 }
610 }
611
612 pub fn url(mut self, url: &str) -> Self {
614 self.url = Some(url.to_owned());
615 self
616 }
617
618 pub fn max_size(mut self, size: usize) -> Self {
622 self.max_size = size;
623 self
624 }
625
626 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
629 self.max_lifetime = lifetime;
630 self
631 }
632
633 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
636 self.acquire_timeout = timeout;
637 self
638 }
639
640 pub fn min_idle(mut self, count: usize) -> Self {
643 self.min_idle = count;
644 self
645 }
646
647 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
651 self.max_stmt_cache_size = size;
652 self
653 }
654
655 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
659 self.stale_timeout = timeout;
660 self
661 }
662
663 #[cfg(feature = "detect-n-plus-one")]
668 pub fn n_plus_one_threshold(mut self, n: u16) -> Self {
669 self.n_plus_one_threshold = Some(n);
670 self
671 }
672
673 pub fn build(self) -> Result<Pool, DriverError> {
675 let url = self
676 .url
677 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
678
679 let config = Arc::new(Config::from_url(&url)?);
680
681 let pool = Pool {
682 inner: Arc::new(PoolInner {
683 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
684 max_size: self.max_size,
685 open_count: AtomicUsize::new(0),
686 config,
687 closed: AtomicBool::new(false),
688 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
689 max_lifetime: self.max_lifetime,
690 acquire_timeout: self.acquire_timeout,
691 min_idle: self.min_idle,
692 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
693 max_stmt_cache_size: self.max_stmt_cache_size,
694 stale_timeout: self.stale_timeout,
695 #[cfg(feature = "detect-n-plus-one")]
696 n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
697 }),
698 };
699
700 if self.min_idle > 0 {
701 let inner = pool.inner.clone();
702 std::thread::spawn(move || {
703 maintain_min_idle(inner);
704 });
705 }
706
707 Ok(pool)
708 }
709}
710
711fn maintain_min_idle(inner: Arc<PoolInner>) {
713 loop {
714 if inner.closed.load(Ordering::Acquire) {
715 return;
716 }
717
718 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
719 let needed = inner.min_idle.saturating_sub(idle_count);
720
721 for _ in 0..needed {
722 if inner.closed.load(Ordering::Acquire) {
723 return;
724 }
725 let current = inner.open_count.load(Ordering::Acquire);
726 if current >= inner.max_size {
727 break;
728 }
729 if inner
730 .open_count
731 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
732 .is_err()
733 {
734 continue;
735 }
736
737 match Connection::connect_arc(inner.config.clone()) {
738 Ok(conn) => {
739 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
740 stack.push(PoolSlot::Sync(conn));
741 let (_, cvar) = &inner.release_pair;
742 cvar.notify_one();
743 }
744 Err(_) => {
745 inner.open_count.fetch_sub(1, Ordering::AcqRel);
746 }
747 }
748 }
749
750 std::thread::sleep(Duration::from_secs(1));
753 }
754}
755
756pub struct PoolGuard {
763 conn: Option<PoolSlot>,
764 pool: Arc<PoolInner>,
765 discard: bool,
767 #[cfg(feature = "detect-n-plus-one")]
769 detector: NPlusOneDetector,
770}
771
772impl PoolGuard {
773 #[inline]
776 fn sync_conn(&self) -> Result<&Connection, DriverError> {
777 match self.conn.as_ref() {
778 Some(PoolSlot::Sync(conn)) => Ok(conn),
779 #[cfg(feature = "async")]
780 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
781 "expected sync connection, got async; use async methods".into(),
782 )),
783 None => Err(DriverError::Pool("connection already taken".into())),
784 }
785 }
786
787 #[inline]
789 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
790 match self.conn.as_mut() {
791 Some(PoolSlot::Sync(conn)) => Ok(conn),
792 #[cfg(feature = "async")]
793 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
794 "expected sync connection, got async; use async methods".into(),
795 )),
796 None => Err(DriverError::Pool("connection already taken".into())),
797 }
798 }
799
800 pub fn mark_discard(&mut self) {
803 self.discard = true;
804 }
805
806 pub fn cancel(&self) -> Result<(), DriverError> {
811 self.sync_conn()?.cancel()
812 }
813
814 pub fn pid(&self) -> i32 {
823 match self.conn.as_ref().expect("connection returned to pool") {
824 PoolSlot::Sync(conn) => conn.pid(),
825 #[cfg(feature = "async")]
826 PoolSlot::Async(conn) => conn.pid(),
827 }
828 }
829
830 pub fn is_idle(&self) -> bool {
837 match self.conn.as_ref().expect("connection returned to pool") {
838 PoolSlot::Sync(conn) => conn.is_idle(),
839 #[cfg(feature = "async")]
840 PoolSlot::Async(conn) => conn.is_idle(),
841 }
842 }
843
844 pub fn is_in_transaction(&self) -> bool {
851 match self.conn.as_ref().expect("connection returned to pool") {
852 PoolSlot::Sync(conn) => conn.is_in_transaction(),
853 #[cfg(feature = "async")]
854 PoolSlot::Async(conn) => conn.is_in_transaction(),
855 }
856 }
857
858 #[inline]
862 pub fn query(
863 &mut self,
864 sql: &str,
865 sql_hash: u64,
866 params: &[&(dyn Encode + Sync)],
867 ) -> Result<QueryResult, DriverError> {
868 #[cfg(feature = "detect-n-plus-one")]
869 self.detector.track(sql_hash);
870 self.sync_conn_mut()?.query(sql, sql_hash, params)
871 }
872
873 #[inline]
875 pub fn execute(
876 &mut self,
877 sql: &str,
878 sql_hash: u64,
879 params: &[&(dyn Encode + Sync)],
880 ) -> Result<u64, DriverError> {
881 #[cfg(feature = "detect-n-plus-one")]
882 self.detector.track(sql_hash);
883 self.sync_conn_mut()?.execute(sql, sql_hash, params)
884 }
885
886 pub fn execute_pipeline(
891 &mut self,
892 sql: &str,
893 sql_hash: u64,
894 param_sets: &[&[&(dyn Encode + Sync)]],
895 ) -> Result<Vec<u64>, DriverError> {
896 #[cfg(feature = "detect-n-plus-one")]
897 self.detector.track(sql_hash);
898 self.sync_conn_mut()?
899 .execute_pipeline(sql, sql_hash, param_sets)
900 }
901
902 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
904 self.sync_conn_mut()?.simple_query(sql)
905 }
906
907 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
911 self.sync_conn_mut()?.simple_query_rows(sql)
912 }
913
914 pub fn for_each<F>(
916 &mut self,
917 sql: &str,
918 sql_hash: u64,
919 params: &[&(dyn Encode + Sync)],
920 f: F,
921 ) -> Result<(), DriverError>
922 where
923 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
924 {
925 #[cfg(feature = "detect-n-plus-one")]
926 self.detector.track(sql_hash);
927 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
928 }
929
930 pub fn for_each_raw<F>(
932 &mut self,
933 sql: &str,
934 sql_hash: u64,
935 params: &[&(dyn Encode + Sync)],
936 f: F,
937 ) -> Result<(), DriverError>
938 where
939 F: FnMut(&[u8]) -> Result<(), DriverError>,
940 {
941 #[cfg(feature = "detect-n-plus-one")]
942 self.detector.track(sql_hash);
943 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
944 }
945
946 pub fn query_streaming_start(
950 &mut self,
951 sql: &str,
952 sql_hash: u64,
953 params: &[&(dyn Encode + Sync)],
954 chunk_size: i32,
955 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
956 #[cfg(feature = "detect-n-plus-one")]
957 self.detector.track(sql_hash);
958 self.sync_conn_mut()?
959 .query_streaming_start(sql, sql_hash, params, chunk_size)
960 }
961
962 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
964 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
965 }
966
967 pub fn streaming_next_chunk(
969 &mut self,
970 arena: &mut Arena,
971 all_col_offsets: &mut Vec<(usize, i32)>,
972 ) -> Result<bool, DriverError> {
973 self.sync_conn_mut()?
974 .streaming_next_chunk(arena, all_col_offsets)
975 }
976
977 pub fn copy_in<'a, I>(
983 &mut self,
984 table: &str,
985 columns: &[&str],
986 rows: I,
987 ) -> Result<u64, DriverError>
988 where
989 I: IntoIterator<Item = &'a str>,
990 {
991 self.sync_conn_mut()?.copy_in(table, columns, rows)
992 }
993
994 pub fn copy_out<W: std::io::Write>(
998 &mut self,
999 query: &str,
1000 writer: &mut W,
1001 ) -> Result<u64, DriverError> {
1002 self.sync_conn_mut()?.copy_out(query, writer)
1003 }
1004
1005 pub fn is_sync(&self) -> bool {
1007 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
1008 }
1009
1010 #[cfg(feature = "async")]
1012 pub fn is_async(&self) -> bool {
1013 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
1014 }
1015
1016 #[cfg(feature = "async")]
1024 pub async fn query_async(
1025 &mut self,
1026 sql: &str,
1027 sql_hash: u64,
1028 params: &[&(dyn Encode + Sync)],
1029 ) -> Result<QueryResult, DriverError> {
1030 #[cfg(feature = "detect-n-plus-one")]
1031 self.detector.track(sql_hash);
1032 match self.conn.as_mut() {
1033 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
1034 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
1035 None => Err(DriverError::Pool("connection already taken".into())),
1036 }
1037 }
1038
1039 #[cfg(feature = "async")]
1041 pub async fn execute_async(
1042 &mut self,
1043 sql: &str,
1044 sql_hash: u64,
1045 params: &[&(dyn Encode + Sync)],
1046 ) -> Result<u64, DriverError> {
1047 #[cfg(feature = "detect-n-plus-one")]
1048 self.detector.track(sql_hash);
1049 match self.conn.as_mut() {
1050 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
1051 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
1052 None => Err(DriverError::Pool("connection already taken".into())),
1053 }
1054 }
1055
1056 #[cfg(feature = "async")]
1058 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
1059 match self.conn.as_mut() {
1060 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
1061 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
1062 None => Err(DriverError::Pool("connection already taken".into())),
1063 }
1064 }
1065
1066 pub(crate) fn ensure_stmt_prepared(
1070 &mut self,
1071 sql: &str,
1072 sql_hash: u64,
1073 params: &[&(dyn Encode + Sync)],
1074 ) -> Result<[u8; 18], DriverError> {
1075 self.sync_conn_mut()?
1076 .ensure_stmt_prepared(sql, sql_hash, params)
1077 }
1078
1079 pub(crate) fn write_deferred_bind_execute(
1081 &self,
1082 sql: &str,
1083 sql_hash: u64,
1084 params: &[&(dyn Encode + Sync)],
1085 buf: &mut Vec<u8>,
1086 ) -> Result<(), DriverError> {
1087 let conn = self.sync_conn()?;
1088 conn.write_deferred_bind_execute(sql, sql_hash, params, buf)
1089 }
1090
1091 pub(crate) fn flush_deferred_pipeline(
1093 &mut self,
1094 buf: &mut Vec<u8>,
1095 count: usize,
1096 ) -> Result<Vec<u64>, DriverError> {
1097 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
1098 }
1099}
1100
1101impl Drop for PoolGuard {
1102 fn drop(&mut self) {
1103 #[cfg(feature = "detect-n-plus-one")]
1104 self.detector.emit_final_warning();
1105
1106 if let Some(slot) = self.conn.take() {
1107 let should_discard = self.discard
1109 || self.pool.closed.load(Ordering::Acquire)
1110 || match &slot {
1111 PoolSlot::Sync(conn) => {
1112 conn.is_in_failed_transaction()
1113 || conn.is_in_transaction()
1114 || conn.is_streaming()
1115 }
1116 #[cfg(feature = "async")]
1117 PoolSlot::Async(conn) => {
1118 conn.is_in_failed_transaction() || conn.is_in_transaction()
1119 }
1120 };
1121
1122 if should_discard {
1123 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1124 return;
1125 }
1126
1127 let mut slot = slot;
1130 match &mut slot {
1131 PoolSlot::Sync(conn) => {
1132 if conn.query_counter() & 63 == 0 {
1133 conn.touch();
1134 }
1135 }
1136 #[cfg(feature = "async")]
1137 PoolSlot::Async(conn) => {
1138 if conn.query_counter() & 63 == 0 {
1139 conn.touch();
1140 }
1141 }
1142 }
1143
1144 {
1146 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1147 stack.push(slot);
1148 }
1149
1150 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1152 let (_, cvar) = &self.pool.release_pair;
1153 cvar.notify_one();
1154 }
1155 }
1156 }
1157}
1158
1159pub struct Transaction {
1175 guard: PoolGuard,
1176 committed: bool,
1177 deferred_buf: Vec<u8>,
1179 deferred_count: usize,
1181}
1182
1183impl Transaction {
1184 pub fn commit(mut self) -> Result<(), DriverError> {
1188 if self.deferred_count > 0 {
1189 self.flush_deferred()?;
1190 }
1191 self.guard.simple_query("COMMIT")?;
1192 self.committed = true;
1193 Ok(())
1194 }
1195
1196 pub fn rollback(mut self) -> Result<(), DriverError> {
1200 self.deferred_buf.clear();
1201 self.deferred_count = 0;
1202 self.guard.simple_query("ROLLBACK")?;
1203 self.committed = true; Ok(())
1205 }
1206
1207 pub fn query(
1212 &mut self,
1213 sql: &str,
1214 sql_hash: u64,
1215 params: &[&(dyn Encode + Sync)],
1216 ) -> Result<QueryResult, DriverError> {
1217 if self.deferred_count > 0 {
1218 self.flush_deferred()?;
1219 }
1220 self.guard.query(sql, sql_hash, params)
1221 }
1222
1223 pub fn execute(
1225 &mut self,
1226 sql: &str,
1227 sql_hash: u64,
1228 params: &[&(dyn Encode + Sync)],
1229 ) -> Result<u64, DriverError> {
1230 self.guard.execute(sql, sql_hash, params)
1231 }
1232
1233 pub fn execute_pipeline(
1235 &mut self,
1236 sql: &str,
1237 sql_hash: u64,
1238 param_sets: &[&[&(dyn Encode + Sync)]],
1239 ) -> Result<Vec<u64>, DriverError> {
1240 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1241 }
1242
1243 pub fn for_each<F>(
1247 &mut self,
1248 sql: &str,
1249 sql_hash: u64,
1250 params: &[&(dyn Encode + Sync)],
1251 f: F,
1252 ) -> Result<(), DriverError>
1253 where
1254 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1255 {
1256 if self.deferred_count > 0 {
1257 self.flush_deferred()?;
1258 }
1259 self.guard.for_each(sql, sql_hash, params, f)
1260 }
1261
1262 pub fn for_each_raw<F>(
1266 &mut self,
1267 sql: &str,
1268 sql_hash: u64,
1269 params: &[&(dyn Encode + Sync)],
1270 f: F,
1271 ) -> Result<(), DriverError>
1272 where
1273 F: FnMut(&[u8]) -> Result<(), DriverError>,
1274 {
1275 if self.deferred_count > 0 {
1276 self.flush_deferred()?;
1277 }
1278 self.guard.for_each_raw(sql, sql_hash, params, f)
1279 }
1280
1281 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1285 if self.deferred_count > 0 {
1286 self.flush_deferred()?;
1287 }
1288 self.guard.simple_query(sql)
1289 }
1290
1291 pub fn defer_execute(
1320 &mut self,
1321 sql: &str,
1322 sql_hash: u64,
1323 params: &[&(dyn Encode + Sync)],
1324 ) -> Result<(), DriverError> {
1325 if params.len() > i16::MAX as usize {
1326 return Err(DriverError::Protocol(format!(
1327 "parameter count {} exceeds maximum {}",
1328 params.len(),
1329 i16::MAX
1330 )));
1331 }
1332
1333 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1335
1336 self.guard
1338 .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf)?;
1339 self.deferred_count += 1;
1340 Ok(())
1341 }
1342
1343 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1348 let count = self.deferred_count;
1349 self.deferred_count = 0;
1350 self.guard
1351 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1352 }
1353
1354 pub fn deferred_count(&self) -> usize {
1356 self.deferred_count
1357 }
1358}
1359
1360impl Drop for Transaction {
1361 fn drop(&mut self) {
1362 if !self.committed {
1363 if let Some(_slot) = self.guard.conn.take() {
1366 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1367 }
1369 }
1370 }
1371}
1372
1373#[cfg(test)]
1374mod tests {
1375 use super::*;
1376
1377 #[test]
1378 fn pool_builder_requires_url() {
1379 let result = PoolBuilder::new().build();
1380 assert!(result.is_err());
1381 }
1382
1383 #[test]
1384 fn pool_builder_validates_url() {
1385 let result = PoolBuilder::new().url("not_a_url").build();
1386 assert!(result.is_err());
1387 }
1388
1389 #[test]
1390 fn pool_builder_accepts_valid_url() {
1391 let pool = PoolBuilder::new()
1392 .url("postgres://user:pass@localhost/db")
1393 .max_size(5)
1394 .build()
1395 .unwrap();
1396 assert_eq!(pool.max_size(), 5);
1397 assert_eq!(pool.open_count(), 0);
1398 }
1399
1400 #[test]
1401 fn pool_connect_validates_url() {
1402 let result = Pool::connect("not_a_url");
1403 assert!(result.is_err());
1404 }
1405
1406 #[test]
1407 fn pool_max_size_zero() {
1408 let pool = PoolBuilder::new()
1409 .url("postgres://user:pass@localhost/db")
1410 .max_size(0)
1411 .build()
1412 .unwrap();
1413
1414 let result = pool.acquire();
1415 assert!(result.is_err());
1416 match result {
1417 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1418 Err(e) => panic!("expected Pool error, got: {e:?}"),
1419 Ok(_) => panic!("expected error, got Ok"),
1420 }
1421 }
1422
1423 #[test]
1424 fn pool_clone_shares_state() {
1425 let pool = PoolBuilder::new()
1426 .url("postgres://user:pass@localhost/db")
1427 .max_size(5)
1428 .build()
1429 .unwrap();
1430
1431 let pool2 = pool.clone();
1432 assert_eq!(pool.max_size(), pool2.max_size());
1433 }
1434
1435 #[test]
1439 fn pool_builder_max_lifetime() {
1440 let pool = PoolBuilder::new()
1441 .url("postgres://user:pass@localhost/db")
1442 .max_lifetime(Some(Duration::from_secs(60)))
1443 .build()
1444 .unwrap();
1445 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1446 }
1447
1448 #[test]
1450 fn pool_builder_max_lifetime_none() {
1451 let pool = PoolBuilder::new()
1452 .url("postgres://user:pass@localhost/db")
1453 .max_lifetime(None)
1454 .build()
1455 .unwrap();
1456 assert_eq!(pool.inner.max_lifetime, None);
1457 }
1458
1459 #[test]
1461 fn pool_builder_acquire_timeout_none() {
1462 let pool = PoolBuilder::new()
1463 .url("postgres://user:pass@localhost/db")
1464 .acquire_timeout(None)
1465 .build()
1466 .unwrap();
1467 assert_eq!(pool.inner.acquire_timeout, None);
1468 }
1469
1470 #[test]
1472 fn pool_builder_acquire_timeout_custom() {
1473 let pool = PoolBuilder::new()
1474 .url("postgres://user:pass@localhost/db")
1475 .acquire_timeout(Some(Duration::from_secs(10)))
1476 .build()
1477 .unwrap();
1478 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1479 }
1480
1481 #[test]
1483 fn pool_builder_min_idle() {
1484 let pool = PoolBuilder::new()
1485 .url("postgres://user:pass@localhost/db")
1486 .min_idle(2)
1487 .build()
1488 .unwrap();
1489 assert_eq!(pool.inner.min_idle, 2);
1490 }
1491
1492 #[test]
1494 fn pool_close_marks_closed() {
1495 let pool = PoolBuilder::new()
1496 .url("postgres://user:pass@localhost/db")
1497 .max_size(5)
1498 .build()
1499 .unwrap();
1500
1501 assert!(!pool.is_closed());
1502 pool.close();
1503 assert!(pool.is_closed());
1504
1505 let result = pool.acquire();
1507 assert!(result.is_err());
1508 match result {
1509 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1510 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1511 Ok(_) => panic!("expected error, got Ok"),
1512 }
1513 }
1514
1515 #[test]
1517 fn pool_status_initial() {
1518 let pool = PoolBuilder::new()
1519 .url("postgres://user:pass@localhost/db")
1520 .max_size(10)
1521 .build()
1522 .unwrap();
1523
1524 let status = pool.status();
1525 assert_eq!(status.idle, 0);
1526 assert_eq!(status.active, 0);
1527 assert_eq!(status.open, 0);
1528 assert_eq!(status.max_size, 10);
1529 }
1530
1531 #[test]
1533 fn pool_builder_defaults() {
1534 let pool = PoolBuilder::new()
1535 .url("postgres://user:pass@localhost/db")
1536 .build()
1537 .unwrap();
1538
1539 assert_eq!(pool.max_size(), 10);
1540 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1541 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1542 assert_eq!(pool.inner.min_idle, 0);
1543 }
1544
1545 #[test]
1547 fn pool_open_count_initial() {
1548 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1549 assert_eq!(pool.open_count(), 0);
1550 }
1551
1552 #[test]
1555 fn pool_builder_max_stmt_cache_size_default() {
1556 let pool = PoolBuilder::new()
1557 .url("postgres://user:pass@localhost/db")
1558 .build()
1559 .unwrap();
1560 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1561 }
1562
1563 #[test]
1564 fn pool_builder_max_stmt_cache_size_custom() {
1565 let pool = PoolBuilder::new()
1566 .url("postgres://user:pass@localhost/db")
1567 .max_stmt_cache_size(512)
1568 .build()
1569 .unwrap();
1570 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1571 }
1572
1573 #[test]
1576 fn pool_is_uds_false_for_tcp() {
1577 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1578 assert!(!pool.is_uds());
1579 }
1580
1581 #[cfg(unix)]
1582 #[test]
1583 fn pool_is_uds_true_for_unix_socket() {
1584 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1585 assert!(pool.is_uds());
1586 }
1587
1588 #[cfg(unix)]
1589 #[test]
1590 fn pool_is_uds_true_for_var_run_socket() {
1591 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1592 assert!(pool.is_uds());
1593 }
1594
1595 #[test]
1596 fn pool_is_uds_false_for_ip_address() {
1597 let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
1598 assert!(!pool.is_uds());
1599 }
1600
1601 #[cfg(unix)]
1602 #[test]
1603 fn pool_slot_sync_created_for_uds_config() {
1604 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1605 assert!(config.host_is_uds());
1606 }
1607
1608 #[test]
1609 fn pool_slot_tcp_config() {
1610 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1611 assert!(!config.host_is_uds());
1612 }
1613
1614 #[test]
1619 fn pool_is_uds_false_for_hostname() {
1620 let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
1621 assert!(!pool.is_uds());
1622 }
1623
1624 #[cfg(unix)]
1625 #[test]
1626 fn pool_is_uds_true_for_tmp() {
1627 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1628 assert!(pool.is_uds());
1629 }
1630
1631 #[test]
1636 fn pool_close_then_acquire_fails() {
1637 let pool = PoolBuilder::new()
1638 .url("postgres://user:pass@localhost/db")
1639 .max_size(5)
1640 .build()
1641 .unwrap();
1642 pool.close();
1643 let result = pool.acquire();
1644 assert!(result.is_err());
1645 match result {
1646 Err(DriverError::Pool(msg)) => {
1647 assert!(msg.contains("closed"), "should say closed: {msg}")
1648 }
1649 Err(e) => panic!("expected Pool error, got: {e:?}"),
1650 Ok(_) => panic!("expected error"),
1651 }
1652 }
1653
1654 #[test]
1655 fn pool_is_closed_before_and_after() {
1656 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1657 assert!(!pool.is_closed());
1658 pool.close();
1659 assert!(pool.is_closed());
1660 }
1661
1662 #[test]
1667 fn pool_exhausted_no_timeout() {
1668 let pool = PoolBuilder::new()
1669 .url("postgres://user:pass@localhost/db")
1670 .max_size(0)
1671 .acquire_timeout(None) .build()
1673 .unwrap();
1674 let result = pool.acquire();
1675 assert!(result.is_err());
1676 match result {
1677 Err(DriverError::Pool(msg)) => {
1678 assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1679 }
1680 Err(e) => panic!("expected Pool error, got: {e:?}"),
1681 Ok(_) => panic!("expected error"),
1682 }
1683 }
1684
1685 #[test]
1690 fn pool_builder_no_url_error() {
1691 let result = PoolBuilder::new().max_size(5).build();
1692 assert!(result.is_err());
1693 match result {
1694 Err(DriverError::Pool(msg)) => {
1695 assert!(msg.contains("URL"), "should mention URL: {msg}")
1696 }
1697 Err(e) => panic!("expected Pool error, got: {e:?}"),
1698 Ok(_) => panic!("expected error"),
1699 }
1700 }
1701
1702 #[test]
1703 fn pool_builder_invalid_url_error() {
1704 let result = PoolBuilder::new().url("ftp://something").build();
1705 assert!(result.is_err());
1706 }
1707
1708 #[test]
1709 fn pool_builder_stmt_cache_size_zero() {
1710 let pool = PoolBuilder::new()
1711 .url("postgres://user:pass@localhost/db")
1712 .max_stmt_cache_size(0)
1713 .build()
1714 .unwrap();
1715 assert_eq!(pool.inner.max_stmt_cache_size, 0);
1716 }
1717
1718 #[test]
1721 fn pool_builder_stale_timeout_default() {
1722 let pool = PoolBuilder::new()
1723 .url("postgres://user:pass@localhost/db")
1724 .build()
1725 .unwrap();
1726 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(30));
1727 }
1728
1729 #[test]
1730 fn pool_builder_stale_timeout_custom() {
1731 let pool = PoolBuilder::new()
1732 .url("postgres://user:pass@localhost/db")
1733 .stale_timeout(Duration::from_secs(60))
1734 .build()
1735 .unwrap();
1736 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(60));
1737 }
1738
1739 #[test]
1740 fn pool_builder_stale_timeout_zero() {
1741 let pool = PoolBuilder::new()
1742 .url("postgres://user:pass@localhost/db")
1743 .stale_timeout(Duration::from_secs(0))
1744 .build()
1745 .unwrap();
1746 assert_eq!(pool.inner.stale_timeout, Duration::from_secs(0));
1747 }
1748
1749 #[test]
1754 fn pool_status_reflects_max_size() {
1755 let pool = PoolBuilder::new()
1756 .url("postgres://user:pass@localhost/db")
1757 .max_size(20)
1758 .build()
1759 .unwrap();
1760 let status = pool.status();
1761 assert_eq!(status.max_size, 20);
1762 assert_eq!(status.idle, 0);
1763 assert_eq!(status.active, 0);
1764 assert_eq!(status.open, 0);
1765 }
1766
1767 #[test]
1772 fn pool_clone_shares_config() {
1773 let pool = PoolBuilder::new()
1774 .url("postgres://user:pass@localhost/db")
1775 .max_size(7)
1776 .build()
1777 .unwrap();
1778 let p2 = pool.clone();
1779 assert_eq!(pool.max_size(), 7);
1780 assert_eq!(p2.max_size(), 7);
1781 assert_eq!(pool.open_count(), p2.open_count());
1782 }
1783
1784 #[test]
1789 fn pool_set_warmup_sqls_empty() {
1790 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1791 pool.set_warmup_sqls(&[]);
1792 let sqls = pool
1793 .inner
1794 .warmup_sqls
1795 .lock()
1796 .unwrap_or_else(|e| e.into_inner())
1797 .clone();
1798 assert!(sqls.is_empty());
1799 }
1800
1801 #[test]
1802 fn pool_set_warmup_sqls_multiple() {
1803 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1804 pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
1805 let sqls = pool
1806 .inner
1807 .warmup_sqls
1808 .lock()
1809 .unwrap_or_else(|e| e.into_inner())
1810 .clone();
1811 assert_eq!(sqls.len(), 3);
1812 assert_eq!(&*sqls[0], "SELECT 1");
1813 assert_eq!(&*sqls[1], "SELECT 2");
1814 assert_eq!(&*sqls[2], "SELECT 3");
1815 }
1816
1817 #[test]
1818 fn pool_set_warmup_sqls_overwrite() {
1819 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1820 pool.set_warmup_sqls(&["SELECT 1"]);
1821 pool.set_warmup_sqls(&["SELECT 99"]);
1822 let sqls = pool
1823 .inner
1824 .warmup_sqls
1825 .lock()
1826 .unwrap_or_else(|e| e.into_inner())
1827 .clone();
1828 assert_eq!(sqls.len(), 1);
1829 assert_eq!(&*sqls[0], "SELECT 99");
1830 }
1831
1832 #[test]
1837 fn pool_status_debug() {
1838 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1839 let status = pool.status();
1840 let dbg = format!("{status:?}");
1841 assert!(dbg.contains("PoolStatus"));
1842 assert!(dbg.contains("idle"));
1843 assert!(dbg.contains("active"));
1844 assert!(dbg.contains("open"));
1845 assert!(dbg.contains("max_size"));
1846 }
1847
1848 #[test]
1853 fn config_host_is_uds_returns_true_for_slash() {
1854 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1855 assert!(config.host_is_uds());
1856 }
1857
1858 #[test]
1859 fn config_host_is_uds_returns_false_for_tcp() {
1860 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1861 assert!(!config.host_is_uds());
1862 }
1863
1864 #[test]
1865 fn config_host_is_uds_returns_false_for_ip() {
1866 let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
1867 assert!(!config.host_is_uds());
1868 }
1869
1870 #[test]
1875 fn pool_builder_full_chain() {
1876 let pool = PoolBuilder::new()
1877 .url("postgres://user:pass@localhost/db")
1878 .max_size(3)
1879 .max_lifetime(Some(Duration::from_secs(600)))
1880 .acquire_timeout(Some(Duration::from_secs(5)))
1881 .min_idle(1)
1882 .max_stmt_cache_size(128)
1883 .build()
1884 .unwrap();
1885 assert_eq!(pool.max_size(), 3);
1886 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
1887 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1888 assert_eq!(pool.inner.min_idle, 1);
1889 assert_eq!(pool.inner.max_stmt_cache_size, 128);
1890 }
1891
1892 #[test]
1895 fn pool_max_size_zero_rejects_all_acquires() {
1896 let pool = PoolBuilder::new()
1897 .url("postgres://user:pass@localhost/db")
1898 .max_size(0)
1899 .build()
1900 .unwrap();
1901 let result = pool.acquire();
1902 assert!(result.is_err());
1903 match &result {
1904 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1905 _ => panic!("expected pool exhausted error"),
1906 }
1907 }
1908
1909 #[test]
1912 fn url_parse_unknown_sslmode_returns_error() {
1913 let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
1914 assert!(result.is_err());
1915 let msg = format!("{}", result.unwrap_err());
1916 assert!(msg.contains("unknown sslmode"));
1917 }
1918
1919 #[test]
1920 fn url_parse_invalid_port_returns_error() {
1921 let result = Config::from_url("postgres://u:p@h:abc/d");
1922 assert!(result.is_err());
1923 let msg = format!("{}", result.unwrap_err());
1924 assert!(msg.contains("invalid port"));
1925 }
1926
1927 #[test]
1928 fn url_parse_missing_at_sign_returns_error() {
1929 let result = Config::from_url("postgres://u:plocalhost/d");
1930 assert!(result.is_err());
1931 let msg = format!("{}", result.unwrap_err());
1932 assert!(msg.contains("missing @"));
1933 }
1934
1935 #[test]
1936 fn url_parse_empty_host_returns_error() {
1937 let result = Config::from_url("postgres://u:p@/d");
1938 assert!(result.is_err());
1939 }
1940
1941 #[test]
1942 fn url_parse_empty_user_returns_error() {
1943 let result = Config::from_url("postgres://:p@h/d");
1944 assert!(result.is_err());
1945 }
1946
1947 #[test]
1948 fn url_parse_statement_timeout_invalid_uses_default() {
1949 let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
1950 assert_eq!(config.statement_timeout_secs, 30);
1951 }
1952
1953 #[test]
1954 fn url_parse_malformed_percent_encoding() {
1955 let result = Config::from_url("postgres://u%:p@h/d");
1956 assert!(result.is_err());
1957 }
1958
1959 #[test]
1960 fn url_parse_invalid_hex_in_percent_encoding() {
1961 let result = Config::from_url("postgres://u%ZZ:p@h/d");
1962 assert!(result.is_err());
1963 }
1964}
1965
#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    //! Tests for `NPlusOneDetector`.
    //!
    //! Behaviour under test: `track` counts consecutive calls with the same hash
    //! (saturating, and resetting to 1 when the hash changes). `check_final`
    //! reports `Some((hash, count))` only when `count > threshold` (strictly) and
    //! `hash != 0`. Hash `0` is also the detector's initial "nothing tracked"
    //! value, so runs of hash `0` are counted but never reported.

    use super::NPlusOneDetector;

    /// Tracks `hash` `times` times in a row.
    fn track_repeated(d: &mut NPlusOneDetector, hash: u64, times: usize) {
        for _ in 0..times {
            d.track(hash);
        }
    }

    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        track_repeated(&mut d, 42, 10);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        track_repeated(&mut d, 42, 11);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 11));
    }

    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        track_repeated(&mut d, 99, 5);
        assert!(d.check_final().is_none(), "> not >=");
    }

    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        track_repeated(&mut d, 99, 6);
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    /// With threshold 0 the very first query already exceeds it (1 > 0), so,
    /// despite the test name, the warning condition holds after a single `track`.
    #[test]
    fn threshold_zero_warns_on_second() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    /// `repeat_count` saturates at `u16::MAX` and the comparison is strict, so a
    /// threshold of `u16::MAX` can never be exceeded.
    #[test]
    fn threshold_max_never_warns() {
        let mut d = NPlusOneDetector::new(u16::MAX);
        track_repeated(&mut d, 42, 1000);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        // Start one step below the ceiling, then push past it.
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42); // reaches u16::MAX
        d.track(42); // would overflow without saturating_add
        assert_eq!(d.repeat_count, u16::MAX);
    }

    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        track_repeated(&mut d, 1, 50);
        d.track(2);
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    #[test]
    fn multiple_n_plus_one_sequences() {
        let mut d = NPlusOneDetector::new(3);
        track_repeated(&mut d, 1, 5);
        assert_eq!(d.check_final(), Some((1, 5)));
        track_repeated(&mut d, 2, 4);
        // Only the most recent run is reported.
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    /// The warning itself goes through `log::warn!` and is not observable here;
    /// this pins the state before and after the switch that triggers it.
    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        track_repeated(&mut d, 10, 3);
        assert_eq!(d.check_final(), Some((10, 3)));
        d.track(20);
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
        assert!(d.check_final().is_none());
    }

    /// Despite the test name, hash `0` is *not* reported like other hashes: it
    /// coincides with the initial `last_query_hash`, so a run of zeros is counted
    /// but `check_final` suppresses it.
    #[test]
    fn hash_zero_treated_normally() {
        let mut d = NPlusOneDetector::new(2);
        track_repeated(&mut d, 0, 3);
        // The run is counted (3 > threshold 2)...
        assert_eq!(d.repeat_count, 3);
        // ...but never reported.
        assert!(d.check_final().is_none());
    }

    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        track_repeated(&mut d, 42, 500);
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        track_repeated(&mut d, 1, 2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        for h in [1, 2, 1, 2] {
            d.track(h);
        }
        track_repeated(&mut d, 5, 5);
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    /// The initial state (hash 0, count 0) must not bleed into the count of the
    /// first real query.
    #[test]
    fn nonzero_hash_after_zero_init() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 1));
    }

    #[test]
    fn independent_detectors_dont_interfere() {
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);

        track_repeated(&mut d1, 42, 10);
        for h in [1, 2, 3] {
            d2.track(h);
        }

        assert_eq!(d1.check_final(), Some((42, 10)));
        assert!(d2.check_final().is_none());
    }

    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        let mut d = NPlusOneDetector::new(2);
        // Every hash is distinct, so no run ever exceeds length 1.
        for i in 0u64..1000 {
            d.track(i);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_reset_state_after_warning() {
        let mut d = NPlusOneDetector::new(2);
        track_repeated(&mut d, 1, 3);
        assert_eq!(d.check_final(), Some((1, 3)));
        // Switching hashes restarts the count; two repeats do not exceed 2.
        track_repeated(&mut d, 2, 2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_with_realistic_orm_pattern() {
        let mut d = NPlusOneDetector::new(5);
        // One parent query followed by one child query per row.
        d.track(100);
        track_repeated(&mut d, 200, 20);
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    /// The detector only sees hashes, so a legitimate loop of identical queries
    /// is flagged exactly like an N+1. This pins that known limitation.
    #[test]
    fn detector_with_legitimate_batch_pattern() {
        let mut d = NPlusOneDetector::new(10);
        track_repeated(&mut d, 300, 15);
        assert_eq!(d.check_final(), Some((300, 15)));
    }

    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            track_repeated(&mut d, 42, usize::from(threshold));
            assert!(
                d.check_final().is_none(),
                "threshold={threshold} should not warn at count={threshold}"
            );
            d.track(42);
            assert_eq!(
                d.check_final(),
                Some((42, threshold + 1)),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    #[test]
    fn detector_with_deterministic_random_sequences() {
        let mut d = NPlusOneDetector::new(5);
        // 7 ≡ 3 (mod 4), so (7i + 3) mod 4 cycles 3, 2, 1, 0 and consecutive
        // hashes always differ.
        let hashes: Vec<u64> = (0u64..100).map(|i| (i * 7 + 3) % 4).collect();
        assert!(hashes.windows(2).all(|w| w[0] != w[1]));
        for &h in &hashes {
            d.track(h);
        }
        // Every run has length 1, which cannot exceed threshold 5.
        assert!(d.check_final().is_none());
    }

    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for &h in &hashes {
                    d.track(h);
                }
                // Whatever is reported must satisfy the reporting rule.
                if let Some((hash, count)) = d.check_final() {
                    prop_assert!(count > threshold);
                    prop_assert_ne!(hash, 0);
                }
            }

            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 2u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                track_repeated(&mut d, hash, usize::from(count));
                if count > threshold {
                    prop_assert_eq!(
                        d.check_final(),
                        Some((hash, count)),
                        "count={} > threshold={} should trigger",
                        count,
                        threshold
                    );
                } else {
                    prop_assert_eq!(d.check_final(), None);
                }
            }

            #[test]
            fn matches_reference_model(
                hashes in proptest::collection::vec(0u64..4, 0..300),
                threshold in 0u16..20,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for &h in &hashes {
                    d.track(h);
                }
                // Reference model: only the trailing run of identical hashes
                // matters, and hash 0 is never reported.
                let expected = match hashes.last() {
                    Some(&last) if last != 0 => {
                        let run = hashes.iter().rev().take_while(|&&h| h == last).count();
                        // Lossless: `run` is at most the vec length (< 300).
                        let run = run as u16;
                        if run > threshold { Some((last, run)) } else { None }
                    }
                    _ => None,
                };
                prop_assert_eq!(d.check_final(), expected);
            }
        }
    }
}