1use std::time::Duration;
7
8use bsql_driver_postgres::arena::acquire_arena;
9use bsql_driver_postgres::codec::Encode;
10use tokio::sync::Mutex;
11
12use crate::error::{BsqlError, BsqlResult};
13use crate::stream::QueryStream;
14use crate::transaction::Transaction;
15
16pub struct Pool {
38 pub(crate) inner: bsql_driver_postgres::Pool,
39 pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
41}
42
/// Builder for [`Pool`], obtained from `Pool::builder()`.
///
/// Every setter consumes the builder and returns the updated one. The type is
/// `#[must_use]` so that a setter called as a bare statement, which silently
/// discards the configuration, is reported by the compiler.
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct PoolBuilder {
    /// Primary server URL; `build` fails if this is still `None`.
    url: Option<String>,
    /// Maximum size of the primary pool (also the replica default).
    max_size: usize,
    /// `None` = not set (driver default); `Some(None)` = explicitly disabled;
    /// `Some(Some(d))` = explicit lifetime.
    max_lifetime: Option<Option<Duration>>,
    /// Same three-state encoding as `max_lifetime`.
    acquire_timeout: Option<Option<Duration>>,
    /// Minimum idle connections for the primary pool, if set.
    min_idle: Option<usize>,
    /// Replica server URL; when set, `build` also creates a read pool.
    replica_url: Option<String>,
    /// Replica pool size; falls back to `max_size` when `None`.
    replica_max_size: Option<usize>,
}
71
72impl PoolBuilder {
73 pub fn url(mut self, url: &str) -> Self {
77 self.url = Some(url.into());
78 self
79 }
80
81 pub fn max_size(mut self, size: usize) -> Self {
82 self.max_size = size;
83 self
84 }
85
86 pub fn max_lifetime(mut self, d: Option<Duration>) -> Self {
91 self.max_lifetime = Some(d);
92 self
93 }
94
95 pub fn max_lifetime_secs(self, secs: u64) -> Self {
98 self.max_lifetime(Some(Duration::from_secs(secs)))
99 }
100
101 pub fn lifetime_secs(self, secs: u64) -> Self {
103 self.max_lifetime_secs(secs)
104 }
105
106 pub fn acquire_timeout(mut self, d: Option<Duration>) -> Self {
111 self.acquire_timeout = Some(d);
112 self
113 }
114
115 pub fn acquire_timeout_secs(self, secs: u64) -> Self {
118 self.acquire_timeout(Some(Duration::from_secs(secs)))
119 }
120
121 pub fn timeout_secs(self, secs: u64) -> Self {
123 self.acquire_timeout_secs(secs)
124 }
125
126 pub fn min_idle(mut self, n: usize) -> Self {
131 self.min_idle = Some(n);
132 self
133 }
134
135 pub fn replica_url(mut self, url: &str) -> Self {
141 self.replica_url = Some(url.into());
142 self
143 }
144
145 pub fn replica_max_size(mut self, size: usize) -> Self {
148 self.replica_max_size = Some(size);
149 self
150 }
151
152 pub async fn build(self) -> BsqlResult<Pool> {
153 let url = self.url.ok_or_else(|| {
154 BsqlError::from(bsql_driver_postgres::DriverError::Pool(
155 "pool builder requires a URL".into(),
156 ))
157 })?;
158 let mut builder = bsql_driver_postgres::Pool::builder()
159 .url(&url)
160 .max_size(self.max_size);
161
162 if let Some(lt) = self.max_lifetime {
163 builder = builder.max_lifetime(lt);
164 }
165 if let Some(at) = self.acquire_timeout {
166 builder = builder.acquire_timeout(at);
167 }
168 if let Some(mi) = self.min_idle {
169 builder = builder.min_idle(mi);
170 }
171
172 let inner = builder.build().await.map_err(BsqlError::from)?;
173
174 let read_pool = if let Some(replica_url) = &self.replica_url {
176 let replica_size = self.replica_max_size.unwrap_or(self.max_size);
177 let mut rbuilder = bsql_driver_postgres::Pool::builder()
178 .url(replica_url)
179 .max_size(replica_size);
180 if let Some(lt) = self.max_lifetime {
181 rbuilder = rbuilder.max_lifetime(lt);
182 }
183 if let Some(at) = self.acquire_timeout {
184 rbuilder = rbuilder.acquire_timeout(at);
185 }
186 Some(rbuilder.build().await.map_err(BsqlError::from)?)
187 } else {
188 None
189 };
190
191 Ok(Pool { inner, read_pool })
192 }
193}
194
195impl Pool {
196 pub async fn connect(url: &str) -> BsqlResult<Self> {
200 let inner = bsql_driver_postgres::Pool::connect(url)
201 .await
202 .map_err(BsqlError::from)?;
203 Ok(Pool {
204 inner,
205 read_pool: None,
206 })
207 }
208
209 pub fn builder() -> PoolBuilder {
211 PoolBuilder {
212 url: None,
213 max_size: 10,
214 max_lifetime: None,
215 acquire_timeout: None,
216 min_idle: None,
217 replica_url: None,
218 replica_max_size: None,
219 }
220 }
221
222 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
227 let guard = self.inner.acquire().await.map_err(BsqlError::from)?;
228 Ok(PoolConnection {
229 inner: Mutex::new(guard),
230 })
231 }
232
233 pub async fn begin(&self) -> BsqlResult<Transaction> {
237 let tx = self.inner.begin().await.map_err(BsqlError::from)?;
238 Ok(Transaction::from_driver(tx))
239 }
240
241 pub async fn query_stream(
250 &self,
251 sql: &str,
252 sql_hash: u64,
253 params: &[&(dyn Encode + Sync)],
254 ) -> BsqlResult<QueryStream> {
255 let mut guard = self.inner.acquire().await.map_err(BsqlError::from)?;
256 let mut arena = acquire_arena();
257
258 const CHUNK_SIZE: i32 = 64;
260
261 let (columns, _) = guard
262 .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
263 .await
264 .map_err(BsqlError::from)?;
265
266 let num_cols = columns.len();
267 let mut all_col_offsets: Vec<(usize, i32)> =
268 Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
269
270 let more = guard
271 .streaming_next_chunk(&mut arena, &mut all_col_offsets)
272 .await
273 .map_err(BsqlError::from)?;
274
275 let first_result = bsql_driver_postgres::QueryResult::from_parts(
276 all_col_offsets,
277 num_cols,
278 columns.clone(),
279 0,
280 );
281
282 Ok(QueryStream::new(guard, arena, first_result, columns, !more))
283 }
284
285 pub fn set_warmup_sqls(&self, sqls: &[&str]) {
293 self.inner.set_warmup_sqls(sqls);
294 }
295
296 pub fn status(&self) -> PoolStatus {
300 let driver_status = self.inner.status();
301 PoolStatus {
302 idle: driver_status.idle,
303 active: driver_status.active,
304 open: driver_status.open,
305 max_size: driver_status.max_size,
306 }
307 }
308
309 pub async fn close(&self) {
315 self.inner.close().await;
316 if let Some(ref rp) = self.read_pool {
317 rp.close().await;
318 }
319 }
320
321 pub fn is_closed(&self) -> bool {
323 self.inner.is_closed()
324 }
325
326 pub fn has_replica(&self) -> bool {
328 self.read_pool.is_some()
329 }
330
331 pub fn is_uds(&self) -> bool {
337 self.inner.is_uds()
338 }
339
340 pub async fn for_each_raw<F>(
349 &self,
350 sql: &str,
351 sql_hash: u64,
352 params: &[&(dyn Encode + Sync)],
353 readonly: bool,
354 mut f: F,
355 ) -> BsqlResult<()>
356 where
357 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
358 {
359 let pool = if readonly {
360 self.read_pool.as_ref().unwrap_or(&self.inner)
361 } else {
362 &self.inner
363 };
364 let mut guard = pool.acquire().await.map_err(BsqlError::from)?;
365 let mut user_err: Option<BsqlError> = None;
369 let driver_result = guard
370 .for_each(sql, sql_hash, params, |row| match f(row) {
371 Ok(()) => Ok(()),
372 Err(e) => {
373 user_err = Some(e);
374 Err(bsql_driver_postgres::DriverError::Protocol(
375 "for_each closure error".into(),
376 ))
377 }
378 })
379 .await;
380 if let Some(e) = user_err {
382 return Err(e);
383 }
384 driver_result.map_err(BsqlError::from_driver_query)
385 }
386
387 #[doc(hidden)]
394 pub async fn __for_each_raw_bytes<F>(
395 &self,
396 sql: &str,
397 sql_hash: u64,
398 params: &[&(dyn Encode + Sync)],
399 readonly: bool,
400 mut f: F,
401 ) -> BsqlResult<()>
402 where
403 F: FnMut(&[u8]) -> BsqlResult<()>,
404 {
405 let pool = if readonly {
406 self.read_pool.as_ref().unwrap_or(&self.inner)
407 } else {
408 &self.inner
409 };
410 let mut guard = pool.acquire().await.map_err(BsqlError::from)?;
411 let mut user_err: Option<BsqlError> = None;
412 let driver_result = guard
413 .for_each_raw(sql, sql_hash, params, |data| match f(data) {
414 Ok(()) => Ok(()),
415 Err(e) => {
416 user_err = Some(e);
417 Err(bsql_driver_postgres::DriverError::Protocol(
418 "for_each closure error".into(),
419 ))
420 }
421 })
422 .await;
423 if let Some(e) = user_err {
424 return Err(e);
425 }
426 driver_result.map_err(BsqlError::from_driver_query)
427 }
428}
429
430impl Clone for Pool {
431 fn clone(&self) -> Self {
432 Pool {
433 inner: self.inner.clone(),
434 read_pool: self.read_pool.clone(),
435 }
436 }
437}
438
439impl std::fmt::Debug for Pool {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 f.debug_struct("Pool")
442 .field("status", &self.status())
443 .finish()
444 }
445}
446
447pub struct PoolConnection {
458 pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
459}
460
461impl std::fmt::Debug for PoolConnection {
462 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463 f.debug_struct("PoolConnection").finish()
464 }
465}
466
/// Point-in-time snapshot of pool connection counts, copied from the driver's
/// status report.
///
/// `PartialEq`, `Eq` and `Hash` are derived so snapshots can be compared or
/// used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Idle connection count, as reported by the driver.
    pub idle: usize,
    /// Active (checked-out) connection count, as reported by the driver.
    pub active: usize,
    /// Open connection count, as reported by the driver.
    pub open: usize,
    /// Maximum pool size, as reported by the driver.
    pub max_size: usize,
}
479
480impl std::fmt::Display for PoolStatus {
481 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482 write!(
483 f,
484 "idle={}, active={}, open={}, max={}",
485 self.idle, self.active, self.open, self.max_size
486 )
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert!(b.url.is_none());
        assert_eq!(b.max_size, 10);
        assert!(b.max_lifetime.is_none());
        assert!(b.acquire_timeout.is_none());
        assert!(b.min_idle.is_none());
    }

    #[test]
    fn builder_max_lifetime() {
        let b = Pool::builder().max_lifetime(Some(Duration::from_secs(60)));
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(60))));
    }

    #[test]
    fn builder_max_lifetime_none_disables() {
        let b = Pool::builder().max_lifetime(None);
        assert_eq!(b.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let b = Pool::builder().acquire_timeout(Some(Duration::from_secs(3)));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let b = Pool::builder().acquire_timeout(None);
        assert_eq!(b.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let b = Pool::builder().min_idle(5);
        assert_eq!(b.min_idle, Some(5));
    }

    #[test]
    fn builder_max_lifetime_secs() {
        let b = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(1800))));
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let b = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(5))));
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let b = Pool::builder().lifetime_secs(900);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let b = Pool::builder().timeout_secs(3);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_defaults_no_replica() {
        let b = Pool::builder();
        assert!(b.replica_url.is_none());
        assert!(b.replica_max_size.is_none());
    }

    #[test]
    fn builder_replica_url() {
        let b = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(b.replica_url.as_deref(), Some("postgres://replica:5432/db"));
    }

    #[test]
    fn builder_replica_max_size() {
        let b = Pool::builder().replica_max_size(20);
        assert_eq!(b.replica_max_size, Some(20));
    }

    #[tokio::test]
    async fn pool_connect_has_no_replica() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.has_replica());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
            .await
            .unwrap();
        assert!(pool.is_uds());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_ip() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[test]
    fn pool_status_display() {
        let status = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        };
        assert_eq!(status.to_string(), "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let status = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        };
        assert_eq!(status.to_string(), "idle=0, active=0, open=0, max=0");
    }

    // A `PoolConnection` cannot be constructed without a live server, so this
    // only checks at compile time that the `Debug` impl exists.
    #[test]
    fn pool_connection_debug() {
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<PoolConnection>();
    }

    #[tokio::test]
    async fn pool_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[tokio::test]
    async fn pool_clone_is_cheap() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.status().max_size, pool2.status().max_size);
        assert!(!pool.has_replica());
        assert!(!pool2.has_replica());
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    #[tokio::test]
    async fn builder_build_without_url_errors() {
        let result = Pool::builder().build().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    #[test]
    fn builder_chaining() {
        let b = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(b.url.as_deref(), Some("postgres://u@localhost/db"));
        assert_eq!(b.max_size, 20);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(600))));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
        assert_eq!(b.min_idle, Some(2));
        assert_eq!(b.replica_url.as_deref(), Some("postgres://u@replica/db"));
        assert_eq!(b.replica_max_size, Some(10));
    }
}