1use std::sync::Mutex;
7use std::time::Duration;
8
9use bsql_driver_postgres::arena::acquire_arena;
10use bsql_driver_postgres::codec::Encode;
11
12use crate::error::{BsqlError, BsqlResult};
13use crate::stream::QueryStream;
14use crate::transaction::Transaction;
15
/// One row returned by [`Pool::raw_query`], with every column delivered as text.
///
/// Each entry is `Some(text)` for a value or `None` for SQL `NULL`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawRow(Vec<Option<String>>);

impl RawRow {
    /// Text of column `idx`.
    ///
    /// Returns `None` both when the column is `NULL` and when `idx` is out of
    /// range; compare against [`len`](Self::len) to tell the two apart.
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.0.get(idx)?.as_deref()
    }

    /// Number of columns in the row.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when the row has no columns.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Iterate over the columns in order, yielding `None` for `NULL` values.
    pub fn iter(&self) -> impl Iterator<Item = Option<&str>> {
        self.0.iter().map(Option::as_deref)
    }
}
44
/// A handle to a PostgreSQL connection pool, optionally paired with a separate
/// read-replica pool.
pub struct Pool {
    /// Primary pool. Transactions, streaming queries, raw queries and all
    /// non-read-only work run here.
    pub(crate) inner: bsql_driver_postgres::Pool,
    /// Replica pool, present only when a replica URL was configured on the
    /// builder. Within this file, only read-only `for_each_raw` /
    /// `__for_each_raw_bytes` calls are routed to it.
    pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
}
70
/// Builder for [`Pool`]; obtain one with [`Pool::builder`].
// NOTE(review): avoid deriving `Debug` here — `url` / `replica_url` may embed
// passwords and would then leak into logs.
pub struct PoolBuilder {
    /// Primary connection URL. Required: `build` errors when this is `None`.
    url: Option<String>,
    /// Primary pool size; also the replica's size when `replica_max_size` is unset.
    max_size: usize,
    /// Outer `None`: not configured, so the driver default applies.
    /// `Some(inner)`: `inner` is forwarded to the driver for both pools.
    max_lifetime: Option<Option<Duration>>,
    /// Same tri-state as `max_lifetime`; forwarded to both pools.
    acquire_timeout: Option<Option<Duration>>,
    /// Forwarded to the primary pool only.
    min_idle: Option<usize>,
    /// When set, `build` also creates a replica pool at this URL.
    replica_url: Option<String>,
    /// Replica pool size; falls back to `max_size`.
    replica_max_size: Option<usize>,
}
98
99impl PoolBuilder {
100 pub fn url(mut self, url: &str) -> Self {
104 self.url = Some(url.into());
105 self
106 }
107
108 pub fn max_size(mut self, size: usize) -> Self {
109 self.max_size = size;
110 self
111 }
112
113 pub fn max_lifetime(mut self, d: Option<Duration>) -> Self {
118 self.max_lifetime = Some(d);
119 self
120 }
121
122 pub fn max_lifetime_secs(self, secs: u64) -> Self {
125 self.max_lifetime(Some(Duration::from_secs(secs)))
126 }
127
128 pub fn lifetime_secs(self, secs: u64) -> Self {
130 self.max_lifetime_secs(secs)
131 }
132
133 pub fn acquire_timeout(mut self, d: Option<Duration>) -> Self {
138 self.acquire_timeout = Some(d);
139 self
140 }
141
142 pub fn acquire_timeout_secs(self, secs: u64) -> Self {
145 self.acquire_timeout(Some(Duration::from_secs(secs)))
146 }
147
148 pub fn timeout_secs(self, secs: u64) -> Self {
150 self.acquire_timeout_secs(secs)
151 }
152
153 pub fn min_idle(mut self, n: usize) -> Self {
158 self.min_idle = Some(n);
159 self
160 }
161
162 pub fn replica_url(mut self, url: &str) -> Self {
168 self.replica_url = Some(url.into());
169 self
170 }
171
172 pub fn replica_max_size(mut self, size: usize) -> Self {
175 self.replica_max_size = Some(size);
176 self
177 }
178
179 pub fn build(self) -> BsqlResult<Pool> {
180 let url = self.url.ok_or_else(|| {
181 BsqlError::from(bsql_driver_postgres::DriverError::Pool(
182 "pool builder requires a URL".into(),
183 ))
184 })?;
185 let mut builder = bsql_driver_postgres::Pool::builder()
186 .url(&url)
187 .max_size(self.max_size);
188
189 if let Some(lt) = self.max_lifetime {
190 builder = builder.max_lifetime(lt);
191 }
192 if let Some(at) = self.acquire_timeout {
193 builder = builder.acquire_timeout(at);
194 }
195 if let Some(mi) = self.min_idle {
196 builder = builder.min_idle(mi);
197 }
198
199 let inner = builder.build().map_err(BsqlError::from)?;
200
201 let read_pool = if let Some(replica_url) = &self.replica_url {
203 let replica_size = self.replica_max_size.unwrap_or(self.max_size);
204 let mut rbuilder = bsql_driver_postgres::Pool::builder()
205 .url(replica_url)
206 .max_size(replica_size);
207 if let Some(lt) = self.max_lifetime {
208 rbuilder = rbuilder.max_lifetime(lt);
209 }
210 if let Some(at) = self.acquire_timeout {
211 rbuilder = rbuilder.acquire_timeout(at);
212 }
213 Some(rbuilder.build().map_err(BsqlError::from)?)
214 } else {
215 None
216 };
217
218 Ok(Pool { inner, read_pool })
219 }
220}
221
222impl Pool {
223 pub fn connect(url: &str) -> BsqlResult<Self> {
227 let inner = bsql_driver_postgres::Pool::connect(url).map_err(BsqlError::from)?;
228 Ok(Pool {
229 inner,
230 read_pool: None,
231 })
232 }
233
234 pub fn builder() -> PoolBuilder {
236 PoolBuilder {
237 url: None,
238 max_size: 10,
239 max_lifetime: None,
240 acquire_timeout: None,
241 min_idle: None,
242 replica_url: None,
243 replica_max_size: None,
244 }
245 }
246
247 pub fn acquire(&self) -> BsqlResult<PoolConnection> {
252 let guard = self.inner.acquire().map_err(BsqlError::from)?;
253 Ok(PoolConnection {
254 inner: Mutex::new(guard),
255 })
256 }
257
258 pub fn begin(&self) -> BsqlResult<Transaction> {
262 let tx = self.inner.begin().map_err(BsqlError::from)?;
263 Ok(Transaction::from_driver(tx))
264 }
265
266 pub fn query_stream(
275 &self,
276 sql: &str,
277 sql_hash: u64,
278 params: &[&(dyn Encode + Sync)],
279 ) -> BsqlResult<QueryStream> {
280 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
281 let mut arena = acquire_arena();
282
283 const CHUNK_SIZE: i32 = 64;
285
286 let (columns, _) = guard
287 .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
288 .map_err(BsqlError::from)?;
289
290 let num_cols = columns.len();
291 let mut all_col_offsets: Vec<(usize, i32)> =
292 Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
293
294 let more = guard
295 .streaming_next_chunk(&mut arena, &mut all_col_offsets)
296 .map_err(BsqlError::from)?;
297
298 let first_result = bsql_driver_postgres::QueryResult::from_parts(
299 all_col_offsets,
300 num_cols,
301 columns.clone(),
302 0,
303 );
304
305 Ok(QueryStream::new(guard, arena, first_result, columns, !more))
306 }
307
308 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
316 self.inner.set_warmup_sqls(sqls);
317 }
318
319 pub fn raw_query(&self, sql: &str) -> BsqlResult<Vec<RawRow>> {
327 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
328 let rows = guard
329 .simple_query_rows(sql)
330 .map_err(BsqlError::from_driver_query)?;
331 Ok(rows.into_iter().map(RawRow).collect())
332 }
333
334 pub fn raw_execute(&self, sql: &str) -> BsqlResult<()> {
339 let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
340 guard
341 .simple_query(sql)
342 .map_err(BsqlError::from_driver_query)?;
343 Ok(())
344 }
345
346 pub fn status(&self) -> PoolStatus {
350 let driver_status = self.inner.status();
351 PoolStatus {
352 idle: driver_status.idle,
353 active: driver_status.active,
354 open: driver_status.open,
355 max_size: driver_status.max_size,
356 }
357 }
358
359 pub fn close(&self) {
365 self.inner.close();
366 if let Some(ref rp) = self.read_pool {
367 rp.close();
368 }
369 }
370
371 pub fn is_closed(&self) -> bool {
373 self.inner.is_closed()
374 }
375
376 pub fn has_replica(&self) -> bool {
378 self.read_pool.is_some()
379 }
380
381 pub fn is_uds(&self) -> bool {
387 self.inner.is_uds()
388 }
389
390 pub fn for_each_raw<F>(
399 &self,
400 sql: &str,
401 sql_hash: u64,
402 params: &[&(dyn Encode + Sync)],
403 readonly: bool,
404 mut f: F,
405 ) -> BsqlResult<()>
406 where
407 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
408 {
409 let pool = if readonly {
410 self.read_pool.as_ref().unwrap_or(&self.inner)
411 } else {
412 &self.inner
413 };
414 let mut guard = pool.acquire().map_err(BsqlError::from)?;
415 let mut user_err: Option<BsqlError> = None;
419 let driver_result = guard.for_each(sql, sql_hash, params, |row| match f(row) {
420 Ok(()) => Ok(()),
421 Err(e) => {
422 user_err = Some(e);
423 Err(bsql_driver_postgres::DriverError::Protocol(
424 "for_each closure error".into(),
425 ))
426 }
427 });
428 if let Some(e) = user_err {
430 return Err(e);
431 }
432 driver_result.map_err(BsqlError::from_driver_query)
433 }
434
435 #[doc(hidden)]
442 pub fn __for_each_raw_bytes<F>(
443 &self,
444 sql: &str,
445 sql_hash: u64,
446 params: &[&(dyn Encode + Sync)],
447 readonly: bool,
448 mut f: F,
449 ) -> BsqlResult<()>
450 where
451 F: FnMut(&[u8]) -> BsqlResult<()>,
452 {
453 let pool = if readonly {
454 self.read_pool.as_ref().unwrap_or(&self.inner)
455 } else {
456 &self.inner
457 };
458 let mut guard = pool.acquire().map_err(BsqlError::from)?;
459 let mut user_err: Option<BsqlError> = None;
460 let driver_result = guard.for_each_raw(sql, sql_hash, params, |data| match f(data) {
461 Ok(()) => Ok(()),
462 Err(e) => {
463 user_err = Some(e);
464 Err(bsql_driver_postgres::DriverError::Protocol(
465 "for_each closure error".into(),
466 ))
467 }
468 });
469 if let Some(e) = user_err {
470 return Err(e);
471 }
472 driver_result.map_err(BsqlError::from_driver_query)
473 }
474}
475
476impl Clone for Pool {
477 fn clone(&self) -> Self {
478 Pool {
479 inner: self.inner.clone(),
480 read_pool: self.read_pool.clone(),
481 }
482 }
483}
484
485impl std::fmt::Debug for Pool {
486 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
487 f.debug_struct("Pool")
488 .field("status", &self.status())
489 .finish()
490 }
491}
492
/// A connection checked out of the primary pool by [`Pool::acquire`].
pub struct PoolConnection {
    /// Driver guard for the checked-out connection. The `Mutex` keeps
    /// `PoolConnection` `Sync` (asserted in this file's tests) — presumably so
    /// it can be used through `&self`; confirm at the use sites.
    pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
}
504
505impl std::fmt::Debug for PoolConnection {
506 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507 f.debug_struct("PoolConnection").finish()
508 }
509}
510
/// Point-in-time connection counters for a pool, copied from the driver's
/// status by [`Pool::status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Idle connections, as reported by the driver.
    pub idle: usize,
    /// Active (checked-out) connections, as reported by the driver.
    pub active: usize,
    /// Open connections, as reported by the driver.
    pub open: usize,
    /// Configured maximum size of the pool.
    pub max_size: usize,
}
523
524impl std::fmt::Display for PoolStatus {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 write!(
527 f,
528 "idle={}, active={}, open={}, max={}",
529 self.idle, self.active, self.open, self.max_size
530 )
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PoolBuilder: option plumbing (inspects private fields directly) ----

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert!(b.url.is_none());
        assert_eq!(b.max_size, 10);
        assert!(b.max_lifetime.is_none());
        assert!(b.acquire_timeout.is_none());
        assert!(b.min_idle.is_none());
    }

    #[test]
    fn builder_url() {
        let b = Pool::builder().url("postgres://u@localhost/db");
        assert_eq!(b.url.as_deref(), Some("postgres://u@localhost/db"));
    }

    #[test]
    fn builder_max_size() {
        let b = Pool::builder().max_size(32);
        assert_eq!(b.max_size, 32);
    }

    #[test]
    fn builder_max_lifetime() {
        let b = Pool::builder().max_lifetime(Some(Duration::from_secs(60)));
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(60))));
    }

    #[test]
    fn builder_max_lifetime_none_disables() {
        let b = Pool::builder().max_lifetime(None);
        assert_eq!(b.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let b = Pool::builder().acquire_timeout(Some(Duration::from_secs(3)));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let b = Pool::builder().acquire_timeout(None);
        assert_eq!(b.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let b = Pool::builder().min_idle(5);
        assert_eq!(b.min_idle, Some(5));
    }

    #[test]
    fn builder_max_lifetime_secs() {
        let b = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(1800))));
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let b = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(5))));
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let b = Pool::builder().lifetime_secs(900);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let b = Pool::builder().timeout_secs(3);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_defaults_no_replica() {
        let b = Pool::builder();
        assert!(b.replica_url.is_none());
        assert!(b.replica_max_size.is_none());
    }

    #[test]
    fn builder_replica_url() {
        let b = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(b.replica_url.as_deref(), Some("postgres://replica:5432/db"));
    }

    #[test]
    fn builder_replica_max_size() {
        let b = Pool::builder().replica_max_size(20);
        assert_eq!(b.replica_max_size, Some(20));
    }

    // ---- Pool handle ----

    #[test]
    fn pool_connect_has_no_replica() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        assert!(!pool.has_replica());
    }

    #[test]
    fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[test]
    fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
        assert!(pool.is_uds());
    }

    #[test]
    fn pool_is_uds_false_for_ip() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
        assert!(!pool.is_uds());
    }

    #[test]
    fn pool_status_display() {
        let status = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        };
        assert_eq!(status.to_string(), "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let status = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        };
        assert_eq!(status.to_string(), "idle=0, active=0, open=0, max=0");
    }

    #[test]
    fn pool_connection_debug() {
        // A `PoolConnection` cannot be constructed without a live driver guard,
        // so this only checks at compile time that the type implements `Debug`.
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<PoolConnection>();
    }

    #[test]
    fn pool_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[test]
    fn pool_clone_is_cheap() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.status().max_size, pool2.status().max_size);
        assert!(!pool.has_replica());
        assert!(!pool2.has_replica());
    }

    // Compile-time auto-trait checks.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    #[test]
    fn builder_build_without_url_errors() {
        let result = Pool::builder().build();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    #[test]
    fn builder_chaining() {
        let b = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(b.url.as_deref(), Some("postgres://u@localhost/db"));
        assert_eq!(b.max_size, 20);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(600))));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
        assert_eq!(b.min_idle, Some(2));
        assert_eq!(b.replica_url.as_deref(), Some("postgres://u@replica/db"));
        assert_eq!(b.replica_max_size, Some(10));
    }

    // ---- RawRow ----

    #[test]
    fn raw_row_get() {
        let row = RawRow(vec![Some("hello".into()), None, Some("42".into())]);
        assert_eq!(row.get(0), Some("hello"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), Some("42"));
        assert_eq!(row.get(99), None);
        assert_eq!(row.len(), 3);
    }

    #[test]
    fn raw_row_is_empty() {
        let empty = RawRow(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty = RawRow(vec![Some("x".into())]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn raw_row_iter() {
        let row = RawRow(vec![Some("a".into()), None, Some("b".into())]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![Some("a"), None, Some("b")]);
    }

    #[test]
    fn raw_row_clone() {
        let row = RawRow(vec![Some("hello".into()), None]);
        let cloned = row.clone();
        assert_eq!(cloned.get(0), Some("hello"));
        assert_eq!(cloned.get(1), None);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn raw_row_debug() {
        let row = RawRow(vec![Some("x".into())]);
        let dbg = format!("{row:?}");
        assert!(dbg.contains("RawRow"), "Debug should show RawRow: {dbg}");
    }

    #[test]
    fn raw_row_all_null_values() {
        let row = RawRow(vec![None, None, None]);
        assert_eq!(row.len(), 3);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), None);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![None, None, None]);
    }

    #[test]
    fn raw_row_empty_string_values() {
        let row = RawRow(vec![Some(String::new()), Some("".into())]);
        assert_eq!(row.len(), 2);
        // An empty string is a value, distinct from NULL (`None`).
        assert_eq!(row.get(0), Some(""));
        assert_eq!(row.get(1), Some(""));
    }

    #[test]
    fn raw_row_get_out_of_bounds() {
        let row = RawRow(vec![Some("only".into())]);
        assert_eq!(row.get(0), Some("only"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(100), None);
        assert_eq!(row.get(usize::MAX), None);
    }

    #[test]
    fn raw_row_iter_empty() {
        let row = RawRow(vec![]);
        let vals: Vec<_> = row.iter().collect();
        assert!(vals.is_empty());
    }

    #[test]
    fn raw_row_iter_mixed() {
        let row = RawRow(vec![
            Some("hello".into()),
            None,
            Some("world".into()),
            None,
            Some("".into()),
        ]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(
            vals,
            vec![Some("hello"), None, Some("world"), None, Some("")]
        );
    }

    #[test]
    fn raw_row_single_null() {
        let row = RawRow(vec![None]);
        assert_eq!(row.len(), 1);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
    }
}