1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use crate::DriverError;
14use crate::arena::Arena;
15use crate::codec::Encode;
16use crate::conn::Connection;
17use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22pub(crate) enum PoolSlot {
30 Sync(Connection),
32 #[cfg(feature = "async")]
34 Async(AsyncConnection),
35}
36
37pub struct Pool {
53 inner: Arc<PoolInner>,
54}
55
56struct PoolInner {
57 stack: std::sync::Mutex<Vec<PoolSlot>>,
61 max_size: usize,
62 open_count: AtomicUsize,
63 config: Arc<Config>,
64 closed: AtomicBool,
66 release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
69 max_lifetime: Option<Duration>,
72 acquire_timeout: Option<Duration>,
74 min_idle: usize,
76 warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
78 max_stmt_cache_size: usize,
80}
81
82impl Pool {
83 pub fn connect(url: &str) -> Result<Self, DriverError> {
87 PoolBuilder::new().url(url).build()
88 }
89
90 pub fn builder() -> PoolBuilder {
92 PoolBuilder::new()
93 }
94
95 #[inline]
103 pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
104 if self.inner.closed.load(Ordering::Acquire) {
105 return Err(DriverError::Pool("pool is closed".into()));
106 }
107
108 if let Some(guard) = self.try_pop_idle()? {
110 return Ok(guard);
111 }
112
113 loop {
115 let current = self.inner.open_count.load(Ordering::Acquire);
116 if current >= self.inner.max_size {
117 if let Some(timeout) = self.inner.acquire_timeout {
118 let (lock, cvar) = &self.inner.release_pair;
119 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
120 let (_guard, result) = cvar
121 .wait_timeout(guard, timeout)
122 .unwrap_or_else(|e| e.into_inner());
123 if result.timed_out() {
124 return Err(DriverError::Pool(
125 "pool exhausted: acquire timeout expired".into(),
126 ));
127 }
128 if let Some(guard) = self.try_pop_idle()? {
130 return Ok(guard);
131 }
132 continue;
134 }
135 return Err(DriverError::Pool(
136 "pool exhausted: all connections in use".into(),
137 ));
138 }
139 if self
140 .inner
141 .open_count
142 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
143 .is_ok()
144 {
145 break;
146 }
147 }
149
150 let conn_result = Connection::connect_arc(self.inner.config.clone());
152 match conn_result {
153 Ok(mut conn) => {
154 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
156 self.warmup_conn(&mut conn);
158
159 Ok(PoolGuard {
160 conn: Some(PoolSlot::Sync(conn)),
161 pool: self.inner.clone(),
162 discard: false,
163 })
164 }
165 Err(e) => {
166 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
168 Err(e)
169 }
170 }
171 }
172
173 #[inline]
175 fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
176 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
177 while let Some(slot) = stack.pop() {
178 let (created_at, idle_dur) = match &slot {
179 PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
180 #[cfg(feature = "async")]
181 PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
182 };
183 if let Some(max_lifetime) = self.inner.max_lifetime {
184 if created_at.elapsed() >= max_lifetime {
185 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
186 continue;
187 }
188 }
189 if idle_dur < Duration::from_secs(30) {
190 return Ok(Some(PoolGuard {
191 conn: Some(slot),
192 pool: self.inner.clone(),
193 discard: false,
194 }));
195 }
196 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
198 }
199 Ok(None)
200 }
201
202 pub fn is_uds(&self) -> bool {
207 #[cfg(unix)]
208 {
209 self.inner.config.host_is_uds()
210 }
211 #[cfg(not(unix))]
212 {
213 false
214 }
215 }
216
217 pub fn begin(&self) -> Result<Transaction, DriverError> {
219 let mut guard = self.acquire()?;
220 guard.simple_query("BEGIN")?;
221 Ok(Transaction {
222 guard,
223 committed: false,
224 deferred_buf: Vec::new(),
225 deferred_count: 0,
226 })
227 }
228
229 pub fn open_count(&self) -> usize {
231 self.inner.open_count.load(Ordering::Relaxed)
232 }
233
234 pub fn max_size(&self) -> usize {
236 self.inner.max_size
237 }
238
239 pub fn status(&self) -> PoolStatus {
241 let idle = self
242 .inner
243 .stack
244 .lock()
245 .unwrap_or_else(|e| e.into_inner())
246 .len();
247 let open = self.inner.open_count.load(Ordering::Relaxed);
248 let active = open.saturating_sub(idle);
249 PoolStatus {
250 idle,
251 active,
252 open,
253 max_size: self.inner.max_size,
254 }
255 }
256
257 fn warmup_conn(&self, conn: &mut Connection) {
265 let sqls = self
266 .inner
267 .warmup_sqls
268 .lock()
269 .unwrap_or_else(|e| e.into_inner())
270 .clone();
271
272 if sqls.is_empty() {
273 return;
274 }
275
276 for sql in sqls.iter() {
277 let sql_hash = crate::types::hash_sql(sql);
278 let _ = conn.prepare_only(sql, sql_hash);
279 }
280 }
281
282 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
304 let boxed: Arc<Vec<Box<str>>> =
305 Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
306 *self
307 .inner
308 .warmup_sqls
309 .lock()
310 .unwrap_or_else(|e| e.into_inner()) = boxed;
311 }
312
313 pub fn close(&self) {
316 self.inner.closed.store(true, Ordering::Release);
317 let slots: Vec<PoolSlot> = {
319 let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
320 std::mem::take(&mut *stack)
321 };
322 for slot in slots {
323 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
324 match slot {
325 PoolSlot::Sync(conn) => {
326 let _ = conn.close();
327 }
328 #[cfg(feature = "async")]
329 PoolSlot::Async(_conn) => {
330 }
333 }
334 }
335 let (_, cvar) = &self.inner.release_pair;
337 cvar.notify_all();
338 }
339
340 pub fn is_closed(&self) -> bool {
342 self.inner.closed.load(Ordering::Acquire)
343 }
344
345 #[cfg(feature = "async")]
355 pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
356 if self.inner.closed.load(Ordering::Acquire) {
357 return Err(DriverError::Pool("pool is closed".into()));
358 }
359
360 if let Some(guard) = self.try_pop_idle()? {
362 return Ok(guard);
363 }
364
365 loop {
367 let current = self.inner.open_count.load(Ordering::Acquire);
368 if current >= self.inner.max_size {
369 if let Some(timeout) = self.inner.acquire_timeout {
370 let (lock, cvar) = &self.inner.release_pair;
371 let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
372 let (_guard, result) = cvar
373 .wait_timeout(guard, timeout)
374 .unwrap_or_else(|e| e.into_inner());
375 if result.timed_out() {
376 return Err(DriverError::Pool(
377 "pool exhausted: acquire timeout expired".into(),
378 ));
379 }
380 if let Some(guard) = self.try_pop_idle()? {
381 return Ok(guard);
382 }
383 continue;
384 }
385 return Err(DriverError::Pool(
386 "pool exhausted: all connections in use".into(),
387 ));
388 }
389 if self
390 .inner
391 .open_count
392 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
393 .is_ok()
394 {
395 break;
396 }
397 }
398
399 if self.inner.config.host_is_uds() {
401 let conn_result = Connection::connect_arc(self.inner.config.clone());
403 match conn_result {
404 Ok(mut conn) => {
405 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
406 self.warmup_conn(&mut conn);
407 Ok(PoolGuard {
408 conn: Some(PoolSlot::Sync(conn)),
409 pool: self.inner.clone(),
410 discard: false,
411 })
412 }
413 Err(e) => {
414 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
415 Err(e)
416 }
417 }
418 } else {
419 let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
421 match conn_result {
422 Ok(mut conn) => {
423 conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
424 Ok(PoolGuard {
425 conn: Some(PoolSlot::Async(conn)),
426 pool: self.inner.clone(),
427 discard: false,
428 })
429 }
430 Err(e) => {
431 self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
432 Err(e)
433 }
434 }
435 }
436 }
437}
438
439impl Clone for Pool {
440 fn clone(&self) -> Self {
441 Pool {
442 inner: self.inner.clone(),
443 }
444 }
445}
446
447#[derive(Debug, Clone, Copy)]
451pub struct PoolStatus {
452 pub idle: usize,
454 pub active: usize,
456 pub open: usize,
458 pub max_size: usize,
460}
461
462pub struct PoolBuilder {
466 url: Option<String>,
467 max_size: usize,
468 max_lifetime: Option<Duration>,
470 acquire_timeout: Option<Duration>,
472 min_idle: usize,
474 max_stmt_cache_size: usize,
476}
477
478impl PoolBuilder {
479 fn new() -> Self {
480 Self {
481 url: None,
482 max_size: 10,
483 max_lifetime: Some(Duration::from_secs(30 * 60)), acquire_timeout: Some(Duration::from_secs(5)), min_idle: 0, max_stmt_cache_size: 256, }
488 }
489
490 pub fn url(mut self, url: &str) -> Self {
492 self.url = Some(url.to_owned());
493 self
494 }
495
496 pub fn max_size(mut self, size: usize) -> Self {
500 self.max_size = size;
501 self
502 }
503
504 pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
507 self.max_lifetime = lifetime;
508 self
509 }
510
511 pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
514 self.acquire_timeout = timeout;
515 self
516 }
517
518 pub fn min_idle(mut self, count: usize) -> Self {
521 self.min_idle = count;
522 self
523 }
524
525 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
529 self.max_stmt_cache_size = size;
530 self
531 }
532
533 pub fn build(self) -> Result<Pool, DriverError> {
535 let url = self
536 .url
537 .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
538
539 let config = Arc::new(Config::from_url(&url)?);
540
541 let pool = Pool {
542 inner: Arc::new(PoolInner {
543 stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
544 max_size: self.max_size,
545 open_count: AtomicUsize::new(0),
546 config,
547 closed: AtomicBool::new(false),
548 release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
549 max_lifetime: self.max_lifetime,
550 acquire_timeout: self.acquire_timeout,
551 min_idle: self.min_idle,
552 warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
553 max_stmt_cache_size: self.max_stmt_cache_size,
554 }),
555 };
556
557 if self.min_idle > 0 {
558 let inner = pool.inner.clone();
559 std::thread::spawn(move || {
560 maintain_min_idle(inner);
561 });
562 }
563
564 Ok(pool)
565 }
566}
567
568fn maintain_min_idle(inner: Arc<PoolInner>) {
570 loop {
571 if inner.closed.load(Ordering::Acquire) {
572 return;
573 }
574
575 let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
576 let needed = inner.min_idle.saturating_sub(idle_count);
577
578 for _ in 0..needed {
579 if inner.closed.load(Ordering::Acquire) {
580 return;
581 }
582 let current = inner.open_count.load(Ordering::Acquire);
583 if current >= inner.max_size {
584 break;
585 }
586 if inner
587 .open_count
588 .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
589 .is_err()
590 {
591 continue;
592 }
593
594 match Connection::connect_arc(inner.config.clone()) {
595 Ok(conn) => {
596 let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
597 stack.push(PoolSlot::Sync(conn));
598 let (_, cvar) = &inner.release_pair;
599 cvar.notify_one();
600 }
601 Err(_) => {
602 inner.open_count.fetch_sub(1, Ordering::AcqRel);
603 }
604 }
605 }
606
607 std::thread::sleep(Duration::from_secs(1));
610 }
611}
612
613pub struct PoolGuard {
620 conn: Option<PoolSlot>,
621 pool: Arc<PoolInner>,
622 discard: bool,
624}
625
626impl PoolGuard {
627 #[inline]
630 fn sync_conn(&self) -> Result<&Connection, DriverError> {
631 match self.conn.as_ref() {
632 Some(PoolSlot::Sync(conn)) => Ok(conn),
633 #[cfg(feature = "async")]
634 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
635 "expected sync connection, got async; use async methods".into(),
636 )),
637 None => Err(DriverError::Pool("connection already taken".into())),
638 }
639 }
640
641 #[inline]
643 fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
644 match self.conn.as_mut() {
645 Some(PoolSlot::Sync(conn)) => Ok(conn),
646 #[cfg(feature = "async")]
647 Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
648 "expected sync connection, got async; use async methods".into(),
649 )),
650 None => Err(DriverError::Pool("connection already taken".into())),
651 }
652 }
653
654 pub fn mark_discard(&mut self) {
657 self.discard = true;
658 }
659
660 pub fn cancel(&self) -> Result<(), DriverError> {
665 self.sync_conn()?.cancel()
666 }
667
668 pub fn pid(&self) -> i32 {
672 match self.conn.as_ref().expect("connection taken") {
673 PoolSlot::Sync(conn) => conn.pid(),
674 #[cfg(feature = "async")]
675 PoolSlot::Async(conn) => conn.pid(),
676 }
677 }
678
679 pub fn is_idle(&self) -> bool {
681 match self.conn.as_ref().expect("connection taken") {
682 PoolSlot::Sync(conn) => conn.is_idle(),
683 #[cfg(feature = "async")]
684 PoolSlot::Async(conn) => conn.is_idle(),
685 }
686 }
687
688 pub fn is_in_transaction(&self) -> bool {
690 match self.conn.as_ref().expect("connection taken") {
691 PoolSlot::Sync(conn) => conn.is_in_transaction(),
692 #[cfg(feature = "async")]
693 PoolSlot::Async(conn) => conn.is_in_transaction(),
694 }
695 }
696
697 #[inline]
701 pub fn query(
702 &mut self,
703 sql: &str,
704 sql_hash: u64,
705 params: &[&(dyn Encode + Sync)],
706 ) -> Result<QueryResult, DriverError> {
707 self.sync_conn_mut()?.query(sql, sql_hash, params)
708 }
709
710 #[inline]
712 pub fn execute(
713 &mut self,
714 sql: &str,
715 sql_hash: u64,
716 params: &[&(dyn Encode + Sync)],
717 ) -> Result<u64, DriverError> {
718 self.sync_conn_mut()?.execute(sql, sql_hash, params)
719 }
720
721 pub fn execute_pipeline(
726 &mut self,
727 sql: &str,
728 sql_hash: u64,
729 param_sets: &[&[&(dyn Encode + Sync)]],
730 ) -> Result<Vec<u64>, DriverError> {
731 self.sync_conn_mut()?
732 .execute_pipeline(sql, sql_hash, param_sets)
733 }
734
735 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
737 self.sync_conn_mut()?.simple_query(sql)
738 }
739
740 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
744 self.sync_conn_mut()?.simple_query_rows(sql)
745 }
746
747 pub fn for_each<F>(
749 &mut self,
750 sql: &str,
751 sql_hash: u64,
752 params: &[&(dyn Encode + Sync)],
753 f: F,
754 ) -> Result<(), DriverError>
755 where
756 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
757 {
758 self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
759 }
760
761 pub fn for_each_raw<F>(
763 &mut self,
764 sql: &str,
765 sql_hash: u64,
766 params: &[&(dyn Encode + Sync)],
767 f: F,
768 ) -> Result<(), DriverError>
769 where
770 F: FnMut(&[u8]) -> Result<(), DriverError>,
771 {
772 self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
773 }
774
775 pub fn query_streaming_start(
779 &mut self,
780 sql: &str,
781 sql_hash: u64,
782 params: &[&(dyn Encode + Sync)],
783 chunk_size: i32,
784 ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
785 self.sync_conn_mut()?
786 .query_streaming_start(sql, sql_hash, params, chunk_size)
787 }
788
789 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
791 self.sync_conn_mut()?.streaming_send_execute(chunk_size)
792 }
793
794 pub fn streaming_next_chunk(
796 &mut self,
797 arena: &mut Arena,
798 all_col_offsets: &mut Vec<(usize, i32)>,
799 ) -> Result<bool, DriverError> {
800 self.sync_conn_mut()?
801 .streaming_next_chunk(arena, all_col_offsets)
802 }
803
804 pub fn copy_in<'a, I>(
810 &mut self,
811 table: &str,
812 columns: &[&str],
813 rows: I,
814 ) -> Result<u64, DriverError>
815 where
816 I: IntoIterator<Item = &'a str>,
817 {
818 self.sync_conn_mut()?.copy_in(table, columns, rows)
819 }
820
821 pub fn copy_out<W: std::io::Write>(
825 &mut self,
826 query: &str,
827 writer: &mut W,
828 ) -> Result<u64, DriverError> {
829 self.sync_conn_mut()?.copy_out(query, writer)
830 }
831
832 pub fn is_sync(&self) -> bool {
834 matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
835 }
836
837 #[cfg(feature = "async")]
839 pub fn is_async(&self) -> bool {
840 matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
841 }
842
843 #[cfg(feature = "async")]
851 pub async fn query_async(
852 &mut self,
853 sql: &str,
854 sql_hash: u64,
855 params: &[&(dyn Encode + Sync)],
856 ) -> Result<QueryResult, DriverError> {
857 match self.conn.as_mut() {
858 Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
859 Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
860 None => Err(DriverError::Pool("connection already taken".into())),
861 }
862 }
863
864 #[cfg(feature = "async")]
866 pub async fn execute_async(
867 &mut self,
868 sql: &str,
869 sql_hash: u64,
870 params: &[&(dyn Encode + Sync)],
871 ) -> Result<u64, DriverError> {
872 match self.conn.as_mut() {
873 Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
874 Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
875 None => Err(DriverError::Pool("connection already taken".into())),
876 }
877 }
878
879 #[cfg(feature = "async")]
881 pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
882 match self.conn.as_mut() {
883 Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
884 Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
885 None => Err(DriverError::Pool("connection already taken".into())),
886 }
887 }
888
889 pub(crate) fn ensure_stmt_prepared(
893 &mut self,
894 sql: &str,
895 sql_hash: u64,
896 params: &[&(dyn Encode + Sync)],
897 ) -> Result<[u8; 18], DriverError> {
898 self.sync_conn_mut()?
899 .ensure_stmt_prepared(sql, sql_hash, params)
900 }
901
902 pub(crate) fn write_deferred_bind_execute(
904 &self,
905 sql_hash: u64,
906 params: &[&(dyn Encode + Sync)],
907 buf: &mut Vec<u8>,
908 ) {
909 let conn = self
910 .sync_conn()
911 .expect("sync_conn failed in write_deferred");
912 conn.write_deferred_bind_execute(sql_hash, params, buf);
913 }
914
915 pub(crate) fn flush_deferred_pipeline(
917 &mut self,
918 buf: &mut Vec<u8>,
919 count: usize,
920 ) -> Result<Vec<u64>, DriverError> {
921 self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
922 }
923}
924
925impl Drop for PoolGuard {
926 fn drop(&mut self) {
927 if let Some(slot) = self.conn.take() {
928 let should_discard = self.discard
930 || self.pool.closed.load(Ordering::Acquire)
931 || match &slot {
932 PoolSlot::Sync(conn) => {
933 conn.is_in_failed_transaction()
934 || conn.is_in_transaction()
935 || conn.is_streaming()
936 }
937 #[cfg(feature = "async")]
938 PoolSlot::Async(conn) => {
939 conn.is_in_failed_transaction() || conn.is_in_transaction()
940 }
941 };
942
943 if should_discard {
944 self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
945 return;
946 }
947
948 let mut slot = slot;
951 match &mut slot {
952 PoolSlot::Sync(conn) => {
953 if conn.query_counter() & 63 == 0 {
954 conn.touch();
955 }
956 }
957 #[cfg(feature = "async")]
958 PoolSlot::Async(conn) => {
959 if conn.query_counter() & 63 == 0 {
960 conn.touch();
961 }
962 }
963 }
964
965 {
967 let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
968 stack.push(slot);
969 }
970
971 if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
973 let (_, cvar) = &self.pool.release_pair;
974 cvar.notify_one();
975 }
976 }
977 }
978}
979
980pub struct Transaction {
996 guard: PoolGuard,
997 committed: bool,
998 deferred_buf: Vec<u8>,
1000 deferred_count: usize,
1002}
1003
1004impl Transaction {
1005 pub fn commit(mut self) -> Result<(), DriverError> {
1009 if self.deferred_count > 0 {
1010 self.flush_deferred()?;
1011 }
1012 self.guard.simple_query("COMMIT")?;
1013 self.committed = true;
1014 Ok(())
1015 }
1016
1017 pub fn rollback(mut self) -> Result<(), DriverError> {
1021 self.deferred_buf.clear();
1022 self.deferred_count = 0;
1023 self.guard.simple_query("ROLLBACK")?;
1024 self.committed = true; Ok(())
1026 }
1027
1028 pub fn query(
1033 &mut self,
1034 sql: &str,
1035 sql_hash: u64,
1036 params: &[&(dyn Encode + Sync)],
1037 ) -> Result<QueryResult, DriverError> {
1038 if self.deferred_count > 0 {
1039 self.flush_deferred()?;
1040 }
1041 self.guard.query(sql, sql_hash, params)
1042 }
1043
1044 pub fn execute(
1046 &mut self,
1047 sql: &str,
1048 sql_hash: u64,
1049 params: &[&(dyn Encode + Sync)],
1050 ) -> Result<u64, DriverError> {
1051 self.guard.execute(sql, sql_hash, params)
1052 }
1053
1054 pub fn execute_pipeline(
1056 &mut self,
1057 sql: &str,
1058 sql_hash: u64,
1059 param_sets: &[&[&(dyn Encode + Sync)]],
1060 ) -> Result<Vec<u64>, DriverError> {
1061 self.guard.execute_pipeline(sql, sql_hash, param_sets)
1062 }
1063
1064 pub fn for_each<F>(
1068 &mut self,
1069 sql: &str,
1070 sql_hash: u64,
1071 params: &[&(dyn Encode + Sync)],
1072 f: F,
1073 ) -> Result<(), DriverError>
1074 where
1075 F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1076 {
1077 if self.deferred_count > 0 {
1078 self.flush_deferred()?;
1079 }
1080 self.guard.for_each(sql, sql_hash, params, f)
1081 }
1082
1083 pub fn for_each_raw<F>(
1087 &mut self,
1088 sql: &str,
1089 sql_hash: u64,
1090 params: &[&(dyn Encode + Sync)],
1091 f: F,
1092 ) -> Result<(), DriverError>
1093 where
1094 F: FnMut(&[u8]) -> Result<(), DriverError>,
1095 {
1096 if self.deferred_count > 0 {
1097 self.flush_deferred()?;
1098 }
1099 self.guard.for_each_raw(sql, sql_hash, params, f)
1100 }
1101
1102 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1106 if self.deferred_count > 0 {
1107 self.flush_deferred()?;
1108 }
1109 self.guard.simple_query(sql)
1110 }
1111
1112 pub fn defer_execute(
1141 &mut self,
1142 sql: &str,
1143 sql_hash: u64,
1144 params: &[&(dyn Encode + Sync)],
1145 ) -> Result<(), DriverError> {
1146 if params.len() > i16::MAX as usize {
1147 return Err(DriverError::Protocol(format!(
1148 "parameter count {} exceeds maximum {}",
1149 params.len(),
1150 i16::MAX
1151 )));
1152 }
1153
1154 self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1156
1157 self.guard
1159 .write_deferred_bind_execute(sql_hash, params, &mut self.deferred_buf);
1160 self.deferred_count += 1;
1161 Ok(())
1162 }
1163
1164 pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1169 let count = self.deferred_count;
1170 self.deferred_count = 0;
1171 self.guard
1172 .flush_deferred_pipeline(&mut self.deferred_buf, count)
1173 }
1174
1175 pub fn deferred_count(&self) -> usize {
1177 self.deferred_count
1178 }
1179}
1180
1181impl Drop for Transaction {
1182 fn drop(&mut self) {
1183 if !self.committed {
1184 if let Some(_slot) = self.guard.conn.take() {
1187 self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1188 }
1190 }
1191 }
1192}
1193
1194#[cfg(test)]
1195mod tests {
1196 use super::*;
1197
1198 #[test]
1199 fn pool_builder_requires_url() {
1200 let result = PoolBuilder::new().build();
1201 assert!(result.is_err());
1202 }
1203
1204 #[test]
1205 fn pool_builder_validates_url() {
1206 let result = PoolBuilder::new().url("not_a_url").build();
1207 assert!(result.is_err());
1208 }
1209
1210 #[test]
1211 fn pool_builder_accepts_valid_url() {
1212 let pool = PoolBuilder::new()
1213 .url("postgres://user:pass@localhost/db")
1214 .max_size(5)
1215 .build()
1216 .unwrap();
1217 assert_eq!(pool.max_size(), 5);
1218 assert_eq!(pool.open_count(), 0);
1219 }
1220
1221 #[test]
1222 fn pool_connect_validates_url() {
1223 let result = Pool::connect("not_a_url");
1224 assert!(result.is_err());
1225 }
1226
1227 #[test]
1228 fn pool_max_size_zero() {
1229 let pool = PoolBuilder::new()
1230 .url("postgres://user:pass@localhost/db")
1231 .max_size(0)
1232 .build()
1233 .unwrap();
1234
1235 let result = pool.acquire();
1236 assert!(result.is_err());
1237 match result {
1238 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1239 Err(e) => panic!("expected Pool error, got: {e:?}"),
1240 Ok(_) => panic!("expected error, got Ok"),
1241 }
1242 }
1243
1244 #[test]
1245 fn pool_clone_shares_state() {
1246 let pool = PoolBuilder::new()
1247 .url("postgres://user:pass@localhost/db")
1248 .max_size(5)
1249 .build()
1250 .unwrap();
1251
1252 let pool2 = pool.clone();
1253 assert_eq!(pool.max_size(), pool2.max_size());
1254 }
1255
1256 #[test]
1260 fn pool_builder_max_lifetime() {
1261 let pool = PoolBuilder::new()
1262 .url("postgres://user:pass@localhost/db")
1263 .max_lifetime(Some(Duration::from_secs(60)))
1264 .build()
1265 .unwrap();
1266 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1267 }
1268
1269 #[test]
1271 fn pool_builder_max_lifetime_none() {
1272 let pool = PoolBuilder::new()
1273 .url("postgres://user:pass@localhost/db")
1274 .max_lifetime(None)
1275 .build()
1276 .unwrap();
1277 assert_eq!(pool.inner.max_lifetime, None);
1278 }
1279
1280 #[test]
1282 fn pool_builder_acquire_timeout_none() {
1283 let pool = PoolBuilder::new()
1284 .url("postgres://user:pass@localhost/db")
1285 .acquire_timeout(None)
1286 .build()
1287 .unwrap();
1288 assert_eq!(pool.inner.acquire_timeout, None);
1289 }
1290
1291 #[test]
1293 fn pool_builder_acquire_timeout_custom() {
1294 let pool = PoolBuilder::new()
1295 .url("postgres://user:pass@localhost/db")
1296 .acquire_timeout(Some(Duration::from_secs(10)))
1297 .build()
1298 .unwrap();
1299 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1300 }
1301
1302 #[test]
1304 fn pool_builder_min_idle() {
1305 let pool = PoolBuilder::new()
1306 .url("postgres://user:pass@localhost/db")
1307 .min_idle(2)
1308 .build()
1309 .unwrap();
1310 assert_eq!(pool.inner.min_idle, 2);
1311 }
1312
1313 #[test]
1315 fn pool_close_marks_closed() {
1316 let pool = PoolBuilder::new()
1317 .url("postgres://user:pass@localhost/db")
1318 .max_size(5)
1319 .build()
1320 .unwrap();
1321
1322 assert!(!pool.is_closed());
1323 pool.close();
1324 assert!(pool.is_closed());
1325
1326 let result = pool.acquire();
1328 assert!(result.is_err());
1329 match result {
1330 Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1331 Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1332 Ok(_) => panic!("expected error, got Ok"),
1333 }
1334 }
1335
1336 #[test]
1338 fn pool_status_initial() {
1339 let pool = PoolBuilder::new()
1340 .url("postgres://user:pass@localhost/db")
1341 .max_size(10)
1342 .build()
1343 .unwrap();
1344
1345 let status = pool.status();
1346 assert_eq!(status.idle, 0);
1347 assert_eq!(status.active, 0);
1348 assert_eq!(status.open, 0);
1349 assert_eq!(status.max_size, 10);
1350 }
1351
1352 #[test]
1354 fn pool_builder_defaults() {
1355 let pool = PoolBuilder::new()
1356 .url("postgres://user:pass@localhost/db")
1357 .build()
1358 .unwrap();
1359
1360 assert_eq!(pool.max_size(), 10);
1361 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1362 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1363 assert_eq!(pool.inner.min_idle, 0);
1364 }
1365
1366 #[test]
1368 fn pool_open_count_initial() {
1369 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1370 assert_eq!(pool.open_count(), 0);
1371 }
1372
1373 #[test]
1376 fn pool_builder_max_stmt_cache_size_default() {
1377 let pool = PoolBuilder::new()
1378 .url("postgres://user:pass@localhost/db")
1379 .build()
1380 .unwrap();
1381 assert_eq!(pool.inner.max_stmt_cache_size, 256);
1382 }
1383
1384 #[test]
1385 fn pool_builder_max_stmt_cache_size_custom() {
1386 let pool = PoolBuilder::new()
1387 .url("postgres://user:pass@localhost/db")
1388 .max_stmt_cache_size(512)
1389 .build()
1390 .unwrap();
1391 assert_eq!(pool.inner.max_stmt_cache_size, 512);
1392 }
1393
1394 #[test]
1397 fn pool_is_uds_false_for_tcp() {
1398 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1399 assert!(!pool.is_uds());
1400 }
1401
1402 #[cfg(unix)]
1403 #[test]
1404 fn pool_is_uds_true_for_unix_socket() {
1405 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1406 assert!(pool.is_uds());
1407 }
1408
1409 #[cfg(unix)]
1410 #[test]
1411 fn pool_is_uds_true_for_var_run_socket() {
1412 let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1413 assert!(pool.is_uds());
1414 }
1415
1416 #[test]
1417 fn pool_is_uds_false_for_ip_address() {
1418 let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
1419 assert!(!pool.is_uds());
1420 }
1421
1422 #[cfg(unix)]
1423 #[test]
1424 fn pool_slot_sync_created_for_uds_config() {
1425 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1426 assert!(config.host_is_uds());
1427 }
1428
1429 #[test]
1430 fn pool_slot_tcp_config() {
1431 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1432 assert!(!config.host_is_uds());
1433 }
1434
1435 #[test]
1440 fn pool_is_uds_false_for_hostname() {
1441 let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
1442 assert!(!pool.is_uds());
1443 }
1444
1445 #[cfg(unix)]
1446 #[test]
1447 fn pool_is_uds_true_for_tmp() {
1448 let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1449 assert!(pool.is_uds());
1450 }
1451
1452 #[test]
1457 fn pool_close_then_acquire_fails() {
1458 let pool = PoolBuilder::new()
1459 .url("postgres://user:pass@localhost/db")
1460 .max_size(5)
1461 .build()
1462 .unwrap();
1463 pool.close();
1464 let result = pool.acquire();
1465 assert!(result.is_err());
1466 match result {
1467 Err(DriverError::Pool(msg)) => {
1468 assert!(msg.contains("closed"), "should say closed: {msg}")
1469 }
1470 Err(e) => panic!("expected Pool error, got: {e:?}"),
1471 Ok(_) => panic!("expected error"),
1472 }
1473 }
1474
1475 #[test]
1476 fn pool_is_closed_before_and_after() {
1477 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1478 assert!(!pool.is_closed());
1479 pool.close();
1480 assert!(pool.is_closed());
1481 }
1482
1483 #[test]
1488 fn pool_exhausted_no_timeout() {
1489 let pool = PoolBuilder::new()
1490 .url("postgres://user:pass@localhost/db")
1491 .max_size(0)
1492 .acquire_timeout(None) .build()
1494 .unwrap();
1495 let result = pool.acquire();
1496 assert!(result.is_err());
1497 match result {
1498 Err(DriverError::Pool(msg)) => {
1499 assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1500 }
1501 Err(e) => panic!("expected Pool error, got: {e:?}"),
1502 Ok(_) => panic!("expected error"),
1503 }
1504 }
1505
1506 #[test]
1511 fn pool_builder_no_url_error() {
1512 let result = PoolBuilder::new().max_size(5).build();
1513 assert!(result.is_err());
1514 match result {
1515 Err(DriverError::Pool(msg)) => {
1516 assert!(msg.contains("URL"), "should mention URL: {msg}")
1517 }
1518 Err(e) => panic!("expected Pool error, got: {e:?}"),
1519 Ok(_) => panic!("expected error"),
1520 }
1521 }
1522
1523 #[test]
1524 fn pool_builder_invalid_url_error() {
1525 let result = PoolBuilder::new().url("ftp://something").build();
1526 assert!(result.is_err());
1527 }
1528
1529 #[test]
1530 fn pool_builder_stmt_cache_size_zero() {
1531 let pool = PoolBuilder::new()
1532 .url("postgres://user:pass@localhost/db")
1533 .max_stmt_cache_size(0)
1534 .build()
1535 .unwrap();
1536 assert_eq!(pool.inner.max_stmt_cache_size, 0);
1537 }
1538
1539 #[test]
1544 fn pool_status_reflects_max_size() {
1545 let pool = PoolBuilder::new()
1546 .url("postgres://user:pass@localhost/db")
1547 .max_size(20)
1548 .build()
1549 .unwrap();
1550 let status = pool.status();
1551 assert_eq!(status.max_size, 20);
1552 assert_eq!(status.idle, 0);
1553 assert_eq!(status.active, 0);
1554 assert_eq!(status.open, 0);
1555 }
1556
1557 #[test]
1562 fn pool_clone_shares_config() {
1563 let pool = PoolBuilder::new()
1564 .url("postgres://user:pass@localhost/db")
1565 .max_size(7)
1566 .build()
1567 .unwrap();
1568 let p2 = pool.clone();
1569 assert_eq!(pool.max_size(), 7);
1570 assert_eq!(p2.max_size(), 7);
1571 assert_eq!(pool.open_count(), p2.open_count());
1572 }
1573
1574 #[test]
1579 fn pool_set_warmup_sqls_empty() {
1580 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1581 pool.set_warmup_sqls(&[]);
1582 let sqls = pool
1583 .inner
1584 .warmup_sqls
1585 .lock()
1586 .unwrap_or_else(|e| e.into_inner())
1587 .clone();
1588 assert!(sqls.is_empty());
1589 }
1590
1591 #[test]
1592 fn pool_set_warmup_sqls_multiple() {
1593 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1594 pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
1595 let sqls = pool
1596 .inner
1597 .warmup_sqls
1598 .lock()
1599 .unwrap_or_else(|e| e.into_inner())
1600 .clone();
1601 assert_eq!(sqls.len(), 3);
1602 assert_eq!(&*sqls[0], "SELECT 1");
1603 assert_eq!(&*sqls[1], "SELECT 2");
1604 assert_eq!(&*sqls[2], "SELECT 3");
1605 }
1606
1607 #[test]
1608 fn pool_set_warmup_sqls_overwrite() {
1609 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1610 pool.set_warmup_sqls(&["SELECT 1"]);
1611 pool.set_warmup_sqls(&["SELECT 99"]);
1612 let sqls = pool
1613 .inner
1614 .warmup_sqls
1615 .lock()
1616 .unwrap_or_else(|e| e.into_inner())
1617 .clone();
1618 assert_eq!(sqls.len(), 1);
1619 assert_eq!(&*sqls[0], "SELECT 99");
1620 }
1621
1622 #[test]
1627 fn pool_status_debug() {
1628 let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1629 let status = pool.status();
1630 let dbg = format!("{status:?}");
1631 assert!(dbg.contains("PoolStatus"));
1632 assert!(dbg.contains("idle"));
1633 assert!(dbg.contains("active"));
1634 assert!(dbg.contains("open"));
1635 assert!(dbg.contains("max_size"));
1636 }
1637
1638 #[test]
1643 fn config_host_is_uds_returns_true_for_slash() {
1644 let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1645 assert!(config.host_is_uds());
1646 }
1647
1648 #[test]
1649 fn config_host_is_uds_returns_false_for_tcp() {
1650 let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1651 assert!(!config.host_is_uds());
1652 }
1653
1654 #[test]
1655 fn config_host_is_uds_returns_false_for_ip() {
1656 let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
1657 assert!(!config.host_is_uds());
1658 }
1659
1660 #[test]
1665 fn pool_builder_full_chain() {
1666 let pool = PoolBuilder::new()
1667 .url("postgres://user:pass@localhost/db")
1668 .max_size(3)
1669 .max_lifetime(Some(Duration::from_secs(600)))
1670 .acquire_timeout(Some(Duration::from_secs(5)))
1671 .min_idle(1)
1672 .max_stmt_cache_size(128)
1673 .build()
1674 .unwrap();
1675 assert_eq!(pool.max_size(), 3);
1676 assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
1677 assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1678 assert_eq!(pool.inner.min_idle, 1);
1679 assert_eq!(pool.inner.max_stmt_cache_size, 128);
1680 }
1681
1682 #[test]
1685 fn pool_max_size_zero_rejects_all_acquires() {
1686 let pool = PoolBuilder::new()
1687 .url("postgres://user:pass@localhost/db")
1688 .max_size(0)
1689 .build()
1690 .unwrap();
1691 let result = pool.acquire();
1692 assert!(result.is_err());
1693 match &result {
1694 Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1695 _ => panic!("expected pool exhausted error"),
1696 }
1697 }
1698
1699 #[test]
1702 fn url_parse_unknown_sslmode_returns_error() {
1703 let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
1704 assert!(result.is_err());
1705 let msg = format!("{}", result.unwrap_err());
1706 assert!(msg.contains("unknown sslmode"));
1707 }
1708
1709 #[test]
1710 fn url_parse_invalid_port_returns_error() {
1711 let result = Config::from_url("postgres://u:p@h:abc/d");
1712 assert!(result.is_err());
1713 let msg = format!("{}", result.unwrap_err());
1714 assert!(msg.contains("invalid port"));
1715 }
1716
1717 #[test]
1718 fn url_parse_missing_at_sign_returns_error() {
1719 let result = Config::from_url("postgres://u:plocalhost/d");
1720 assert!(result.is_err());
1721 let msg = format!("{}", result.unwrap_err());
1722 assert!(msg.contains("missing @"));
1723 }
1724
1725 #[test]
1726 fn url_parse_empty_host_returns_error() {
1727 let result = Config::from_url("postgres://u:p@/d");
1728 assert!(result.is_err());
1729 }
1730
1731 #[test]
1732 fn url_parse_empty_user_returns_error() {
1733 let result = Config::from_url("postgres://:p@h/d");
1734 assert!(result.is_err());
1735 }
1736
1737 #[test]
1738 fn url_parse_statement_timeout_invalid_uses_default() {
1739 let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
1740 assert_eq!(config.statement_timeout_secs, 30);
1741 }
1742
1743 #[test]
1744 fn url_parse_malformed_percent_encoding() {
1745 let result = Config::from_url("postgres://u%:p@h/d");
1746 assert!(result.is_err());
1747 }
1748
1749 #[test]
1750 fn url_parse_invalid_hex_in_percent_encoding() {
1751 let result = Config::from_url("postgres://u%ZZ:p@h/d");
1752 assert!(result.is_err());
1753 }
1754}