1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::arena::Arena;
14use crate::codec::Encode;
15use crate::conn::Connection;
16use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
17use crate::DriverError;
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22pub(crate) enum PoolSlot {
30 Sync(Connection),
32 #[cfg(feature = "async")]
34 Async(AsyncConnection),
35}
36
/// Per-guard heuristic that notices the same statement being issued many times in a
/// row, which is the typical shape of an N+1 query pattern.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
    /// Hash of the statement in the current run (`0` while nothing has been tracked).
    last_query_hash: u64,
    /// Length of the current run of identical hashes; saturates at `u16::MAX`.
    repeat_count: u16,
    /// Runs strictly longer than this are reported.
    threshold: u16,
}
48
#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Creates a detector that reports runs of more than `threshold` identical statements.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            last_query_hash: 0,
            repeat_count: 0,
            threshold,
        }
    }

    /// Records one executed statement identified by `sql_hash`.
    ///
    /// Repeating the previous hash extends the current run. A different hash closes the
    /// run (logging a warning if it exceeded the threshold) and starts a new run of length 1.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash == self.last_query_hash {
            self.repeat_count = self.repeat_count.saturating_add(1);
            return;
        }
        // Switching statements is the common case. Test the cheap condition inline and
        // only take the out-of-line `#[cold]` call when there is something to report.
        if self.check_final().is_some() {
            self.emit_warning();
        }
        self.last_query_hash = sql_hash;
        self.repeat_count = 1;
    }

    /// Returns `(hash, count)` for the current run if it is longer than the threshold.
    ///
    /// A hash of `0` is treated as "nothing tracked" and never reported.
    #[inline]
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        if self.repeat_count > self.threshold && self.last_query_hash != 0 {
            Some((self.last_query_hash, self.repeat_count))
        } else {
            None
        }
    }

    /// Logs the current run if it exceeds the threshold (rare path).
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        if let Some((hash, count)) = self.check_final() {
            log::warn!(
                "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
                hash,
                count,
                self.threshold,
            );
        }
    }

    /// Reports the run that is still open when the owning guard is released.
    ///
    /// This runs on every guard drop, so it is not marked `#[cold]`; only the actual
    /// logging is kept out of line.
    #[inline]
    pub(crate) fn emit_final_warning(&self) {
        if self.check_final().is_some() {
            self.emit_warning();
        }
    }
}
104
105pub struct Pool {
121 inner: Arc<PoolInner>,
122}
123
124struct PoolInner {
125 stack: std::sync::Mutex<Vec<PoolSlot>>,
129 max_size: usize,
130 open_count: AtomicUsize,
131 config: Arc<Config>,
132 closed: AtomicBool,
134 release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
137 max_lifetime: Option<Duration>,
140 acquire_timeout: Option<Duration>,
142 min_idle: usize,
144 warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
146 max_stmt_cache_size: usize,
148 stale_timeout: Duration,
151 #[cfg(feature = "detect-n-plus-one")]
154 n_plus_one_threshold: u16,
155}
156
157impl Pool {
158 pub fn connect(url: &str) -> Result<Self, DriverError> {
162 PoolBuilder::new().url(url).build()
163 }
164
165 pub fn builder() -> PoolBuilder {
167 PoolBuilder::new()
168 }
169
170 #[inline]
178 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
179 if self.inner.closed.load(Ordering::Acquire) {
180 return Err(DriverError::Pool("pool is closed".into()));
181 }
182
183 if let Some(guard) = self.try_pop_idle()? {
185 return Ok(guard);
186 }
187
188 loop {
190 let current = self.inner.open_count.load(Ordering::Acquire);
191 if current >= self.inner.max_size {
192 if let Some(timeout) = self.inner.acquire_timeout {
193 let (lock, cvar) = &self.inner.release_pair;
194 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
195 let (_guard, result) = cvar
196 .wait_timeout(guard, timeout)
197 .unwrap_or_else(|e| e.into_inner());
198 if result.timed_out() {
199 return Err(DriverError::Pool(
200 "pool exhausted: acquire timeout expired".into(),
201 ));
202 }
203 if let Some(guard) = self.try_pop_idle()? {
205 return Ok(guard);
206 }
207 continue;
209 }
210 return Err(DriverError::Pool(
211 "pool exhausted: all connections in use".into(),
212 ));
213 }
214 if self
215 .inner
216 .open_count
217 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
218 .is_ok()
219 {
220 break;
221 }
222 }
224
225 let conn_result = Connection::connect_arc(self.inner.config.clone());
227 match conn_result {
228 Ok(mut conn) => {
229 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
231 self.warmup_conn(&mut conn);
233
234 Ok(PoolGuard {
235 conn: Some(PoolSlot::Sync(conn)),
236 pool: self.inner.clone(),
237 discard: false,
238 #[cfg(feature = "detect-n-plus-one")]
239 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
240 })
241 }
242 Err(e) => {
243 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
245 Err(e)
246 }
247 }
248 }
249
250 #[inline]
256 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
257 loop {
261 let (mut slot, needs_health_check) = {
262 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
263 loop {
264 let Some(slot) = stack.pop() else {
265 return Ok(None);
266 };
267 let (created_at, idle_dur) = match &slot {
268 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
269 #[cfg(feature = "async")]
270 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
271 };
272 if let Some(max_lifetime) = self.inner.max_lifetime {
273 if created_at.elapsed() >= max_lifetime {
274 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
275 continue;
276 }
277 }
278 if idle_dur >= self.inner.stale_timeout {
279 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
281 continue;
282 }
283 break (slot, idle_dur > Duration::from_secs(5));
284 }
285 };
286 if needs_health_check {
290 let alive = match &mut slot {
291 PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
292 #[cfg(feature = "async")]
293 PoolSlot::Async(_) => true, };
295 if !alive {
296 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
297 continue; }
299 }
300 return Ok(Some(PoolGuard {
301 conn: Some(slot),
302 pool: self.inner.clone(),
303 discard: false,
304 #[cfg(feature = "detect-n-plus-one")]
305 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
306 }));
307 }
308 }
309
310 pub fn is_uds(&self) -> bool {
315 #[cfg(unix)]
316 {
317 self.inner.config.host_is_uds()
318 }
319 #[cfg(not(unix))]
320 {
321 false
322 }
323 }
324
325 pub fn begin(&self) -> Result<Transaction, DriverError> {
327 let mut guard = self.acquire()?;
328 guard.simple_query("BEGIN")?;
329 Ok(Transaction {
330 guard,
331 committed: false,
332 deferred_buf: Vec::new(),
333 deferred_count: 0,
334 })
335 }
336
337 pub fn open_count(&self) -> usize {
339 self.inner.open_count.load(Ordering::Relaxed)
340 }
341
342 pub fn max_size(&self) -> usize {
344 self.inner.max_size
345 }
346
347 pub fn status(&self) -> PoolStatus {
349 let idle = self
350 .inner
351 .stack
352 .lock()
353 .unwrap_or_else(|e| e.into_inner())
354 .len();
355 let open = self.inner.open_count.load(Ordering::Relaxed);
356 let active = open.saturating_sub(idle);
357 PoolStatus {
358 idle,
359 active,
360 open,
361 max_size: self.inner.max_size,
362 }
363 }
364
365 fn warmup_conn(&self, conn: &mut Connection) {
373 let sqls = self
374 .inner
375 .warmup_sqls
376 .lock()
377 .unwrap_or_else(|e| e.into_inner())
378 .clone();
379
380 if sqls.is_empty() {
381 return;
382 }
383
384 let batch: Vec<(&str, u64)> = sqls
385 .iter()
386 .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
387 .collect();
388
389 let _ = conn.prepare_batch(&batch);
390 }
391
392 pub fn set_warmup_sqls<S: Into<Box<str>>>(&self, sqls: impl IntoIterator<Item = S>) {
419 let boxed: Arc<Vec<Box<str>>> = Arc::new(sqls.into_iter().map(Into::into).collect());
420 *self
421 .inner
422 .warmup_sqls
423 .lock()
424 .unwrap_or_else(|e| e.into_inner()) = boxed;
425 }
426
427 pub fn close(&self) {
430 self.inner.closed.store(true, Ordering::Release);
431 let slots: Vec<PoolSlot> = {
433 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
434 std::mem::take(&mut *stack)
435 };
436 for slot in slots {
437 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
438 match slot {
439 PoolSlot::Sync(conn) => {
440 let _ = conn.close();
441 }
442 #[cfg(feature = "async")]
443 PoolSlot::Async(_conn) => {
444 }
447 }
448 }
449 let (_, cvar) = &self.inner.release_pair;
451 cvar.notify_all();
452 }
453
454 pub fn is_closed(&self) -> bool {
456 self.inner.closed.load(Ordering::Acquire)
457 }
458
459 #[cfg(feature = "async")]
469 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
470 if self.inner.closed.load(Ordering::Acquire) {
471 return Err(DriverError::Pool("pool is closed".into()));
472 }
473
474 if let Some(guard) = self.try_pop_idle()? {
476 return Ok(guard);
477 }
478
479 loop {
481 let current = self.inner.open_count.load(Ordering::Acquire);
482 if current >= self.inner.max_size {
483 if let Some(timeout) = self.inner.acquire_timeout {
484 let (lock, cvar) = &self.inner.release_pair;
485 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
486 let (_guard, result) = cvar
487 .wait_timeout(guard, timeout)
488 .unwrap_or_else(|e| e.into_inner());
489 if result.timed_out() {
490 return Err(DriverError::Pool(
491 "pool exhausted: acquire timeout expired".into(),
492 ));
493 }
494 if let Some(guard) = self.try_pop_idle()? {
495 return Ok(guard);
496 }
497 continue;
498 }
499 return Err(DriverError::Pool(
500 "pool exhausted: all connections in use".into(),
501 ));
502 }
503 if self
504 .inner
505 .open_count
506 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
507 .is_ok()
508 {
509 break;
510 }
511 }
512
513 if self.inner.config.host_is_uds() {
515 let conn_result = Connection::connect_arc(self.inner.config.clone());
517 match conn_result {
518 Ok(mut conn) => {
519 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
520 self.warmup_conn(&mut conn);
521 Ok(PoolGuard {
522 conn: Some(PoolSlot::Sync(conn)),
523 pool: self.inner.clone(),
524 discard: false,
525 #[cfg(feature = "detect-n-plus-one")]
526 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
527 })
528 }
529 Err(e) => {
530 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
531 Err(e)
532 }
533 }
534 } else {
535 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
537 match conn_result {
538 Ok(mut conn) => {
539 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
540 Ok(PoolGuard {
541 conn: Some(PoolSlot::Async(conn)),
542 pool: self.inner.clone(),
543 discard: false,
544 #[cfg(feature = "detect-n-plus-one")]
545 detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
546 })
547 }
548 Err(e) => {
549 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
550 Err(e)
551 }
552 }
553 }
554 }
555}
556
557impl Clone for Pool {
558 fn clone(&self) -> Self {
559 Pool {
560 inner: self.inner.clone(),
561 }
562 }
563}
564
/// Point-in-time snapshot of pool occupancy, as returned by `Pool::status`.
///
/// Implements `PartialEq`/`Eq` so snapshots can be compared directly, for example in
/// tests or change detection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Connections currently sitting in the idle stack.
    pub idle: usize,
    /// Open connections not in the idle stack (`open - idle`, saturating at 0).
    pub active: usize,
    /// Total connections counted as open.
    pub open: usize,
    /// Configured pool capacity.
    pub max_size: usize,
}
579
/// Configuration collected before a `Pool` is built; obtain one via `Pool::builder`.
pub struct PoolBuilder {
    /// Connection URL; required by `build`.
    url: Option<String>,
    /// Maximum number of open connections.
    max_size: usize,
    /// Age after which idle connections are discarded (`None` = unlimited).
    max_lifetime: Option<Duration>,
    /// How long `acquire` may wait for capacity (`None` = fail immediately).
    acquire_timeout: Option<Duration>,
    /// Idle connections maintained by a background thread (0 = no thread).
    min_idle: usize,
    /// Prepared-statement cache capacity per connection.
    max_stmt_cache_size: usize,
    /// Idle time after which a connection is discarded instead of reused.
    stale_timeout: Duration,
    /// Custom N+1 warning threshold; `build` falls back to 10 when unset.
    #[cfg(feature = "detect-n-plus-one")]
    n_plus_one_threshold: Option<u16>,
}
600
601impl PoolBuilder {
602 fn new() -> Self {
603 Self {
604 url: None,
605 max_size: 10,
606 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, stale_timeout: Duration::from_secs(30), #[cfg(feature = "detect-n-plus-one")]
612 n_plus_one_threshold: None,
613 }
614 }
615
616 pub fn url(mut self, url: &str) -> Self {
618 self.url = Some(url.to_owned());
619 self
620 }
621
622 pub fn max_size(mut self, size: usize) -> Self {
626 self.max_size = size;
627 self
628 }
629
630 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
633 self.max_lifetime = lifetime;
634 self
635 }
636
637 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
640 self.acquire_timeout = timeout;
641 self
642 }
643
644 pub fn min_idle(mut self, count: usize) -> Self {
647 self.min_idle = count;
648 self
649 }
650
651 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
655 self.max_stmt_cache_size = size;
656 self
657 }
658
659 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
663 self.stale_timeout = timeout;
664 self
665 }
666
667 #[cfg(feature = "detect-n-plus-one")]
672 pub fn n_plus_one_threshold(mut self, n: u16) -> Self {
673 self.n_plus_one_threshold = Some(n);
674 self
675 }
676
677 pub fn build(self) -> Result<Pool, DriverError> {
679 let url = self
680 .url
681 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
682
683 let config = Arc::new(Config::from_url(&url)?);
684
685 let pool = Pool {
686 inner: Arc::new(PoolInner {
687 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
688 max_size: self.max_size,
689 open_count: AtomicUsize::new(0),
690 config,
691 closed: AtomicBool::new(false),
692 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
693 max_lifetime: self.max_lifetime,
694 acquire_timeout: self.acquire_timeout,
695 min_idle: self.min_idle,
696 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
697 max_stmt_cache_size: self.max_stmt_cache_size,
698 stale_timeout: self.stale_timeout,
699 #[cfg(feature = "detect-n-plus-one")]
700 n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
701 }),
702 };
703
704 if self.min_idle > 0 {
705 let inner = pool.inner.clone();
706 std::thread::spawn(move || {
707 maintain_min_idle(inner);
708 });
709 }
710
711 Ok(pool)
712 }
713}
714
715fn maintain_min_idle(inner: Arc<PoolInner>) {
717 loop {
718 if inner.closed.load(Ordering::Acquire) {
719 return;
720 }
721
722 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
723 let needed = inner.min_idle.saturating_sub(idle_count);
724
725 for _ in 0..needed {
726 if inner.closed.load(Ordering::Acquire) {
727 return;
728 }
729 let current = inner.open_count.load(Ordering::Acquire);
730 if current >= inner.max_size {
731 break;
732 }
733 if inner
734 .open_count
735 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
736 .is_err()
737 {
738 continue;
739 }
740
741 match Connection::connect_arc(inner.config.clone()) {
742 Ok(conn) => {
743 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
744 stack.push(PoolSlot::Sync(conn));
745 let (_, cvar) = &inner.release_pair;
746 cvar.notify_one();
747 }
748 Err(_) => {
749 inner.open_count.fetch_sub(1, Ordering::AcqRel);
750 }
751 }
752 }
753
754 std::thread::sleep(Duration::from_secs(1));
757 }
758}
759
760pub struct PoolGuard {
767 conn: Option<PoolSlot>,
768 pool: Arc<PoolInner>,
769 discard: bool,
771 #[cfg(feature = "detect-n-plus-one")]
773 detector: NPlusOneDetector,
774}
775
776impl PoolGuard {
777 #[inline]
780 fn sync_conn(&self) -> Result<&Connection, DriverError> {
781 match self.conn.as_ref() {
782 Some(PoolSlot::Sync(conn)) => Ok(conn),
783 #[cfg(feature = "async")]
784 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
785 "expected sync connection, got async; use async methods".into(),
786 )),
787 None => Err(DriverError::Pool("connection already taken".into())),
788 }
789 }
790
791 #[inline]
793 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
794 match self.conn.as_mut() {
795 Some(PoolSlot::Sync(conn)) => Ok(conn),
796 #[cfg(feature = "async")]
797 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
798 "expected sync connection, got async; use async methods".into(),
799 )),
800 None => Err(DriverError::Pool("connection already taken".into())),
801 }
802 }
803
804 pub fn mark_discard(&mut self) {
807 self.discard = true;
808 }
809
810 pub fn cancel(&self) -> Result<(), DriverError> {
815 self.sync_conn()?.cancel()
816 }
817
818 pub fn pid(&self) -> i32 {
827 match self.conn.as_ref().expect("connection returned to pool") {
828 PoolSlot::Sync(conn) => conn.pid(),
829 #[cfg(feature = "async")]
830 PoolSlot::Async(conn) => conn.pid(),
831 }
832 }
833
834 pub fn is_idle(&self) -> bool {
841 match self.conn.as_ref().expect("connection returned to pool") {
842 PoolSlot::Sync(conn) => conn.is_idle(),
843 #[cfg(feature = "async")]
844 PoolSlot::Async(conn) => conn.is_idle(),
845 }
846 }
847
848 pub fn is_in_transaction(&self) -> bool {
855 match self.conn.as_ref().expect("connection returned to pool") {
856 PoolSlot::Sync(conn) => conn.is_in_transaction(),
857 #[cfg(feature = "async")]
858 PoolSlot::Async(conn) => conn.is_in_transaction(),
859 }
860 }
861
862 #[inline]
866 pub fn query(
867 &mut self,
868 sql: &str,
869 sql_hash: u64,
870 params: &[&(dyn Encode + Sync)],
871 ) -> Result<QueryResult, DriverError> {
872 #[cfg(feature = "detect-n-plus-one")]
873 self.detector.track(sql_hash);
874 self.sync_conn_mut()?.query(sql, sql_hash, params)
875 }
876
877 #[inline]
879 pub fn execute(
880 &mut self,
881 sql: &str,
882 sql_hash: u64,
883 params: &[&(dyn Encode + Sync)],
884 ) -> Result<u64, DriverError> {
885 #[cfg(feature = "detect-n-plus-one")]
886 self.detector.track(sql_hash);
887 self.sync_conn_mut()?.execute(sql, sql_hash, params)
888 }
889
890 pub fn execute_pipeline(
895 &mut self,
896 sql: &str,
897 sql_hash: u64,
898 param_sets: &[&[&(dyn Encode + Sync)]],
899 ) -> Result<Vec<u64>, DriverError> {
900 #[cfg(feature = "detect-n-plus-one")]
901 self.detector.track(sql_hash);
902 self.sync_conn_mut()?
903 .execute_pipeline(sql, sql_hash, param_sets)
904 }
905
906 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
908 self.sync_conn_mut()?.simple_query(sql)
909 }
910
911 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
915 self.sync_conn_mut()?.simple_query_rows(sql)
916 }
917
918 pub fn for_each<F>(
920 &mut self,
921 sql: &str,
922 sql_hash: u64,
923 params: &[&(dyn Encode + Sync)],
924 f: F,
925 ) -> Result<(), DriverError>
926 where
927 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
928 {
929 #[cfg(feature = "detect-n-plus-one")]
930 self.detector.track(sql_hash);
931 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
932 }
933
934 pub fn for_each_raw<F>(
936 &mut self,
937 sql: &str,
938 sql_hash: u64,
939 params: &[&(dyn Encode + Sync)],
940 f: F,
941 ) -> Result<(), DriverError>
942 where
943 F: FnMut(&[u8]) -> Result<(), DriverError>,
944 {
945 #[cfg(feature = "detect-n-plus-one")]
946 self.detector.track(sql_hash);
947 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
948 }
949
950 pub fn query_streaming_start(
954 &mut self,
955 sql: &str,
956 sql_hash: u64,
957 params: &[&(dyn Encode + Sync)],
958 chunk_size: i32,
959 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
960 #[cfg(feature = "detect-n-plus-one")]
961 self.detector.track(sql_hash);
962 self.sync_conn_mut()?
963 .query_streaming_start(sql, sql_hash, params, chunk_size)
964 }
965
966 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
968 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
969 }
970
971 pub fn streaming_next_chunk(
973 &mut self,
974 arena: &mut Arena,
975 all_col_offsets: &mut Vec<(usize, i32)>,
976 ) -> Result<bool, DriverError> {
977 self.sync_conn_mut()?
978 .streaming_next_chunk(arena, all_col_offsets)
979 }
980
981 pub fn copy_in<'a, I>(
987 &mut self,
988 table: &str,
989 columns: &[&str],
990 rows: I,
991 ) -> Result<u64, DriverError>
992 where
993 I: IntoIterator<Item = &'a str>,
994 {
995 self.sync_conn_mut()?.copy_in(table, columns, rows)
996 }
997
998 pub fn copy_out<W: std::io::Write>(
1002 &mut self,
1003 query: &str,
1004 writer: &mut W,
1005 ) -> Result<u64, DriverError> {
1006 self.sync_conn_mut()?.copy_out(query, writer)
1007 }
1008
1009 pub fn is_sync(&self) -> bool {
1011 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
1012 }
1013
1014 #[cfg(feature = "async")]
1016 pub fn is_async(&self) -> bool {
1017 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
1018 }
1019
1020 #[cfg(feature = "async")]
1028 pub async fn query_async(
1029 &mut self,
1030 sql: &str,
1031 sql_hash: u64,
1032 params: &[&(dyn Encode + Sync)],
1033 ) -> Result<QueryResult, DriverError> {
1034 #[cfg(feature = "detect-n-plus-one")]
1035 self.detector.track(sql_hash);
1036 match self.conn.as_mut() {
1037 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
1038 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
1039 None => Err(DriverError::Pool("connection already taken".into())),
1040 }
1041 }
1042
1043 #[cfg(feature = "async")]
1045 pub async fn execute_async(
1046 &mut self,
1047 sql: &str,
1048 sql_hash: u64,
1049 params: &[&(dyn Encode + Sync)],
1050 ) -> Result<u64, DriverError> {
1051 #[cfg(feature = "detect-n-plus-one")]
1052 self.detector.track(sql_hash);
1053 match self.conn.as_mut() {
1054 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
1055 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
1056 None => Err(DriverError::Pool("connection already taken".into())),
1057 }
1058 }
1059
1060 #[cfg(feature = "async")]
1062 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
1063 match self.conn.as_mut() {
1064 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
1065 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
1066 None => Err(DriverError::Pool("connection already taken".into())),
1067 }
1068 }
1069
1070 pub(crate) fn ensure_stmt_prepared(
1074 &mut self,
1075 sql: &str,
1076 sql_hash: u64,
1077 params: &[&(dyn Encode + Sync)],
1078 ) -> Result<[u8; 18], DriverError> {
1079 self.sync_conn_mut()?
1080 .ensure_stmt_prepared(sql, sql_hash, params)
1081 }
1082
1083 pub(crate) fn write_deferred_bind_execute(
1085 &self,
1086 sql: &str,
1087 sql_hash: u64,
1088 params: &[&(dyn Encode + Sync)],
1089 buf: &mut Vec<u8>,
1090 ) -> Result<(), DriverError> {
1091 let conn = self.sync_conn()?;
1092 conn.write_deferred_bind_execute(sql, sql_hash, params, buf)
1093 }
1094
1095 pub(crate) fn flush_deferred_pipeline(
1097 &mut self,
1098 buf: &mut Vec<u8>,
1099 count: usize,
1100 ) -> Result<Vec<u64>, DriverError> {
1101 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
1102 }
1103}
1104
1105impl Drop for PoolGuard {
1106 fn drop(&mut self) {
1107 #[cfg(feature = "detect-n-plus-one")]
1108 self.detector.emit_final_warning();
1109
1110 if let Some(slot) = self.conn.take() {
1111 let should_discard = self.discard
1113 || self.pool.closed.load(Ordering::Acquire)
1114 || match &slot {
1115 PoolSlot::Sync(conn) => {
1116 conn.is_in_failed_transaction()
1117 || conn.is_in_transaction()
1118 || conn.is_streaming()
1119 }
1120 #[cfg(feature = "async")]
1121 PoolSlot::Async(conn) => {
1122 conn.is_in_failed_transaction() || conn.is_in_transaction()
1123 }
1124 };
1125
1126 if should_discard {
1127 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1128 return;
1129 }
1130
1131 let mut slot = slot;
1134 match &mut slot {
1135 PoolSlot::Sync(conn) => {
1136 if conn.query_counter() & 63 == 0 {
1137 conn.touch();
1138 }
1139 }
1140 #[cfg(feature = "async")]
1141 PoolSlot::Async(conn) => {
1142 if conn.query_counter() & 63 == 0 {
1143 conn.touch();
1144 }
1145 }
1146 }
1147
1148 {
1150 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1151 stack.push(slot);
1152 }
1153
1154 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1156 let (_, cvar) = &self.pool.release_pair;
1157 cvar.notify_one();
1158 }
1159 }
1160 }
1161}
1162
1163pub struct Transaction {
1179 guard: PoolGuard,
1180 committed: bool,
1181 deferred_buf: Vec<u8>,
1183 deferred_count: usize,
1185}
1186
1187impl Transaction {
1188 pub fn commit(mut self) -> Result<(), DriverError> {
1192 if self.deferred_count > 0 {
1193 self.flush_deferred()?;
1194 }
1195 self.guard.simple_query("COMMIT")?;
1196 self.committed = true;
1197 Ok(())
1198 }
1199
1200 pub fn rollback(mut self) -> Result<(), DriverError> {
1204 self.deferred_buf.clear();
1205 self.deferred_count = 0;
1206 self.guard.simple_query("ROLLBACK")?;
1207 self.committed = true; Ok(())
1209 }
1210
1211 pub fn query(
1216 &mut self,
1217 sql: &str,
1218 sql_hash: u64,
1219 params: &[&(dyn Encode + Sync)],
1220 ) -> Result<QueryResult, DriverError> {
1221 if self.deferred_count > 0 {
1222 self.flush_deferred()?;
1223 }
1224 self.guard.query(sql, sql_hash, params)
1225 }
1226
1227 pub fn execute(
1229 &mut self,
1230 sql: &str,
1231 sql_hash: u64,
1232 params: &[&(dyn Encode + Sync)],
1233 ) -> Result<u64, DriverError> {
1234 self.guard.execute(sql, sql_hash, params)
1235 }
1236
1237 pub fn execute_pipeline(
1239 &mut self,
1240 sql: &str,
1241 sql_hash: u64,
1242 param_sets: &[&[&(dyn Encode + Sync)]],
1243 ) -> Result<Vec<u64>, DriverError> {
1244 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1245 }
1246
1247 pub fn for_each<F>(
1251 &mut self,
1252 sql: &str,
1253 sql_hash: u64,
1254 params: &[&(dyn Encode + Sync)],
1255 f: F,
1256 ) -> Result<(), DriverError>
1257 where
1258 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1259 {
1260 if self.deferred_count > 0 {
1261 self.flush_deferred()?;
1262 }
1263 self.guard.for_each(sql, sql_hash, params, f)
1264 }
1265
1266 pub fn for_each_raw<F>(
1270 &mut self,
1271 sql: &str,
1272 sql_hash: u64,
1273 params: &[&(dyn Encode + Sync)],
1274 f: F,
1275 ) -> Result<(), DriverError>
1276 where
1277 F: FnMut(&[u8]) -> Result<(), DriverError>,
1278 {
1279 if self.deferred_count > 0 {
1280 self.flush_deferred()?;
1281 }
1282 self.guard.for_each_raw(sql, sql_hash, params, f)
1283 }
1284
1285 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1289 if self.deferred_count > 0 {
1290 self.flush_deferred()?;
1291 }
1292 self.guard.simple_query(sql)
1293 }
1294
1295 pub fn defer_execute(
1324 &mut self,
1325 sql: &str,
1326 sql_hash: u64,
1327 params: &[&(dyn Encode + Sync)],
1328 ) -> Result<(), DriverError> {
1329 if params.len() > i16::MAX as usize {
1330 return Err(DriverError::Protocol(format!(
1331 "parameter count {} exceeds maximum {}",
1332 params.len(),
1333 i16::MAX
1334 )));
1335 }
1336
1337 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1339
1340 self.guard
1342 .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf)?;
1343 self.deferred_count += 1;
1344 Ok(())
1345 }
1346
1347 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1352 let count = self.deferred_count;
1353 self.deferred_count = 0;
1354 self.guard
1355 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1356 }
1357
1358 pub fn deferred_count(&self) -> usize {
1360 self.deferred_count
1361 }
1362}
1363
1364impl Drop for Transaction {
1365 fn drop(&mut self) {
1366 if !self.committed {
1367 if let Some(_slot) = self.guard.conn.take() {
1370 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1371 }
1373 }
1374 }
1375}
1376
1377#[cfg(test)]
1378mod tests {
1379 use super::*;
1380
1381 #[test]
1382 fn pool_builder_requires_url() {
1383 let result = PoolBuilder::new().build();
1384 assert!(result.is_err());
1385 }
1386
1387 #[test]
1388 fn pool_builder_validates_url() {
1389 let result = PoolBuilder::new().url("not_a_url").build();
1390 assert!(result.is_err());
1391 }
1392
1393 #[test]
1394 fn pool_builder_accepts_valid_url() {
1395 let pool = PoolBuilder::new()
1396 .url("postgres://user:pass@localhost/db")
1397 .max_size(5)
1398 .build()
1399 .unwrap();
1400 assert_eq!(pool.max_size(), 5);
1401 assert_eq!(pool.open_count(), 0);
1402 }
1403
1404 #[test]
1405 fn pool_connect_validates_url() {
1406 let result = Pool::connect("not_a_url");
1407 assert!(result.is_err());
1408 }
1409
1410 #[test]
1411 fn pool_max_size_zero() {
1412 let pool = PoolBuilder::new()
1413 .url("postgres://user:pass@localhost/db")
1414 .max_size(0)
1415 .build()
1416 .unwrap();
1417
1418 let result = pool.acquire();
1419 assert!(result.is_err());
1420 match result {
1421 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1422 Err(e) => panic!("expected Pool error, got: {e:?}"),
1423 Ok(_) => panic!("expected error, got Ok"),
1424 }
1425 }
1426
1427 #[test]
1428 fn pool_clone_shares_state() {
1429 let pool = PoolBuilder::new()
1430 .url("postgres://user:pass@localhost/db")
1431 .max_size(5)
1432 .build()
1433 .unwrap();
1434
1435 let pool2 = pool.clone();
1436 assert_eq!(pool.max_size(), pool2.max_size());
1437 }
1438
1439 #[test]
1443 fn pool_builder_max_lifetime() {
1444 let pool = PoolBuilder::new()
1445 .url("postgres://user:pass@localhost/db")
1446 .max_lifetime(Some(Duration::from_secs(60)))
1447 .build()
1448 .unwrap();
1449 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1450 }
1451
1452 #[test]
1454 fn pool_builder_max_lifetime_none() {
1455 let pool = PoolBuilder::new()
1456 .url("postgres://user:pass@localhost/db")
1457 .max_lifetime(None)
1458 .build()
1459 .unwrap();
1460 assert_eq!(pool.inner.max_lifetime, None);
1461 }
1462
1463 #[test]
1465 fn pool_builder_acquire_timeout_none() {
1466 let pool = PoolBuilder::new()
1467 .url("postgres://user:pass@localhost/db")
1468 .acquire_timeout(None)
1469 .build()
1470 .unwrap();
1471 assert_eq!(pool.inner.acquire_timeout, None);
1472 }
1473
1474 #[test]
1476 fn pool_builder_acquire_timeout_custom() {
1477 let pool = PoolBuilder::new()
1478 .url("postgres://user:pass@localhost/db")
1479 .acquire_timeout(Some(Duration::from_secs(10)))
1480 .build()
1481 .unwrap();
1482 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1483 }
1484
1485 #[test]
1487 fn pool_builder_min_idle() {
1488 let pool = PoolBuilder::new()
1489 .url("postgres://user:pass@localhost/db")
1490 .min_idle(2)
1491 .build()
1492 .unwrap();
1493 assert_eq!(pool.inner.min_idle, 2);
1494 }
1495
1496 #[test]
1498 fn pool_close_marks_closed() {
1499 let pool = PoolBuilder::new()
1500 .url("postgres://user:pass@localhost/db")
1501 .max_size(5)
1502 .build()
1503 .unwrap();
1504
1505 assert!(!pool.is_closed());
1506 pool.close();
1507 assert!(pool.is_closed());
1508
1509 let result = pool.acquire();
1511 assert!(result.is_err());
1512 match result {
1513 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1514 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1515 Ok(_) => panic!("expected error, got Ok"),
1516 }
1517 }
1518
1519 #[test]
1521 fn pool_status_initial() {
1522 let pool = PoolBuilder::new()
1523 .url("postgres://user:pass@localhost/db")
1524 .max_size(10)
1525 .build()
1526 .unwrap();
1527
1528 let status = pool.status();
1529 assert_eq!(status.idle, 0);
1530 assert_eq!(status.active, 0);
1531 assert_eq!(status.open, 0);
1532 assert_eq!(status.max_size, 10);
1533 }
1534
1535 #[test]
1537 fn pool_builder_defaults() {
1538 let pool = PoolBuilder::new()
1539 .url("postgres://user:pass@localhost/db")
1540 .build()
1541 .unwrap();
1542
1543 assert_eq!(pool.max_size(), 10);
1544 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1545 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1546 assert_eq!(pool.inner.min_idle, 0);
1547 }
1548
1549 #[test]
1551 fn pool_open_count_initial() {
1552 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1553 assert_eq!(pool.open_count(), 0);
1554 }
1555
1556 #[test]
1559 fn pool_builder_max_stmt_cache_size_default() {
1560 let pool = PoolBuilder::new()
1561 .url("postgres://user:pass@localhost/db")
1562 .build()
1563 .unwrap();
1564 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1565 }
1566
1567 #[test]
1568 fn pool_builder_max_stmt_cache_size_custom() {
1569 let pool = PoolBuilder::new()
1570 .url("postgres://user:pass@localhost/db")
1571 .max_stmt_cache_size(512)
1572 .build()
1573 .unwrap();
1574 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1575 }
1576
1577 #[test]
1580 fn pool_is_uds_false_for_tcp() {
1581 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1582 assert!(!pool.is_uds());
1583 }
1584
1585 #[cfg(unix)]
1586 #[test]
1587 fn pool_is_uds_true_for_unix_socket() {
1588 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1589 assert!(pool.is_uds());
1590 }
1591
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_var_run_socket() {
    // Multi-component socket directories must be recognised as well.
    let url = "postgres://user@localhost/db?host=/var/run/postgresql";
    let pool = Pool::connect(url).unwrap();
    assert!(pool.is_uds());
}
1598
#[test]
fn pool_is_uds_false_for_ip_address() {
    // A numeric IPv4 host is never a socket path.
    let uds = Pool::connect("postgres://user:pass@127.0.0.1/db")
        .unwrap()
        .is_uds();
    assert!(!uds);
}
1604
#[cfg(unix)]
#[test]
fn pool_slot_sync_created_for_uds_config() {
    // Only the URL classification is checked here; no pool or slot is built.
    let url = "postgres://user@localhost/db?host=/tmp";
    let is_uds = Config::from_url(url).unwrap().host_is_uds();
    assert!(is_uds);
}
1611
#[test]
fn pool_slot_tcp_config() {
    // Only the URL classification is checked here; no pool or slot is built.
    let url = "postgres://user:pass@localhost/db";
    let is_uds = Config::from_url(url).unwrap().host_is_uds();
    assert!(!is_uds);
}
1617
#[test]
fn pool_is_uds_false_for_hostname() {
    // A fully-qualified DNS name is a TCP target.
    let uds = Pool::connect("postgres://user:pass@db.example.com/db")
        .unwrap()
        .is_uds();
    assert!(!uds);
}
1627
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_tmp() {
    // Same URL as `pool_is_uds_true_for_unix_socket`; kept as a named
    // regression check for the `/tmp` socket directory.
    let url = "postgres://user@localhost/db?host=/tmp";
    let pool = Pool::connect(url).unwrap();
    assert!(pool.is_uds());
}
1634
#[test]
fn pool_close_then_acquire_fails() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(5)
        .build()
        .unwrap();
    pool.close();

    // After `close`, acquiring must fail with a pool error mentioning "closed".
    let result = pool.acquire();
    assert!(result.is_err());
    let err = match result {
        Ok(_) => panic!("expected error"),
        Err(err) => err,
    };
    match err {
        DriverError::Pool(msg) => assert!(msg.contains("closed"), "should say closed: {msg}"),
        other => panic!("expected Pool error, got: {other:?}"),
    }
}
1657
#[test]
fn pool_is_closed_before_and_after() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // Record the flag on both sides of `close` and check the transition.
    let open_before = !pool.is_closed();
    pool.close();
    let closed_after = pool.is_closed();
    assert!(open_before);
    assert!(closed_after);
}
1665
#[test]
fn pool_exhausted_no_timeout() {
    // Zero capacity with no acquire timeout configured: acquire must fail
    // with a pool error that mentions "exhausted".
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(0)
        .acquire_timeout(None)
        .build()
        .unwrap();
    let result = pool.acquire();
    assert!(result.is_err());
    let err = match result {
        Ok(_) => panic!("expected error"),
        Err(err) => err,
    };
    match err {
        DriverError::Pool(msg) => {
            assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
        }
        other => panic!("expected Pool error, got: {other:?}"),
    }
}
1688
#[test]
fn pool_builder_no_url_error() {
    // Building without ever calling `.url(..)` must be rejected.
    let result = PoolBuilder::new().max_size(5).build();
    assert!(result.is_err());
    let err = match result {
        Ok(_) => panic!("expected error"),
        Err(err) => err,
    };
    match err {
        DriverError::Pool(msg) => assert!(msg.contains("URL"), "should mention URL: {msg}"),
        other => panic!("expected Pool error, got: {other:?}"),
    }
}
1705
#[test]
fn pool_builder_invalid_url_error() {
    // A non-postgres scheme must not produce a pool.
    assert!(PoolBuilder::new().url("ftp://something").build().is_err());
}
1711
#[test]
fn pool_builder_stmt_cache_size_zero() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_stmt_cache_size(0)
        .build()
        .unwrap();
    // Zero is accepted and stored as-is.
    let cap = pool.inner.max_stmt_cache_size;
    assert_eq!(cap, 0);
}
1721
#[test]
fn pool_builder_stale_timeout_default() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .build()
        .unwrap();
    // Unset stale timeout falls back to thirty seconds.
    let expected = Duration::from_secs(30);
    assert_eq!(pool.inner.stale_timeout, expected);
}
1732
#[test]
fn pool_builder_stale_timeout_custom() {
    let one_minute = Duration::from_secs(60);
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .stale_timeout(one_minute)
        .build()
        .unwrap();
    // The builder value is stored verbatim.
    assert_eq!(pool.inner.stale_timeout, one_minute);
}
1742
#[test]
fn pool_builder_stale_timeout_zero() {
    let zero = Duration::from_secs(0);
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .stale_timeout(zero)
        .build()
        .unwrap();
    // A zero duration is accepted and not replaced by the default.
    assert_eq!(pool.inner.stale_timeout, zero);
}
1752
#[test]
fn pool_status_reflects_max_size() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(20)
        .build()
        .unwrap();
    // Status echoes the configured capacity while all counters stay at zero.
    let s = pool.status();
    assert_eq!(s.max_size, 20);
    assert_eq!((s.idle, s.active, s.open), (0, 0, 0));
}
1770
#[test]
fn pool_clone_shares_config() {
    let original = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(7)
        .build()
        .unwrap();
    let cloned = Pool::clone(&original);

    // Both handles must report the same configuration and connection count.
    for handle in [&original, &cloned] {
        assert_eq!(handle.max_size(), 7);
    }
    assert_eq!(original.open_count(), cloned.open_count());
}
1787
#[test]
fn pool_set_warmup_sqls_empty() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    pool.set_warmup_sqls([] as [&str; 0]);

    // Snapshot the stored list, tolerating a poisoned lock.
    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    assert!(stored.is_empty());
}
1804
#[test]
fn pool_set_warmup_sqls_multiple() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    let expected = ["SELECT 1", "SELECT 2", "SELECT 3"];
    pool.set_warmup_sqls(expected);

    // Snapshot the stored list, tolerating a poisoned lock.
    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    // Length and every element must match what was handed in.
    let as_strs: Vec<&str> = stored.iter().map(|s| &**s).collect();
    assert_eq!(as_strs, expected);
}
1820
#[test]
fn pool_set_warmup_sqls_overwrite() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // The second call must replace, not append to, the first list.
    for batch in [["SELECT 1"], ["SELECT 99"]] {
        pool.set_warmup_sqls(batch);
    }

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    let as_strs: Vec<&str> = stored.iter().map(|s| &**s).collect();
    assert_eq!(as_strs, ["SELECT 99"]);
}
1835
#[test]
fn pool_set_warmup_sqls_with_iter_empty() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // Any IntoIterator is accepted, including an empty lazy iterator.
    let nothing = std::iter::empty::<&str>();
    pool.set_warmup_sqls(nothing);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    assert!(stored.is_empty());
}
1848
#[test]
fn pool_set_warmup_sqls_with_owned_string() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // Runtime-built SQL (an owned String) must be accepted as well.
    let schema_id = 42;
    let dynamic = format!("SET search_path TO test_{schema_id}");
    pool.set_warmup_sqls([dynamic]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    let as_strs: Vec<&str> = stored.iter().map(|s| &**s).collect();
    assert_eq!(as_strs, ["SET search_path TO test_42"]);
}
1863
#[test]
fn pool_set_warmup_sqls_with_vec_of_strings() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // A Vec<String> is consumed directly.
    let owned: Vec<String> = ["SELECT 1", "SELECT 2"]
        .iter()
        .map(|&s| String::from(s))
        .collect();
    pool.set_warmup_sqls(owned);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    assert_eq!(stored.len(), 2);
    assert_eq!(stored.first().map(|s| &**s), Some("SELECT 1"));
}
1878
#[test]
fn pool_set_warmup_sqls_with_boxed_str() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // A Box<str> input must round-trip unchanged.
    let boxed: Box<str> = Box::from("SELECT 1");
    pool.set_warmup_sqls([boxed]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    assert_eq!(stored.first().map(|s| &**s), Some("SELECT 1"));
}
1892
#[test]
fn pool_set_warmup_sqls_single_static_str() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    let only = "SET statement_timeout = '30s'";
    pool.set_warmup_sqls([only]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    // Exactly one statement is kept.
    let count = stored.len();
    assert_eq!(count, 1);
}
1905
#[test]
fn pool_set_warmup_sqls_preserves_order() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    let ordered = ["first", "second", "third"];
    pool.set_warmup_sqls(ordered);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    // Each statement must sit at the same position it was supplied in.
    for (idx, want) in ordered.into_iter().enumerate() {
        assert_eq!(&*stored[idx], want);
    }
}
1920
#[test]
fn pool_set_warmup_sqls_unicode() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    pool.set_warmup_sqls(["SET client_encoding TO 'UTF8'", "SELECT '日本語'"]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    // Multi-byte text must be stored without corruption.
    assert_eq!(stored.get(1).map(|s| &**s), Some("SELECT '日本語'"));
}
1933
#[test]
fn pool_set_warmup_sqls_empty_string() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // An empty statement is kept as an entry, not filtered out.
    pool.set_warmup_sqls([""]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    let as_strs: Vec<&str> = stored.iter().map(|s| &**s).collect();
    assert_eq!(as_strs, [""]);
}
1947
#[test]
fn pool_set_warmup_sqls_long_sql() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    // Roughly 3 KB of SQL must be stored without truncation.
    let long = format!("SELECT {}1", "x, ".repeat(1000));
    pool.set_warmup_sqls([long]);

    let stored = {
        let guard = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    };
    let first_len = stored[0].len();
    assert!(first_len > 3000);
}
1961
#[test]
fn pool_status_debug() {
    let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
    let rendered = format!("{:?}", pool.status());
    // The Debug output must name the type and every field.
    for needle in ["PoolStatus", "idle", "active", "open", "max_size"] {
        assert!(rendered.contains(needle));
    }
}
1977
#[test]
fn config_host_is_uds_returns_true_for_slash() {
    // A host given as an absolute path is classified as a Unix socket.
    let cfg = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
    let uds = cfg.host_is_uds();
    assert!(uds);
}
1987
#[test]
fn config_host_is_uds_returns_false_for_tcp() {
    // A hostname is a TCP target.
    let cfg = Config::from_url("postgres://user:pass@localhost/db").unwrap();
    let uds = cfg.host_is_uds();
    assert!(!uds);
}
1993
#[test]
fn config_host_is_uds_returns_false_for_ip() {
    // A private-range IPv4 address is a TCP target.
    let cfg = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
    let uds = cfg.host_is_uds();
    assert!(!uds);
}
1999
#[test]
fn pool_builder_full_chain() {
    let lifetime = Duration::from_secs(600);
    let timeout = Duration::from_secs(5);
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(3)
        .max_lifetime(Some(lifetime))
        .acquire_timeout(Some(timeout))
        .min_idle(1)
        .max_stmt_cache_size(128)
        .build()
        .unwrap();

    // Every setter in the chain must land in the shared pool state.
    let inner = &pool.inner;
    assert_eq!(pool.max_size(), 3);
    assert_eq!(inner.max_lifetime, Some(lifetime));
    assert_eq!(inner.acquire_timeout, Some(timeout));
    assert_eq!(inner.min_idle, 1);
    assert_eq!(inner.max_stmt_cache_size, 128);
}
2021
#[test]
fn pool_max_size_zero_rejects_all_acquires() {
    // A pool with zero capacity (default acquire timeout) can never hand out
    // a connection; the failure must be a pool error mentioning "exhausted".
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(0)
        .build()
        .unwrap();
    // Every arm reports what actually came back, so a failure is diagnosable
    // instead of collapsing into a generic panic.
    match pool.acquire() {
        Err(DriverError::Pool(msg)) => {
            assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
        }
        Err(e) => panic!("expected pool exhausted error, got: {e:?}"),
        Ok(_) => panic!("expected pool exhausted error, got a connection"),
    }
}
2038
#[test]
fn url_parse_unknown_sslmode_returns_error() {
    // `unwrap_err` already fails the test if parsing succeeds; the error's
    // Display text must name the offending option.
    let msg = Config::from_url("postgres://u:p@h/d?sslmode=bogus")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("unknown sslmode"), "unexpected error message: {msg}");
}
2048
#[test]
fn url_parse_invalid_port_returns_error() {
    // A non-numeric port must be rejected with a descriptive message.
    let msg = Config::from_url("postgres://u:p@h:abc/d")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("invalid port"), "unexpected error message: {msg}");
}
2056
#[test]
fn url_parse_missing_at_sign_returns_error() {
    // Credentials without the `@` separator must be rejected with a
    // descriptive message.
    let msg = Config::from_url("postgres://u:plocalhost/d")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("missing @"), "unexpected error message: {msg}");
}
2064
#[test]
fn url_parse_empty_host_returns_error() {
    // Nothing between `@` and `/` means no host: parsing must fail.
    assert!(Config::from_url("postgres://u:p@/d").is_err());
}
2070
#[test]
fn url_parse_empty_user_returns_error() {
    // A password without a user name must be rejected.
    assert!(Config::from_url("postgres://:p@h/d").is_err());
}
2076
#[test]
fn url_parse_statement_timeout_invalid_uses_default() {
    // An unparsable timeout is tolerated and falls back to 30 seconds.
    let secs = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum")
        .unwrap()
        .statement_timeout_secs;
    assert_eq!(secs, 30);
}
2082
#[test]
fn url_parse_malformed_percent_encoding() {
    // A bare `%` with no hex digits after it is malformed.
    assert!(Config::from_url("postgres://u%:p@h/d").is_err());
}
2088
#[test]
fn url_parse_invalid_hex_in_percent_encoding() {
    // `%ZZ` is not a valid hex escape.
    assert!(Config::from_url("postgres://u%ZZ:p@h/d").is_err());
}
2094}
2095
#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    // Unit tests for `NPlusOneDetector`: `track` counts consecutive repeats of
    // one SQL hash, and `check_final` reports `(hash, count)` only when the
    // count is strictly greater than the threshold and the hash is non-zero.
    use super::NPlusOneDetector;

    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..10 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..11 {
            d.track(42);
        }
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 11));
    }

    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..5 {
            d.track(99);
        }
        assert!(d.check_final().is_none(), "> not >=");
    }

    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..6 {
            d.track(99);
        }
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn threshold_zero_warns_on_second() {
        // Despite the name, threshold 0 already reports after the FIRST
        // execution, because a count of 1 is strictly greater than 0.
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn threshold_max_never_warns() {
        let mut d = NPlusOneDetector::new(u16::MAX);
        for _ in 0..1000 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42);
        d.track(42);
        // The second increment would overflow; the counter must pin at MAX.
        assert_eq!(d.repeat_count, u16::MAX);
    }

    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        for _ in 0..50 {
            d.track(1);
        }
        d.track(2);
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    #[test]
    fn multiple_n_plus_one_sequences() {
        // `check_final` only sees the most recent run; the earlier run of hash 1
        // is handled by the warning emitted when the hash switches.
        let mut d = NPlusOneDetector::new(3);
        for _ in 0..5 {
            d.track(1);
        }
        for _ in 0..4 {
            d.track(2);
        }
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        d.track(10);
        d.track(10);
        d.track(10);
        d.track(20);
        // The log output itself is not captured; only the state reset is checked.
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
    }

    #[test]
    fn hash_zero_treated_normally() {
        // Hash 0 doubles as the detector's "no query yet" value, so a run of
        // hash-0 queries is never reported even above the threshold.
        let mut d = NPlusOneDetector::new(2);
        d.track(0);
        d.track(0);
        d.track(0);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..500 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(1);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        d.track(1);
        d.track(2);
        d.track(1);
        d.track(2);
        for _ in 0..5 {
            d.track(5);
        }
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn nonzero_hash_after_zero_init() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 1));
    }

    #[test]
    fn independent_detectors_dont_interfere() {
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);

        for _ in 0..10 {
            d1.track(42);
        }
        d2.track(1);
        d2.track(2);
        d2.track(3);

        assert!(d1.check_final().is_some());
        assert!(d2.check_final().is_none());
    }

    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0u64..1000 {
            d.track(i);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_reset_state_after_warning() {
        let mut d = NPlusOneDetector::new(2);
        // First run exceeds the threshold (3 > 2) ...
        d.track(1);
        d.track(1);
        d.track(1);
        // ... but the new run of hash 2 starts from scratch and stays at 2.
        d.track(2);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_with_realistic_orm_pattern() {
        // One parent query followed by one child query per row.
        let mut d = NPlusOneDetector::new(5);
        d.track(100);
        for _ in 0..20 {
            d.track(200);
        }
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    #[test]
    fn detector_with_legitimate_batch_pattern() {
        // The detector only counts repeats, so it cannot tell an intentional
        // batch from an N+1 loop: 15 identical queries over threshold 10 ARE
        // reported.
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..15 {
            d.track(300);
        }
        assert!(d.check_final().is_some());
    }

    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            for _ in 0..threshold {
                d.track(42);
            }
            // Exactly `threshold` repeats is still silent (strict `>`) ...
            assert!(
                d.check_final().is_none(),
                "threshold={threshold} must not warn at count={threshold}"
            );
            // ... and one more repeat reports the exact hash and count.
            d.track(42);
            assert_eq!(
                d.check_final(),
                Some((42, threshold + 1)),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    #[test]
    fn detector_with_deterministic_random_sequences() {
        let mut d = NPlusOneDetector::new(5);
        // (i * 7 + 3) % 4 cycles 3, 2, 1, 0, so no hash ever repeats back to
        // back and the run length never exceeds 1.
        let hashes: Vec<u64> = (0u64..100).map(|i| (i * 7 + 3) % 4).collect();
        for &h in &hashes {
            d.track(h);
        }
        assert!(d.check_final().is_none());
    }

    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for h in &hashes {
                    d.track(*h);
                }
                let _ = d.check_final();
            }

            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 2u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for _ in 0..count {
                    d.track(hash);
                }
                // A run of `count` identical non-zero hashes leaves the
                // repeat counter at exactly `count` (well below u16::MAX), so
                // the report is fully determined in both directions.
                let expected = if count > threshold { Some((hash, count)) } else { None };
                prop_assert_eq!(d.check_final(), expected);
            }
        }
    }
}