1use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17#[cfg(not(feature = "tls"))]
18use tokio_postgres::NoTls;
19use tokio_postgres::types::ToSql;
20
21use crate::error::{BsqlError, BsqlResult, ConnectError};
22use crate::singleflight::{FlightStatus, Singleflight, sql_key};
23use crate::stream::QueryStream;
24use crate::transaction::Transaction;
25
/// Connection pool over a primary PostgreSQL server plus optional read replicas.
///
/// Construct it with [`Pool::connect`] or [`Pool::builder`].
pub struct Pool {
    /// Pool for the primary server. Used by `acquire`, `begin`, `query_stream`,
    /// `warmup` and `status`, and as the fallback target for replica reads.
    primary: deadpool_postgres::Pool,
    /// One pool per configured replica URL; empty when no replicas were added.
    replicas: Vec<deadpool_postgres::Pool>,
    /// Round-robin cursor used by `execute_on_pool` to pick the next replica.
    replica_idx: std::sync::atomic::AtomicUsize,
    /// `true` when the `SHOW POOLS` probe succeeded on the primary or on any
    /// replica while the pool was being built.
    pgbouncer: bool,
    /// Coalesces concurrent executions of identical parameterless queries.
    singleflight: Singleflight,
}
41
/// Builder for [`Pool`], obtained from [`Pool::builder`].
///
/// Nothing is connected until [`PoolBuilder::build`] is awaited.
#[must_use = "a PoolBuilder does nothing until `.build().await` is called"]
pub struct PoolBuilder {
    /// Server host name; when filled from a URL this may also be a Unix socket path.
    host: Option<String>,
    /// Server port.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// User name to authenticate as.
    user: Option<String>,
    /// Password to authenticate with.
    password: Option<String>,
    /// Maximum connections for the primary pool and for each replica pool.
    max_size: usize,
    /// Connect timeout in seconds, applied by `build` to every pool it creates.
    connect_timeout_secs: u64,
    /// Replica connection URLs; each becomes its own pool in `build`.
    replica_urls: Vec<String>,
}
53
54impl PoolBuilder {
55 pub fn url(mut self, url: &str) -> Result<Self, BsqlError> {
63 let parsed = parse_pg_url(url)?;
64 self.host = parsed.host;
65 self.port = parsed.port;
66 self.dbname = parsed.dbname;
67 self.user = parsed.user;
68 self.password = parsed.password;
69 Ok(self)
70 }
71
72 pub fn host(mut self, host: &str) -> Self {
73 self.host = Some(host.into());
74 self
75 }
76
77 pub fn port(mut self, port: u16) -> Self {
78 self.port = Some(port);
79 self
80 }
81
82 pub fn dbname(mut self, dbname: &str) -> Self {
83 self.dbname = Some(dbname.into());
84 self
85 }
86
87 pub fn user(mut self, user: &str) -> Self {
88 self.user = Some(user.into());
89 self
90 }
91
92 pub fn password(mut self, password: &str) -> Self {
93 self.password = Some(password.into());
94 self
95 }
96
97 pub fn max_size(mut self, size: usize) -> Self {
98 self.max_size = size;
99 self
100 }
101
102 pub fn connect_timeout(mut self, secs: u64) -> Self {
105 self.connect_timeout_secs = secs;
106 self
107 }
108
109 pub fn replica(mut self, url: &str) -> Self {
117 self.replica_urls.push(url.into());
118 self
119 }
120
121 pub async fn build(self) -> BsqlResult<Pool> {
122 let mut cfg = Config::new();
123 cfg.host = self.host;
124 cfg.port = self.port;
125 cfg.dbname = self.dbname;
126 cfg.user = self.user;
127 cfg.password = self.password;
128 cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
129 cfg.manager = Some(ManagerConfig {
130 recycling_method: RecyclingMethod::Fast,
131 });
132 cfg.pool = Some(deadpool_postgres::PoolConfig {
134 max_size: self.max_size,
135 timeouts: deadpool_postgres::Timeouts {
136 wait: Some(std::time::Duration::ZERO),
137 create: None,
138 recycle: None,
139 },
140 ..Default::default()
141 });
142
143 let pool = create_deadpool(cfg)?;
144
145 let mut pgbouncer = detect_pgbouncer(&pool).await?;
147
148 let mut replicas = Vec::with_capacity(self.replica_urls.len());
154 for url in &self.replica_urls {
155 let replica_pool =
156 create_pool_from_url(url, self.max_size, self.connect_timeout_secs).await?;
157 pgbouncer |= detect_pgbouncer(&replica_pool).await?;
158 replicas.push(replica_pool);
159 }
160
161 Ok(Pool {
162 primary: pool,
163 replicas,
164 replica_idx: std::sync::atomic::AtomicUsize::new(0),
165 pgbouncer,
166 singleflight: Singleflight::new(),
167 })
168 }
169}
170
171impl Pool {
172 pub async fn connect(url: &str) -> BsqlResult<Self> {
176 Pool::builder().url(url)?.build().await
177 }
178
179 pub fn builder() -> PoolBuilder {
181 PoolBuilder {
182 host: None,
183 port: None,
184 dbname: None,
185 user: None,
186 password: None,
187 max_size: 16,
188 connect_timeout_secs: 5,
189 replica_urls: Vec::new(),
190 }
191 }
192
193 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
198 let conn = self.primary.get().await.map_err(BsqlError::from)?;
199
200 Ok(PoolConnection { inner: conn })
201 }
202
203 pub fn is_pgbouncer(&self) -> bool {
205 self.pgbouncer
206 }
207
208 pub fn has_replicas(&self) -> bool {
210 !self.replicas.is_empty()
211 }
212
213 pub async fn begin(&self) -> BsqlResult<Transaction> {
227 let conn = self.acquire().await?;
228 Ok(Transaction::new(conn))
229 }
230
231 pub async fn query_stream(
245 &self,
246 sql: &str,
247 params: &[&(dyn ToSql + Sync)],
248 ) -> BsqlResult<QueryStream> {
249 let conn = self.acquire().await?;
250 let stmt = conn
251 .inner
252 .prepare_cached(sql)
253 .await
254 .map_err(BsqlError::from)?;
255
256 let row_stream = conn
257 .inner
258 .query_raw(&stmt, params.iter().copied())
259 .await
260 .map_err(BsqlError::from)?;
261
262 Ok(QueryStream::new(conn, row_stream))
263 }
264
265 pub async fn warmup(&self, sqls: &[&str]) -> BsqlResult<()> {
287 if sqls.is_empty() {
288 return Ok(());
289 }
290 let conn = self.acquire().await?;
291 for sql in sqls {
292 conn.inner
293 .prepare_cached(sql)
294 .await
295 .map_err(BsqlError::from)?;
296 }
297 Ok(())
298 }
299
300 pub fn status(&self) -> PoolStatus {
302 let status = self.primary.status();
303 PoolStatus {
304 available: status.available,
305 size: status.size,
306 max_size: status.max_size,
307 }
308 }
309
310 pub(crate) async fn query_raw_primary(
314 &self,
315 sql: &str,
316 params: &[&(dyn ToSql + Sync)],
317 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
318 if params.is_empty() {
322 let key = sql_key(sql);
323 self.query_with_singleflight(key, sql, params, false).await
324 } else {
325 self.execute_on_pool(sql, params, false).await
326 }
327 }
328
329 pub(crate) async fn query_raw_read(
332 &self,
333 sql: &str,
334 params: &[&(dyn ToSql + Sync)],
335 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
336 if self.replicas.is_empty() {
337 return self.query_raw_primary(sql, params).await;
338 }
339
340 if params.is_empty() {
341 let key = sql_key(sql);
342 match self.query_with_singleflight(key, sql, params, true).await {
344 Ok(rows) => Ok(rows),
345 Err(_) => self.query_with_singleflight(key, sql, params, false).await,
346 }
347 } else {
348 match self.execute_on_pool(sql, params, true).await {
350 Ok(rows) => Ok(rows),
351 Err(_) => self.execute_on_pool(sql, params, false).await,
352 }
353 }
354 }
355
356 async fn query_with_singleflight(
358 &self,
359 key: u64,
360 sql: &str,
361 params: &[&(dyn ToSql + Sync)],
362 use_replica: bool,
363 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
364 match self.singleflight.try_join(key) {
365 FlightStatus::Follower(mut rx) => {
366 match rx.recv().await {
368 Ok(rows) => Ok(rows),
369 Err(_) => {
370 self.execute_on_pool(sql, params, use_replica).await
372 }
373 }
374 }
375 FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
376 Ok(rows) => {
377 self.singleflight.complete(key, Arc::clone(&rows));
378 Ok(rows)
379 }
380 Err(e) => {
381 self.singleflight.abandon(key);
382 Err(e)
383 }
384 },
385 }
386 }
387
388 async fn execute_on_pool(
390 &self,
391 sql: &str,
392 params: &[&(dyn ToSql + Sync)],
393 use_replica: bool,
394 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
395 let raw_conn = if use_replica && !self.replicas.is_empty() {
396 let idx = self
397 .replica_idx
398 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
399 % self.replicas.len();
400 self.replicas[idx].get().await.map_err(BsqlError::from)?
401 } else {
402 self.primary.get().await.map_err(BsqlError::from)?
403 };
404
405 let stmt = raw_conn
406 .prepare_cached(sql)
407 .await
408 .map_err(BsqlError::from)?;
409
410 let rows = raw_conn
411 .query(&stmt, params)
412 .await
413 .map_err(BsqlError::from)?;
414
415 Ok(Arc::from(rows))
416 }
417}
418
419impl std::fmt::Debug for Pool {
420 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421 f.debug_struct("Pool")
422 .field("status", &self.status())
423 .field("is_pgbouncer", &self.pgbouncer)
424 .field("replicas", &self.replicas.len())
425 .finish()
426 }
427}
428
/// A connection checked out of the primary pool by [`Pool::acquire`].
///
/// It wraps a deadpool object, which deadpool returns to its pool when dropped.
pub struct PoolConnection {
    /// The underlying deadpool-managed client.
    pub(crate) inner: deadpool_postgres::Object,
}
435
/// Snapshot of the primary pool's occupancy, as returned by [`Pool::status`].
///
/// This is a plain value type, so it can be compared and hashed, e.g. to detect
/// changes between two snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Connections available for checkout, as reported by deadpool.
    pub available: usize,
    /// Connections currently managed by the pool.
    pub size: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
}
443
/// The subset of a PostgreSQL connection URL that the pool configuration uses.
struct ParsedUrl {
    /// First host listed in the URL (a TCP host name, or a socket path on Unix).
    host: Option<String>,
    /// First port listed in the URL.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// User name.
    user: Option<String>,
    /// Password, already validated as UTF-8.
    password: Option<String>,
}
452
453fn parse_pg_url(url: &str) -> BsqlResult<ParsedUrl> {
455 let config: tokio_postgres::Config = url
456 .parse()
457 .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
458
459 let host = config.get_hosts().first().map(|h| match h {
460 tokio_postgres::config::Host::Tcp(s) => s.clone(),
461 #[cfg(unix)]
462 tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
463 });
464 let port = config.get_ports().first().copied();
465 let dbname = config.get_dbname().map(String::from);
466 let user = config.get_user().map(String::from);
467 let password = match config.get_password() {
468 Some(p) => Some(
469 String::from_utf8(p.to_vec())
470 .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
471 ),
472 None => None,
473 };
474 Ok(ParsedUrl {
475 host,
476 port,
477 dbname,
478 user,
479 password,
480 })
481}
482
483async fn create_pool_from_url(
487 url: &str,
488 max_size: usize,
489 connect_timeout_secs: u64,
490) -> BsqlResult<deadpool_postgres::Pool> {
491 let parsed = parse_pg_url(url)?;
492
493 let mut cfg = Config::new();
494 cfg.host = parsed.host;
495 cfg.port = parsed.port;
496 cfg.dbname = parsed.dbname;
497 cfg.user = parsed.user;
498 cfg.password = parsed.password;
499 cfg.connect_timeout = Some(std::time::Duration::from_secs(connect_timeout_secs));
500 cfg.manager = Some(ManagerConfig {
501 recycling_method: RecyclingMethod::Fast,
502 });
503 cfg.pool = Some(deadpool_postgres::PoolConfig {
504 max_size,
505 timeouts: deadpool_postgres::Timeouts {
506 wait: Some(std::time::Duration::ZERO),
507 create: None,
508 recycle: None,
509 },
510 ..Default::default()
511 });
512
513 let pool = create_deadpool(cfg)?;
514
515 let _conn = pool
517 .get()
518 .await
519 .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
520
521 Ok(pool)
522}
523
524fn create_deadpool(cfg: Config) -> BsqlResult<deadpool_postgres::Pool> {
529 #[cfg(feature = "tls")]
530 {
531 let tls = make_rustls_connect();
532 cfg.create_pool(Some(Runtime::Tokio1), tls)
533 .map_err(|e| ConnectError::create(e.to_string()))
534 }
535 #[cfg(not(feature = "tls"))]
536 {
537 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
538 .map_err(|e| ConnectError::create(e.to_string()))
539 }
540}
541
/// Builds a rustls-based TLS connector for tokio-postgres.
///
/// It trusts the root certificates bundled by `webpki_roots` and sends no
/// client certificate.
#[cfg(feature = "tls")]
pub(crate) fn make_rustls_connect() -> tokio_postgres_rustls::MakeRustlsConnect {
    let mut trust_anchors = rustls::RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    tokio_postgres_rustls::MakeRustlsConnect::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth(),
    )
}
555
556async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<bool> {
563 let conn = pool.get().await.map_err(|e| {
564 ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
565 })?;
566
567 Ok(conn.simple_query("SHOW POOLS").await.is_ok())
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert_eq!(b.max_size, 16);
        assert_eq!(b.connect_timeout_secs, 5);
        assert!(b.replica_urls.is_empty());
    }

    #[test]
    fn builder_config() {
        let b = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(b.host.as_deref(), Some("localhost"));
        assert_eq!(b.port, Some(5432));
        assert_eq!(b.dbname.as_deref(), Some("test"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        assert_eq!(b.max_size, 8);
        assert_eq!(b.connect_timeout_secs, 10);
    }

    #[test]
    fn builder_replicas() {
        let b = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");
        assert_eq!(b.replica_urls.len(), 2);
    }

    #[test]
    fn builder_url_populates_connection_fields() {
        let b = match Pool::builder().url("postgres://app:secret@db.example.com:6432/orders") {
            Ok(b) => b,
            Err(_) => panic!("well-formed URL should parse"),
        };
        assert_eq!(b.host.as_deref(), Some("db.example.com"));
        assert_eq!(b.port, Some(6432));
        assert_eq!(b.dbname.as_deref(), Some("orders"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        // url() must leave the sizing/timeout defaults untouched.
        assert_eq!(b.max_size, 16);
        assert_eq!(b.connect_timeout_secs, 5);
    }

    #[test]
    fn builder_url_rejects_invalid_port() {
        assert!(Pool::builder().url("postgres://localhost:notaport/db").is_err());
    }
}