1use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17#[cfg(not(feature = "tls"))]
18use tokio_postgres::NoTls;
19use tokio_postgres::types::ToSql;
20
21use crate::error::{BsqlError, BsqlResult, ConnectError};
22use crate::singleflight::{FlightStatus, Singleflight, sql_key};
23use crate::stream::QueryStream;
24use crate::transaction::Transaction;
25
26pub struct Pool {
31 primary: deadpool_postgres::Pool,
32 replicas: Vec<deadpool_postgres::Pool>,
35 replica_idx: std::sync::atomic::AtomicUsize,
37 pgbouncer: bool,
39 singleflight: Singleflight,
40}
41
/// Step-by-step configuration for a [`Pool`].
///
/// Obtain one from [`Pool::builder`], chain the setters, then call
/// [`PoolBuilder::build`].
pub struct PoolBuilder {
    /// Host of the primary server, if configured.
    host: Option<String>,

    /// TCP port of the primary server, if configured.
    port: Option<u16>,

    /// Database name, if configured.
    dbname: Option<String>,

    /// Login role, if configured.
    user: Option<String>,

    /// Login password, if configured.
    password: Option<String>,

    /// Maximum number of connections per pool (primary and each replica).
    max_size: usize,

    /// Connection timeout, in whole seconds.
    connect_timeout_secs: u64,

    /// Connection URLs of the read replicas, in registration order.
    replica_urls: Vec<String>,
}
53
54impl PoolBuilder {
55 pub fn url(mut self, url: &str) -> Result<Self, BsqlError> {
63 let parsed = parse_pg_url(url)?;
64 self.host = parsed.host;
65 self.port = parsed.port;
66 self.dbname = parsed.dbname;
67 self.user = parsed.user;
68 self.password = parsed.password;
69 Ok(self)
70 }
71
72 pub fn host(mut self, host: &str) -> Self {
73 self.host = Some(host.into());
74 self
75 }
76
77 pub fn port(mut self, port: u16) -> Self {
78 self.port = Some(port);
79 self
80 }
81
82 pub fn dbname(mut self, dbname: &str) -> Self {
83 self.dbname = Some(dbname.into());
84 self
85 }
86
87 pub fn user(mut self, user: &str) -> Self {
88 self.user = Some(user.into());
89 self
90 }
91
92 pub fn password(mut self, password: &str) -> Self {
93 self.password = Some(password.into());
94 self
95 }
96
97 pub fn max_size(mut self, size: usize) -> Self {
98 self.max_size = size;
99 self
100 }
101
102 pub fn connect_timeout(mut self, secs: u64) -> Self {
105 self.connect_timeout_secs = secs;
106 self
107 }
108
109 pub fn replica(mut self, url: &str) -> Self {
117 self.replica_urls.push(url.into());
118 self
119 }
120
121 pub async fn build(self) -> BsqlResult<Pool> {
122 let mut cfg = Config::new();
123 cfg.host = self.host;
124 cfg.port = self.port;
125 cfg.dbname = self.dbname;
126 cfg.user = self.user;
127 cfg.password = self.password;
128 cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
129 cfg.manager = Some(ManagerConfig {
130 recycling_method: RecyclingMethod::Fast,
131 });
132 cfg.pool = Some(deadpool_postgres::PoolConfig {
134 max_size: self.max_size,
135 timeouts: deadpool_postgres::Timeouts {
136 wait: Some(std::time::Duration::ZERO),
137 create: None,
138 recycle: None,
139 },
140 ..Default::default()
141 });
142
143 let pool = create_deadpool(cfg)?;
144
145 let mut pgbouncer = detect_pgbouncer(&pool).await?;
147
148 let mut replicas = Vec::with_capacity(self.replica_urls.len());
154 for url in &self.replica_urls {
155 let replica_pool =
156 create_pool_from_url(url, self.max_size, self.connect_timeout_secs).await?;
157 pgbouncer |= detect_pgbouncer(&replica_pool).await?;
158 replicas.push(replica_pool);
159 }
160
161 Ok(Pool {
162 primary: pool,
163 replicas,
164 replica_idx: std::sync::atomic::AtomicUsize::new(0),
165 pgbouncer,
166 singleflight: Singleflight::new(),
167 })
168 }
169}
170
171impl Pool {
172 pub async fn connect(url: &str) -> BsqlResult<Self> {
176 Pool::builder().url(url)?.build().await
177 }
178
179 pub fn builder() -> PoolBuilder {
181 PoolBuilder {
182 host: None,
183 port: None,
184 dbname: None,
185 user: None,
186 password: None,
187 max_size: 16,
188 connect_timeout_secs: 5,
189 replica_urls: Vec::new(),
190 }
191 }
192
193 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
198 let conn = self.primary.get().await.map_err(BsqlError::from)?;
199
200 Ok(PoolConnection { inner: conn })
201 }
202
203 pub fn is_pgbouncer(&self) -> bool {
205 self.pgbouncer
206 }
207
208 pub fn has_replicas(&self) -> bool {
210 !self.replicas.is_empty()
211 }
212
213 pub async fn begin(&self) -> BsqlResult<Transaction> {
227 let conn = self.acquire().await?;
228 Ok(Transaction::new(conn))
229 }
230
231 pub async fn query_stream(
245 &self,
246 sql: &str,
247 params: &[&(dyn ToSql + Sync)],
248 ) -> BsqlResult<QueryStream> {
249 let conn = self.acquire().await?;
250 let stmt = conn
251 .inner
252 .prepare_cached(sql)
253 .await
254 .map_err(BsqlError::from)?;
255
256 let row_stream = conn
257 .inner
258 .query_raw(&stmt, params.iter().copied())
259 .await
260 .map_err(BsqlError::from)?;
261
262 Ok(QueryStream::new(conn, row_stream))
263 }
264
265 pub fn status(&self) -> PoolStatus {
267 let status = self.primary.status();
268 PoolStatus {
269 available: status.available,
270 size: status.size,
271 max_size: status.max_size,
272 }
273 }
274
275 pub(crate) async fn query_raw_primary(
279 &self,
280 sql: &str,
281 params: &[&(dyn ToSql + Sync)],
282 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
283 if params.is_empty() {
287 let key = sql_key(sql);
288 self.query_with_singleflight(key, sql, params, false).await
289 } else {
290 self.execute_on_pool(sql, params, false).await
291 }
292 }
293
294 pub(crate) async fn query_raw_read(
297 &self,
298 sql: &str,
299 params: &[&(dyn ToSql + Sync)],
300 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
301 if self.replicas.is_empty() {
302 return self.query_raw_primary(sql, params).await;
303 }
304
305 if params.is_empty() {
306 let key = sql_key(sql);
307 match self.query_with_singleflight(key, sql, params, true).await {
309 Ok(rows) => Ok(rows),
310 Err(_) => self.query_with_singleflight(key, sql, params, false).await,
311 }
312 } else {
313 match self.execute_on_pool(sql, params, true).await {
315 Ok(rows) => Ok(rows),
316 Err(_) => self.execute_on_pool(sql, params, false).await,
317 }
318 }
319 }
320
321 async fn query_with_singleflight(
323 &self,
324 key: u64,
325 sql: &str,
326 params: &[&(dyn ToSql + Sync)],
327 use_replica: bool,
328 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
329 match self.singleflight.try_join(key) {
330 FlightStatus::Follower(mut rx) => {
331 match rx.recv().await {
333 Ok(rows) => Ok(rows),
334 Err(_) => {
335 self.execute_on_pool(sql, params, use_replica).await
337 }
338 }
339 }
340 FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
341 Ok(rows) => {
342 self.singleflight.complete(key, Arc::clone(&rows));
343 Ok(rows)
344 }
345 Err(e) => {
346 self.singleflight.abandon(key);
347 Err(e)
348 }
349 },
350 }
351 }
352
353 async fn execute_on_pool(
355 &self,
356 sql: &str,
357 params: &[&(dyn ToSql + Sync)],
358 use_replica: bool,
359 ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
360 let raw_conn = if use_replica && !self.replicas.is_empty() {
361 let idx = self
362 .replica_idx
363 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
364 % self.replicas.len();
365 self.replicas[idx].get().await.map_err(BsqlError::from)?
366 } else {
367 self.primary.get().await.map_err(BsqlError::from)?
368 };
369
370 let stmt = raw_conn
371 .prepare_cached(sql)
372 .await
373 .map_err(BsqlError::from)?;
374
375 let rows = raw_conn
376 .query(&stmt, params)
377 .await
378 .map_err(BsqlError::from)?;
379
380 Ok(Arc::from(rows))
381 }
382}
383
384impl std::fmt::Debug for Pool {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct("Pool")
387 .field("status", &self.status())
388 .field("is_pgbouncer", &self.pgbouncer)
389 .field("replicas", &self.replicas.len())
390 .finish()
391 }
392}
393
394pub struct PoolConnection {
398 pub(crate) inner: deadpool_postgres::Object,
399}
400
/// Point-in-time counters of the primary pool, as returned by `Pool::status`.
///
/// The type is plain data: it is `Copy` and can be compared for equality, so
/// two snapshots can be checked against each other directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Value of the underlying pool's `available` counter.
    pub available: usize,
    /// Value of the underlying pool's `size` counter.
    pub size: usize,
    /// Configured upper bound on the number of connections.
    pub max_size: usize,
}
408
/// The pieces of a PostgreSQL connection URL that this module cares about.
///
/// Every field is `None` when the URL does not carry the corresponding
/// component.
struct ParsedUrl {
    /// First host listed in the URL.
    host: Option<String>,
    /// First port listed in the URL.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// Login role.
    user: Option<String>,
    /// Login password, decoded as UTF-8.
    password: Option<String>,
}
417
418fn parse_pg_url(url: &str) -> BsqlResult<ParsedUrl> {
420 let config: tokio_postgres::Config = url
421 .parse()
422 .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
423
424 let host = config.get_hosts().first().map(|h| match h {
425 tokio_postgres::config::Host::Tcp(s) => s.clone(),
426 #[cfg(unix)]
427 tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
428 });
429 let port = config.get_ports().first().copied();
430 let dbname = config.get_dbname().map(String::from);
431 let user = config.get_user().map(String::from);
432 let password = match config.get_password() {
433 Some(p) => Some(
434 String::from_utf8(p.to_vec())
435 .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
436 ),
437 None => None,
438 };
439 Ok(ParsedUrl {
440 host,
441 port,
442 dbname,
443 user,
444 password,
445 })
446}
447
448async fn create_pool_from_url(
452 url: &str,
453 max_size: usize,
454 connect_timeout_secs: u64,
455) -> BsqlResult<deadpool_postgres::Pool> {
456 let parsed = parse_pg_url(url)?;
457
458 let mut cfg = Config::new();
459 cfg.host = parsed.host;
460 cfg.port = parsed.port;
461 cfg.dbname = parsed.dbname;
462 cfg.user = parsed.user;
463 cfg.password = parsed.password;
464 cfg.connect_timeout = Some(std::time::Duration::from_secs(connect_timeout_secs));
465 cfg.manager = Some(ManagerConfig {
466 recycling_method: RecyclingMethod::Fast,
467 });
468 cfg.pool = Some(deadpool_postgres::PoolConfig {
469 max_size,
470 timeouts: deadpool_postgres::Timeouts {
471 wait: Some(std::time::Duration::ZERO),
472 create: None,
473 recycle: None,
474 },
475 ..Default::default()
476 });
477
478 let pool = create_deadpool(cfg)?;
479
480 let _conn = pool
482 .get()
483 .await
484 .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
485
486 Ok(pool)
487}
488
489fn create_deadpool(cfg: Config) -> BsqlResult<deadpool_postgres::Pool> {
494 #[cfg(feature = "tls")]
495 {
496 let tls = make_rustls_connect();
497 cfg.create_pool(Some(Runtime::Tokio1), tls)
498 .map_err(|e| ConnectError::create(e.to_string()))
499 }
500 #[cfg(not(feature = "tls"))]
501 {
502 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
503 .map_err(|e| ConnectError::create(e.to_string()))
504 }
505}
506
/// Builds the rustls-based TLS connector used when the `tls` feature is on.
///
/// The client trusts the root certificates shipped in
/// `webpki_roots::TLS_SERVER_ROOTS` and presents no client certificate.
#[cfg(feature = "tls")]
pub(crate) fn make_rustls_connect() -> tokio_postgres_rustls::MakeRustlsConnect {
    let mut trust_store = rustls::RootCertStore::empty();
    trust_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let builder = rustls::ClientConfig::builder().with_root_certificates(trust_store);
    let client_config = builder.with_no_client_auth();

    tokio_postgres_rustls::MakeRustlsConnect::new(client_config)
}
520
521async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<bool> {
528 let conn = pool.get().await.map_err(|e| {
529 ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
530 })?;
531
532 Ok(conn.simple_query("SHOW POOLS").await.is_ok())
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder carries the default pool size and timeout and has no
    /// replicas registered.
    #[test]
    fn builder_defaults() {
        let PoolBuilder {
            max_size,
            connect_timeout_secs,
            replica_urls,
            ..
        } = Pool::builder();

        assert_eq!(max_size, 16);
        assert_eq!(connect_timeout_secs, 5);
        assert!(replica_urls.is_empty());
    }

    /// Every setter stores exactly the value it was given.
    #[test]
    fn builder_config() {
        let configured = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(configured.host, Some(String::from("localhost")));
        assert_eq!(configured.port, Some(5432));
        assert_eq!(configured.dbname, Some(String::from("test")));
        assert_eq!(configured.user, Some(String::from("app")));
        assert_eq!(configured.password, Some(String::from("secret")));
        assert_eq!(configured.max_size, 8);
        assert_eq!(configured.connect_timeout_secs, 10);
    }

    /// Each `replica` call appends one URL.
    #[test]
    fn builder_replicas() {
        let with_replicas = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");

        let registered = with_replicas.replica_urls.len();
        assert_eq!(registered, 2);
    }
}