1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use crate::DriverError;
14use crate::arena::Arena;
15use crate::codec::Encode;
16use crate::conn::Connection;
17use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
/// A pooled connection: either the blocking [`Connection`] or, with the `async` feature,
/// an [`AsyncConnection`].
pub(crate) enum PoolSlot {
    /// Blocking connection. `Pool::acquire_async` also creates this variant for
    /// Unix-domain-socket hosts.
    Sync(Connection),
    /// Async connection, created by `Pool::acquire_async` for non-UDS hosts.
    #[cfg(feature = "async")]
    Async(AsyncConnection),
}
36
/// A cloneable handle to a connection pool.
///
/// All clones share the same [`PoolInner`] state (idle connections, counters, configuration).
pub struct Pool {
    inner: Arc<PoolInner>,
}
55
/// Shared state behind every [`Pool`] handle and every [`PoolGuard`].
struct PoolInner {
    /// Idle connections; pushed and popped at the end, so reuse is LIFO.
    stack: std::sync::Mutex<Vec<PoolSlot>>,
    /// Upper bound on `open_count`.
    max_size: usize,
    /// Open connections (idle + checked out), including slots reserved while a new
    /// connection is still being established.
    open_count: AtomicUsize,
    /// Parsed connection settings, passed to every connection the pool opens.
    config: Arc<Config>,
    /// Set by `Pool::close`; afterwards acquires fail and returned connections are discarded.
    closed: AtomicBool,
    /// Mutex/condvar pair that acquirers wait on when the pool is at `max_size`; notified by
    /// connection releases, the min-idle thread, and `Pool::close`.
    release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
    /// Idle connections at least this old are discarded when popped; `None` disables the check.
    max_lifetime: Option<Duration>,
    /// How long an acquirer may wait when the pool is at `max_size`; `None` fails immediately.
    acquire_timeout: Option<Duration>,
    /// Idle-connection target for the background maintenance thread (0 = no thread).
    min_idle: usize,
    /// SQL prepared on newly opened connections; replaced wholesale by `set_warmup_sqls`.
    warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
    /// Statement-cache limit applied to connections opened by the acquire paths.
    max_stmt_cache_size: usize,
    /// Idle connections idle for at least this long are discarded instead of reused.
    stale_timeout: Duration,
}
84
85impl Pool {
86 pub fn connect(url: &str) -> Result<Self, DriverError> {
90 PoolBuilder::new().url(url).build()
91 }
92
93 pub fn builder() -> PoolBuilder {
95 PoolBuilder::new()
96 }
97
98 #[inline]
106 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
107 if self.inner.closed.load(Ordering::Acquire) {
108 return Err(DriverError::Pool("pool is closed".into()));
109 }
110
111 if let Some(guard) = self.try_pop_idle()? {
113 return Ok(guard);
114 }
115
116 loop {
118 let current = self.inner.open_count.load(Ordering::Acquire);
119 if current >= self.inner.max_size {
120 if let Some(timeout) = self.inner.acquire_timeout {
121 let (lock, cvar) = &self.inner.release_pair;
122 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
123 let (_guard, result) = cvar
124 .wait_timeout(guard, timeout)
125 .unwrap_or_else(|e| e.into_inner());
126 if result.timed_out() {
127 return Err(DriverError::Pool(
128 "pool exhausted: acquire timeout expired".into(),
129 ));
130 }
131 if let Some(guard) = self.try_pop_idle()? {
133 return Ok(guard);
134 }
135 continue;
137 }
138 return Err(DriverError::Pool(
139 "pool exhausted: all connections in use".into(),
140 ));
141 }
142 if self
143 .inner
144 .open_count
145 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
146 .is_ok()
147 {
148 break;
149 }
150 }
152
153 let conn_result = Connection::connect_arc(self.inner.config.clone());
155 match conn_result {
156 Ok(mut conn) => {
157 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
159 self.warmup_conn(&mut conn);
161
162 Ok(PoolGuard {
163 conn: Some(PoolSlot::Sync(conn)),
164 pool: self.inner.clone(),
165 discard: false,
166 })
167 }
168 Err(e) => {
169 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
171 Err(e)
172 }
173 }
174 }
175
176 #[inline]
182 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
183 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
184 while let Some(mut slot) = stack.pop() {
185 let (created_at, idle_dur) = match &slot {
186 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
187 #[cfg(feature = "async")]
188 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
189 };
190 if let Some(max_lifetime) = self.inner.max_lifetime {
191 if created_at.elapsed() >= max_lifetime {
192 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
193 continue;
194 }
195 }
196 if idle_dur >= self.inner.stale_timeout {
197 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
199 continue;
200 }
201 if idle_dur > Duration::from_secs(5) {
205 let alive = match &mut slot {
206 PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
207 #[cfg(feature = "async")]
208 PoolSlot::Async(_) => true, };
210 if !alive {
211 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
212 continue;
213 }
214 }
215 return Ok(Some(PoolGuard {
216 conn: Some(slot),
217 pool: self.inner.clone(),
218 discard: false,
219 }));
220 }
221 Ok(None)
222 }
223
224 pub fn is_uds(&self) -> bool {
229 #[cfg(unix)]
230 {
231 self.inner.config.host_is_uds()
232 }
233 #[cfg(not(unix))]
234 {
235 false
236 }
237 }
238
239 pub fn begin(&self) -> Result<Transaction, DriverError> {
241 let mut guard = self.acquire()?;
242 guard.simple_query("BEGIN")?;
243 Ok(Transaction {
244 guard,
245 committed: false,
246 deferred_buf: Vec::new(),
247 deferred_count: 0,
248 })
249 }
250
251 pub fn open_count(&self) -> usize {
253 self.inner.open_count.load(Ordering::Relaxed)
254 }
255
256 pub fn max_size(&self) -> usize {
258 self.inner.max_size
259 }
260
261 pub fn status(&self) -> PoolStatus {
263 let idle = self
264 .inner
265 .stack
266 .lock()
267 .unwrap_or_else(|e| e.into_inner())
268 .len();
269 let open = self.inner.open_count.load(Ordering::Relaxed);
270 let active = open.saturating_sub(idle);
271 PoolStatus {
272 idle,
273 active,
274 open,
275 max_size: self.inner.max_size,
276 }
277 }
278
279 fn warmup_conn(&self, conn: &mut Connection) {
287 let sqls = self
288 .inner
289 .warmup_sqls
290 .lock()
291 .unwrap_or_else(|e| e.into_inner())
292 .clone();
293
294 if sqls.is_empty() {
295 return;
296 }
297
298 for sql in sqls.iter() {
299 let sql_hash = crate::types::hash_sql(sql);
300 let _ = conn.prepare_only(sql, sql_hash);
301 }
302 }
303
304 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
326 let boxed: Arc<Vec<Box<str>>> =
327 Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
328 *self
329 .inner
330 .warmup_sqls
331 .lock()
332 .unwrap_or_else(|e| e.into_inner()) = boxed;
333 }
334
335 pub fn close(&self) {
338 self.inner.closed.store(true, Ordering::Release);
339 let slots: Vec<PoolSlot> = {
341 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
342 std::mem::take(&mut *stack)
343 };
344 for slot in slots {
345 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
346 match slot {
347 PoolSlot::Sync(conn) => {
348 let _ = conn.close();
349 }
350 #[cfg(feature = "async")]
351 PoolSlot::Async(_conn) => {
352 }
355 }
356 }
357 let (_, cvar) = &self.inner.release_pair;
359 cvar.notify_all();
360 }
361
362 pub fn is_closed(&self) -> bool {
364 self.inner.closed.load(Ordering::Acquire)
365 }
366
367 #[cfg(feature = "async")]
377 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
378 if self.inner.closed.load(Ordering::Acquire) {
379 return Err(DriverError::Pool("pool is closed".into()));
380 }
381
382 if let Some(guard) = self.try_pop_idle()? {
384 return Ok(guard);
385 }
386
387 loop {
389 let current = self.inner.open_count.load(Ordering::Acquire);
390 if current >= self.inner.max_size {
391 if let Some(timeout) = self.inner.acquire_timeout {
392 let (lock, cvar) = &self.inner.release_pair;
393 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
394 let (_guard, result) = cvar
395 .wait_timeout(guard, timeout)
396 .unwrap_or_else(|e| e.into_inner());
397 if result.timed_out() {
398 return Err(DriverError::Pool(
399 "pool exhausted: acquire timeout expired".into(),
400 ));
401 }
402 if let Some(guard) = self.try_pop_idle()? {
403 return Ok(guard);
404 }
405 continue;
406 }
407 return Err(DriverError::Pool(
408 "pool exhausted: all connections in use".into(),
409 ));
410 }
411 if self
412 .inner
413 .open_count
414 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
415 .is_ok()
416 {
417 break;
418 }
419 }
420
421 if self.inner.config.host_is_uds() {
423 let conn_result = Connection::connect_arc(self.inner.config.clone());
425 match conn_result {
426 Ok(mut conn) => {
427 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
428 self.warmup_conn(&mut conn);
429 Ok(PoolGuard {
430 conn: Some(PoolSlot::Sync(conn)),
431 pool: self.inner.clone(),
432 discard: false,
433 })
434 }
435 Err(e) => {
436 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
437 Err(e)
438 }
439 }
440 } else {
441 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
443 match conn_result {
444 Ok(mut conn) => {
445 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
446 Ok(PoolGuard {
447 conn: Some(PoolSlot::Async(conn)),
448 pool: self.inner.clone(),
449 discard: false,
450 })
451 }
452 Err(e) => {
453 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
454 Err(e)
455 }
456 }
457 }
458 }
459}
460
461impl Clone for Pool {
462 fn clone(&self) -> Self {
463 Pool {
464 inner: self.inner.clone(),
465 }
466 }
467}
468
/// Point-in-time snapshot of pool occupancy, returned by [`Pool::status`].
///
/// `idle` and `open` are read separately without a common lock, so under concurrent use the
/// figures may be slightly inconsistent with each other.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Connections currently in the idle stack.
    pub idle: usize,
    /// Connections presumed checked out: `open - idle`, saturating at 0.
    pub active: usize,
    /// Open connections as counted by the pool (includes slots still connecting).
    pub open: usize,
    /// Configured maximum number of open connections.
    pub max_size: usize,
}
483
/// Builder for [`Pool`]; obtain one with [`Pool::builder`].
pub struct PoolBuilder {
    /// Connection URL; required by `build`.
    url: Option<String>,
    /// Maximum number of open connections.
    max_size: usize,
    /// Age after which idle connections are discarded; `None` disables the limit.
    max_lifetime: Option<Duration>,
    /// How long `acquire` may wait at capacity; `None` fails immediately.
    acquire_timeout: Option<Duration>,
    /// Idle-connection target for the background maintenance thread (0 = no thread).
    min_idle: usize,
    /// Statement-cache limit for connections opened by the acquire paths.
    max_stmt_cache_size: usize,
    /// Idle duration after which a pooled connection is discarded instead of reused.
    stale_timeout: Duration,
}
501
502impl PoolBuilder {
503 fn new() -> Self {
504 Self {
505 url: None,
506 max_size: 10,
507 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, stale_timeout: Duration::from_secs(30), }
513 }
514
515 pub fn url(mut self, url: &str) -> Self {
517 self.url = Some(url.to_owned());
518 self
519 }
520
521 pub fn max_size(mut self, size: usize) -> Self {
525 self.max_size = size;
526 self
527 }
528
529 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
532 self.max_lifetime = lifetime;
533 self
534 }
535
536 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
539 self.acquire_timeout = timeout;
540 self
541 }
542
543 pub fn min_idle(mut self, count: usize) -> Self {
546 self.min_idle = count;
547 self
548 }
549
550 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
554 self.max_stmt_cache_size = size;
555 self
556 }
557
558 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
562 self.stale_timeout = timeout;
563 self
564 }
565
566 pub fn build(self) -> Result<Pool, DriverError> {
568 let url = self
569 .url
570 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
571
572 let config = Arc::new(Config::from_url(&url)?);
573
574 let pool = Pool {
575 inner: Arc::new(PoolInner {
576 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
577 max_size: self.max_size,
578 open_count: AtomicUsize::new(0),
579 config,
580 closed: AtomicBool::new(false),
581 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
582 max_lifetime: self.max_lifetime,
583 acquire_timeout: self.acquire_timeout,
584 min_idle: self.min_idle,
585 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
586 max_stmt_cache_size: self.max_stmt_cache_size,
587 stale_timeout: self.stale_timeout,
588 }),
589 };
590
591 if self.min_idle > 0 {
592 let inner = pool.inner.clone();
593 std::thread::spawn(move || {
594 maintain_min_idle(inner);
595 });
596 }
597
598 Ok(pool)
599 }
600}
601
602fn maintain_min_idle(inner: Arc<PoolInner>) {
604 loop {
605 if inner.closed.load(Ordering::Acquire) {
606 return;
607 }
608
609 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
610 let needed = inner.min_idle.saturating_sub(idle_count);
611
612 for _ in 0..needed {
613 if inner.closed.load(Ordering::Acquire) {
614 return;
615 }
616 let current = inner.open_count.load(Ordering::Acquire);
617 if current >= inner.max_size {
618 break;
619 }
620 if inner
621 .open_count
622 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
623 .is_err()
624 {
625 continue;
626 }
627
628 match Connection::connect_arc(inner.config.clone()) {
629 Ok(conn) => {
630 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
631 stack.push(PoolSlot::Sync(conn));
632 let (_, cvar) = &inner.release_pair;
633 cvar.notify_one();
634 }
635 Err(_) => {
636 inner.open_count.fetch_sub(1, Ordering::AcqRel);
637 }
638 }
639 }
640
641 std::thread::sleep(Duration::from_secs(1));
644 }
645}
646
/// A checked-out connection.
///
/// On drop the connection is returned to the pool's idle stack or discarded.
pub struct PoolGuard {
    /// The connection; `None` once it has been moved out (e.g. during drop).
    conn: Option<PoolSlot>,
    /// Shared pool state the connection is returned to.
    pool: Arc<PoolInner>,
    /// When `true`, the connection is discarded on drop instead of reused (see `mark_discard`).
    discard: bool,
}
659
660impl PoolGuard {
661 #[inline]
664 fn sync_conn(&self) -> Result<&Connection, DriverError> {
665 match self.conn.as_ref() {
666 Some(PoolSlot::Sync(conn)) => Ok(conn),
667 #[cfg(feature = "async")]
668 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
669 "expected sync connection, got async; use async methods".into(),
670 )),
671 None => Err(DriverError::Pool("connection already taken".into())),
672 }
673 }
674
675 #[inline]
677 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
678 match self.conn.as_mut() {
679 Some(PoolSlot::Sync(conn)) => Ok(conn),
680 #[cfg(feature = "async")]
681 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
682 "expected sync connection, got async; use async methods".into(),
683 )),
684 None => Err(DriverError::Pool("connection already taken".into())),
685 }
686 }
687
688 pub fn mark_discard(&mut self) {
691 self.discard = true;
692 }
693
694 pub fn cancel(&self) -> Result<(), DriverError> {
699 self.sync_conn()?.cancel()
700 }
701
702 pub fn pid(&self) -> i32 {
706 match self.conn.as_ref().expect("connection taken") {
707 PoolSlot::Sync(conn) => conn.pid(),
708 #[cfg(feature = "async")]
709 PoolSlot::Async(conn) => conn.pid(),
710 }
711 }
712
713 pub fn is_idle(&self) -> bool {
715 match self.conn.as_ref().expect("connection taken") {
716 PoolSlot::Sync(conn) => conn.is_idle(),
717 #[cfg(feature = "async")]
718 PoolSlot::Async(conn) => conn.is_idle(),
719 }
720 }
721
722 pub fn is_in_transaction(&self) -> bool {
724 match self.conn.as_ref().expect("connection taken") {
725 PoolSlot::Sync(conn) => conn.is_in_transaction(),
726 #[cfg(feature = "async")]
727 PoolSlot::Async(conn) => conn.is_in_transaction(),
728 }
729 }
730
731 #[inline]
735 pub fn query(
736 &mut self,
737 sql: &str,
738 sql_hash: u64,
739 params: &[&(dyn Encode + Sync)],
740 ) -> Result<QueryResult, DriverError> {
741 self.sync_conn_mut()?.query(sql, sql_hash, params)
742 }
743
744 #[inline]
746 pub fn execute(
747 &mut self,
748 sql: &str,
749 sql_hash: u64,
750 params: &[&(dyn Encode + Sync)],
751 ) -> Result<u64, DriverError> {
752 self.sync_conn_mut()?.execute(sql, sql_hash, params)
753 }
754
755 pub fn execute_pipeline(
760 &mut self,
761 sql: &str,
762 sql_hash: u64,
763 param_sets: &[&[&(dyn Encode + Sync)]],
764 ) -> Result<Vec<u64>, DriverError> {
765 self.sync_conn_mut()?
766 .execute_pipeline(sql, sql_hash, param_sets)
767 }
768
769 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
771 self.sync_conn_mut()?.simple_query(sql)
772 }
773
774 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
778 self.sync_conn_mut()?.simple_query_rows(sql)
779 }
780
781 pub fn for_each<F>(
783 &mut self,
784 sql: &str,
785 sql_hash: u64,
786 params: &[&(dyn Encode + Sync)],
787 f: F,
788 ) -> Result<(), DriverError>
789 where
790 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
791 {
792 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
793 }
794
795 pub fn for_each_raw<F>(
797 &mut self,
798 sql: &str,
799 sql_hash: u64,
800 params: &[&(dyn Encode + Sync)],
801 f: F,
802 ) -> Result<(), DriverError>
803 where
804 F: FnMut(&[u8]) -> Result<(), DriverError>,
805 {
806 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
807 }
808
809 pub fn query_streaming_start(
813 &mut self,
814 sql: &str,
815 sql_hash: u64,
816 params: &[&(dyn Encode + Sync)],
817 chunk_size: i32,
818 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
819 self.sync_conn_mut()?
820 .query_streaming_start(sql, sql_hash, params, chunk_size)
821 }
822
823 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
825 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
826 }
827
828 pub fn streaming_next_chunk(
830 &mut self,
831 arena: &mut Arena,
832 all_col_offsets: &mut Vec<(usize, i32)>,
833 ) -> Result<bool, DriverError> {
834 self.sync_conn_mut()?
835 .streaming_next_chunk(arena, all_col_offsets)
836 }
837
838 pub fn copy_in<'a, I>(
844 &mut self,
845 table: &str,
846 columns: &[&str],
847 rows: I,
848 ) -> Result<u64, DriverError>
849 where
850 I: IntoIterator<Item = &'a str>,
851 {
852 self.sync_conn_mut()?.copy_in(table, columns, rows)
853 }
854
855 pub fn copy_out<W: std::io::Write>(
859 &mut self,
860 query: &str,
861 writer: &mut W,
862 ) -> Result<u64, DriverError> {
863 self.sync_conn_mut()?.copy_out(query, writer)
864 }
865
866 pub fn is_sync(&self) -> bool {
868 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
869 }
870
871 #[cfg(feature = "async")]
873 pub fn is_async(&self) -> bool {
874 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
875 }
876
877 #[cfg(feature = "async")]
885 pub async fn query_async(
886 &mut self,
887 sql: &str,
888 sql_hash: u64,
889 params: &[&(dyn Encode + Sync)],
890 ) -> Result<QueryResult, DriverError> {
891 match self.conn.as_mut() {
892 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
893 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
894 None => Err(DriverError::Pool("connection already taken".into())),
895 }
896 }
897
898 #[cfg(feature = "async")]
900 pub async fn execute_async(
901 &mut self,
902 sql: &str,
903 sql_hash: u64,
904 params: &[&(dyn Encode + Sync)],
905 ) -> Result<u64, DriverError> {
906 match self.conn.as_mut() {
907 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
908 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
909 None => Err(DriverError::Pool("connection already taken".into())),
910 }
911 }
912
913 #[cfg(feature = "async")]
915 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
916 match self.conn.as_mut() {
917 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
918 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
919 None => Err(DriverError::Pool("connection already taken".into())),
920 }
921 }
922
923 pub(crate) fn ensure_stmt_prepared(
927 &mut self,
928 sql: &str,
929 sql_hash: u64,
930 params: &[&(dyn Encode + Sync)],
931 ) -> Result<[u8; 18], DriverError> {
932 self.sync_conn_mut()?
933 .ensure_stmt_prepared(sql, sql_hash, params)
934 }
935
936 pub(crate) fn write_deferred_bind_execute(
938 &self,
939 sql: &str,
940 sql_hash: u64,
941 params: &[&(dyn Encode + Sync)],
942 buf: &mut Vec<u8>,
943 ) {
944 let conn = self
945 .sync_conn()
946 .expect("sync_conn failed in write_deferred");
947 conn.write_deferred_bind_execute(sql, sql_hash, params, buf);
948 }
949
950 pub(crate) fn flush_deferred_pipeline(
952 &mut self,
953 buf: &mut Vec<u8>,
954 count: usize,
955 ) -> Result<Vec<u64>, DriverError> {
956 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
957 }
958}
959
960impl Drop for PoolGuard {
961 fn drop(&mut self) {
962 if let Some(slot) = self.conn.take() {
963 let should_discard = self.discard
965 || self.pool.closed.load(Ordering::Acquire)
966 || match &slot {
967 PoolSlot::Sync(conn) => {
968 conn.is_in_failed_transaction()
969 || conn.is_in_transaction()
970 || conn.is_streaming()
971 }
972 #[cfg(feature = "async")]
973 PoolSlot::Async(conn) => {
974 conn.is_in_failed_transaction() || conn.is_in_transaction()
975 }
976 };
977
978 if should_discard {
979 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
980 return;
981 }
982
983 let mut slot = slot;
986 match &mut slot {
987 PoolSlot::Sync(conn) => {
988 if conn.query_counter() & 63 == 0 {
989 conn.touch();
990 }
991 }
992 #[cfg(feature = "async")]
993 PoolSlot::Async(conn) => {
994 if conn.query_counter() & 63 == 0 {
995 conn.touch();
996 }
997 }
998 }
999
1000 {
1002 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1003 stack.push(slot);
1004 }
1005
1006 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1008 let (_, cvar) = &self.pool.release_pair;
1009 cvar.notify_one();
1010 }
1011 }
1012 }
1013}
1014
/// A transaction on a pooled connection, started by [`Pool::begin`].
///
/// If it is dropped without a successful `commit` or `rollback`, its connection is not
/// returned to the pool.
pub struct Transaction {
    /// The connection the transaction runs on.
    guard: PoolGuard,
    /// Set after a successful `COMMIT` or `ROLLBACK`.
    committed: bool,
    /// Messages written by `write_deferred_bind_execute` for `defer_execute`, pending until
    /// `flush_deferred`.
    deferred_buf: Vec<u8>,
    /// Number of statements queued in `deferred_buf`.
    deferred_count: usize,
}
1038
1039impl Transaction {
1040 pub fn commit(mut self) -> Result<(), DriverError> {
1044 if self.deferred_count > 0 {
1045 self.flush_deferred()?;
1046 }
1047 self.guard.simple_query("COMMIT")?;
1048 self.committed = true;
1049 Ok(())
1050 }
1051
1052 pub fn rollback(mut self) -> Result<(), DriverError> {
1056 self.deferred_buf.clear();
1057 self.deferred_count = 0;
1058 self.guard.simple_query("ROLLBACK")?;
1059 self.committed = true; Ok(())
1061 }
1062
1063 pub fn query(
1068 &mut self,
1069 sql: &str,
1070 sql_hash: u64,
1071 params: &[&(dyn Encode + Sync)],
1072 ) -> Result<QueryResult, DriverError> {
1073 if self.deferred_count > 0 {
1074 self.flush_deferred()?;
1075 }
1076 self.guard.query(sql, sql_hash, params)
1077 }
1078
1079 pub fn execute(
1081 &mut self,
1082 sql: &str,
1083 sql_hash: u64,
1084 params: &[&(dyn Encode + Sync)],
1085 ) -> Result<u64, DriverError> {
1086 self.guard.execute(sql, sql_hash, params)
1087 }
1088
1089 pub fn execute_pipeline(
1091 &mut self,
1092 sql: &str,
1093 sql_hash: u64,
1094 param_sets: &[&[&(dyn Encode + Sync)]],
1095 ) -> Result<Vec<u64>, DriverError> {
1096 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1097 }
1098
1099 pub fn for_each<F>(
1103 &mut self,
1104 sql: &str,
1105 sql_hash: u64,
1106 params: &[&(dyn Encode + Sync)],
1107 f: F,
1108 ) -> Result<(), DriverError>
1109 where
1110 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1111 {
1112 if self.deferred_count > 0 {
1113 self.flush_deferred()?;
1114 }
1115 self.guard.for_each(sql, sql_hash, params, f)
1116 }
1117
1118 pub fn for_each_raw<F>(
1122 &mut self,
1123 sql: &str,
1124 sql_hash: u64,
1125 params: &[&(dyn Encode + Sync)],
1126 f: F,
1127 ) -> Result<(), DriverError>
1128 where
1129 F: FnMut(&[u8]) -> Result<(), DriverError>,
1130 {
1131 if self.deferred_count > 0 {
1132 self.flush_deferred()?;
1133 }
1134 self.guard.for_each_raw(sql, sql_hash, params, f)
1135 }
1136
1137 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1141 if self.deferred_count > 0 {
1142 self.flush_deferred()?;
1143 }
1144 self.guard.simple_query(sql)
1145 }
1146
1147 pub fn defer_execute(
1176 &mut self,
1177 sql: &str,
1178 sql_hash: u64,
1179 params: &[&(dyn Encode + Sync)],
1180 ) -> Result<(), DriverError> {
1181 if params.len() > i16::MAX as usize {
1182 return Err(DriverError::Protocol(format!(
1183 "parameter count {} exceeds maximum {}",
1184 params.len(),
1185 i16::MAX
1186 )));
1187 }
1188
1189 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1191
1192 self.guard
1194 .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf);
1195 self.deferred_count += 1;
1196 Ok(())
1197 }
1198
1199 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1204 let count = self.deferred_count;
1205 self.deferred_count = 0;
1206 self.guard
1207 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1208 }
1209
1210 pub fn deferred_count(&self) -> usize {
1212 self.deferred_count
1213 }
1214}
1215
1216impl Drop for Transaction {
1217 fn drop(&mut self) {
1218 if !self.committed {
1219 if let Some(_slot) = self.guard.conn.take() {
1222 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1223 }
1225 }
1226 }
1227}
1228
1229#[cfg(test)]
1230mod tests {
1231 use super::*;
1232
1233 #[test]
1234 fn pool_builder_requires_url() {
1235 let result = PoolBuilder::new().build();
1236 assert!(result.is_err());
1237 }
1238
1239 #[test]
1240 fn pool_builder_validates_url() {
1241 let result = PoolBuilder::new().url("not_a_url").build();
1242 assert!(result.is_err());
1243 }
1244
1245 #[test]
1246 fn pool_builder_accepts_valid_url() {
1247 let pool = PoolBuilder::new()
1248 .url("postgres://user:pass@localhost/db")
1249 .max_size(5)
1250 .build()
1251 .unwrap();
1252 assert_eq!(pool.max_size(), 5);
1253 assert_eq!(pool.open_count(), 0);
1254 }
1255
1256 #[test]
1257 fn pool_connect_validates_url() {
1258 let result = Pool::connect("not_a_url");
1259 assert!(result.is_err());
1260 }
1261
1262 #[test]
1263 fn pool_max_size_zero() {
1264 let pool = PoolBuilder::new()
1265 .url("postgres://user:pass@localhost/db")
1266 .max_size(0)
1267 .build()
1268 .unwrap();
1269
1270 let result = pool.acquire();
1271 assert!(result.is_err());
1272 match result {
1273 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1274 Err(e) => panic!("expected Pool error, got: {e:?}"),
1275 Ok(_) => panic!("expected error, got Ok"),
1276 }
1277 }
1278
1279 #[test]
1280 fn pool_clone_shares_state() {
1281 let pool = PoolBuilder::new()
1282 .url("postgres://user:pass@localhost/db")
1283 .max_size(5)
1284 .build()
1285 .unwrap();
1286
1287 let pool2 = pool.clone();
1288 assert_eq!(pool.max_size(), pool2.max_size());
1289 }
1290
1291 #[test]
1295 fn pool_builder_max_lifetime() {
1296 let pool = PoolBuilder::new()
1297 .url("postgres://user:pass@localhost/db")
1298 .max_lifetime(Some(Duration::from_secs(60)))
1299 .build()
1300 .unwrap();
1301 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1302 }
1303
1304 #[test]
1306 fn pool_builder_max_lifetime_none() {
1307 let pool = PoolBuilder::new()
1308 .url("postgres://user:pass@localhost/db")
1309 .max_lifetime(None)
1310 .build()
1311 .unwrap();
1312 assert_eq!(pool.inner.max_lifetime, None);
1313 }
1314
1315 #[test]
1317 fn pool_builder_acquire_timeout_none() {
1318 let pool = PoolBuilder::new()
1319 .url("postgres://user:pass@localhost/db")
1320 .acquire_timeout(None)
1321 .build()
1322 .unwrap();
1323 assert_eq!(pool.inner.acquire_timeout, None);
1324 }
1325
1326 #[test]
1328 fn pool_builder_acquire_timeout_custom() {
1329 let pool = PoolBuilder::new()
1330 .url("postgres://user:pass@localhost/db")
1331 .acquire_timeout(Some(Duration::from_secs(10)))
1332 .build()
1333 .unwrap();
1334 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1335 }
1336
1337 #[test]
1339 fn pool_builder_min_idle() {
1340 let pool = PoolBuilder::new()
1341 .url("postgres://user:pass@localhost/db")
1342 .min_idle(2)
1343 .build()
1344 .unwrap();
1345 assert_eq!(pool.inner.min_idle, 2);
1346 }
1347
1348 #[test]
1350 fn pool_close_marks_closed() {
1351 let pool = PoolBuilder::new()
1352 .url("postgres://user:pass@localhost/db")
1353 .max_size(5)
1354 .build()
1355 .unwrap();
1356
1357 assert!(!pool.is_closed());
1358 pool.close();
1359 assert!(pool.is_closed());
1360
1361 let result = pool.acquire();
1363 assert!(result.is_err());
1364 match result {
1365 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1366 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1367 Ok(_) => panic!("expected error, got Ok"),
1368 }
1369 }
1370
1371 #[test]
1373 fn pool_status_initial() {
1374 let pool = PoolBuilder::new()
1375 .url("postgres://user:pass@localhost/db")
1376 .max_size(10)
1377 .build()
1378 .unwrap();
1379
1380 let status = pool.status();
1381 assert_eq!(status.idle, 0);
1382 assert_eq!(status.active, 0);
1383 assert_eq!(status.open, 0);
1384 assert_eq!(status.max_size, 10);
1385 }
1386
1387 #[test]
1389 fn pool_builder_defaults() {
1390 let pool = PoolBuilder::new()
1391 .url("postgres://user:pass@localhost/db")
1392 .build()
1393 .unwrap();
1394
1395 assert_eq!(pool.max_size(), 10);
1396 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1397 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1398 assert_eq!(pool.inner.min_idle, 0);
1399 }
1400
1401 #[test]
1403 fn pool_open_count_initial() {
1404 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1405 assert_eq!(pool.open_count(), 0);
1406 }
1407
1408 #[test]
1411 fn pool_builder_max_stmt_cache_size_default() {
1412 let pool = PoolBuilder::new()
1413 .url("postgres://user:pass@localhost/db")
1414 .build()
1415 .unwrap();
1416 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1417 }
1418
1419 #[test]
1420 fn pool_builder_max_stmt_cache_size_custom() {
1421 let pool = PoolBuilder::new()
1422 .url("postgres://user:pass@localhost/db")
1423 .max_stmt_cache_size(512)
1424 .build()
1425 .unwrap();
1426 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1427 }
1428
1429 #[test]
1432 fn pool_is_uds_false_for_tcp() {
1433 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1434 assert!(!pool.is_uds());
1435 }
1436
1437 #[cfg(unix)]
1438 #[test]
1439 fn pool_is_uds_true_for_unix_socket() {
1440 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1441 assert!(pool.is_uds());
1442 }
1443
1444 #[cfg(unix)]
1445 #[test]
1446 fn pool_is_uds_true_for_var_run_socket() {
1447 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1448 assert!(pool.is_uds());
1449 }
1450
1451 #[test]
1452 fn pool_is_uds_false_for_ip_address() {
1453 let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
1454 assert!(!pool.is_uds());
1455 }
1456
1457 #[cfg(unix)]
1458 #[test]
1459 fn pool_slot_sync_created_for_uds_config() {
1460 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1461 assert!(config.host_is_uds());
1462 }
1463
1464 #[test]
1465 fn pool_slot_tcp_config() {
1466 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1467 assert!(!config.host_is_uds());
1468 }
1469
1470 #[test]
1475 fn pool_is_uds_false_for_hostname() {
1476 let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
1477 assert!(!pool.is_uds());
1478 }
1479
1480 #[cfg(unix)]
1481 #[test]
1482 fn pool_is_uds_true_for_tmp() {
1483 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1484 assert!(pool.is_uds());
1485 }
1486
1487 #[test]
1492 fn pool_close_then_acquire_fails() {
1493 let pool = PoolBuilder::new()
1494 .url("postgres://user:pass@localhost/db")
1495 .max_size(5)
1496 .build()
1497 .unwrap();
1498 pool.close();
1499 let result = pool.acquire();
1500 assert!(result.is_err());
1501 match result {
1502 Err(DriverError::Pool(msg)) => {
1503 assert!(msg.contains("closed"), "should say closed: {msg}")
1504 }
1505 Err(e) => panic!("expected Pool error, got: {e:?}"),
1506 Ok(_) => panic!("expected error"),
1507 }
1508 }
1509
1510 #[test]
1511 fn pool_is_closed_before_and_after() {
1512 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1513 assert!(!pool.is_closed());
1514 pool.close();
1515 assert!(pool.is_closed());
1516 }
1517
1518 #[test]
1523 fn pool_exhausted_no_timeout() {
1524 let pool = PoolBuilder::new()
1525 .url("postgres://user:pass@localhost/db")
1526 .max_size(0)
1527 .acquire_timeout(None) .build()
1529 .unwrap();
1530 let result = pool.acquire();
1531 assert!(result.is_err());
1532 match result {
1533 Err(DriverError::Pool(msg)) => {
1534 assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1535 }
1536 Err(e) => panic!("expected Pool error, got: {e:?}"),
1537 Ok(_) => panic!("expected error"),
1538 }
1539 }
1540
1541 #[test]
1546 fn pool_builder_no_url_error() {
1547 let result = PoolBuilder::new().max_size(5).build();
1548 assert!(result.is_err());
1549 match result {
1550 Err(DriverError::Pool(msg)) => {
1551 assert!(msg.contains("URL"), "should mention URL: {msg}")
1552 }
1553 Err(e) => panic!("expected Pool error, got: {e:?}"),
1554 Ok(_) => panic!("expected error"),
1555 }
1556 }
1557
1558 #[test]
1559 fn pool_builder_invalid_url_error() {
1560 let result = PoolBuilder::new().url("ftp://something").build();
1561 assert!(result.is_err());
1562 }
1563
1564 #[test]
1565 fn pool_builder_stmt_cache_size_zero() {
1566 let pool = PoolBuilder::new()
1567 .url("postgres://user:pass@localhost/db")
1568 .max_stmt_cache_size(0)
1569 .build()
1570 .unwrap();
1571 assert_eq!(pool.inner.max_stmt_cache_size, 0);
1572 }
1573
1574 #[test]
1579 fn pool_status_reflects_max_size() {
1580 let pool = PoolBuilder::new()
1581 .url("postgres://user:pass@localhost/db")
1582 .max_size(20)
1583 .build()
1584 .unwrap();
1585 let status = pool.status();
1586 assert_eq!(status.max_size, 20);
1587 assert_eq!(status.idle, 0);
1588 assert_eq!(status.active, 0);
1589 assert_eq!(status.open, 0);
1590 }
1591
1592 #[test]
1597 fn pool_clone_shares_config() {
1598 let pool = PoolBuilder::new()
1599 .url("postgres://user:pass@localhost/db")
1600 .max_size(7)
1601 .build()
1602 .unwrap();
1603 let p2 = pool.clone();
1604 assert_eq!(pool.max_size(), 7);
1605 assert_eq!(p2.max_size(), 7);
1606 assert_eq!(pool.open_count(), p2.open_count());
1607 }
1608
1609 #[test]
1614 fn pool_set_warmup_sqls_empty() {
1615 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1616 pool.set_warmup_sqls(&[]);
1617 let sqls = pool
1618 .inner
1619 .warmup_sqls
1620 .lock()
1621 .unwrap_or_else(|e| e.into_inner())
1622 .clone();
1623 assert!(sqls.is_empty());
1624 }
1625
1626 #[test]
1627 fn pool_set_warmup_sqls_multiple() {
1628 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1629 pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
1630 let sqls = pool
1631 .inner
1632 .warmup_sqls
1633 .lock()
1634 .unwrap_or_else(|e| e.into_inner())
1635 .clone();
1636 assert_eq!(sqls.len(), 3);
1637 assert_eq!(&*sqls[0], "SELECT 1");
1638 assert_eq!(&*sqls[1], "SELECT 2");
1639 assert_eq!(&*sqls[2], "SELECT 3");
1640 }
1641
1642 #[test]
1643 fn pool_set_warmup_sqls_overwrite() {
1644 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1645 pool.set_warmup_sqls(&["SELECT 1"]);
1646 pool.set_warmup_sqls(&["SELECT 99"]);
1647 let sqls = pool
1648 .inner
1649 .warmup_sqls
1650 .lock()
1651 .unwrap_or_else(|e| e.into_inner())
1652 .clone();
1653 assert_eq!(sqls.len(), 1);
1654 assert_eq!(&*sqls[0], "SELECT 99");
1655 }
1656
1657 #[test]
1662 fn pool_status_debug() {
1663 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1664 let status = pool.status();
1665 let dbg = format!("{status:?}");
1666 assert!(dbg.contains("PoolStatus"));
1667 assert!(dbg.contains("idle"));
1668 assert!(dbg.contains("active"));
1669 assert!(dbg.contains("open"));
1670 assert!(dbg.contains("max_size"));
1671 }
1672
1673 #[test]
1678 fn config_host_is_uds_returns_true_for_slash() {
1679 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1680 assert!(config.host_is_uds());
1681 }
1682
1683 #[test]
1684 fn config_host_is_uds_returns_false_for_tcp() {
1685 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1686 assert!(!config.host_is_uds());
1687 }
1688
1689 #[test]
1690 fn config_host_is_uds_returns_false_for_ip() {
1691 let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
1692 assert!(!config.host_is_uds());
1693 }
1694
1695 #[test]
1700 fn pool_builder_full_chain() {
1701 let pool = PoolBuilder::new()
1702 .url("postgres://user:pass@localhost/db")
1703 .max_size(3)
1704 .max_lifetime(Some(Duration::from_secs(600)))
1705 .acquire_timeout(Some(Duration::from_secs(5)))
1706 .min_idle(1)
1707 .max_stmt_cache_size(128)
1708 .build()
1709 .unwrap();
1710 assert_eq!(pool.max_size(), 3);
1711 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
1712 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1713 assert_eq!(pool.inner.min_idle, 1);
1714 assert_eq!(pool.inner.max_stmt_cache_size, 128);
1715 }
1716
1717 #[test]
1720 fn pool_max_size_zero_rejects_all_acquires() {
1721 let pool = PoolBuilder::new()
1722 .url("postgres://user:pass@localhost/db")
1723 .max_size(0)
1724 .build()
1725 .unwrap();
1726 let result = pool.acquire();
1727 assert!(result.is_err());
1728 match &result {
1729 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1730 _ => panic!("expected pool exhausted error"),
1731 }
1732 }
1733
1734 #[test]
1737 fn url_parse_unknown_sslmode_returns_error() {
1738 let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
1739 assert!(result.is_err());
1740 let msg = format!("{}", result.unwrap_err());
1741 assert!(msg.contains("unknown sslmode"));
1742 }
1743
1744 #[test]
1745 fn url_parse_invalid_port_returns_error() {
1746 let result = Config::from_url("postgres://u:p@h:abc/d");
1747 assert!(result.is_err());
1748 let msg = format!("{}", result.unwrap_err());
1749 assert!(msg.contains("invalid port"));
1750 }
1751
1752 #[test]
1753 fn url_parse_missing_at_sign_returns_error() {
1754 let result = Config::from_url("postgres://u:plocalhost/d");
1755 assert!(result.is_err());
1756 let msg = format!("{}", result.unwrap_err());
1757 assert!(msg.contains("missing @"));
1758 }
1759
1760 #[test]
1761 fn url_parse_empty_host_returns_error() {
1762 let result = Config::from_url("postgres://u:p@/d");
1763 assert!(result.is_err());
1764 }
1765
1766 #[test]
1767 fn url_parse_empty_user_returns_error() {
1768 let result = Config::from_url("postgres://:p@h/d");
1769 assert!(result.is_err());
1770 }
1771
1772 #[test]
1773 fn url_parse_statement_timeout_invalid_uses_default() {
1774 let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
1775 assert_eq!(config.statement_timeout_secs, 30);
1776 }
1777
1778 #[test]
1779 fn url_parse_malformed_percent_encoding() {
1780 let result = Config::from_url("postgres://u%:p@h/d");
1781 assert!(result.is_err());
1782 }
1783
1784 #[test]
1785 fn url_parse_invalid_hex_in_percent_encoding() {
1786 let result = Config::from_url("postgres://u%ZZ:p@h/d");
1787 assert!(result.is_err());
1788 }
1789}