1use std::sync::Mutex;
7use std::time::Duration;
8
9use bsql_driver_postgres::arena::acquire_arena;
10use bsql_driver_postgres::codec::Encode;
11
12use crate::error::{BsqlError, BsqlResult};
13use crate::stream::QueryStream;
14use crate::transaction::Transaction;
15
/// One result row from [`Pool::raw_query`], with every column as text.
///
/// Each entry is `None` when the column value was SQL `NULL` (as delivered by
/// the driver's `simple_query_rows`), otherwise the column's text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RawRow(Vec<Option<String>>);

impl RawRow {
    /// Returns the text of column `idx`.
    ///
    /// Returns `None` both when the column is `NULL` and when `idx` is out of
    /// bounds; use [`RawRow::len`] to tell the two cases apart.
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.0.get(idx)?.as_deref()
    }

    /// Number of columns in the row (including `NULL` columns).
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when the row has no columns at all.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Iterates over the columns in order, yielding `None` for `NULL` values.
    pub fn iter(&self) -> impl Iterator<Item = Option<&str>> {
        self.0.iter().map(|v| v.as_deref())
    }
}
44
45pub struct Pool {
66 pub(crate) inner: bsql_driver_postgres::Pool,
67 pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
69}
70
/// Configuration collected before creating a [`Pool`]; obtain one from `Pool::builder()`.
///
/// Optional settings left unset are simply not forwarded to the driver's
/// builder, so the driver's own defaults apply.
pub struct PoolBuilder {
    /// Primary connection URL. `build` fails when this is still `None`.
    url: Option<String>,
    /// Maximum number of connections in the primary pool.
    max_size: usize,
    /// Outer `None`: leave the driver default. `Some(inner)`: forward `inner`
    /// to the driver (`Some(None)` is what the "disable" setter produces).
    max_lifetime: Option<Option<Duration>>,
    /// Same three-state encoding as `max_lifetime`, for the acquire timeout.
    acquire_timeout: Option<Option<Duration>>,
    /// Minimum idle connections, applied to the primary pool only.
    min_idle: Option<usize>,
    /// URL of a read replica; when set, a second pool is created for it.
    replica_url: Option<String>,
    /// Size of the replica pool; falls back to `max_size` when unset.
    replica_max_size: Option<usize>,
}
98
99impl PoolBuilder {
100 pub fn url(mut self, url: &str) -> Self {
104 self.url = Some(url.into());
105 self
106 }
107
108 pub fn max_size(mut self, size: usize) -> Self {
109 self.max_size = size;
110 self
111 }
112
113 pub fn max_lifetime(mut self, d: Option<Duration>) -> Self {
118 self.max_lifetime = Some(d);
119 self
120 }
121
122 pub fn max_lifetime_secs(self, secs: u64) -> Self {
125 self.max_lifetime(Some(Duration::from_secs(secs)))
126 }
127
128 pub fn lifetime_secs(self, secs: u64) -> Self {
130 self.max_lifetime_secs(secs)
131 }
132
133 pub fn acquire_timeout(mut self, d: Option<Duration>) -> Self {
138 self.acquire_timeout = Some(d);
139 self
140 }
141
142 pub fn acquire_timeout_secs(self, secs: u64) -> Self {
145 self.acquire_timeout(Some(Duration::from_secs(secs)))
146 }
147
148 pub fn timeout_secs(self, secs: u64) -> Self {
150 self.acquire_timeout_secs(secs)
151 }
152
153 pub fn min_idle(mut self, n: usize) -> Self {
158 self.min_idle = Some(n);
159 self
160 }
161
162 pub fn replica_url(mut self, url: &str) -> Self {
168 self.replica_url = Some(url.into());
169 self
170 }
171
172 pub fn replica_max_size(mut self, size: usize) -> Self {
175 self.replica_max_size = Some(size);
176 self
177 }
178
179 pub async fn build(self) -> BsqlResult<Pool> {
180 let url = self.url.ok_or_else(|| {
181 BsqlError::from(bsql_driver_postgres::DriverError::Pool(
182 "pool builder requires a URL".into(),
183 ))
184 })?;
185 let mut builder = bsql_driver_postgres::Pool::builder()
186 .url(&url)
187 .max_size(self.max_size);
188
189 if let Some(lt) = self.max_lifetime {
190 builder = builder.max_lifetime(lt);
191 }
192 if let Some(at) = self.acquire_timeout {
193 builder = builder.acquire_timeout(at);
194 }
195 if let Some(mi) = self.min_idle {
196 builder = builder.min_idle(mi);
197 }
198
199 let inner = builder.build().map_err(BsqlError::from)?;
200
201 let read_pool = if let Some(replica_url) = &self.replica_url {
203 let replica_size = self.replica_max_size.unwrap_or(self.max_size);
204 let mut rbuilder = bsql_driver_postgres::Pool::builder()
205 .url(replica_url)
206 .max_size(replica_size);
207 if let Some(lt) = self.max_lifetime {
208 rbuilder = rbuilder.max_lifetime(lt);
209 }
210 if let Some(at) = self.acquire_timeout {
211 rbuilder = rbuilder.acquire_timeout(at);
212 }
213 Some(rbuilder.build().map_err(BsqlError::from)?)
214 } else {
215 None
216 };
217
218 Ok(Pool { inner, read_pool })
219 }
220}
221
222impl Pool {
223 pub async fn connect(url: &str) -> BsqlResult<Self> {
230 let inner = bsql_driver_postgres::Pool::connect(url).map_err(BsqlError::from)?;
231 Ok(Pool {
232 inner,
233 read_pool: None,
234 })
235 }
236
237 pub fn builder() -> PoolBuilder {
239 PoolBuilder {
240 url: None,
241 max_size: 10,
242 max_lifetime: None,
243 acquire_timeout: None,
244 min_idle: None,
245 replica_url: None,
246 replica_max_size: None,
247 }
248 }
249
250 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
255 let guard = self.inner.acquire().map_err(BsqlError::from)?;
256 Ok(PoolConnection {
257 inner: Mutex::new(guard),
258 })
259 }
260
261 pub async fn begin(&self) -> BsqlResult<Transaction> {
265 let tx = self.inner.begin().map_err(BsqlError::from)?;
266 Ok(Transaction::from_driver(tx))
267 }
268
269 pub async fn query_stream(
278 &self,
279 sql: &str,
280 sql_hash: u64,
281 params: &[&(dyn Encode + Sync)],
282 ) -> BsqlResult<QueryStream> {
283 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
284 let mut arena = acquire_arena();
285
286 const CHUNK_SIZE: i32 = 64;
288
289 let (columns, _) = guard
290 .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
291 .map_err(BsqlError::from)?;
292
293 let num_cols = columns.len();
294 let mut all_col_offsets: Vec<(usize, i32)> =
295 Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
296
297 let more = guard
298 .streaming_next_chunk(&mut arena, &mut all_col_offsets)
299 .map_err(BsqlError::from)?;
300
301 let first_result = bsql_driver_postgres::QueryResult::from_parts(
302 all_col_offsets,
303 num_cols,
304 columns.clone(),
305 0,
306 );
307
308 Ok(QueryStream::new(guard, arena, first_result, columns, !more))
309 }
310
311 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
319 self.inner.set_warmup_sqls(sqls);
320 }
321
322 pub async fn raw_query(&self, sql: &str) -> BsqlResult<Vec<RawRow>> {
330 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
331 let rows = guard
332 .simple_query_rows(sql)
333 .map_err(BsqlError::from_driver_query)?;
334 Ok(rows.into_iter().map(RawRow).collect())
335 }
336
337 pub async fn raw_execute(&self, sql: &str) -> BsqlResult<()> {
342 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
343 guard
344 .simple_query(sql)
345 .map_err(BsqlError::from_driver_query)?;
346 Ok(())
347 }
348
349 pub async fn copy_in<'a, I>(&self, table: &str, columns: &[&str], rows: I) -> BsqlResult<u64>
363 where
364 I: IntoIterator<Item = &'a str>,
365 {
366 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
367 guard
368 .copy_in(table, columns, rows)
369 .map_err(BsqlError::from_driver_query)
370 }
371
372 pub async fn copy_out<W: std::io::Write>(
384 &self,
385 query: &str,
386 writer: &mut W,
387 ) -> BsqlResult<u64> {
388 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
389 guard
390 .copy_out(query, writer)
391 .map_err(BsqlError::from_driver_query)
392 }
393
394 pub fn status(&self) -> PoolStatus {
398 let driver_status = self.inner.status();
399 PoolStatus {
400 idle: driver_status.idle,
401 active: driver_status.active,
402 open: driver_status.open,
403 max_size: driver_status.max_size,
404 }
405 }
406
407 pub fn close(&self) {
413 self.inner.close();
414 if let Some(ref rp) = self.read_pool {
415 rp.close();
416 }
417 }
418
419 pub fn is_closed(&self) -> bool {
421 self.inner.is_closed()
422 }
423
424 pub fn has_replica(&self) -> bool {
426 self.read_pool.is_some()
427 }
428
429 pub fn is_uds(&self) -> bool {
435 self.inner.is_uds()
436 }
437
438 pub async fn for_each_raw<F>(
447 &self,
448 sql: &str,
449 sql_hash: u64,
450 params: &[&(dyn Encode + Sync)],
451 readonly: bool,
452 mut f: F,
453 ) -> BsqlResult<()>
454 where
455 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
456 {
457 let pool = if readonly {
458 self.read_pool.as_ref().unwrap_or(&self.inner)
459 } else {
460 &self.inner
461 };
462 let mut guard = pool.acquire().map_err(BsqlError::from)?;
463 let mut user_err: Option<BsqlError> = None;
467 let driver_result = guard.for_each(sql, sql_hash, params, |row| match f(row) {
468 Ok(()) => Ok(()),
469 Err(e) => {
470 user_err = Some(e);
471 Err(bsql_driver_postgres::DriverError::Protocol(
472 "for_each closure error".into(),
473 ))
474 }
475 });
476 if let Some(e) = user_err {
478 return Err(e);
479 }
480 driver_result.map_err(BsqlError::from_driver_query)
481 }
482
483 #[doc(hidden)]
490 pub async fn __for_each_raw_bytes<F>(
491 &self,
492 sql: &str,
493 sql_hash: u64,
494 params: &[&(dyn Encode + Sync)],
495 readonly: bool,
496 mut f: F,
497 ) -> BsqlResult<()>
498 where
499 F: FnMut(&[u8]) -> BsqlResult<()>,
500 {
501 let pool = if readonly {
502 self.read_pool.as_ref().unwrap_or(&self.inner)
503 } else {
504 &self.inner
505 };
506 let mut guard = pool.acquire().map_err(BsqlError::from)?;
507 let mut user_err: Option<BsqlError> = None;
508 let driver_result = guard.for_each_raw(sql, sql_hash, params, |data| match f(data) {
509 Ok(()) => Ok(()),
510 Err(e) => {
511 user_err = Some(e);
512 Err(bsql_driver_postgres::DriverError::Protocol(
513 "for_each closure error".into(),
514 ))
515 }
516 });
517 if let Some(e) = user_err {
518 return Err(e);
519 }
520 driver_result.map_err(BsqlError::from_driver_query)
521 }
522}
523
524impl Clone for Pool {
525 fn clone(&self) -> Self {
526 Pool {
527 inner: self.inner.clone(),
528 read_pool: self.read_pool.clone(),
529 }
530 }
531}
532
533impl std::fmt::Debug for Pool {
534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 f.debug_struct("Pool")
536 .field("status", &self.status())
537 .finish()
538 }
539}
540
541pub struct PoolConnection {
550 pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
551}
552
553impl std::fmt::Debug for PoolConnection {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.debug_struct("PoolConnection").finish()
556 }
557}
558
/// Point-in-time connection counts for the primary pool, copied from the
/// driver's status (see `Pool::status`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Idle connections, as reported by the driver.
    pub idle: usize,
    /// Checked-out connections, as reported by the driver.
    pub active: usize,
    /// Open connections, as reported by the driver (presumably `idle + active` — confirm).
    pub open: usize,
    /// Configured maximum pool size.
    pub max_size: usize,
}

impl std::fmt::Display for PoolStatus {
    /// Renders as `idle=I, active=A, open=O, max=M`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "idle={}, active={}, open={}, max={}",
            self.idle, self.active, self.open, self.max_size
        )
    }
}
581
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert!(b.url.is_none());
        assert_eq!(b.max_size, 10);
        assert!(b.max_lifetime.is_none());
        assert!(b.acquire_timeout.is_none());
        assert!(b.min_idle.is_none());
    }

    #[test]
    fn builder_max_lifetime() {
        let b = Pool::builder().max_lifetime(Some(Duration::from_secs(60)));
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(60))));
    }

    #[test]
    fn builder_max_lifetime_none_disables() {
        let b = Pool::builder().max_lifetime(None);
        assert_eq!(b.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let b = Pool::builder().acquire_timeout(Some(Duration::from_secs(3)));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let b = Pool::builder().acquire_timeout(None);
        assert_eq!(b.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let b = Pool::builder().min_idle(5);
        assert_eq!(b.min_idle, Some(5));
    }

    #[test]
    fn builder_max_lifetime_secs() {
        let b = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(1800))));
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let b = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(5))));
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let b = Pool::builder().lifetime_secs(900);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let b = Pool::builder().timeout_secs(3);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_defaults_no_replica() {
        let b = Pool::builder();
        assert!(b.replica_url.is_none());
        assert!(b.replica_max_size.is_none());
    }

    #[test]
    fn builder_replica_url() {
        let b = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(b.replica_url.as_deref(), Some("postgres://replica:5432/db"));
    }

    #[test]
    fn builder_replica_max_size() {
        let b = Pool::builder().replica_max_size(20);
        assert_eq!(b.replica_max_size, Some(20));
    }

    #[tokio::test]
    async fn pool_connect_has_no_replica() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.has_replica());
    }

    #[tokio::test]
    async fn builder_build_with_replica_has_replica() {
        let pool = Pool::builder()
            .url("postgres://user:pass@localhost/db")
            .replica_url("postgres://user:pass@localhost/db")
            .build()
            .await
            .unwrap();
        assert!(pool.has_replica());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
            .await
            .unwrap();
        assert!(pool.is_uds());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_ip() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[test]
    fn pool_status_display() {
        let status = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        };
        assert_eq!(status.to_string(), "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let status = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        };
        assert_eq!(status.to_string(), "idle=0, active=0, open=0, max=0");
    }

    #[test]
    fn pool_connection_debug() {
        // A PoolConnection needs a live checkout, so only verify at compile
        // time that it implements Debug.
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<PoolConnection>();
    }

    #[tokio::test]
    async fn pool_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[tokio::test]
    async fn pool_clone_is_cheap() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.status().max_size, pool2.status().max_size);
        assert!(!pool.has_replica());
        assert!(!pool2.has_replica());
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    #[tokio::test]
    async fn builder_build_without_url_errors() {
        let result = Pool::builder().build().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    #[test]
    fn builder_chaining() {
        let b = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(b.url.as_deref(), Some("postgres://u@localhost/db"));
        assert_eq!(b.max_size, 20);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(600))));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
        assert_eq!(b.min_idle, Some(2));
        assert_eq!(b.replica_url.as_deref(), Some("postgres://u@replica/db"));
        assert_eq!(b.replica_max_size, Some(10));
    }

    #[test]
    fn raw_row_get() {
        let row = RawRow(vec![Some("hello".into()), None, Some("42".into())]);
        assert_eq!(row.get(0), Some("hello"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), Some("42"));
        assert_eq!(row.get(99), None);
        assert_eq!(row.len(), 3);
    }

    #[test]
    fn raw_row_is_empty() {
        let empty = RawRow(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty = RawRow(vec![Some("x".into())]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn raw_row_iter() {
        let row = RawRow(vec![Some("a".into()), None, Some("b".into())]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![Some("a"), None, Some("b")]);
    }

    #[test]
    fn raw_row_clone() {
        let row = RawRow(vec![Some("hello".into()), None]);
        let cloned = row.clone();
        assert_eq!(cloned.get(0), Some("hello"));
        assert_eq!(cloned.get(1), None);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn raw_row_debug() {
        let row = RawRow(vec![Some("x".into())]);
        let dbg = format!("{row:?}");
        assert!(dbg.contains("RawRow"), "Debug should show RawRow: {dbg}");
    }

    #[test]
    fn raw_row_all_null_values() {
        let row = RawRow(vec![None, None, None]);
        assert_eq!(row.len(), 3);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), None);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![None, None, None]);
    }

    #[test]
    fn raw_row_empty_string_values() {
        let row = RawRow(vec![Some(String::new()), Some("".into())]);
        assert_eq!(row.len(), 2);
        assert_eq!(row.get(0), Some(""));
        assert_eq!(row.get(1), Some(""));
    }

    #[test]
    fn raw_row_get_out_of_bounds() {
        let row = RawRow(vec![Some("only".into())]);
        assert_eq!(row.get(0), Some("only"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(100), None);
        assert_eq!(row.get(usize::MAX), None);
    }

    #[test]
    fn raw_row_iter_empty() {
        let row = RawRow(vec![]);
        let vals: Vec<_> = row.iter().collect();
        assert!(vals.is_empty());
    }

    #[test]
    fn raw_row_iter_mixed() {
        let row = RawRow(vec![
            Some("hello".into()),
            None,
            Some("world".into()),
            None,
            Some("".into()),
        ]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(
            vals,
            vec![Some("hello"), None, Some("world"), None, Some("")]
        );
    }

    #[test]
    fn raw_row_single_null() {
        let row = RawRow(vec![None]);
        assert_eq!(row.len(), 1);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
    }
}