1use std::sync::Mutex;
7use std::time::Duration;
8
9use bsql_driver_postgres::arena::acquire_arena;
10use bsql_driver_postgres::codec::Encode;
11
12use crate::error::{BsqlError, BsqlResult};
13use crate::stream::QueryStream;
14use crate::transaction::Transaction;
15
/// One result row whose columns are kept as optional owned strings.
///
/// Rows of this type are produced by `Pool::raw_query`. A column stored as
/// `None` carries no text at all, which is distinct from an empty string.
#[derive(Debug, Clone)]
pub struct RawRow(Vec<Option<String>>);

impl RawRow {
    /// Returns the text stored in column `idx`.
    ///
    /// The result is `None` both when `idx` is past the last column and when
    /// the column itself holds no value; compare against [`RawRow::len`] to
    /// tell the two cases apart.
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.0.get(idx).and_then(|cell| cell.as_deref())
    }

    /// Number of columns in this row, counting columns without a value.
    pub fn len(&self) -> usize {
        let Self(cells) = self;
        cells.len()
    }

    /// Returns `true` when the row has no columns at all.
    pub fn is_empty(&self) -> bool {
        let Self(cells) = self;
        cells.is_empty()
    }

    /// Iterates over the columns in order, borrowing each value as `Option<&str>`.
    pub fn iter(&self) -> impl Iterator<Item = Option<&str>> {
        let Self(cells) = self;
        cells.iter().map(Option::as_deref)
    }
}
44
45pub struct Pool {
66 pub(crate) inner: bsql_driver_postgres::Pool,
67 pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
69}
70
71pub struct PoolBuilder {
87 url: Option<String>,
88 max_size: usize,
89 max_lifetime: Option<Option<Duration>>,
90 acquire_timeout: Option<Option<Duration>>,
91 min_idle: Option<usize>,
92 replica_url: Option<String>,
95 replica_max_size: Option<usize>,
97 stale_timeout: Option<Duration>,
99 max_stmt_cache_size: Option<usize>,
101}
102
103impl PoolBuilder {
104 pub fn url(mut self, url: &str) -> Self {
108 self.url = Some(url.into());
109 self
110 }
111
112 pub fn max_size(mut self, size: usize) -> Self {
113 self.max_size = size;
114 self
115 }
116
117 pub fn max_lifetime(mut self, d: Option<Duration>) -> Self {
122 self.max_lifetime = Some(d);
123 self
124 }
125
126 pub fn max_lifetime_secs(self, secs: u64) -> Self {
129 self.max_lifetime(Some(Duration::from_secs(secs)))
130 }
131
132 pub fn lifetime_secs(self, secs: u64) -> Self {
134 self.max_lifetime_secs(secs)
135 }
136
137 pub fn acquire_timeout(mut self, d: Option<Duration>) -> Self {
142 self.acquire_timeout = Some(d);
143 self
144 }
145
146 pub fn acquire_timeout_secs(self, secs: u64) -> Self {
149 self.acquire_timeout(Some(Duration::from_secs(secs)))
150 }
151
152 pub fn timeout_secs(self, secs: u64) -> Self {
154 self.acquire_timeout_secs(secs)
155 }
156
157 pub fn min_idle(mut self, n: usize) -> Self {
162 self.min_idle = Some(n);
163 self
164 }
165
166 pub fn replica_url(mut self, url: &str) -> Self {
172 self.replica_url = Some(url.into());
173 self
174 }
175
176 pub fn replica_max_size(mut self, size: usize) -> Self {
179 self.replica_max_size = Some(size);
180 self
181 }
182
183 pub fn stale_timeout(mut self, timeout: Duration) -> Self {
187 self.stale_timeout = Some(timeout);
188 self
189 }
190
191 pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
195 self.max_stmt_cache_size = Some(size);
196 self
197 }
198
199 pub async fn build(self) -> BsqlResult<Pool> {
200 let url = self.url.ok_or_else(|| {
201 BsqlError::from(bsql_driver_postgres::DriverError::Pool(
202 "pool builder requires a URL".into(),
203 ))
204 })?;
205 let mut builder = bsql_driver_postgres::Pool::builder()
206 .url(&url)
207 .max_size(self.max_size);
208
209 if let Some(lt) = self.max_lifetime {
210 builder = builder.max_lifetime(lt);
211 }
212 if let Some(at) = self.acquire_timeout {
213 builder = builder.acquire_timeout(at);
214 }
215 if let Some(mi) = self.min_idle {
216 builder = builder.min_idle(mi);
217 }
218 if let Some(st) = self.stale_timeout {
219 builder = builder.stale_timeout(st);
220 }
221 if let Some(msc) = self.max_stmt_cache_size {
222 builder = builder.max_stmt_cache_size(msc);
223 }
224
225 let inner = builder.build().map_err(BsqlError::from)?;
226
227 let read_pool = if let Some(replica_url) = &self.replica_url {
229 let replica_size = self.replica_max_size.unwrap_or(self.max_size);
230 let mut rbuilder = bsql_driver_postgres::Pool::builder()
231 .url(replica_url)
232 .max_size(replica_size);
233 if let Some(lt) = self.max_lifetime {
234 rbuilder = rbuilder.max_lifetime(lt);
235 }
236 if let Some(at) = self.acquire_timeout {
237 rbuilder = rbuilder.acquire_timeout(at);
238 }
239 Some(rbuilder.build().map_err(BsqlError::from)?)
240 } else {
241 None
242 };
243
244 Ok(Pool { inner, read_pool })
245 }
246}
247
248impl Pool {
249 pub async fn connect(url: &str) -> BsqlResult<Self> {
256 let inner = bsql_driver_postgres::Pool::connect(url).map_err(BsqlError::from)?;
257 Ok(Pool {
258 inner,
259 read_pool: None,
260 })
261 }
262
263 pub fn builder() -> PoolBuilder {
265 PoolBuilder {
266 url: None,
267 max_size: 10,
268 max_lifetime: None,
269 acquire_timeout: None,
270 min_idle: None,
271 replica_url: None,
272 replica_max_size: None,
273 stale_timeout: None,
274 max_stmt_cache_size: None,
275 }
276 }
277
278 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
283 let guard = self.inner.acquire().map_err(BsqlError::from)?;
284 Ok(PoolConnection {
285 inner: Mutex::new(guard),
286 })
287 }
288
289 pub async fn begin(&self) -> BsqlResult<Transaction> {
293 let tx = self.inner.begin().map_err(BsqlError::from)?;
294 Ok(Transaction::from_driver(tx))
295 }
296
297 pub async fn query_stream(
306 &self,
307 sql: &str,
308 sql_hash: u64,
309 params: &[&(dyn Encode + Sync)],
310 ) -> BsqlResult<QueryStream> {
311 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
312 let mut arena = acquire_arena();
313
314 const CHUNK_SIZE: i32 = 64;
316
317 let (columns, _) = guard
318 .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
319 .map_err(BsqlError::from)?;
320
321 let num_cols = columns.len();
322 let mut all_col_offsets: Vec<(usize, i32)> =
323 Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
324
325 let more = guard
326 .streaming_next_chunk(&mut arena, &mut all_col_offsets)
327 .map_err(BsqlError::from)?;
328
329 let first_result = bsql_driver_postgres::QueryResult::from_parts(
330 all_col_offsets,
331 num_cols,
332 columns.clone(),
333 0,
334 );
335
336 Ok(QueryStream::new(guard, arena, first_result, columns, !more))
337 }
338
339 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
347 self.inner.set_warmup_sqls(sqls);
348 }
349
350 pub async fn raw_query(&self, sql: &str) -> BsqlResult<Vec<RawRow>> {
358 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
359 let rows = guard
360 .simple_query_rows(sql)
361 .map_err(BsqlError::from_driver_query)?;
362 Ok(rows.into_iter().map(RawRow).collect())
363 }
364
365 pub async fn raw_execute(&self, sql: &str) -> BsqlResult<()> {
370 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
371 guard
372 .simple_query(sql)
373 .map_err(BsqlError::from_driver_query)?;
374 Ok(())
375 }
376
377 pub async fn copy_in<'a, I>(&self, table: &str, columns: &[&str], rows: I) -> BsqlResult<u64>
391 where
392 I: IntoIterator<Item = &'a str>,
393 {
394 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
395 guard
396 .copy_in(table, columns, rows)
397 .map_err(BsqlError::from_driver_query)
398 }
399
400 pub async fn copy_out<W: std::io::Write>(
412 &self,
413 query: &str,
414 writer: &mut W,
415 ) -> BsqlResult<u64> {
416 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
417 guard
418 .copy_out(query, writer)
419 .map_err(BsqlError::from_driver_query)
420 }
421
422 pub fn status(&self) -> PoolStatus {
426 let driver_status = self.inner.status();
427 PoolStatus {
428 idle: driver_status.idle,
429 active: driver_status.active,
430 open: driver_status.open,
431 max_size: driver_status.max_size,
432 }
433 }
434
435 pub fn close(&self) {
441 self.inner.close();
442 if let Some(ref rp) = self.read_pool {
443 rp.close();
444 }
445 }
446
447 pub fn is_closed(&self) -> bool {
449 self.inner.is_closed()
450 }
451
452 pub fn has_replica(&self) -> bool {
454 self.read_pool.is_some()
455 }
456
457 pub fn is_uds(&self) -> bool {
463 self.inner.is_uds()
464 }
465
466 pub async fn for_each_raw<F>(
475 &self,
476 sql: &str,
477 sql_hash: u64,
478 params: &[&(dyn Encode + Sync)],
479 readonly: bool,
480 mut f: F,
481 ) -> BsqlResult<()>
482 where
483 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
484 {
485 let pool = if readonly {
486 self.read_pool.as_ref().unwrap_or(&self.inner)
487 } else {
488 &self.inner
489 };
490 let mut guard = pool.acquire().map_err(BsqlError::from)?;
491 let mut user_err: Option<BsqlError> = None;
495 let driver_result = guard.for_each(sql, sql_hash, params, |row| match f(row) {
496 Ok(()) => Ok(()),
497 Err(e) => {
498 user_err = Some(e);
499 Err(bsql_driver_postgres::DriverError::Protocol(
500 "for_each closure error".into(),
501 ))
502 }
503 });
504 if let Some(e) = user_err {
506 return Err(e);
507 }
508 driver_result.map_err(BsqlError::from_driver_query)
509 }
510
511 #[doc(hidden)]
518 pub async fn __for_each_raw_bytes<F>(
519 &self,
520 sql: &str,
521 sql_hash: u64,
522 params: &[&(dyn Encode + Sync)],
523 readonly: bool,
524 mut f: F,
525 ) -> BsqlResult<()>
526 where
527 F: FnMut(&[u8]) -> BsqlResult<()>,
528 {
529 let pool = if readonly {
530 self.read_pool.as_ref().unwrap_or(&self.inner)
531 } else {
532 &self.inner
533 };
534 let mut guard = pool.acquire().map_err(BsqlError::from)?;
535 let mut user_err: Option<BsqlError> = None;
536 let driver_result = guard.for_each_raw(sql, sql_hash, params, |data| match f(data) {
537 Ok(()) => Ok(()),
538 Err(e) => {
539 user_err = Some(e);
540 Err(bsql_driver_postgres::DriverError::Protocol(
541 "for_each closure error".into(),
542 ))
543 }
544 });
545 if let Some(e) = user_err {
546 return Err(e);
547 }
548 driver_result.map_err(BsqlError::from_driver_query)
549 }
550}
551
552impl Clone for Pool {
553 fn clone(&self) -> Self {
554 Pool {
555 inner: self.inner.clone(),
556 read_pool: self.read_pool.clone(),
557 }
558 }
559}
560
561impl std::fmt::Debug for Pool {
562 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
563 f.debug_struct("Pool")
564 .field("status", &self.status())
565 .finish()
566 }
567}
568
569pub struct PoolConnection {
578 pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
579}
580
581impl std::fmt::Debug for PoolConnection {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 f.debug_struct("PoolConnection").finish()
584 }
585}
586
/// Point-in-time counters of a pool, as returned by `Pool::status`.
///
/// The type is plain data: it is `Copy` and comparable, so two snapshots can
/// be checked for equality directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Value of the driver status' `idle` counter.
    pub idle: usize,
    /// Value of the driver status' `active` counter.
    pub active: usize,
    /// Value of the driver status' `open` counter.
    pub open: usize,
    /// Value of the driver status' `max_size` field.
    pub max_size: usize,
}

impl std::fmt::Display for PoolStatus {
    /// Formats the snapshot as `idle=I, active=A, open=O, max=M`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            idle,
            active,
            open,
            max_size,
        } = self;
        write!(f, "idle={idle}, active={active}, open={open}, max={max_size}")
    }
}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used by the tests that only need a pool object for a TCP-style address.
    const TCP_URL: &str = "postgres://user:pass@localhost/db";

    /// Creates a pool for `url`, panicking when construction fails.
    async fn connect_to(url: &str) -> Pool {
        Pool::connect(url).await.unwrap()
    }

    /// Builds a `RawRow` from borrowed cells so the row tests stay compact.
    fn row_of(cells: &[Option<&str>]) -> RawRow {
        RawRow(cells.iter().map(|cell| cell.map(str::to_owned)).collect())
    }

    /// Compiles only when `T` is `Send`.
    fn _assert_send<T: Send>() {}

    /// Compiles only when `T` is `Sync`.
    fn _assert_sync<T: Sync>() {}

    // --- builder: plain setters -------------------------------------------

    #[test]
    fn builder_defaults() {
        let builder = Pool::builder();
        assert_eq!(builder.max_size, 10);
        assert_eq!(builder.max_lifetime, None);
        assert_eq!(builder.acquire_timeout, None);
        assert_eq!(builder.min_idle, None);
    }

    #[test]
    fn builder_max_lifetime() {
        let wanted = Duration::from_secs(60);
        let builder = Pool::builder().max_lifetime(Some(wanted));
        assert_eq!(builder.max_lifetime, Some(Some(wanted)));
    }

    #[test]
    fn builder_max_lifetime_none_disables() {
        let builder = Pool::builder().max_lifetime(None);
        assert_eq!(builder.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let wanted = Duration::from_secs(3);
        let builder = Pool::builder().acquire_timeout(Some(wanted));
        assert_eq!(builder.acquire_timeout, Some(Some(wanted)));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let builder = Pool::builder().acquire_timeout(None);
        assert_eq!(builder.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let builder = Pool::builder().min_idle(5);
        assert_eq!(builder.min_idle, Some(5));
    }

    // --- builder: second-based conveniences -------------------------------

    #[test]
    fn builder_max_lifetime_secs() {
        let builder = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(
            builder.max_lifetime,
            Some(Some(Duration::from_secs(1800)))
        );
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let builder = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(
            builder.acquire_timeout,
            Some(Some(Duration::from_secs(5)))
        );
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let builder = Pool::builder().lifetime_secs(900);
        assert_eq!(builder.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let builder = Pool::builder().timeout_secs(3);
        assert_eq!(builder.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    // --- builder: replica settings ----------------------------------------

    #[test]
    fn builder_defaults_no_replica() {
        let builder = Pool::builder();
        assert_eq!(builder.replica_url, None);
        assert_eq!(builder.replica_max_size, None);
    }

    #[test]
    fn builder_replica_url() {
        let builder = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(
            builder.replica_url.as_deref(),
            Some("postgres://replica:5432/db")
        );
    }

    #[test]
    fn builder_replica_max_size() {
        let builder = Pool::builder().replica_max_size(20);
        assert_eq!(builder.replica_max_size, Some(20));
    }

    // --- pool construction --------------------------------------------------

    #[tokio::test]
    async fn pool_connect_has_no_replica() {
        let pool = connect_to(TCP_URL).await;
        assert!(!pool.has_replica());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_tcp() {
        let pool = connect_to(TCP_URL).await;
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_is_uds_true_for_unix_socket() {
        let pool = connect_to("postgres://user@localhost/db?host=/tmp").await;
        assert!(pool.is_uds());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_ip() {
        let pool = connect_to("postgres://user:pass@127.0.0.1/db").await;
        assert!(!pool.is_uds());
    }

    // --- PoolStatus display -------------------------------------------------

    #[test]
    fn pool_status_display() {
        let rendered = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        }
        .to_string();
        assert_eq!(rendered, "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let rendered = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        }
        .to_string();
        assert_eq!(rendered, "idle=0, active=0, open=0, max=0");
    }

    // --- Debug / Clone ------------------------------------------------------

    #[test]
    fn pool_connection_debug() {
        // Compile-time proof that `PoolConnection` implements `Debug`.
        fn _assert_debug<T: std::fmt::Debug>() {}

        let label = "PoolConnection";
        assert!(!label.is_empty());
        _assert_debug::<PoolConnection>();
    }

    #[tokio::test]
    async fn pool_debug() {
        let pool = connect_to(TCP_URL).await;
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[tokio::test]
    async fn pool_clone_is_cheap() {
        let original = connect_to(TCP_URL).await;
        let duplicate = original.clone();
        assert_eq!(original.status().max_size, duplicate.status().max_size);
        assert!(!original.has_replica());
        assert!(!duplicate.has_replica());
    }

    // --- auto-trait checks --------------------------------------------------

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    // --- builder: build / chaining -----------------------------------------

    #[tokio::test]
    async fn builder_build_without_url_errors() {
        let outcome = Pool::builder().build().await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    #[test]
    fn builder_chaining() {
        let builder = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(builder.max_size, 20);
        assert_eq!(builder.min_idle, Some(2));
        assert_eq!(builder.replica_max_size, Some(10));
    }

    // --- RawRow ---------------------------------------------------------------

    #[test]
    fn raw_row_get() {
        let row = row_of(&[Some("hello"), None, Some("42")]);
        assert_eq!(row.get(0), Some("hello"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), Some("42"));
        assert_eq!(row.get(99), None);
        assert_eq!(row.len(), 3);
    }

    #[test]
    fn raw_row_is_empty() {
        let empty = row_of(&[]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty = row_of(&[Some("x")]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn raw_row_iter() {
        let row = row_of(&[Some("a"), None, Some("b")]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![Some("a"), None, Some("b")]);
    }

    #[test]
    fn raw_row_clone() {
        let cloned = row_of(&[Some("hello"), None]).clone();
        assert_eq!(cloned.get(0), Some("hello"));
        assert_eq!(cloned.get(1), None);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn raw_row_debug() {
        let row = row_of(&[Some("x")]);
        let dbg = format!("{row:?}");
        assert!(dbg.contains("RawRow"), "Debug should show RawRow: {dbg}");
    }

    #[test]
    fn raw_row_all_null_values() {
        let row = row_of(&[None, None, None]);
        assert_eq!(row.len(), 3);
        assert!(!row.is_empty());
        for idx in 0..3 {
            assert_eq!(row.get(idx), None);
        }
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![None, None, None]);
    }

    #[test]
    fn raw_row_empty_string_values() {
        let row = row_of(&[Some(""), Some("")]);
        assert_eq!(row.len(), 2);
        assert_eq!(row.get(0), Some(""));
        assert_eq!(row.get(1), Some(""));
    }

    #[test]
    fn raw_row_get_out_of_bounds() {
        let row = row_of(&[Some("only")]);
        assert_eq!(row.get(0), Some("only"));
        for idx in [1, 100, usize::MAX] {
            assert_eq!(row.get(idx), None);
        }
    }

    #[test]
    fn raw_row_iter_empty() {
        let row = row_of(&[]);
        assert_eq!(row.iter().count(), 0);
    }

    #[test]
    fn raw_row_iter_mixed() {
        let expected = vec![Some("hello"), None, Some("world"), None, Some("")];
        let row = row_of(&expected);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, expected);
    }

    #[test]
    fn raw_row_single_null() {
        let row = row_of(&[None]);
        assert_eq!(row.len(), 1);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
    }
}