1use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
16use std::time::Duration;
17
18use tokio::sync::Notify;
19
20use crate::DriverError;
21use crate::arena::Arena;
22use crate::codec::Encode;
23use crate::conn::{Config, Connection, PgDataRow, QueryResult};
24#[cfg(unix)]
25use crate::sync_conn::SyncConnection;
26
/// A pooled connection.
///
/// Hosts that are Unix-domain sockets (Unix only) get the blocking
/// `SyncConnection`; every other host gets the async [`Connection`]
/// (see `open_new_connection_inner`).
enum PoolSlot {
    /// Async connection, used for all non-UDS hosts.
    Async(Connection),
    /// Blocking connection over a Unix-domain socket; callers in this module
    /// wrap its blocking calls in `tokio::task::block_in_place`.
    #[cfg(unix)]
    Sync(SyncConnection),
}
42
43impl PoolSlot {
44 fn created_at(&self) -> std::time::Instant {
45 match self {
46 PoolSlot::Async(c) => c.created_at(),
47 #[cfg(unix)]
48 PoolSlot::Sync(c) => c.created_at(),
49 }
50 }
51
52 fn idle_duration(&self) -> Duration {
53 match self {
54 PoolSlot::Async(c) => c.idle_duration(),
55 #[cfg(unix)]
56 PoolSlot::Sync(c) => c.idle_duration(),
57 }
58 }
59
60 fn is_in_failed_transaction(&self) -> bool {
61 match self {
62 PoolSlot::Async(c) => c.is_in_failed_transaction(),
63 #[cfg(unix)]
64 PoolSlot::Sync(c) => c.is_in_failed_transaction(),
65 }
66 }
67
68 fn is_in_transaction(&self) -> bool {
69 match self {
70 PoolSlot::Async(c) => c.is_in_transaction(),
71 #[cfg(unix)]
72 PoolSlot::Sync(c) => c.is_in_transaction(),
73 }
74 }
75
76 fn is_streaming(&self) -> bool {
77 match self {
78 PoolSlot::Async(c) => c.is_streaming(),
79 #[cfg(unix)]
81 PoolSlot::Sync(_) => false,
82 }
83 }
84
85 fn set_max_stmt_cache_size(&mut self, size: usize) {
86 match self {
87 PoolSlot::Async(c) => c.set_max_stmt_cache_size(size),
88 #[cfg(unix)]
89 PoolSlot::Sync(c) => c.set_max_stmt_cache_size(size),
90 }
91 }
92
93 async fn close(self) -> Result<(), DriverError> {
94 match self {
95 PoolSlot::Async(c) => c.close().await,
96 #[cfg(unix)]
97 PoolSlot::Sync(c) => c.close(),
98 }
99 }
100
101 #[cfg(unix)]
103 fn is_sync(&self) -> bool {
104 matches!(self, PoolSlot::Sync(_))
105 }
106
107 fn touch(&mut self) {
111 match self {
112 PoolSlot::Async(c) => c.touch(),
113 #[cfg(unix)]
114 PoolSlot::Sync(c) => c.touch(),
115 }
116 }
117}
118
/// A cloneable handle to a PostgreSQL connection pool.
///
/// All clones share the same connections, counters and configuration through
/// one `Arc<PoolInner>`. Connections are opened lazily by [`Pool::acquire`]
/// (and by the `min_idle` background task when configured).
pub struct Pool {
    inner: Arc<PoolInner>,
}
137
/// State shared by every [`Pool`] handle, every [`PoolGuard`], and the
/// `min_idle` background task.
struct PoolInner {
    /// Idle connections, used as a stack (`push` on release, `pop` on
    /// acquire), so the most recently returned connection is reused first.
    stack: std::sync::Mutex<Vec<PoolSlot>>,
    /// Upper bound on `open_count`.
    max_size: usize,
    /// Connections counted against `max_size`: idle ones, checked-out ones,
    /// and ones still being opened (the unit is reserved before connecting).
    open_count: AtomicUsize,
    /// Parsed connection settings, used to open connections and passed to
    /// `Connection::cancel`.
    config: Config,
    /// `notify_waiters` is called after each connection attempt in `acquire`.
    /// NOTE(review): no visible code awaits this — confirm it is still needed.
    connecting: Notify,
    /// `notify_one` when a connection is pushed back onto `stack`,
    /// `notify_waiters` when the pool closes; awaited by `acquire` while the
    /// pool is exhausted and an `acquire_timeout` is set.
    release_notify: Notify,
    /// Set by `Pool::close`; checked by `acquire`, `PoolGuard`'s drop, and the
    /// `min_idle` task.
    closed: AtomicBool,
    /// Connections at least this old are discarded when popped from `stack`;
    /// `None` disables the check.
    max_lifetime: Option<Duration>,
    /// How long `acquire` waits on an exhausted pool; `None` fails immediately.
    acquire_timeout: Option<Duration>,
    /// Idle-connection target for the background task (spawned only when > 0).
    min_idle: usize,
    /// SQL prepared on each connection opened by `acquire` (best effort).
    warmup_sqls: std::sync::Mutex<Arc<[Box<str>]>>,
    /// Statement-cache limit set on each connection opened by `acquire`.
    max_stmt_cache_size: usize,
}
170
171impl Pool {
172 pub async fn connect(url: &str) -> Result<Self, DriverError> {
176 PoolBuilder::new().url(url).build().await
177 }
178
179 pub fn builder() -> PoolBuilder {
181 PoolBuilder::new()
182 }
183
184 pub async fn acquire(&self) -> Result<PoolGuard, DriverError> {
191 if self.inner.closed.load(Ordering::Acquire) {
192 return Err(DriverError::Pool("pool is closed".into()));
193 }
194
195 if let Some(guard) = self.try_pop_idle()? {
203 return Ok(guard);
204 }
205
206 loop {
209 let current = self.inner.open_count.load(Ordering::Acquire);
210 if current >= self.inner.max_size {
211 if let Some(timeout) = self.inner.acquire_timeout {
212 let result =
213 tokio::time::timeout(timeout, self.inner.release_notify.notified()).await;
214 if result.is_err() {
215 return Err(DriverError::Pool(
216 "pool exhausted: acquire timeout expired".into(),
217 ));
218 }
219 if let Some(guard) = self.try_pop_idle()? {
221 return Ok(guard);
222 }
223 continue;
225 }
226 return Err(DriverError::Pool(
227 "pool exhausted: all connections in use".into(),
228 ));
229 }
230 if self
231 .inner
232 .open_count
233 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
234 .is_ok()
235 {
236 break;
237 }
238 }
240
241 let slot_result = self.open_new_connection().await;
243 match slot_result {
244 Ok(mut slot) => {
245 slot.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
247 self.warmup_slot(&mut slot).await;
249
250 self.inner.connecting.notify_waiters();
251 Ok(PoolGuard {
252 conn: Some(slot),
253 pool: self.inner.clone(),
254 discard: false,
255 })
256 }
257 Err(e) => {
258 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
260 self.inner.connecting.notify_waiters();
261 Err(e)
262 }
263 }
264 }
265
266 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
268 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
269 while let Some(slot) = stack.pop() {
270 if let Some(max_lifetime) = self.inner.max_lifetime {
271 if slot.created_at().elapsed() >= max_lifetime {
272 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
273 continue;
274 }
275 }
276 if slot.idle_duration() < Duration::from_secs(30) {
277 return Ok(Some(PoolGuard {
278 conn: Some(slot),
279 pool: self.inner.clone(),
280 discard: false,
281 }));
282 }
283 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
285 }
286 Ok(None)
287 }
288
289 async fn open_new_connection(&self) -> Result<PoolSlot, DriverError> {
295 open_new_connection_inner(&self.inner.config).await
296 }
297
298 pub fn is_uds(&self) -> bool {
303 #[cfg(unix)]
304 {
305 self.inner.config.host_is_uds()
306 }
307 #[cfg(not(unix))]
308 {
309 false
310 }
311 }
312
313 pub async fn begin(&self) -> Result<Transaction, DriverError> {
315 let mut guard = self.acquire().await?;
316 guard.simple_query("BEGIN").await?;
317 Ok(Transaction {
318 guard,
319 committed: false,
320 deferred_buf: Vec::new(),
321 deferred_count: 0,
322 })
323 }
324
325 pub fn open_count(&self) -> usize {
327 self.inner.open_count.load(Ordering::Relaxed)
328 }
329
330 pub fn max_size(&self) -> usize {
332 self.inner.max_size
333 }
334
335 pub fn status(&self) -> PoolStatus {
337 let idle = self
338 .inner
339 .stack
340 .lock()
341 .unwrap_or_else(|e| e.into_inner())
342 .len();
343 let open = self.inner.open_count.load(Ordering::Relaxed);
344 let active = open.saturating_sub(idle);
345 PoolStatus {
346 idle,
347 active,
348 open,
349 max_size: self.inner.max_size,
350 }
351 }
352
353 async fn warmup_slot(&self, slot: &mut PoolSlot) {
361 let sqls = self
362 .inner
363 .warmup_sqls
364 .lock()
365 .unwrap_or_else(|e| e.into_inner())
366 .clone();
367
368 if sqls.is_empty() {
369 return;
370 }
371
372 match slot {
373 PoolSlot::Async(conn) => {
374 for sql in sqls.iter() {
375 let sql_hash = crate::conn::hash_sql(sql);
376 let _ = tokio::time::timeout(
377 std::time::Duration::from_secs(5),
378 conn.prepare_only(sql, sql_hash),
379 )
380 .await;
381 }
382 }
383 #[cfg(unix)]
384 PoolSlot::Sync(conn) => {
385 tokio::task::block_in_place(|| {
386 for sql in sqls.iter() {
387 let sql_hash = crate::conn::hash_sql(sql);
388 let _ = conn.prepare_only(sql, sql_hash);
389 }
390 });
391 }
392 }
393 }
394
395 pub async fn close(&self) {
419 self.inner.closed.store(true, Ordering::Release);
420 let slots: Vec<PoolSlot> = {
422 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
423 std::mem::take(&mut *stack)
424 };
425 for slot in slots {
426 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
427 let _ = slot.close().await;
428 }
429 self.inner.release_notify.notify_waiters();
431 }
432
433 pub fn is_closed(&self) -> bool {
435 self.inner.closed.load(Ordering::Acquire)
436 }
437
438 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
439 let boxed: Arc<[Box<str>]> = sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>().into();
440 *self
441 .inner
442 .warmup_sqls
443 .lock()
444 .unwrap_or_else(|e| e.into_inner()) = boxed;
445 }
446}
447
448impl Clone for Pool {
449 fn clone(&self) -> Self {
450 Pool {
451 inner: self.inner.clone(),
452 }
453 }
454}
455
/// Point-in-time snapshot of pool usage, returned by `Pool::status`.
///
/// Comparable and hashable, so callers can detect changes between snapshots
/// or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Connections sitting in the idle stack.
    pub idle: usize,
    /// `open - idle` (saturating); includes connections still being opened.
    pub active: usize,
    /// Connections counted against `max_size`.
    pub open: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
}
470
/// Configures and creates a [`Pool`]; obtain one via [`Pool::builder`].
///
/// Defaults (set in `PoolBuilder::new`): `max_size` 10, `max_lifetime`
/// 30 minutes, no `acquire_timeout`, `min_idle` 0, `max_stmt_cache_size` 256.
pub struct PoolBuilder {
    /// Connection URL; `build` fails if it is unset.
    url: Option<String>,
    /// Maximum connections counted against the pool.
    max_size: usize,
    /// Age limit for idle connections; `None` disables it.
    max_lifetime: Option<Duration>,
    /// Wait budget for `acquire` on an exhausted pool; `None` fails fast.
    acquire_timeout: Option<Duration>,
    /// Idle-connection target maintained by a background task when > 0.
    min_idle: usize,
    /// Statement-cache limit applied to newly opened connections.
    max_stmt_cache_size: usize,
}
486
487impl PoolBuilder {
488 fn new() -> Self {
489 Self {
490 url: None,
491 max_size: 10,
492 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: None, min_idle: 0, max_stmt_cache_size: 256, }
497 }
498
499 pub fn url(mut self, url: &str) -> Self {
501 self.url = Some(url.to_owned());
502 self
503 }
504
505 pub fn max_size(mut self, size: usize) -> Self {
509 self.max_size = size;
510 self
511 }
512
513 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
516 self.max_lifetime = lifetime;
517 self
518 }
519
520 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
523 self.acquire_timeout = timeout;
524 self
525 }
526
527 pub fn min_idle(mut self, count: usize) -> Self {
530 self.min_idle = count;
531 self
532 }
533
534 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
538 self.max_stmt_cache_size = size;
539 self
540 }
541
542 pub async fn build(self) -> Result<Pool, DriverError> {
544 let url = self
545 .url
546 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
547
548 let config = Config::from_url(&url)?;
549
550 let pool = Pool {
551 inner: Arc::new(PoolInner {
552 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
553 max_size: self.max_size,
554 open_count: AtomicUsize::new(0),
555 config,
556 connecting: Notify::new(),
557 release_notify: Notify::new(),
558 closed: AtomicBool::new(false),
559 max_lifetime: self.max_lifetime,
560 acquire_timeout: self.acquire_timeout,
561 min_idle: self.min_idle,
562 warmup_sqls: std::sync::Mutex::new(Arc::from(Vec::<Box<str>>::new())),
563 max_stmt_cache_size: self.max_stmt_cache_size,
564 }),
565 };
566
567 if self.min_idle > 0 {
568 let inner = pool.inner.clone();
569 tokio::spawn(async move {
570 maintain_min_idle(inner).await;
571 });
572 }
573
574 Ok(pool)
575 }
576}
577
578async fn maintain_min_idle(inner: Arc<PoolInner>) {
580 loop {
581 if inner.closed.load(Ordering::Acquire) {
582 return;
583 }
584
585 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
586 let needed = inner.min_idle.saturating_sub(idle_count);
587
588 for _ in 0..needed {
589 if inner.closed.load(Ordering::Acquire) {
590 return;
591 }
592 let current = inner.open_count.load(Ordering::Acquire);
593 if current >= inner.max_size {
594 break;
595 }
596 if inner
597 .open_count
598 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
599 .is_err()
600 {
601 continue;
602 }
603
604 let slot_result = open_new_connection_inner(&inner.config).await;
605 match slot_result {
606 Ok(slot) => {
607 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
608 stack.push(slot);
609 inner.release_notify.notify_one();
610 }
611 Err(_) => {
612 inner.open_count.fetch_sub(1, Ordering::AcqRel);
613 }
614 }
615 }
616
617 tokio::time::sleep(Duration::from_secs(5)).await;
619 }
620}
621
622async fn open_new_connection_inner(config: &Config) -> Result<PoolSlot, DriverError> {
625 #[cfg(unix)]
626 if config.host_is_uds() {
627 let config = config.clone();
628 return tokio::task::block_in_place(|| {
629 SyncConnection::connect(&config).map(PoolSlot::Sync)
630 });
631 }
632
633 Connection::connect(config).await.map(PoolSlot::Async)
634}
635
/// A connection checked out from a [`Pool`].
///
/// When dropped, the connection is pushed back onto the idle stack unless one
/// of the following holds:
/// * it was marked for discard;
/// * it is in a transaction or a failed transaction;
/// * it is mid-stream;
/// * the pool is closed.
///
/// In those cases the connection is dropped and its `open_count` unit released.
pub struct PoolGuard {
    /// `Some` until the guard is dropped, or until an unfinished `Transaction`
    /// owning this guard takes it on drop.
    conn: Option<PoolSlot>,
    /// Back-reference used to return or release the connection.
    pool: Arc<PoolInner>,
    /// Set by `mark_discard`; forces the connection to be dropped on release.
    discard: bool,
}
654
655impl PoolGuard {
656 pub fn mark_discard(&mut self) {
663 self.discard = true;
664 }
665
666 pub async fn cancel(&self) -> Result<(), DriverError> {
671 let slot = self
672 .conn
673 .as_ref()
674 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
675 match slot {
676 PoolSlot::Async(conn) => conn.cancel(&self.pool.config).await,
677 #[cfg(unix)]
680 PoolSlot::Sync(_) => Err(DriverError::Pool(
681 "cancel not supported on sync UDS connections".into(),
682 )),
683 }
684 }
685
686 pub fn pid(&self) -> i32 {
690 match self.conn.as_ref().expect("connection taken") {
691 PoolSlot::Async(conn) => conn.pid(),
692 #[cfg(unix)]
693 PoolSlot::Sync(conn) => conn.pid(),
694 }
695 }
696
697 pub fn is_idle(&self) -> bool {
699 match self.conn.as_ref().expect("connection taken") {
700 PoolSlot::Async(conn) => conn.is_idle(),
701 #[cfg(unix)]
702 PoolSlot::Sync(conn) => conn.is_idle(),
703 }
704 }
705
706 pub fn is_in_transaction(&self) -> bool {
708 match self.conn.as_ref().expect("connection taken") {
709 PoolSlot::Async(conn) => conn.is_in_transaction(),
710 #[cfg(unix)]
711 PoolSlot::Sync(conn) => conn.is_in_transaction(),
712 }
713 }
714
715 pub async fn query(
719 &mut self,
720 sql: &str,
721 sql_hash: u64,
722 params: &[&(dyn Encode + Sync)],
723 arena: &mut Arena,
724 ) -> Result<QueryResult, DriverError> {
725 let slot = self
726 .conn
727 .as_mut()
728 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
729 match slot {
730 PoolSlot::Async(conn) => conn.query(sql, sql_hash, params, arena).await,
731 #[cfg(unix)]
732 PoolSlot::Sync(conn) => {
733 tokio::task::block_in_place(|| conn.query(sql, sql_hash, params, arena))
734 }
735 }
736 }
737
738 pub async fn execute(
743 &mut self,
744 sql: &str,
745 sql_hash: u64,
746 params: &[&(dyn Encode + Sync)],
747 ) -> Result<u64, DriverError> {
748 let slot = self
749 .conn
750 .as_mut()
751 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
752 match slot {
753 PoolSlot::Async(conn) => conn.execute(sql, sql_hash, params).await,
754 #[cfg(unix)]
755 PoolSlot::Sync(conn) => {
756 tokio::task::block_in_place(|| conn.execute_monolithic(sql, sql_hash, params))
757 }
758 }
759 }
760
761 pub async fn execute_pipeline(
766 &mut self,
767 sql: &str,
768 sql_hash: u64,
769 param_sets: &[&[&(dyn Encode + Sync)]],
770 ) -> Result<Vec<u64>, DriverError> {
771 let slot = self
772 .conn
773 .as_mut()
774 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
775 match slot {
776 PoolSlot::Async(conn) => conn.execute_pipeline(sql, sql_hash, param_sets).await,
777 #[cfg(unix)]
778 PoolSlot::Sync(conn) => {
779 tokio::task::block_in_place(|| conn.execute_pipeline(sql, sql_hash, param_sets))
780 }
781 }
782 }
783
784 pub async fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
786 let slot = self
787 .conn
788 .as_mut()
789 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
790 match slot {
791 PoolSlot::Async(conn) => conn.simple_query(sql).await,
792 #[cfg(unix)]
793 PoolSlot::Sync(conn) => tokio::task::block_in_place(|| conn.simple_query(sql)),
794 }
795 }
796
797 pub async fn for_each<F>(
799 &mut self,
800 sql: &str,
801 sql_hash: u64,
802 params: &[&(dyn Encode + Sync)],
803 f: F,
804 ) -> Result<(), DriverError>
805 where
806 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
807 {
808 let slot = self
809 .conn
810 .as_mut()
811 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
812 match slot {
813 PoolSlot::Async(conn) => conn.for_each(sql, sql_hash, params, f).await,
814 #[cfg(unix)]
815 PoolSlot::Sync(conn) => {
816 tokio::task::block_in_place(|| conn.for_each(sql, sql_hash, params, f))
817 }
818 }
819 }
820
821 pub async fn for_each_raw<F>(
826 &mut self,
827 sql: &str,
828 sql_hash: u64,
829 params: &[&(dyn Encode + Sync)],
830 f: F,
831 ) -> Result<(), DriverError>
832 where
833 F: FnMut(&[u8]) -> Result<(), DriverError>,
834 {
835 let slot = self
836 .conn
837 .as_mut()
838 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
839 match slot {
840 PoolSlot::Async(conn) => conn.for_each_raw(sql, sql_hash, params, f).await,
841 #[cfg(unix)]
842 PoolSlot::Sync(conn) => tokio::task::block_in_place(|| {
843 conn.for_each_raw_monolithic(sql, sql_hash, params, f)
844 }),
845 }
846 }
847
848 pub async fn query_streaming_start(
855 &mut self,
856 sql: &str,
857 sql_hash: u64,
858 params: &[&(dyn Encode + Sync)],
859 chunk_size: i32,
860 ) -> Result<(std::sync::Arc<[crate::conn::ColumnDesc]>, bool), DriverError> {
861 let slot = self
862 .conn
863 .as_mut()
864 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
865 match slot {
866 PoolSlot::Async(conn) => {
867 conn.query_streaming_start(sql, sql_hash, params, chunk_size)
868 .await
869 }
870 #[cfg(unix)]
871 PoolSlot::Sync(_) => Err(DriverError::Pool(
872 "streaming queries not supported on sync UDS connections".into(),
873 )),
874 }
875 }
876
877 pub async fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
879 let slot = self
880 .conn
881 .as_mut()
882 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
883 match slot {
884 PoolSlot::Async(conn) => conn.streaming_send_execute(chunk_size).await,
885 #[cfg(unix)]
886 PoolSlot::Sync(_) => Err(DriverError::Pool(
887 "streaming queries not supported on sync UDS connections".into(),
888 )),
889 }
890 }
891
892 pub async fn streaming_next_chunk(
894 &mut self,
895 arena: &mut Arena,
896 all_col_offsets: &mut Vec<(usize, i32)>,
897 ) -> Result<bool, DriverError> {
898 let slot = self
899 .conn
900 .as_mut()
901 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
902 match slot {
903 PoolSlot::Async(conn) => conn.streaming_next_chunk(arena, all_col_offsets).await,
904 #[cfg(unix)]
905 PoolSlot::Sync(_) => Err(DriverError::Pool(
906 "streaming queries not supported on sync UDS connections".into(),
907 )),
908 }
909 }
910
911 pub fn is_sync(&self) -> bool {
916 #[cfg(unix)]
917 if let Some(slot) = &self.conn {
918 return slot.is_sync();
919 }
920 false
921 }
922
923 pub(crate) async fn ensure_stmt_prepared(
929 &mut self,
930 sql: &str,
931 sql_hash: u64,
932 params: &[&(dyn Encode + Sync)],
933 ) -> Result<Box<str>, DriverError> {
934 let slot = self
935 .conn
936 .as_mut()
937 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
938 match slot {
939 PoolSlot::Async(conn) => conn.ensure_stmt_prepared(sql, sql_hash, params).await,
940 #[cfg(unix)]
941 PoolSlot::Sync(conn) => {
942 tokio::task::block_in_place(|| conn.ensure_stmt_prepared(sql, sql_hash, params))
943 }
944 }
945 }
946
947 pub(crate) fn write_deferred_bind_execute(
951 &self,
952 sql_hash: u64,
953 params: &[&(dyn Encode + Sync)],
954 buf: &mut Vec<u8>,
955 ) {
956 let slot = self.conn.as_ref().expect("connection taken");
957 match slot {
958 PoolSlot::Async(conn) => conn.write_deferred_bind_execute(sql_hash, params, buf),
959 #[cfg(unix)]
960 PoolSlot::Sync(conn) => conn.write_deferred_bind_execute(sql_hash, params, buf),
961 }
962 }
963
964 pub(crate) async fn flush_deferred_pipeline(
966 &mut self,
967 buf: &mut Vec<u8>,
968 count: usize,
969 ) -> Result<Vec<u64>, DriverError> {
970 let slot = self
971 .conn
972 .as_mut()
973 .ok_or_else(|| DriverError::Pool("connection already taken".into()))?;
974 match slot {
975 PoolSlot::Async(conn) => conn.flush_deferred_pipeline(buf, count).await,
976 #[cfg(unix)]
977 PoolSlot::Sync(conn) => {
978 tokio::task::block_in_place(|| conn.flush_deferred_pipeline(buf, count))
979 }
980 }
981 }
982}
983
984impl Drop for PoolGuard {
985 fn drop(&mut self) {
986 if let Some(mut slot) = self.conn.take() {
987 if self.discard
994 || slot.is_in_failed_transaction()
995 || slot.is_in_transaction()
996 || slot.is_streaming()
997 || self.pool.closed.load(Ordering::Acquire)
998 {
999 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1000 return;
1001 }
1002
1003 slot.touch();
1007
1008 {
1012 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1013 stack.push(slot);
1014 }
1015
1016 self.pool.release_notify.notify_one();
1017 }
1018 }
1019}
1020
/// A transaction started by [`Pool::begin`], which holds its connection for
/// its whole lifetime.
///
/// If it is dropped before [`Transaction::commit`] or [`Transaction::rollback`]
/// completes successfully, its connection is discarded rather than returned
/// to the pool.
pub struct Transaction {
    /// The checked-out connection the transaction runs on.
    guard: PoolGuard,
    /// `true` once `commit` or `rollback` has completed successfully.
    committed: bool,
    /// Bytes queued by `defer_execute` via `write_deferred_bind_execute`,
    /// sent by `flush_deferred`.
    deferred_buf: Vec<u8>,
    /// Number of statements queued in `deferred_buf`.
    deferred_count: usize,
}
1044
1045impl Transaction {
1046 pub async fn commit(mut self) -> Result<(), DriverError> {
1050 if self.deferred_count > 0 {
1051 self.flush_deferred().await?;
1052 }
1053 self.guard.simple_query("COMMIT").await?;
1054 self.committed = true;
1055 Ok(())
1056 }
1057
1058 pub async fn rollback(mut self) -> Result<(), DriverError> {
1062 self.deferred_buf.clear();
1063 self.deferred_count = 0;
1064 self.guard.simple_query("ROLLBACK").await?;
1065 self.committed = true; Ok(())
1067 }
1068
1069 pub async fn query(
1074 &mut self,
1075 sql: &str,
1076 sql_hash: u64,
1077 params: &[&(dyn Encode + Sync)],
1078 arena: &mut Arena,
1079 ) -> Result<QueryResult, DriverError> {
1080 if self.deferred_count > 0 {
1081 self.flush_deferred().await?;
1082 }
1083 self.guard.query(sql, sql_hash, params, arena).await
1084 }
1085
1086 pub async fn execute(
1088 &mut self,
1089 sql: &str,
1090 sql_hash: u64,
1091 params: &[&(dyn Encode + Sync)],
1092 ) -> Result<u64, DriverError> {
1093 self.guard.execute(sql, sql_hash, params).await
1094 }
1095
1096 pub async fn execute_pipeline(
1101 &mut self,
1102 sql: &str,
1103 sql_hash: u64,
1104 param_sets: &[&[&(dyn Encode + Sync)]],
1105 ) -> Result<Vec<u64>, DriverError> {
1106 self.guard.execute_pipeline(sql, sql_hash, param_sets).await
1107 }
1108
1109 pub async fn for_each<F>(
1113 &mut self,
1114 sql: &str,
1115 sql_hash: u64,
1116 params: &[&(dyn Encode + Sync)],
1117 f: F,
1118 ) -> Result<(), DriverError>
1119 where
1120 F: FnMut(crate::conn::PgDataRow<'_>) -> Result<(), DriverError>,
1121 {
1122 if self.deferred_count > 0 {
1123 self.flush_deferred().await?;
1124 }
1125 self.guard.for_each(sql, sql_hash, params, f).await
1126 }
1127
1128 pub async fn for_each_raw<F>(
1135 &mut self,
1136 sql: &str,
1137 sql_hash: u64,
1138 params: &[&(dyn Encode + Sync)],
1139 f: F,
1140 ) -> Result<(), DriverError>
1141 where
1142 F: FnMut(&[u8]) -> Result<(), DriverError>,
1143 {
1144 if self.deferred_count > 0 {
1145 self.flush_deferred().await?;
1146 }
1147 self.guard.for_each_raw(sql, sql_hash, params, f).await
1148 }
1149
1150 pub async fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1154 if self.deferred_count > 0 {
1155 self.flush_deferred().await?;
1156 }
1157 self.guard.simple_query(sql).await
1158 }
1159
1160 pub async fn defer_execute(
1201 &mut self,
1202 sql: &str,
1203 sql_hash: u64,
1204 params: &[&(dyn Encode + Sync)],
1205 ) -> Result<(), DriverError> {
1206 if params.len() > i16::MAX as usize {
1207 return Err(DriverError::Protocol(format!(
1208 "parameter count {} exceeds maximum {}",
1209 params.len(),
1210 i16::MAX
1211 )));
1212 }
1213
1214 self.guard
1216 .ensure_stmt_prepared(sql, sql_hash, params)
1217 .await?;
1218
1219 self.guard
1221 .write_deferred_bind_execute(sql_hash, params, &mut self.deferred_buf);
1222 self.deferred_count += 1;
1223 Ok(())
1224 }
1225
1226 pub async fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1234 let count = self.deferred_count;
1235 self.deferred_count = 0;
1236 self.guard
1237 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1238 .await
1239 }
1240
1241 pub fn deferred_count(&self) -> usize {
1243 self.deferred_count
1244 }
1245}
1246
1247impl Drop for Transaction {
1248 fn drop(&mut self) {
1249 if !self.committed {
1250 if let Some(_conn) = self.guard.conn.take() {
1257 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1258 }
1260 }
1261 }
1262}
1263
1264#[cfg(test)]
1265mod tests {
1266 use super::*;
1267
1268 #[tokio::test]
1269 async fn pool_builder_requires_url() {
1270 let result = PoolBuilder::new().build().await;
1271 assert!(result.is_err());
1272 }
1273
1274 #[tokio::test]
1275 async fn pool_builder_validates_url() {
1276 let result = PoolBuilder::new().url("not_a_url").build().await;
1277 assert!(result.is_err());
1278 }
1279
1280 #[tokio::test]
1281 async fn pool_builder_accepts_valid_url() {
1282 let pool = PoolBuilder::new()
1283 .url("postgres://user:pass@localhost/db")
1284 .max_size(5)
1285 .build()
1286 .await
1287 .unwrap();
1288 assert_eq!(pool.max_size(), 5);
1289 assert_eq!(pool.open_count(), 0);
1290 }
1291
1292 #[tokio::test]
1293 async fn pool_connect_validates_url() {
1294 let result = Pool::connect("not_a_url").await;
1295 assert!(result.is_err());
1296 }
1297
1298 #[tokio::test]
1299 async fn pool_max_size_zero() {
1300 let pool = PoolBuilder::new()
1301 .url("postgres://user:pass@localhost/db")
1302 .max_size(0)
1303 .build()
1304 .await
1305 .unwrap();
1306
1307 let result = pool.acquire().await;
1308 assert!(result.is_err());
1309 match result {
1310 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1311 Err(e) => panic!("expected Pool error, got: {e:?}"),
1312 Ok(_) => panic!("expected error, got Ok"),
1313 }
1314 }
1315
1316 #[tokio::test]
1317 async fn pool_clone_shares_state() {
1318 let pool = PoolBuilder::new()
1319 .url("postgres://user:pass@localhost/db")
1320 .max_size(5)
1321 .build()
1322 .await
1323 .unwrap();
1324
1325 let pool2 = pool.clone();
1326 assert_eq!(pool.max_size(), pool2.max_size());
1327 }
1328
1329 #[tokio::test]
1333 async fn pool_builder_max_lifetime() {
1334 let pool = PoolBuilder::new()
1335 .url("postgres://user:pass@localhost/db")
1336 .max_lifetime(Some(Duration::from_secs(60)))
1337 .build()
1338 .await
1339 .unwrap();
1340 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1341 }
1342
1343 #[tokio::test]
1345 async fn pool_builder_max_lifetime_none() {
1346 let pool = PoolBuilder::new()
1347 .url("postgres://user:pass@localhost/db")
1348 .max_lifetime(None)
1349 .build()
1350 .await
1351 .unwrap();
1352 assert_eq!(pool.inner.max_lifetime, None);
1353 }
1354
1355 #[tokio::test]
1357 async fn pool_builder_acquire_timeout_none() {
1358 let pool = PoolBuilder::new()
1359 .url("postgres://user:pass@localhost/db")
1360 .acquire_timeout(None)
1361 .build()
1362 .await
1363 .unwrap();
1364 assert_eq!(pool.inner.acquire_timeout, None);
1365 }
1366
1367 #[tokio::test]
1369 async fn pool_builder_acquire_timeout_custom() {
1370 let pool = PoolBuilder::new()
1371 .url("postgres://user:pass@localhost/db")
1372 .acquire_timeout(Some(Duration::from_secs(10)))
1373 .build()
1374 .await
1375 .unwrap();
1376 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1377 }
1378
1379 #[tokio::test]
1381 async fn pool_builder_min_idle() {
1382 let pool = PoolBuilder::new()
1383 .url("postgres://user:pass@localhost/db")
1384 .min_idle(2)
1385 .build()
1386 .await
1387 .unwrap();
1388 assert_eq!(pool.inner.min_idle, 2);
1389 }
1390
1391 #[tokio::test]
1393 async fn pool_close_marks_closed() {
1394 let pool = PoolBuilder::new()
1395 .url("postgres://user:pass@localhost/db")
1396 .max_size(5)
1397 .build()
1398 .await
1399 .unwrap();
1400
1401 assert!(!pool.is_closed());
1402 pool.close().await;
1403 assert!(pool.is_closed());
1404
1405 let result = pool.acquire().await;
1407 assert!(result.is_err());
1408 match result {
1409 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1410 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1411 Ok(_) => panic!("expected error, got Ok"),
1412 }
1413 }
1414
1415 #[tokio::test]
1417 async fn pool_status_initial() {
1418 let pool = PoolBuilder::new()
1419 .url("postgres://user:pass@localhost/db")
1420 .max_size(10)
1421 .build()
1422 .await
1423 .unwrap();
1424
1425 let status = pool.status();
1426 assert_eq!(status.idle, 0);
1427 assert_eq!(status.active, 0);
1428 assert_eq!(status.open, 0);
1429 assert_eq!(status.max_size, 10);
1430 }
1431
1432 #[tokio::test]
1434 async fn pool_builder_defaults() {
1435 let pool = PoolBuilder::new()
1436 .url("postgres://user:pass@localhost/db")
1437 .build()
1438 .await
1439 .unwrap();
1440
1441 assert_eq!(pool.max_size(), 10);
1442 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1443 assert_eq!(pool.inner.acquire_timeout, None); assert_eq!(pool.inner.min_idle, 0);
1445 }
1446
1447 #[tokio::test]
1449 async fn pool_open_count_initial() {
1450 let pool = Pool::connect("postgres://user:pass@localhost/db")
1451 .await
1452 .unwrap();
1453 assert_eq!(pool.open_count(), 0);
1454 }
1455
1456 #[tokio::test]
1459 async fn pool_builder_max_stmt_cache_size_default() {
1460 let pool = PoolBuilder::new()
1461 .url("postgres://user:pass@localhost/db")
1462 .build()
1463 .await
1464 .unwrap();
1465 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1466 }
1467
1468 #[tokio::test]
1469 async fn pool_builder_max_stmt_cache_size_custom() {
1470 let pool = PoolBuilder::new()
1471 .url("postgres://user:pass@localhost/db")
1472 .max_stmt_cache_size(512)
1473 .build()
1474 .await
1475 .unwrap();
1476 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1477 }
1478
1479 #[tokio::test]
1482 async fn pool_is_uds_false_for_tcp() {
1483 let pool = Pool::connect("postgres://user:pass@localhost/db")
1484 .await
1485 .unwrap();
1486 assert!(!pool.is_uds());
1487 }
1488
1489 #[cfg(unix)]
1490 #[tokio::test]
1491 async fn pool_is_uds_true_for_unix_socket() {
1492 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
1493 .await
1494 .unwrap();
1495 assert!(pool.is_uds());
1496 }
1497
1498 #[cfg(unix)]
1499 #[tokio::test]
1500 async fn pool_is_uds_true_for_var_run_socket() {
1501 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql")
1502 .await
1503 .unwrap();
1504 assert!(pool.is_uds());
1505 }
1506
1507 #[tokio::test]
1508 async fn pool_is_uds_false_for_ip_address() {
1509 let pool = Pool::connect("postgres://user:pass@127.0.0.1/db")
1510 .await
1511 .unwrap();
1512 assert!(!pool.is_uds());
1513 }
1514
1515 #[cfg(unix)]
1516 #[tokio::test]
1517 async fn pool_slot_sync_created_for_uds_config() {
1518 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1522 assert!(config.host_is_uds());
1523 }
1524
1525 #[test]
1526 fn pool_slot_async_created_for_tcp_config() {
1527 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1528 assert!(!config.host_is_uds());
1529 }
1530
1531 #[tokio::test]
1536 async fn pool_is_uds_false_for_hostname() {
1537 let pool = Pool::connect("postgres://user:pass@db.example.com/db")
1538 .await
1539 .unwrap();
1540 assert!(!pool.is_uds());
1541 }
1542
1543 #[cfg(unix)]
1544 #[tokio::test]
1545 async fn pool_is_uds_true_for_tmp() {
1546 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
1547 .await
1548 .unwrap();
1549 assert!(pool.is_uds());
1550 }
1551
1552 #[tokio::test]
1557 async fn pool_close_then_acquire_fails() {
1558 let pool = PoolBuilder::new()
1559 .url("postgres://user:pass@localhost/db")
1560 .max_size(5)
1561 .build()
1562 .await
1563 .unwrap();
1564 pool.close().await;
1565 let result = pool.acquire().await;
1566 assert!(result.is_err());
1567 match result {
1568 Err(DriverError::Pool(msg)) => {
1569 assert!(msg.contains("closed"), "should say closed: {msg}")
1570 }
1571 Err(e) => panic!("expected Pool error, got: {e:?}"),
1572 Ok(_) => panic!("expected error"),
1573 }
1574 }
1575
1576 #[tokio::test]
1577 async fn pool_is_closed_before_and_after() {
1578 let pool = Pool::connect("postgres://user:pass@localhost/db")
1579 .await
1580 .unwrap();
1581 assert!(!pool.is_closed());
1582 pool.close().await;
1583 assert!(pool.is_closed());
1584 }
1585
1586 #[tokio::test]
1591 async fn pool_exhausted_no_timeout() {
1592 let pool = PoolBuilder::new()
1593 .url("postgres://user:pass@localhost/db")
1594 .max_size(0)
1595 .acquire_timeout(None) .build()
1597 .await
1598 .unwrap();
1599 let result = pool.acquire().await;
1600 assert!(result.is_err());
1601 match result {
1602 Err(DriverError::Pool(msg)) => {
1603 assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1604 }
1605 Err(e) => panic!("expected Pool error, got: {e:?}"),
1606 Ok(_) => panic!("expected error"),
1607 }
1608 }
1609
1610 #[tokio::test]
1615 async fn pool_builder_no_url_error() {
1616 let result = PoolBuilder::new().max_size(5).build().await;
1617 assert!(result.is_err());
1618 match result {
1619 Err(DriverError::Pool(msg)) => {
1620 assert!(msg.contains("URL"), "should mention URL: {msg}")
1621 }
1622 Err(e) => panic!("expected Pool error, got: {e:?}"),
1623 Ok(_) => panic!("expected error"),
1624 }
1625 }
1626
1627 #[tokio::test]
1628 async fn pool_builder_invalid_url_error() {
1629 let result = PoolBuilder::new().url("ftp://something").build().await;
1630 assert!(result.is_err());
1631 }
1632
1633 #[tokio::test]
1634 async fn pool_builder_stmt_cache_size_zero() {
1635 let pool = PoolBuilder::new()
1636 .url("postgres://user:pass@localhost/db")
1637 .max_stmt_cache_size(0)
1638 .build()
1639 .await
1640 .unwrap();
1641 assert_eq!(pool.inner.max_stmt_cache_size, 0);
1642 }
1643
1644 #[tokio::test]
1649 async fn pool_status_reflects_max_size() {
1650 let pool = PoolBuilder::new()
1651 .url("postgres://user:pass@localhost/db")
1652 .max_size(20)
1653 .build()
1654 .await
1655 .unwrap();
1656 let status = pool.status();
1657 assert_eq!(status.max_size, 20);
1658 assert_eq!(status.idle, 0);
1659 assert_eq!(status.active, 0);
1660 assert_eq!(status.open, 0);
1661 }
1662
1663 #[tokio::test]
1668 async fn pool_clone_shares_config() {
1669 let pool = PoolBuilder::new()
1670 .url("postgres://user:pass@localhost/db")
1671 .max_size(7)
1672 .build()
1673 .await
1674 .unwrap();
1675 let p2 = pool.clone();
1676 assert_eq!(pool.max_size(), 7);
1677 assert_eq!(p2.max_size(), 7);
1678 assert_eq!(pool.open_count(), p2.open_count());
1679 }
1680
1681 #[tokio::test]
1686 async fn pool_set_warmup_sqls_empty() {
1687 let pool = Pool::connect("postgres://user:pass@localhost/db")
1688 .await
1689 .unwrap();
1690 pool.set_warmup_sqls(&[]);
1691 let sqls = pool
1692 .inner
1693 .warmup_sqls
1694 .lock()
1695 .unwrap_or_else(|e| e.into_inner())
1696 .clone();
1697 assert!(sqls.is_empty());
1698 }
1699
1700 #[tokio::test]
1701 async fn pool_set_warmup_sqls_multiple() {
1702 let pool = Pool::connect("postgres://user:pass@localhost/db")
1703 .await
1704 .unwrap();
1705 pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
1706 let sqls = pool
1707 .inner
1708 .warmup_sqls
1709 .lock()
1710 .unwrap_or_else(|e| e.into_inner())
1711 .clone();
1712 assert_eq!(sqls.len(), 3);
1713 assert_eq!(&*sqls[0], "SELECT 1");
1714 assert_eq!(&*sqls[1], "SELECT 2");
1715 assert_eq!(&*sqls[2], "SELECT 3");
1716 }
1717
1718 #[tokio::test]
1719 async fn pool_set_warmup_sqls_overwrite() {
1720 let pool = Pool::connect("postgres://user:pass@localhost/db")
1721 .await
1722 .unwrap();
1723 pool.set_warmup_sqls(&["SELECT 1"]);
1724 pool.set_warmup_sqls(&["SELECT 99"]);
1725 let sqls = pool
1726 .inner
1727 .warmup_sqls
1728 .lock()
1729 .unwrap_or_else(|e| e.into_inner())
1730 .clone();
1731 assert_eq!(sqls.len(), 1);
1732 assert_eq!(&*sqls[0], "SELECT 99");
1733 }
1734
1735 #[tokio::test]
1740 async fn pool_status_debug() {
1741 let pool = Pool::connect("postgres://user:pass@localhost/db")
1742 .await
1743 .unwrap();
1744 let status = pool.status();
1745 let dbg = format!("{status:?}");
1746 assert!(dbg.contains("PoolStatus"));
1747 assert!(dbg.contains("idle"));
1748 assert!(dbg.contains("active"));
1749 assert!(dbg.contains("open"));
1750 assert!(dbg.contains("max_size"));
1751 }
1752
1753 #[test]
1758 fn config_host_is_uds_returns_true_for_slash() {
1759 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1760 assert!(config.host_is_uds());
1761 }
1762
1763 #[test]
1764 fn config_host_is_uds_returns_false_for_tcp() {
1765 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1766 assert!(!config.host_is_uds());
1767 }
1768
1769 #[test]
1770 fn config_host_is_uds_returns_false_for_ip() {
1771 let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
1772 assert!(!config.host_is_uds());
1773 }
1774
1775 #[tokio::test]
1780 async fn pool_builder_full_chain() {
1781 let pool = PoolBuilder::new()
1782 .url("postgres://user:pass@localhost/db")
1783 .max_size(3)
1784 .max_lifetime(Some(Duration::from_secs(600)))
1785 .acquire_timeout(Some(Duration::from_secs(5)))
1786 .min_idle(1)
1787 .max_stmt_cache_size(128)
1788 .build()
1789 .await
1790 .unwrap();
1791 assert_eq!(pool.max_size(), 3);
1792 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
1793 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1794 assert_eq!(pool.inner.min_idle, 1);
1795 assert_eq!(pool.inner.max_stmt_cache_size, 128);
1796 }
1797
1798 #[tokio::test]
1801 async fn pool_max_size_zero_rejects_all_acquires() {
1802 let pool = PoolBuilder::new()
1803 .url("postgres://user:pass@localhost/db")
1804 .max_size(0)
1805 .build()
1806 .await
1807 .unwrap();
1808 let result = pool.acquire().await;
1809 assert!(result.is_err());
1810 match &result {
1811 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1812 _ => panic!("expected pool exhausted error"),
1813 }
1814 }
1815
1816 #[test]
1819 fn url_parse_unknown_sslmode_returns_error() {
1820 let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
1821 assert!(result.is_err());
1822 let msg = format!("{}", result.unwrap_err());
1823 assert!(msg.contains("unknown sslmode"));
1824 }
1825
1826 #[test]
1827 fn url_parse_invalid_port_returns_error() {
1828 let result = Config::from_url("postgres://u:p@h:abc/d");
1829 assert!(result.is_err());
1830 let msg = format!("{}", result.unwrap_err());
1831 assert!(msg.contains("invalid port"));
1832 }
1833
1834 #[test]
1835 fn url_parse_missing_at_sign_returns_error() {
1836 let result = Config::from_url("postgres://u:plocalhost/d");
1837 assert!(result.is_err());
1838 let msg = format!("{}", result.unwrap_err());
1839 assert!(msg.contains("missing @"));
1840 }
1841
1842 #[test]
1843 fn url_parse_empty_host_returns_error() {
1844 let result = Config::from_url("postgres://u:p@/d");
1845 assert!(result.is_err());
1846 }
1847
1848 #[test]
1849 fn url_parse_empty_user_returns_error() {
1850 let result = Config::from_url("postgres://:p@h/d");
1851 assert!(result.is_err());
1852 }
1853
1854 #[test]
1855 fn url_parse_statement_timeout_invalid_uses_default() {
1856 let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
1857 assert_eq!(config.statement_timeout_secs, 30);
1858 }
1859
1860 #[test]
1861 fn url_parse_malformed_percent_encoding() {
1862 let result = Config::from_url("postgres://u%:p@h/d");
1863 assert!(result.is_err());
1864 }
1865
1866 #[test]
1867 fn url_parse_invalid_hex_in_percent_encoding() {
1868 let result = Config::from_url("postgres://u%ZZ:p@h/d");
1869 assert!(result.is_err());
1870 }
1871}