1use std::sync::Mutex;
7use std::time::Duration;
8
9use bsql_driver_postgres::arena::acquire_arena;
10use bsql_driver_postgres::codec::Encode;
11
12use crate::error::{BsqlError, BsqlResult};
13use crate::stream::QueryStream;
14use crate::transaction::Transaction;
15
/// One row returned by [`Pool::raw_query`]: an ordered list of optional text
/// values, one entry per column.
///
/// A `None` entry is whatever the driver reported as absent for that column
/// (presumably SQL `NULL` — confirm against the driver).
#[derive(Debug, Clone)]
pub struct RawRow(Vec<Option<String>>);

impl RawRow {
    /// Returns the text stored at column `idx`.
    ///
    /// The result is `None` both when `idx` is past the last column and when
    /// the value stored at that position is itself `None`.
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.0.get(idx).and_then(|cell| cell.as_deref())
    }

    /// Number of columns held by this row.
    pub fn len(&self) -> usize {
        let Self(cells) = self;
        cells.len()
    }

    /// `true` when the row holds no columns at all.
    pub fn is_empty(&self) -> bool {
        let Self(cells) = self;
        cells.is_empty()
    }

    /// Iterates over the columns in order, borrowing each value as
    /// `Option<&str>`.
    pub fn iter(&self) -> impl Iterator<Item = Option<&str>> {
        self.0.iter().map(Option::as_deref)
    }
}
44
45pub struct Pool {
66 pub(crate) inner: bsql_driver_postgres::Pool,
67 pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
69}
70
/// Collects pool settings before [`PoolBuilder::build`] creates the pool.
///
/// Obtain one from [`Pool::builder`].
pub struct PoolBuilder {
    /// Primary connection URL; `build` fails while this is `None`.
    url: Option<String>,
    /// Maximum size passed to the primary driver pool (and to the replica
    /// pool when `replica_max_size` is unset).
    max_size: usize,
    /// Outer `None`: never set, so nothing is passed to the driver builder.
    /// `Some(v)`: `v` is passed through as-is (`v` may itself be `None`).
    max_lifetime: Option<Option<Duration>>,
    /// Same two-level encoding as `max_lifetime`, for the acquire timeout.
    acquire_timeout: Option<Option<Duration>>,
    /// Passed to the primary driver pool only, when set.
    min_idle: Option<usize>,
    /// URL of the read replica; a replica pool is built only when present.
    replica_url: Option<String>,
    /// Maximum size of the replica pool; falls back to `max_size`.
    replica_max_size: Option<usize>,
}
98
99impl PoolBuilder {
100 pub fn url(mut self, url: &str) -> Self {
104 self.url = Some(url.into());
105 self
106 }
107
108 pub fn max_size(mut self, size: usize) -> Self {
109 self.max_size = size;
110 self
111 }
112
113 pub fn max_lifetime(mut self, d: Option<Duration>) -> Self {
118 self.max_lifetime = Some(d);
119 self
120 }
121
122 pub fn max_lifetime_secs(self, secs: u64) -> Self {
125 self.max_lifetime(Some(Duration::from_secs(secs)))
126 }
127
128 pub fn lifetime_secs(self, secs: u64) -> Self {
130 self.max_lifetime_secs(secs)
131 }
132
133 pub fn acquire_timeout(mut self, d: Option<Duration>) -> Self {
138 self.acquire_timeout = Some(d);
139 self
140 }
141
142 pub fn acquire_timeout_secs(self, secs: u64) -> Self {
145 self.acquire_timeout(Some(Duration::from_secs(secs)))
146 }
147
148 pub fn timeout_secs(self, secs: u64) -> Self {
150 self.acquire_timeout_secs(secs)
151 }
152
153 pub fn min_idle(mut self, n: usize) -> Self {
158 self.min_idle = Some(n);
159 self
160 }
161
162 pub fn replica_url(mut self, url: &str) -> Self {
168 self.replica_url = Some(url.into());
169 self
170 }
171
172 pub fn replica_max_size(mut self, size: usize) -> Self {
175 self.replica_max_size = Some(size);
176 self
177 }
178
179 pub async fn build(self) -> BsqlResult<Pool> {
180 let url = self.url.ok_or_else(|| {
181 BsqlError::from(bsql_driver_postgres::DriverError::Pool(
182 "pool builder requires a URL".into(),
183 ))
184 })?;
185 let mut builder = bsql_driver_postgres::Pool::builder()
186 .url(&url)
187 .max_size(self.max_size);
188
189 if let Some(lt) = self.max_lifetime {
190 builder = builder.max_lifetime(lt);
191 }
192 if let Some(at) = self.acquire_timeout {
193 builder = builder.acquire_timeout(at);
194 }
195 if let Some(mi) = self.min_idle {
196 builder = builder.min_idle(mi);
197 }
198
199 let inner = builder.build().map_err(BsqlError::from)?;
200
201 let read_pool = if let Some(replica_url) = &self.replica_url {
203 let replica_size = self.replica_max_size.unwrap_or(self.max_size);
204 let mut rbuilder = bsql_driver_postgres::Pool::builder()
205 .url(replica_url)
206 .max_size(replica_size);
207 if let Some(lt) = self.max_lifetime {
208 rbuilder = rbuilder.max_lifetime(lt);
209 }
210 if let Some(at) = self.acquire_timeout {
211 rbuilder = rbuilder.acquire_timeout(at);
212 }
213 Some(rbuilder.build().map_err(BsqlError::from)?)
214 } else {
215 None
216 };
217
218 Ok(Pool { inner, read_pool })
219 }
220}
221
222impl Pool {
223 pub async fn connect(url: &str) -> BsqlResult<Self> {
230 let inner = bsql_driver_postgres::Pool::connect(url).map_err(BsqlError::from)?;
231 Ok(Pool {
232 inner,
233 read_pool: None,
234 })
235 }
236
237 pub fn builder() -> PoolBuilder {
239 PoolBuilder {
240 url: None,
241 max_size: 10,
242 max_lifetime: None,
243 acquire_timeout: None,
244 min_idle: None,
245 replica_url: None,
246 replica_max_size: None,
247 }
248 }
249
250 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
255 let guard = self.inner.acquire().map_err(BsqlError::from)?;
256 Ok(PoolConnection {
257 inner: Mutex::new(guard),
258 })
259 }
260
261 pub async fn begin(&self) -> BsqlResult<Transaction> {
265 let tx = self.inner.begin().map_err(BsqlError::from)?;
266 Ok(Transaction::from_driver(tx))
267 }
268
269 pub async fn query_stream(
278 &self,
279 sql: &str,
280 sql_hash: u64,
281 params: &[&(dyn Encode + Sync)],
282 ) -> BsqlResult<QueryStream> {
283 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
284 let mut arena = acquire_arena();
285
286 const CHUNK_SIZE: i32 = 64;
288
289 let (columns, _) = guard
290 .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
291 .map_err(BsqlError::from)?;
292
293 let num_cols = columns.len();
294 let mut all_col_offsets: Vec<(usize, i32)> =
295 Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
296
297 let more = guard
298 .streaming_next_chunk(&mut arena, &mut all_col_offsets)
299 .map_err(BsqlError::from)?;
300
301 let first_result = bsql_driver_postgres::QueryResult::from_parts(
302 all_col_offsets,
303 num_cols,
304 columns.clone(),
305 0,
306 );
307
308 Ok(QueryStream::new(guard, arena, first_result, columns, !more))
309 }
310
311 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
319 self.inner.set_warmup_sqls(sqls);
320 }
321
322 pub async fn raw_query(&self, sql: &str) -> BsqlResult<Vec<RawRow>> {
330 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
331 let rows = guard
332 .simple_query_rows(sql)
333 .map_err(BsqlError::from_driver_query)?;
334 Ok(rows.into_iter().map(RawRow).collect())
335 }
336
337 pub async fn raw_execute(&self, sql: &str) -> BsqlResult<()> {
342 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
343 guard
344 .simple_query(sql)
345 .map_err(BsqlError::from_driver_query)?;
346 Ok(())
347 }
348
349 pub fn status(&self) -> PoolStatus {
353 let driver_status = self.inner.status();
354 PoolStatus {
355 idle: driver_status.idle,
356 active: driver_status.active,
357 open: driver_status.open,
358 max_size: driver_status.max_size,
359 }
360 }
361
362 pub fn close(&self) {
368 self.inner.close();
369 if let Some(ref rp) = self.read_pool {
370 rp.close();
371 }
372 }
373
374 pub fn is_closed(&self) -> bool {
376 self.inner.is_closed()
377 }
378
379 pub fn has_replica(&self) -> bool {
381 self.read_pool.is_some()
382 }
383
384 pub fn is_uds(&self) -> bool {
390 self.inner.is_uds()
391 }
392
393 pub async fn for_each_raw<F>(
402 &self,
403 sql: &str,
404 sql_hash: u64,
405 params: &[&(dyn Encode + Sync)],
406 readonly: bool,
407 mut f: F,
408 ) -> BsqlResult<()>
409 where
410 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
411 {
412 let pool = if readonly {
413 self.read_pool.as_ref().unwrap_or(&self.inner)
414 } else {
415 &self.inner
416 };
417 let mut guard = pool.acquire().map_err(BsqlError::from)?;
418 let mut user_err: Option<BsqlError> = None;
422 let driver_result = guard.for_each(sql, sql_hash, params, |row| match f(row) {
423 Ok(()) => Ok(()),
424 Err(e) => {
425 user_err = Some(e);
426 Err(bsql_driver_postgres::DriverError::Protocol(
427 "for_each closure error".into(),
428 ))
429 }
430 });
431 if let Some(e) = user_err {
433 return Err(e);
434 }
435 driver_result.map_err(BsqlError::from_driver_query)
436 }
437
438 #[doc(hidden)]
445 pub async fn __for_each_raw_bytes<F>(
446 &self,
447 sql: &str,
448 sql_hash: u64,
449 params: &[&(dyn Encode + Sync)],
450 readonly: bool,
451 mut f: F,
452 ) -> BsqlResult<()>
453 where
454 F: FnMut(&[u8]) -> BsqlResult<()>,
455 {
456 let pool = if readonly {
457 self.read_pool.as_ref().unwrap_or(&self.inner)
458 } else {
459 &self.inner
460 };
461 let mut guard = pool.acquire().map_err(BsqlError::from)?;
462 let mut user_err: Option<BsqlError> = None;
463 let driver_result = guard.for_each_raw(sql, sql_hash, params, |data| match f(data) {
464 Ok(()) => Ok(()),
465 Err(e) => {
466 user_err = Some(e);
467 Err(bsql_driver_postgres::DriverError::Protocol(
468 "for_each closure error".into(),
469 ))
470 }
471 });
472 if let Some(e) = user_err {
473 return Err(e);
474 }
475 driver_result.map_err(BsqlError::from_driver_query)
476 }
477}
478
479impl Clone for Pool {
480 fn clone(&self) -> Self {
481 Pool {
482 inner: self.inner.clone(),
483 read_pool: self.read_pool.clone(),
484 }
485 }
486}
487
488impl std::fmt::Debug for Pool {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 f.debug_struct("Pool")
491 .field("status", &self.status())
492 .finish()
493 }
494}
495
496pub struct PoolConnection {
505 pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
506}
507
508impl std::fmt::Debug for PoolConnection {
509 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
510 f.debug_struct("PoolConnection").finish()
511 }
512}
513
/// Snapshot of pool counters, as produced by [`Pool::status`].
///
/// Each field is copied from the same-named field of the driver's status
/// value. The struct is plain data, so it also derives `PartialEq`/`Eq`,
/// which lets two snapshots be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Driver-reported `idle` count.
    pub idle: usize,
    /// Driver-reported `active` count.
    pub active: usize,
    /// Driver-reported `open` count.
    pub open: usize,
    /// Driver-reported `max_size`.
    pub max_size: usize,
}

impl std::fmt::Display for PoolStatus {
    /// Renders as `idle=<n>, active=<n>, open=<n>, max=<n>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            idle,
            active,
            open,
            max_size,
        } = *self;
        write!(
            f,
            "idle={}, active={}, open={}, max={}",
            idle, active, open, max_size
        )
    }
}
536
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PoolBuilder: defaults and setters ---------------------------------

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert_eq!(b.max_size, 10);
        assert!(b.max_lifetime.is_none());
        assert!(b.acquire_timeout.is_none());
        assert!(b.min_idle.is_none());
    }

    #[test]
    fn builder_max_lifetime() {
        let b = Pool::builder().max_lifetime(Some(Duration::from_secs(60)));
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(60))));
    }

    // `Some(None)` records an explicit "pass `None` to the driver" choice,
    // as opposed to the outer `None` meaning "never set".
    #[test]
    fn builder_max_lifetime_none_disables() {
        let b = Pool::builder().max_lifetime(None);
        assert_eq!(b.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let b = Pool::builder().acquire_timeout(Some(Duration::from_secs(3)));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let b = Pool::builder().acquire_timeout(None);
        assert_eq!(b.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let b = Pool::builder().min_idle(5);
        assert_eq!(b.min_idle, Some(5));
    }

    #[test]
    fn builder_max_lifetime_secs() {
        let b = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(1800))));
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let b = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(5))));
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let b = Pool::builder().lifetime_secs(900);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let b = Pool::builder().timeout_secs(3);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    // ---- PoolBuilder: replica settings -------------------------------------

    #[test]
    fn builder_defaults_no_replica() {
        let b = Pool::builder();
        assert!(b.replica_url.is_none());
        assert!(b.replica_max_size.is_none());
    }

    #[test]
    fn builder_replica_url() {
        let b = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(b.replica_url.as_deref(), Some("postgres://replica:5432/db"));
    }

    #[test]
    fn builder_replica_max_size() {
        let b = Pool::builder().replica_max_size(20);
        assert_eq!(b.replica_max_size, Some(20));
    }

    // ---- Pool::connect -----------------------------------------------------
    // NOTE(review): these tests unwrap `Pool::connect` without any visible
    // server set-up, so they rely on connect not needing a reachable
    // database — confirm against the driver.

    #[tokio::test]
    async fn pool_connect_has_no_replica() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.has_replica());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
            .await
            .unwrap();
        assert!(pool.is_uds());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_ip() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    // ---- PoolStatus --------------------------------------------------------

    #[test]
    fn pool_status_display() {
        let status = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        };
        assert_eq!(status.to_string(), "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let status = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        };
        assert_eq!(status.to_string(), "idle=0, active=0, open=0, max=0");
    }

    // ---- Debug / Clone -----------------------------------------------------

    #[test]
    fn pool_connection_debug() {
        // A `PoolConnection` wraps a driver guard, which this test does not
        // create, so the check is purely at compile time: the type must
        // implement `Debug`.
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<PoolConnection>();
    }

    #[tokio::test]
    async fn pool_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[tokio::test]
    async fn pool_clone_is_cheap() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.status().max_size, pool2.status().max_size);
        assert!(!pool.has_replica());
        assert!(!pool2.has_replica());
    }

    // ---- Send / Sync (compile-time checks) ---------------------------------

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    // ---- PoolBuilder::build ------------------------------------------------

    #[tokio::test]
    async fn builder_build_without_url_errors() {
        let result = Pool::builder().build().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    #[test]
    fn builder_chaining() {
        let b = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(b.max_size, 20);
        assert_eq!(b.min_idle, Some(2));
        assert_eq!(b.replica_max_size, Some(10));
    }

    // ---- RawRow ------------------------------------------------------------

    #[test]
    fn raw_row_get() {
        let row = RawRow(vec![Some("hello".into()), None, Some("42".into())]);
        assert_eq!(row.get(0), Some("hello"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), Some("42"));
        assert_eq!(row.get(99), None);
        assert_eq!(row.len(), 3);
    }

    #[test]
    fn raw_row_is_empty() {
        let empty = RawRow(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty = RawRow(vec![Some("x".into())]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn raw_row_iter() {
        let row = RawRow(vec![Some("a".into()), None, Some("b".into())]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![Some("a"), None, Some("b")]);
    }

    #[test]
    fn raw_row_clone() {
        let row = RawRow(vec![Some("hello".into()), None]);
        let cloned = row.clone();
        assert_eq!(cloned.get(0), Some("hello"));
        assert_eq!(cloned.get(1), None);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn raw_row_debug() {
        let row = RawRow(vec![Some("x".into())]);
        let dbg = format!("{row:?}");
        assert!(dbg.contains("RawRow"), "Debug should show RawRow: {dbg}");
    }

    #[test]
    fn raw_row_all_null_values() {
        let row = RawRow(vec![None, None, None]);
        assert_eq!(row.len(), 3);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), None);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![None, None, None]);
    }

    // An empty string is a present value, distinct from `None`.
    #[test]
    fn raw_row_empty_string_values() {
        let row = RawRow(vec![Some(String::new()), Some("".into())]);
        assert_eq!(row.len(), 2);
        assert_eq!(row.get(0), Some(""));
        assert_eq!(row.get(1), Some(""));
    }

    #[test]
    fn raw_row_get_out_of_bounds() {
        let row = RawRow(vec![Some("only".into())]);
        assert_eq!(row.get(0), Some("only"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(100), None);
        assert_eq!(row.get(usize::MAX), None);
    }

    #[test]
    fn raw_row_iter_empty() {
        let row = RawRow(vec![]);
        let vals: Vec<_> = row.iter().collect();
        assert!(vals.is_empty());
    }

    #[test]
    fn raw_row_iter_mixed() {
        let row = RawRow(vec![
            Some("hello".into()),
            None,
            Some("world".into()),
            None,
            Some("".into()),
        ]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(
            vals,
            vec![Some("hello"), None, Some("world"), None, Some("")]
        );
    }

    #[test]
    fn raw_row_single_null() {
        let row = RawRow(vec![None]);
        assert_eq!(row.len(), 1);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
    }
}