1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::arena::Arena;
14use crate::codec::Encode;
15use crate::conn::Connection;
16use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
17use crate::DriverError;
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22pub(crate) enum PoolSlot {
30 Sync(Connection),
32 #[cfg(feature = "async")]
34 Async(AsyncConnection),
35}
36
/// Counts consecutive executions of the same statement to flag likely N+1
/// query patterns.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
    /// Hash of the most recently tracked statement; starts at `0`, which is
    /// also treated as "nothing tracked" when deciding whether to warn.
    last_query_hash: u64,
    /// Length of the current run of `last_query_hash` (saturating).
    repeat_count: u16,
    /// A run must be strictly longer than this to produce a warning.
    threshold: u16,
}
48
#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Creates a detector with no tracked statement and the given threshold.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            last_query_hash: 0,
            repeat_count: 0,
            threshold,
        }
    }

    /// Records one execution of the statement identified by `sql_hash`.
    ///
    /// A repeat of the previous statement extends the current run. A different
    /// statement first reports the finished run (if it was long enough) and
    /// then starts a new run of length one.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash != self.last_query_hash {
            // The previous run just ended: report it before overwriting it.
            self.emit_warning();
            self.last_query_hash = sql_hash;
            self.repeat_count = 1;
            return;
        }
        self.repeat_count = self.repeat_count.saturating_add(1);
    }

    /// Returns `(hash, repeat_count)` for the current run when it is strictly
    /// longer than the threshold and the tracked hash is non-zero.
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        let over_threshold = self.repeat_count > self.threshold;
        let has_statement = self.last_query_hash != 0;
        (over_threshold && has_statement).then_some((self.last_query_hash, self.repeat_count))
    }

    /// Logs a warning for the current run if `check_final` reports one.
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        let Some((hash, count)) = self.check_final() else {
            return;
        };
        log::warn!(
            "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
            hash,
            count,
            self.threshold,
        );
    }

    /// Reports the run that is still in progress; called when the owning
    /// guard is dropped.
    #[cold]
    #[inline(never)]
    pub(crate) fn emit_final_warning(&self) {
        self.emit_warning();
    }
}
104
105pub struct Pool {
121 inner: Arc<PoolInner>,
122}
123
124struct PoolInner {
125 stack: std::sync::Mutex<Vec<PoolSlot>>,
129 max_size: usize,
130 open_count: AtomicUsize,
131 config: Arc<Config>,
132 closed: AtomicBool,
134 release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
137 max_lifetime: Option<Duration>,
140 acquire_timeout: Option<Duration>,
142 min_idle: usize,
144 warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
146 max_stmt_cache_size: usize,
148 stale_timeout: Duration,
151 #[cfg(feature = "detect-n-plus-one")]
154 n_plus_one_threshold: u16,
155}
156
157impl Pool {
158 pub fn connect(url: &str) -> Result<Self, DriverError> {
162 PoolBuilder::new().url(url).build()
163 }
164
165 pub fn builder() -> PoolBuilder {
167 PoolBuilder::new()
168 }
169
170 #[inline]
178 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
179 if self.inner.closed.load(Ordering::Acquire) {
180 return Err(DriverError::Pool("pool is closed".into()));
181 }
182
183 if let Some(guard) = self.try_pop_idle()? {
185 return Ok(guard);
186 }
187
188 loop {
190 let current = self.inner.open_count.load(Ordering::Acquire);
191 if current >= self.inner.max_size {
192 if let Some(timeout) = self.inner.acquire_timeout {
193 let (lock, cvar) = &self.inner.release_pair;
194 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
195 let (_guard, result) = cvar
196 .wait_timeout(guard, timeout)
197 .unwrap_or_else(|e| e.into_inner());
198 if result.timed_out() {
199 return Err(DriverError::Pool(
200 "pool exhausted: acquire timeout expired".into(),
201 ));
202 }
203 if let Some(guard) = self.try_pop_idle()? {
205 return Ok(guard);
206 }
207 continue;
209 }
210 return Err(DriverError::Pool(
211 "pool exhausted: all connections in use".into(),
212 ));
213 }
214 if self
215 .inner
216 .open_count
217 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
218 .is_ok()
219 {
220 break;
221 }
222 }
224
225 let conn_result = Connection::connect_arc(self.inner.config.clone());
227 match conn_result {
228 Ok(mut conn) => {
229 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
231 self.warmup_conn(&mut conn);
233
234 Ok(PoolGuard {
235 conn: Some(PoolSlot::Sync(conn)),
236 pool: self.inner.clone(),
237 discard: false,
238 #[cfg(feature = "detect-n-plus-one")]
239 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
240 })
241 }
242 Err(e) => {
243 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
245 Err(e)
246 }
247 }
248 }
249
250 #[inline]
256 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
257 loop {
261 let (mut slot, needs_health_check) = {
262 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
263 loop {
264 let Some(slot) = stack.pop() else {
265 return Ok(None);
266 };
267 let (created_at, idle_dur) = match &slot {
268 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
269 #[cfg(feature = "async")]
270 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
271 };
272 if let Some(max_lifetime) = self.inner.max_lifetime {
273 if created_at.elapsed() >= max_lifetime {
274 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
275 continue;
276 }
277 }
278 if idle_dur >= self.inner.stale_timeout {
279 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
281 continue;
282 }
283 break (slot, idle_dur > Duration::from_secs(5));
284 }
285 };
286 if needs_health_check {
290 let alive = match &mut slot {
291 PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
292 #[cfg(feature = "async")]
293 PoolSlot::Async(_) => true, };
295 if !alive {
296 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
297 continue; }
299 }
300 return Ok(Some(PoolGuard {
301 conn: Some(slot),
302 pool: self.inner.clone(),
303 discard: false,
304 #[cfg(feature = "detect-n-plus-one")]
305 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
306 }));
307 }
308 }
309
310 pub fn is_uds(&self) -> bool {
315 #[cfg(unix)]
316 {
317 self.inner.config.host_is_uds()
318 }
319 #[cfg(not(unix))]
320 {
321 false
322 }
323 }
324
325 pub fn begin(&self) -> Result<Transaction, DriverError> {
327 let mut guard = self.acquire()?;
328 guard.simple_query("BEGIN")?;
329 Ok(Transaction {
330 guard,
331 committed: false,
332 deferred_buf: Vec::new(),
333 deferred_count: 0,
334 })
335 }
336
337 pub fn open_count(&self) -> usize {
339 self.inner.open_count.load(Ordering::Relaxed)
340 }
341
342 pub fn max_size(&self) -> usize {
344 self.inner.max_size
345 }
346
347 pub fn status(&self) -> PoolStatus {
349 let idle = self
350 .inner
351 .stack
352 .lock()
353 .unwrap_or_else(|e| e.into_inner())
354 .len();
355 let open = self.inner.open_count.load(Ordering::Relaxed);
356 let active = open.saturating_sub(idle);
357 PoolStatus {
358 idle,
359 active,
360 open,
361 max_size: self.inner.max_size,
362 }
363 }
364
365 fn warmup_conn(&self, conn: &mut Connection) {
373 let sqls = self
374 .inner
375 .warmup_sqls
376 .lock()
377 .unwrap_or_else(|e| e.into_inner())
378 .clone();
379
380 if sqls.is_empty() {
381 return;
382 }
383
384 let batch: Vec<(&str, u64)> = sqls
385 .iter()
386 .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
387 .collect();
388
389 let _ = conn.prepare_batch(&batch);
390 }
391
392 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
414 let boxed: Arc<Vec<Box<str>>> =
415 Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
416 *self
417 .inner
418 .warmup_sqls
419 .lock()
420 .unwrap_or_else(|e| e.into_inner()) = boxed;
421 }
422
423 pub fn close(&self) {
426 self.inner.closed.store(true, Ordering::Release);
427 let slots: Vec<PoolSlot> = {
429 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
430 std::mem::take(&mut *stack)
431 };
432 for slot in slots {
433 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
434 match slot {
435 PoolSlot::Sync(conn) => {
436 let _ = conn.close();
437 }
438 #[cfg(feature = "async")]
439 PoolSlot::Async(_conn) => {
440 }
443 }
444 }
445 let (_, cvar) = &self.inner.release_pair;
447 cvar.notify_all();
448 }
449
450 pub fn is_closed(&self) -> bool {
452 self.inner.closed.load(Ordering::Acquire)
453 }
454
455 #[cfg(feature = "async")]
465 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
466 if self.inner.closed.load(Ordering::Acquire) {
467 return Err(DriverError::Pool("pool is closed".into()));
468 }
469
470 if let Some(guard) = self.try_pop_idle()? {
472 return Ok(guard);
473 }
474
475 loop {
477 let current = self.inner.open_count.load(Ordering::Acquire);
478 if current >= self.inner.max_size {
479 if let Some(timeout) = self.inner.acquire_timeout {
480 let (lock, cvar) = &self.inner.release_pair;
481 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
482 let (_guard, result) = cvar
483 .wait_timeout(guard, timeout)
484 .unwrap_or_else(|e| e.into_inner());
485 if result.timed_out() {
486 return Err(DriverError::Pool(
487 "pool exhausted: acquire timeout expired".into(),
488 ));
489 }
490 if let Some(guard) = self.try_pop_idle()? {
491 return Ok(guard);
492 }
493 continue;
494 }
495 return Err(DriverError::Pool(
496 "pool exhausted: all connections in use".into(),
497 ));
498 }
499 if self
500 .inner
501 .open_count
502 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
503 .is_ok()
504 {
505 break;
506 }
507 }
508
509 if self.inner.config.host_is_uds() {
511 let conn_result = Connection::connect_arc(self.inner.config.clone());
513 match conn_result {
514 Ok(mut conn) => {
515 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
516 self.warmup_conn(&mut conn);
517 Ok(PoolGuard {
518 conn: Some(PoolSlot::Sync(conn)),
519 pool: self.inner.clone(),
520 discard: false,
521 #[cfg(feature = "detect-n-plus-one")]
522 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
523 })
524 }
525 Err(e) => {
526 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
527 Err(e)
528 }
529 }
530 } else {
531 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
533 match conn_result {
534 Ok(mut conn) => {
535 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
536 Ok(PoolGuard {
537 conn: Some(PoolSlot::Async(conn)),
538 pool: self.inner.clone(),
539 discard: false,
540 #[cfg(feature = "detect-n-plus-one")]
541 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
542 })
543 }
544 Err(e) => {
545 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
546 Err(e)
547 }
548 }
549 }
550 }
551}
552
553impl Clone for Pool {
554 fn clone(&self) -> Self {
555 Pool {
556 inner: self.inner.clone(),
557 }
558 }
559}
560
/// Snapshot of the pool counters as returned by `Pool::status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Connections sitting idle in the pool.
    pub idle: usize,
    /// Connections in use, computed as `open - idle` (saturating at zero).
    pub active: usize,
    /// All connections the pool accounts for.
    pub open: usize,
    /// Configured upper bound on `open`.
    pub max_size: usize,
}
575
/// Builder for [`Pool`]; obtained from `Pool::builder`.
pub struct PoolBuilder {
    /// Connection URL; required by `build`.
    url: Option<String>,
    /// Maximum number of open connections.
    max_size: usize,
    /// Maximum age of a reusable connection; `None` disables the check.
    max_lifetime: Option<Duration>,
    /// Maximum wait when the pool is at capacity; `None` fails immediately.
    acquire_timeout: Option<Duration>,
    /// Number of idle connections to keep ready in the background.
    min_idle: usize,
    /// Statement-cache size applied to new connections.
    max_stmt_cache_size: usize,
    /// Idle time after which a connection is dropped instead of reused.
    stale_timeout: Duration,
    /// N+1 detection threshold; `build` falls back to 10 when unset.
    #[cfg(feature = "detect-n-plus-one")]
    n_plus_one_threshold: Option<u16>,
}
596
597impl PoolBuilder {
598 fn new() -> Self {
599 Self {
600 url: None,
601 max_size: 10,
602 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, stale_timeout: Duration::from_secs(30), #[cfg(feature = "detect-n-plus-one")]
608 n_plus_one_threshold: None,
609 }
610 }
611
612 pub fn url(mut self, url: &str) -> Self {
614 self.url = Some(url.to_owned());
615 self
616 }
617
618 pub fn max_size(mut self, size: usize) -> Self {
622 self.max_size = size;
623 self
624 }
625
626 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
629 self.max_lifetime = lifetime;
630 self
631 }
632
633 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
636 self.acquire_timeout = timeout;
637 self
638 }
639
640 pub fn min_idle(mut self, count: usize) -> Self {
643 self.min_idle = count;
644 self
645 }
646
647 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
651 self.max_stmt_cache_size = size;
652 self
653 }
654
655 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
659 self.stale_timeout = timeout;
660 self
661 }
662
663 #[cfg(feature = "detect-n-plus-one")]
668 pub fn n_plus_one_threshold(mut self, n: u16) -> Self {
669 self.n_plus_one_threshold = Some(n);
670 self
671 }
672
673 pub fn build(self) -> Result<Pool, DriverError> {
675 let url = self
676 .url
677 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
678
679 let config = Arc::new(Config::from_url(&url)?);
680
681 let pool = Pool {
682 inner: Arc::new(PoolInner {
683 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
684 max_size: self.max_size,
685 open_count: AtomicUsize::new(0),
686 config,
687 closed: AtomicBool::new(false),
688 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
689 max_lifetime: self.max_lifetime,
690 acquire_timeout: self.acquire_timeout,
691 min_idle: self.min_idle,
692 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
693 max_stmt_cache_size: self.max_stmt_cache_size,
694 stale_timeout: self.stale_timeout,
695 #[cfg(feature = "detect-n-plus-one")]
696 n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
697 }),
698 };
699
700 if self.min_idle > 0 {
701 let inner = pool.inner.clone();
702 std::thread::spawn(move || {
703 maintain_min_idle(inner);
704 });
705 }
706
707 Ok(pool)
708 }
709}
710
711fn maintain_min_idle(inner: Arc<PoolInner>) {
713 loop {
714 if inner.closed.load(Ordering::Acquire) {
715 return;
716 }
717
718 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
719 let needed = inner.min_idle.saturating_sub(idle_count);
720
721 for _ in 0..needed {
722 if inner.closed.load(Ordering::Acquire) {
723 return;
724 }
725 let current = inner.open_count.load(Ordering::Acquire);
726 if current >= inner.max_size {
727 break;
728 }
729 if inner
730 .open_count
731 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
732 .is_err()
733 {
734 continue;
735 }
736
737 match Connection::connect_arc(inner.config.clone()) {
738 Ok(conn) => {
739 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
740 stack.push(PoolSlot::Sync(conn));
741 let (_, cvar) = &inner.release_pair;
742 cvar.notify_one();
743 }
744 Err(_) => {
745 inner.open_count.fetch_sub(1, Ordering::AcqRel);
746 }
747 }
748 }
749
750 std::thread::sleep(Duration::from_secs(1));
753 }
754}
755
756pub struct PoolGuard {
763 conn: Option<PoolSlot>,
764 pool: Arc<PoolInner>,
765 discard: bool,
767 #[cfg(feature = "detect-n-plus-one")]
769 detector: NPlusOneDetector,
770}
771
772impl PoolGuard {
773 #[inline]
776 fn sync_conn(&self) -> Result<&Connection, DriverError> {
777 match self.conn.as_ref() {
778 Some(PoolSlot::Sync(conn)) => Ok(conn),
779 #[cfg(feature = "async")]
780 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
781 "expected sync connection, got async; use async methods".into(),
782 )),
783 None => Err(DriverError::Pool("connection already taken".into())),
784 }
785 }
786
787 #[inline]
789 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
790 match self.conn.as_mut() {
791 Some(PoolSlot::Sync(conn)) => Ok(conn),
792 #[cfg(feature = "async")]
793 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
794 "expected sync connection, got async; use async methods".into(),
795 )),
796 None => Err(DriverError::Pool("connection already taken".into())),
797 }
798 }
799
800 pub fn mark_discard(&mut self) {
803 self.discard = true;
804 }
805
806 pub fn cancel(&self) -> Result<(), DriverError> {
811 self.sync_conn()?.cancel()
812 }
813
814 pub fn pid(&self) -> i32 {
818 match self.conn.as_ref().expect("connection taken") {
819 PoolSlot::Sync(conn) => conn.pid(),
820 #[cfg(feature = "async")]
821 PoolSlot::Async(conn) => conn.pid(),
822 }
823 }
824
825 pub fn is_idle(&self) -> bool {
827 match self.conn.as_ref().expect("connection taken") {
828 PoolSlot::Sync(conn) => conn.is_idle(),
829 #[cfg(feature = "async")]
830 PoolSlot::Async(conn) => conn.is_idle(),
831 }
832 }
833
834 pub fn is_in_transaction(&self) -> bool {
836 match self.conn.as_ref().expect("connection taken") {
837 PoolSlot::Sync(conn) => conn.is_in_transaction(),
838 #[cfg(feature = "async")]
839 PoolSlot::Async(conn) => conn.is_in_transaction(),
840 }
841 }
842
843 #[inline]
847 pub fn query(
848 &mut self,
849 sql: &str,
850 sql_hash: u64,
851 params: &[&(dyn Encode + Sync)],
852 ) -> Result<QueryResult, DriverError> {
853 #[cfg(feature = "detect-n-plus-one")]
854 self.detector.track(sql_hash);
855 self.sync_conn_mut()?.query(sql, sql_hash, params)
856 }
857
858 #[inline]
860 pub fn execute(
861 &mut self,
862 sql: &str,
863 sql_hash: u64,
864 params: &[&(dyn Encode + Sync)],
865 ) -> Result<u64, DriverError> {
866 #[cfg(feature = "detect-n-plus-one")]
867 self.detector.track(sql_hash);
868 self.sync_conn_mut()?.execute(sql, sql_hash, params)
869 }
870
871 pub fn execute_pipeline(
876 &mut self,
877 sql: &str,
878 sql_hash: u64,
879 param_sets: &[&[&(dyn Encode + Sync)]],
880 ) -> Result<Vec<u64>, DriverError> {
881 #[cfg(feature = "detect-n-plus-one")]
882 self.detector.track(sql_hash);
883 self.sync_conn_mut()?
884 .execute_pipeline(sql, sql_hash, param_sets)
885 }
886
887 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
889 self.sync_conn_mut()?.simple_query(sql)
890 }
891
892 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
896 self.sync_conn_mut()?.simple_query_rows(sql)
897 }
898
899 pub fn for_each<F>(
901 &mut self,
902 sql: &str,
903 sql_hash: u64,
904 params: &[&(dyn Encode + Sync)],
905 f: F,
906 ) -> Result<(), DriverError>
907 where
908 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
909 {
910 #[cfg(feature = "detect-n-plus-one")]
911 self.detector.track(sql_hash);
912 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
913 }
914
915 pub fn for_each_raw<F>(
917 &mut self,
918 sql: &str,
919 sql_hash: u64,
920 params: &[&(dyn Encode + Sync)],
921 f: F,
922 ) -> Result<(), DriverError>
923 where
924 F: FnMut(&[u8]) -> Result<(), DriverError>,
925 {
926 #[cfg(feature = "detect-n-plus-one")]
927 self.detector.track(sql_hash);
928 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
929 }
930
931 pub fn query_streaming_start(
935 &mut self,
936 sql: &str,
937 sql_hash: u64,
938 params: &[&(dyn Encode + Sync)],
939 chunk_size: i32,
940 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
941 #[cfg(feature = "detect-n-plus-one")]
942 self.detector.track(sql_hash);
943 self.sync_conn_mut()?
944 .query_streaming_start(sql, sql_hash, params, chunk_size)
945 }
946
947 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
949 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
950 }
951
952 pub fn streaming_next_chunk(
954 &mut self,
955 arena: &mut Arena,
956 all_col_offsets: &mut Vec<(usize, i32)>,
957 ) -> Result<bool, DriverError> {
958 self.sync_conn_mut()?
959 .streaming_next_chunk(arena, all_col_offsets)
960 }
961
962 pub fn copy_in<'a, I>(
968 &mut self,
969 table: &str,
970 columns: &[&str],
971 rows: I,
972 ) -> Result<u64, DriverError>
973 where
974 I: IntoIterator<Item = &'a str>,
975 {
976 self.sync_conn_mut()?.copy_in(table, columns, rows)
977 }
978
979 pub fn copy_out<W: std::io::Write>(
983 &mut self,
984 query: &str,
985 writer: &mut W,
986 ) -> Result<u64, DriverError> {
987 self.sync_conn_mut()?.copy_out(query, writer)
988 }
989
990 pub fn is_sync(&self) -> bool {
992 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
993 }
994
995 #[cfg(feature = "async")]
997 pub fn is_async(&self) -> bool {
998 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
999 }
1000
1001 #[cfg(feature = "async")]
1009 pub async fn query_async(
1010 &mut self,
1011 sql: &str,
1012 sql_hash: u64,
1013 params: &[&(dyn Encode + Sync)],
1014 ) -> Result<QueryResult, DriverError> {
1015 #[cfg(feature = "detect-n-plus-one")]
1016 self.detector.track(sql_hash);
1017 match self.conn.as_mut() {
1018 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
1019 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
1020 None => Err(DriverError::Pool("connection already taken".into())),
1021 }
1022 }
1023
1024 #[cfg(feature = "async")]
1026 pub async fn execute_async(
1027 &mut self,
1028 sql: &str,
1029 sql_hash: u64,
1030 params: &[&(dyn Encode + Sync)],
1031 ) -> Result<u64, DriverError> {
1032 #[cfg(feature = "detect-n-plus-one")]
1033 self.detector.track(sql_hash);
1034 match self.conn.as_mut() {
1035 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
1036 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
1037 None => Err(DriverError::Pool("connection already taken".into())),
1038 }
1039 }
1040
1041 #[cfg(feature = "async")]
1043 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
1044 match self.conn.as_mut() {
1045 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
1046 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
1047 None => Err(DriverError::Pool("connection already taken".into())),
1048 }
1049 }
1050
1051 pub(crate) fn ensure_stmt_prepared(
1055 &mut self,
1056 sql: &str,
1057 sql_hash: u64,
1058 params: &[&(dyn Encode + Sync)],
1059 ) -> Result<[u8; 18], DriverError> {
1060 self.sync_conn_mut()?
1061 .ensure_stmt_prepared(sql, sql_hash, params)
1062 }
1063
1064 pub(crate) fn write_deferred_bind_execute(
1066 &self,
1067 sql: &str,
1068 sql_hash: u64,
1069 params: &[&(dyn Encode + Sync)],
1070 buf: &mut Vec<u8>,
1071 ) {
1072 let conn = self
1073 .sync_conn()
1074 .expect("sync_conn failed in write_deferred");
1075 conn.write_deferred_bind_execute(sql, sql_hash, params, buf);
1076 }
1077
1078 pub(crate) fn flush_deferred_pipeline(
1080 &mut self,
1081 buf: &mut Vec<u8>,
1082 count: usize,
1083 ) -> Result<Vec<u64>, DriverError> {
1084 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
1085 }
1086}
1087
1088impl Drop for PoolGuard {
1089 fn drop(&mut self) {
1090 #[cfg(feature = "detect-n-plus-one")]
1091 self.detector.emit_final_warning();
1092
1093 if let Some(slot) = self.conn.take() {
1094 let should_discard = self.discard
1096 || self.pool.closed.load(Ordering::Acquire)
1097 || match &slot {
1098 PoolSlot::Sync(conn) => {
1099 conn.is_in_failed_transaction()
1100 || conn.is_in_transaction()
1101 || conn.is_streaming()
1102 }
1103 #[cfg(feature = "async")]
1104 PoolSlot::Async(conn) => {
1105 conn.is_in_failed_transaction() || conn.is_in_transaction()
1106 }
1107 };
1108
1109 if should_discard {
1110 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1111 return;
1112 }
1113
1114 let mut slot = slot;
1117 match &mut slot {
1118 PoolSlot::Sync(conn) => {
1119 if conn.query_counter() & 63 == 0 {
1120 conn.touch();
1121 }
1122 }
1123 #[cfg(feature = "async")]
1124 PoolSlot::Async(conn) => {
1125 if conn.query_counter() & 63 == 0 {
1126 conn.touch();
1127 }
1128 }
1129 }
1130
1131 {
1133 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1134 stack.push(slot);
1135 }
1136
1137 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1139 let (_, cvar) = &self.pool.release_pair;
1140 cvar.notify_one();
1141 }
1142 }
1143 }
1144}
1145
1146pub struct Transaction {
1162 guard: PoolGuard,
1163 committed: bool,
1164 deferred_buf: Vec<u8>,
1166 deferred_count: usize,
1168}
1169
1170impl Transaction {
1171 pub fn commit(mut self) -> Result<(), DriverError> {
1175 if self.deferred_count > 0 {
1176 self.flush_deferred()?;
1177 }
1178 self.guard.simple_query("COMMIT")?;
1179 self.committed = true;
1180 Ok(())
1181 }
1182
1183 pub fn rollback(mut self) -> Result<(), DriverError> {
1187 self.deferred_buf.clear();
1188 self.deferred_count = 0;
1189 self.guard.simple_query("ROLLBACK")?;
1190 self.committed = true; Ok(())
1192 }
1193
1194 pub fn query(
1199 &mut self,
1200 sql: &str,
1201 sql_hash: u64,
1202 params: &[&(dyn Encode + Sync)],
1203 ) -> Result<QueryResult, DriverError> {
1204 if self.deferred_count > 0 {
1205 self.flush_deferred()?;
1206 }
1207 self.guard.query(sql, sql_hash, params)
1208 }
1209
1210 pub fn execute(
1212 &mut self,
1213 sql: &str,
1214 sql_hash: u64,
1215 params: &[&(dyn Encode + Sync)],
1216 ) -> Result<u64, DriverError> {
1217 self.guard.execute(sql, sql_hash, params)
1218 }
1219
1220 pub fn execute_pipeline(
1222 &mut self,
1223 sql: &str,
1224 sql_hash: u64,
1225 param_sets: &[&[&(dyn Encode + Sync)]],
1226 ) -> Result<Vec<u64>, DriverError> {
1227 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1228 }
1229
1230 pub fn for_each<F>(
1234 &mut self,
1235 sql: &str,
1236 sql_hash: u64,
1237 params: &[&(dyn Encode + Sync)],
1238 f: F,
1239 ) -> Result<(), DriverError>
1240 where
1241 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1242 {
1243 if self.deferred_count > 0 {
1244 self.flush_deferred()?;
1245 }
1246 self.guard.for_each(sql, sql_hash, params, f)
1247 }
1248
1249 pub fn for_each_raw<F>(
1253 &mut self,
1254 sql: &str,
1255 sql_hash: u64,
1256 params: &[&(dyn Encode + Sync)],
1257 f: F,
1258 ) -> Result<(), DriverError>
1259 where
1260 F: FnMut(&[u8]) -> Result<(), DriverError>,
1261 {
1262 if self.deferred_count > 0 {
1263 self.flush_deferred()?;
1264 }
1265 self.guard.for_each_raw(sql, sql_hash, params, f)
1266 }
1267
1268 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1272 if self.deferred_count > 0 {
1273 self.flush_deferred()?;
1274 }
1275 self.guard.simple_query(sql)
1276 }
1277
1278 pub fn defer_execute(
1307 &mut self,
1308 sql: &str,
1309 sql_hash: u64,
1310 params: &[&(dyn Encode + Sync)],
1311 ) -> Result<(), DriverError> {
1312 if params.len() > i16::MAX as usize {
1313 return Err(DriverError::Protocol(format!(
1314 "parameter count {} exceeds maximum {}",
1315 params.len(),
1316 i16::MAX
1317 )));
1318 }
1319
1320 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1322
1323 self.guard
1325 .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf);
1326 self.deferred_count += 1;
1327 Ok(())
1328 }
1329
1330 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1335 let count = self.deferred_count;
1336 self.deferred_count = 0;
1337 self.guard
1338 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1339 }
1340
1341 pub fn deferred_count(&self) -> usize {
1343 self.deferred_count
1344 }
1345}
1346
1347impl Drop for Transaction {
1348 fn drop(&mut self) {
1349 if !self.committed {
1350 if let Some(_slot) = self.guard.conn.take() {
1353 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1354 }
1356 }
1357 }
1358}
1359
1360#[cfg(test)]
1361mod tests {
1362 use super::*;
1363
1364 #[test]
1365 fn pool_builder_requires_url() {
1366 let result = PoolBuilder::new().build();
1367 assert!(result.is_err());
1368 }
1369
1370 #[test]
1371 fn pool_builder_validates_url() {
1372 let result = PoolBuilder::new().url("not_a_url").build();
1373 assert!(result.is_err());
1374 }
1375
1376 #[test]
1377 fn pool_builder_accepts_valid_url() {
1378 let pool = PoolBuilder::new()
1379 .url("postgres://user:pass@localhost/db")
1380 .max_size(5)
1381 .build()
1382 .unwrap();
1383 assert_eq!(pool.max_size(), 5);
1384 assert_eq!(pool.open_count(), 0);
1385 }
1386
1387 #[test]
1388 fn pool_connect_validates_url() {
1389 let result = Pool::connect("not_a_url");
1390 assert!(result.is_err());
1391 }
1392
1393 #[test]
1394 fn pool_max_size_zero() {
1395 let pool = PoolBuilder::new()
1396 .url("postgres://user:pass@localhost/db")
1397 .max_size(0)
1398 .build()
1399 .unwrap();
1400
1401 let result = pool.acquire();
1402 assert!(result.is_err());
1403 match result {
1404 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1405 Err(e) => panic!("expected Pool error, got: {e:?}"),
1406 Ok(_) => panic!("expected error, got Ok"),
1407 }
1408 }
1409
1410 #[test]
1411 fn pool_clone_shares_state() {
1412 let pool = PoolBuilder::new()
1413 .url("postgres://user:pass@localhost/db")
1414 .max_size(5)
1415 .build()
1416 .unwrap();
1417
1418 let pool2 = pool.clone();
1419 assert_eq!(pool.max_size(), pool2.max_size());
1420 }
1421
1422 #[test]
1426 fn pool_builder_max_lifetime() {
1427 let pool = PoolBuilder::new()
1428 .url("postgres://user:pass@localhost/db")
1429 .max_lifetime(Some(Duration::from_secs(60)))
1430 .build()
1431 .unwrap();
1432 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1433 }
1434
1435 #[test]
1437 fn pool_builder_max_lifetime_none() {
1438 let pool = PoolBuilder::new()
1439 .url("postgres://user:pass@localhost/db")
1440 .max_lifetime(None)
1441 .build()
1442 .unwrap();
1443 assert_eq!(pool.inner.max_lifetime, None);
1444 }
1445
1446 #[test]
1448 fn pool_builder_acquire_timeout_none() {
1449 let pool = PoolBuilder::new()
1450 .url("postgres://user:pass@localhost/db")
1451 .acquire_timeout(None)
1452 .build()
1453 .unwrap();
1454 assert_eq!(pool.inner.acquire_timeout, None);
1455 }
1456
1457 #[test]
1459 fn pool_builder_acquire_timeout_custom() {
1460 let pool = PoolBuilder::new()
1461 .url("postgres://user:pass@localhost/db")
1462 .acquire_timeout(Some(Duration::from_secs(10)))
1463 .build()
1464 .unwrap();
1465 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1466 }
1467
1468 #[test]
1470 fn pool_builder_min_idle() {
1471 let pool = PoolBuilder::new()
1472 .url("postgres://user:pass@localhost/db")
1473 .min_idle(2)
1474 .build()
1475 .unwrap();
1476 assert_eq!(pool.inner.min_idle, 2);
1477 }
1478
1479 #[test]
1481 fn pool_close_marks_closed() {
1482 let pool = PoolBuilder::new()
1483 .url("postgres://user:pass@localhost/db")
1484 .max_size(5)
1485 .build()
1486 .unwrap();
1487
1488 assert!(!pool.is_closed());
1489 pool.close();
1490 assert!(pool.is_closed());
1491
1492 let result = pool.acquire();
1494 assert!(result.is_err());
1495 match result {
1496 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1497 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1498 Ok(_) => panic!("expected error, got Ok"),
1499 }
1500 }
1501
1502 #[test]
1504 fn pool_status_initial() {
1505 let pool = PoolBuilder::new()
1506 .url("postgres://user:pass@localhost/db")
1507 .max_size(10)
1508 .build()
1509 .unwrap();
1510
1511 let status = pool.status();
1512 assert_eq!(status.idle, 0);
1513 assert_eq!(status.active, 0);
1514 assert_eq!(status.open, 0);
1515 assert_eq!(status.max_size, 10);
1516 }
1517
1518 #[test]
1520 fn pool_builder_defaults() {
1521 let pool = PoolBuilder::new()
1522 .url("postgres://user:pass@localhost/db")
1523 .build()
1524 .unwrap();
1525
1526 assert_eq!(pool.max_size(), 10);
1527 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1528 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1529 assert_eq!(pool.inner.min_idle, 0);
1530 }
1531
1532 #[test]
1534 fn pool_open_count_initial() {
1535 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1536 assert_eq!(pool.open_count(), 0);
1537 }
1538
1539 #[test]
1542 fn pool_builder_max_stmt_cache_size_default() {
1543 let pool = PoolBuilder::new()
1544 .url("postgres://user:pass@localhost/db")
1545 .build()
1546 .unwrap();
1547 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1548 }
1549
1550 #[test]
1551 fn pool_builder_max_stmt_cache_size_custom() {
1552 let pool = PoolBuilder::new()
1553 .url("postgres://user:pass@localhost/db")
1554 .max_stmt_cache_size(512)
1555 .build()
1556 .unwrap();
1557 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1558 }
1559
1560 #[test]
1563 fn pool_is_uds_false_for_tcp() {
1564 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1565 assert!(!pool.is_uds());
1566 }
1567
1568 #[cfg(unix)]
1569 #[test]
1570 fn pool_is_uds_true_for_unix_socket() {
1571 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1572 assert!(pool.is_uds());
1573 }
1574
1575 #[cfg(unix)]
1576 #[test]
1577 fn pool_is_uds_true_for_var_run_socket() {
1578 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1579 assert!(pool.is_uds());
1580 }
1581
/// A literal IPv4 host is a TCP target, never a Unix domain socket.
#[test]
fn pool_is_uds_false_for_ip_address() {
    let url = "postgres://user:pass@127.0.0.1/db";
    let pool = Pool::connect(url).unwrap();
    let uds = pool.is_uds();
    assert!(!uds);
}
1587
/// Parsing a URL with `host=/tmp` yields a config whose host is reported as a
/// Unix domain socket.
#[cfg(unix)]
#[test]
fn pool_slot_sync_created_for_uds_config() {
    let parsed = Config::from_url("postgres://user@localhost/db?host=/tmp");
    let config = parsed.unwrap();
    assert!(config.host_is_uds());
}
1594
/// Parsing a plain `localhost` URL yields a config whose host is not a Unix
/// domain socket.
#[test]
fn pool_slot_tcp_config() {
    let parsed = Config::from_url("postgres://user:pass@localhost/db");
    let config = parsed.unwrap();
    assert!(!config.host_is_uds());
}
1600
/// A fully qualified DNS hostname is a TCP target, not a Unix domain socket.
#[test]
fn pool_is_uds_false_for_hostname() {
    let url = "postgres://user:pass@db.example.com/db";
    let pool = Pool::connect(url).unwrap();
    let uds = pool.is_uds();
    assert!(!uds);
}
1610
/// A pool built from a URL with `host=/tmp` reports a Unix domain socket.
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_tmp() {
    let url = "postgres://user@localhost/db?host=/tmp";
    let pool = Pool::connect(url).unwrap();
    let uds = pool.is_uds();
    assert!(uds);
}
1617
/// After `close()`, `acquire()` must fail with a `DriverError::Pool` whose
/// message mentions that the pool is closed.
#[test]
fn pool_close_then_acquire_fails() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(5)
        .build();
    let pool = built.unwrap();
    pool.close();

    let result = pool.acquire();
    assert!(result.is_err());

    // Inspect only the error side; the assert above already ruled out `Ok`.
    match result.err() {
        Some(DriverError::Pool(msg)) => {
            assert!(msg.contains("closed"), "should say closed: {msg}")
        }
        Some(e) => panic!("expected Pool error, got: {e:?}"),
        None => panic!("expected error"),
    }
}
1640
/// `is_closed()` is false on a new pool and flips to true once `close()` ran.
#[test]
fn pool_is_closed_before_and_after() {
    let url = "postgres://user:pass@localhost/db";
    let pool = Pool::connect(url).unwrap();

    let closed_before = pool.is_closed();
    assert!(!closed_before);

    pool.close();

    let closed_after = pool.is_closed();
    assert!(closed_after);
}
1648
/// With zero capacity and no acquire timeout, `acquire()` must fail with a
/// `DriverError::Pool` whose message mentions exhaustion.
#[test]
fn pool_exhausted_no_timeout() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(0)
        .acquire_timeout(None)
        .build();
    let pool = built.unwrap();

    let result = pool.acquire();
    assert!(result.is_err());

    // Inspect only the error side; the assert above already ruled out `Ok`.
    match result.err() {
        Some(DriverError::Pool(msg)) => {
            assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
        }
        Some(e) => panic!("expected Pool error, got: {e:?}"),
        None => panic!("expected error"),
    }
}
1671
/// Building without a URL must fail with a `DriverError::Pool` whose message
/// mentions the missing URL.
#[test]
fn pool_builder_no_url_error() {
    let result = PoolBuilder::new().max_size(5).build();
    assert!(result.is_err());

    if let Err(failure) = result {
        match failure {
            DriverError::Pool(msg) => {
                assert!(msg.contains("URL"), "should mention URL: {msg}")
            }
            e => panic!("expected Pool error, got: {e:?}"),
        }
    } else {
        panic!("expected error");
    }
}
1688
/// A URL with a non-Postgres scheme must be rejected at build time.
#[test]
fn pool_builder_invalid_url_error() {
    let builder_url = "ftp://something";
    let result = PoolBuilder::new().url(builder_url).build();
    assert!(result.is_err());
}
1694
/// A statement cache size of zero is accepted and stored verbatim.
#[test]
fn pool_builder_stmt_cache_size_zero() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_stmt_cache_size(0)
        .build();
    let pool = built.unwrap();
    let cache_cap = pool.inner.max_stmt_cache_size;
    assert_eq!(cache_cap, 0);
}
1704
/// Without an explicit setting the stale timeout is 30 seconds.
#[test]
fn pool_builder_stale_timeout_default() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .build();
    let pool = built.unwrap();
    let expected = Duration::from_secs(30);
    assert_eq!(pool.inner.stale_timeout, expected);
}
1715
/// A custom stale timeout set on the builder is stored on the pool.
#[test]
fn pool_builder_stale_timeout_custom() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .stale_timeout(Duration::from_secs(60))
        .build();
    let pool = built.unwrap();
    let expected = Duration::from_secs(60);
    assert_eq!(pool.inner.stale_timeout, expected);
}
1725
/// A zero stale timeout is accepted and stored verbatim.
#[test]
fn pool_builder_stale_timeout_zero() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .stale_timeout(Duration::from_secs(0))
        .build();
    let pool = built.unwrap();
    let expected = Duration::from_secs(0);
    assert_eq!(pool.inner.stale_timeout, expected);
}
1735
/// `status()` on an untouched pool echoes the configured `max_size` and shows
/// no idle, active, or open connections.
#[test]
fn pool_status_reflects_max_size() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(20)
        .build();
    let pool = built.unwrap();

    let snapshot = pool.status();
    assert_eq!(snapshot.max_size, 20);
    assert_eq!(snapshot.idle, 0);
    assert_eq!(snapshot.active, 0);
    assert_eq!(snapshot.open, 0);
}
1753
/// A cloned pool handle reports the same `max_size` and open count as the
/// handle it was cloned from.
#[test]
fn pool_clone_shares_config() {
    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(7)
        .build();
    let pool = built.unwrap();
    let twin = pool.clone();

    assert_eq!(pool.max_size(), 7);
    assert_eq!(twin.max_size(), 7);
    assert_eq!(pool.open_count(), twin.open_count());
}
1770
/// Setting an empty warm-up list leaves the stored list empty.
#[test]
fn pool_set_warmup_sqls_empty() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    pool.set_warmup_sqls(&[]);

    // Take a snapshot under the lock (tolerating poisoning), then release it.
    let guard = pool
        .inner
        .warmup_sqls
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let sqls = guard.clone();
    drop(guard);

    assert!(sqls.is_empty());
}
1787
/// Several warm-up statements are stored in the order they were supplied.
#[test]
fn pool_set_warmup_sqls_multiple() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);

    // Take a snapshot under the lock (tolerating poisoning), then release it.
    let guard = pool
        .inner
        .warmup_sqls
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let sqls = guard.clone();
    drop(guard);

    let expected = ["SELECT 1", "SELECT 2", "SELECT 3"];
    assert_eq!(sqls.len(), 3);
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(&*sqls[idx], *want);
    }
}
1803
/// A second `set_warmup_sqls` call replaces the earlier list instead of
/// appending to it.
#[test]
fn pool_set_warmup_sqls_overwrite() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    pool.set_warmup_sqls(&["SELECT 1"]);
    pool.set_warmup_sqls(&["SELECT 99"]);

    // Take a snapshot under the lock (tolerating poisoning), then release it.
    let guard = pool
        .inner
        .warmup_sqls
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let sqls = guard.clone();
    drop(guard);

    assert_eq!(sqls.len(), 1);
    assert_eq!(&*sqls[0], "SELECT 99");
}
1818
/// The `Debug` rendering of the pool status names the type and each of its
/// four counters.
#[test]
fn pool_status_debug() {
    let url = "postgres://user:pass@localhost/db";
    let pool = Pool::connect(url).unwrap();
    let status = pool.status();

    let dbg = format!("{status:?}");
    assert!(dbg.contains("PoolStatus"));
    assert!(dbg.contains("idle"));
    assert!(dbg.contains("active"));
    assert!(dbg.contains("open"));
    assert!(dbg.contains("max_size"));
}
1834
/// A `host` query parameter that is an absolute path (`/tmp`) is reported as a
/// Unix domain socket.
///
/// NOTE(review): the other UDS-positive tests in this module are gated on
/// `#[cfg(unix)]`; this one is not — confirm `host_is_uds` is meant to be
/// platform-independent.
#[test]
fn config_host_is_uds_returns_true_for_slash() {
    let parsed = Config::from_url("postgres://user@localhost/db?host=/tmp");
    let config = parsed.unwrap();
    assert!(config.host_is_uds());
}
1844
/// A `localhost` host without a socket path is not a Unix domain socket.
#[test]
fn config_host_is_uds_returns_false_for_tcp() {
    let parsed = Config::from_url("postgres://user:pass@localhost/db");
    let config = parsed.unwrap();
    assert!(!config.host_is_uds());
}
1850
/// A literal IPv4 host is not a Unix domain socket.
#[test]
fn config_host_is_uds_returns_false_for_ip() {
    let parsed = Config::from_url("postgres://user:pass@192.168.1.1/db");
    let config = parsed.unwrap();
    assert!(!config.host_is_uds());
}
1856
/// Every builder knob set in one chain must be reflected on the built pool.
#[test]
fn pool_builder_full_chain() {
    let lifetime = Some(Duration::from_secs(600));
    let acquire_wait = Some(Duration::from_secs(5));

    let built = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(3)
        .max_lifetime(lifetime)
        .acquire_timeout(acquire_wait)
        .min_idle(1)
        .max_stmt_cache_size(128)
        .build();
    let pool = built.unwrap();

    assert_eq!(pool.max_size(), 3);
    assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
    assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
    assert_eq!(pool.inner.min_idle, 1);
    assert_eq!(pool.inner.max_stmt_cache_size, 128);
}
1878
/// A pool built with `max_size(0)` can never hand out a connection, so
/// `acquire()` must fail with a `DriverError::Pool` mentioning "exhausted".
///
/// Each failure arm reports what was actually returned, so a regression shows
/// the offending error instead of a generic panic.
///
/// NOTE(review): unlike `pool_exhausted_no_timeout`, this keeps the builder's
/// default acquire timeout — confirm the zero-capacity path fails fast rather
/// than waiting that timeout out.
#[test]
fn pool_max_size_zero_rejects_all_acquires() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(0)
        .build()
        .unwrap();

    match pool.acquire() {
        Err(DriverError::Pool(msg)) => {
            assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
        }
        Err(e) => panic!("expected Pool error, got: {e:?}"),
        Ok(_) => panic!("expected pool exhausted error"),
    }
}
1895
/// An unrecognised `sslmode` value is rejected with a message naming the
/// problem.
#[test]
fn url_parse_unknown_sslmode_returns_error() {
    let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
    assert!(result.is_err());

    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("unknown sslmode"));
}
1905
/// A non-numeric port is rejected with a message naming the problem.
#[test]
fn url_parse_invalid_port_returns_error() {
    let result = Config::from_url("postgres://u:p@h:abc/d");
    assert!(result.is_err());

    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("invalid port"));
}
1913
/// A URL without the `@` separating credentials from host is rejected with a
/// message naming the problem.
#[test]
fn url_parse_missing_at_sign_returns_error() {
    let result = Config::from_url("postgres://u:plocalhost/d");
    assert!(result.is_err());

    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("missing @"));
}
1921
/// A URL whose host component is empty must fail to parse.
#[test]
fn url_parse_empty_host_returns_error() {
    let url = "postgres://u:p@/d";
    let result = Config::from_url(url);
    assert!(result.is_err());
}
1927
/// A URL whose user component is empty must fail to parse.
#[test]
fn url_parse_empty_user_returns_error() {
    let url = "postgres://:p@h/d";
    let result = Config::from_url(url);
    assert!(result.is_err());
}
1933
/// A non-numeric `statement_timeout` does not fail parsing; the config ends up
/// with a 30-second statement timeout.
#[test]
fn url_parse_statement_timeout_invalid_uses_default() {
    let parsed = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum");
    let config = parsed.unwrap();
    let timeout_secs = config.statement_timeout_secs;
    assert_eq!(timeout_secs, 30);
}
1939
/// A `%` that is not followed by two characters must fail to parse.
#[test]
fn url_parse_malformed_percent_encoding() {
    let url = "postgres://u%:p@h/d";
    let result = Config::from_url(url);
    assert!(result.is_err());
}
1945
/// A percent escape with non-hex digits (`%ZZ`) must fail to parse.
#[test]
fn url_parse_invalid_hex_in_percent_encoding() {
    let url = "postgres://u%ZZ:p@h/d";
    let result = Config::from_url(url);
    assert!(result.is_err());
}
1951}
1952
// Tests for `NPlusOneDetector` and for the pool builder's N+1 threshold knob.
//
// Detector contract exercised here: `track` counts consecutive occurrences of
// one hash (resetting to 1 when the hash changes), and `check_final` reports
// `(hash, count)` only when the count is strictly greater than the threshold
// and the hash is non-zero.
#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    use super::NPlusOneDetector;

    /// Feeds `hash` to the detector `times` times in a row.
    fn track_n(d: &mut NPlusOneDetector, hash: u64, times: u32) {
        for _ in 0..times {
            d.track(hash);
        }
    }

    /// Exactly `threshold` repeats is not enough to report.
    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        track_n(&mut d, 42, 10);
        assert!(d.check_final().is_none());
    }

    /// One repeat past the threshold reports the hash and the count.
    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        track_n(&mut d, 42, 11);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 11));
    }

    /// The comparison is strict: count == threshold stays silent.
    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        track_n(&mut d, 99, 5);
        assert!(d.check_final().is_none(), "> not >=");
    }

    /// count == threshold + 1 is the first reporting count.
    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        track_n(&mut d, 99, 6);
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    /// Two hashes that strictly alternate never build up a run.
    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    /// A single query is far below a threshold of 10.
    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    /// A detector that never saw a query has nothing to report.
    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    /// Despite the name, a zero threshold is already exceeded by the very
    /// first tracked query (count 1 > 0), which is what this test pins.
    #[test]
    fn threshold_zero_warns_on_second() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    /// With the maximum threshold, 1000 repeats cannot exceed it.
    #[test]
    fn threshold_max_never_warns() {
        let mut d = NPlusOneDetector::new(u16::MAX);
        track_n(&mut d, 42, 1000);
        assert!(d.check_final().is_none());
    }

    /// The repeat counter saturates at `u16::MAX` instead of wrapping.
    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        // Start one below the ceiling so the second `track` would overflow
        // without saturation.
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42);
        d.track(42);
        assert_eq!(d.repeat_count, u16::MAX);
    }

    /// A different hash restarts the run at 1 and becomes the tracked hash.
    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        track_n(&mut d, 1, 50);
        d.track(2);
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    /// Only the most recent run is visible to `check_final`; the earlier run
    /// of hash 1 was handled when the hash switched.
    #[test]
    fn multiple_n_plus_one_sequences() {
        let mut d = NPlusOneDetector::new(3);
        track_n(&mut d, 1, 5);
        track_n(&mut d, 2, 4);
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    /// Switching hashes after exceeding the threshold goes through the
    /// warning path inside `track`; only the reset state is observable here.
    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        track_n(&mut d, 10, 3);
        d.track(20);
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
    }

    /// Hash 0 is the detector's initial `last_query_hash`, and `check_final`
    /// refuses to report it — so, despite the name, repeats of hash 0 are
    /// never reported even above the threshold.
    #[test]
    fn hash_zero_treated_normally() {
        let mut d = NPlusOneDetector::new(2);
        track_n(&mut d, 0, 3);
        assert!(d.check_final().is_none());
    }

    /// The reported count matches the number of consecutive repeats.
    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        track_n(&mut d, 42, 500);
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    /// Two repeats stay silent under a threshold of 10.
    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        track_n(&mut d, 1, 2);
        assert!(d.check_final().is_none());
    }

    /// Interleaved noise followed by a burst reports only the burst.
    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        d.track(1);
        d.track(2);
        d.track(1);
        d.track(2);
        track_n(&mut d, 5, 5);
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    /// The builder's default N+1 threshold is 10.
    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    /// A custom N+1 threshold is stored on the pool.
    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    /// A zero N+1 threshold is accepted and stored verbatim.
    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    /// The maximum N+1 threshold is accepted and stored verbatim.
    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    /// Two distinct queries in a row leave a run of length 1.
    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    /// The first non-zero hash differs from the zero initial state, so it
    /// starts a run of 1, which already exceeds a zero threshold.
    #[test]
    fn nonzero_hash_after_zero_init() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 1));
    }

    /// Detectors hold no shared state: one reporting does not affect another.
    #[test]
    fn independent_detectors_dont_interfere() {
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);

        track_n(&mut d1, 42, 10);
        d2.track(1);
        d2.track(2);
        d2.track(3);

        assert!(d1.check_final().is_some());
        assert!(d2.check_final().is_none());
    }

    /// A stream where every hash differs from its predecessor never reports.
    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0u64..1000 {
            d.track(i);
        }
        assert!(d.check_final().is_none());
    }

    /// After a run that exceeded the threshold, a new hash starts counting
    /// from scratch; two repeats of it do not exceed a threshold of 2.
    #[test]
    fn detector_reset_state_after_warning() {
        let mut d = NPlusOneDetector::new(2);
        track_n(&mut d, 1, 3);
        d.track(2);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    /// One parent query followed by twenty identical child queries is the
    /// classic N+1 shape and must be reported.
    #[test]
    fn detector_with_realistic_orm_pattern() {
        let mut d = NPlusOneDetector::new(5);
        d.track(100);
        track_n(&mut d, 200, 20);
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    /// The detector only sees hashes, so fifteen identical statements are
    /// reported regardless of why they were issued.
    #[test]
    fn detector_with_legitimate_batch_pattern() {
        let mut d = NPlusOneDetector::new(10);
        track_n(&mut d, 300, 15);
        assert!(d.check_final().is_some());
    }

    /// For a range of thresholds, `threshold + 1` repeats always reports.
    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            track_n(&mut d, 42, u32::from(threshold) + 1);
            assert!(
                d.check_final().is_some(),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    /// A fixed pseudo-random stream over four hashes.
    ///
    /// Each step advances the hash by 7 ≡ 3 (mod 4), so no two consecutive
    /// hashes are equal and every run has length 1; nothing may be reported.
    #[test]
    fn detector_with_deterministic_random_sequences() {
        let mut d = NPlusOneDetector::new(5);
        let hashes: Vec<u64> = (0..100).map(|i| ((i * 7 + 3) % 4) as u64).collect();
        for &h in &hashes {
            d.track(h);
        }
        assert!(
            d.check_final().is_none(),
            "sequence has no consecutive repeats, so nothing should be reported"
        );
    }

    // Property-based checks over arbitrary hash streams and thresholds.
    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// Any hash stream with any threshold must be processed without
            /// panicking.
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for h in &hashes {
                    d.track(*h);
                }
                let _ = d.check_final();
            }

            /// A run of one non-zero hash longer than the threshold must
            /// always be reported.
            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 2u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                track_n(&mut d, hash, u32::from(count));
                if count > threshold {
                    assert!(d.check_final().is_some(),
                        "count={count} > threshold={threshold} should trigger");
                }
            }
        }
    }
}