Skip to main content

bsql_core/
pool.rs

1//! Connection pool with fail-fast semantics, PgBouncer detection,
2//! singleflight query coalescing, and read/write splitting.
3//!
4//! The pool wraps `deadpool-postgres` with key behaviors:
5//! - **Fail-fast**: `acquire()` returns `PoolExhausted` immediately when no
6//!   connections are available. It does not wait. See CREDO principle #17.
7//! - **PgBouncer detection**: on pool creation, bsql detects whether the
8//!   connection goes through PgBouncer and adjusts prepared statement strategy.
9//! - **Singleflight** (v0.7): identical concurrent SELECT queries are coalesced
10//!   into a single PG round-trip. The result is shared via `Arc<[Row]>`.
11//! - **Read/write splitting** (v0.7): when replicas are configured, SELECT
12//!   queries are routed to replicas. Writes always go to the primary.
13
14use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17#[cfg(not(feature = "tls"))]
18use tokio_postgres::NoTls;
19use tokio_postgres::types::ToSql;
20
21use crate::error::{BsqlError, BsqlResult, ConnectError};
22use crate::singleflight::{FlightStatus, Singleflight, sql_key};
23use crate::stream::QueryStream;
24use crate::transaction::Transaction;
25
26/// A PostgreSQL connection pool.
27///
28/// Wraps `deadpool-postgres` with fail-fast acquire semantics, singleflight
29/// query coalescing, and optional read/write splitting.
30pub struct Pool {
31    primary: deadpool_postgres::Pool,
32    /// Replica pools for read-only queries. Round-robin selection.
33    /// Empty when no replicas are configured.
34    replicas: Vec<deadpool_postgres::Pool>,
35    /// Atomic counter for round-robin replica selection.
36    replica_idx: std::sync::atomic::AtomicUsize,
37    /// True if PgBouncer was detected between the client and PostgreSQL.
38    pgbouncer: bool,
39    singleflight: Singleflight,
40}
41
/// Builder for configuring a connection pool.
///
/// Obtain one from [`Pool::builder`], adjust it with the setter methods or
/// `url()`, then finish with `build()`.
pub struct PoolBuilder {
    // -- Credentials --
    /// Login role name.
    user: Option<String>,
    /// Login password.
    password: Option<String>,

    // -- Server location --
    /// Server host name (or, when taken from a URL on Unix, a socket path).
    host: Option<String>,
    /// Server TCP port.
    port: Option<u16>,
    /// Database to connect to.
    dbname: Option<String>,

    // -- Topology and pool tuning --
    /// Connection URLs of read replicas, in the order they were added.
    replica_urls: Vec<String>,
    /// Upper bound on connections per pool (primary and each replica).
    max_size: usize,
    /// TCP connect timeout, in whole seconds.
    connect_timeout_secs: u64,
}
53
54impl PoolBuilder {
55    /// Configure the pool from a PostgreSQL connection URL.
56    ///
57    /// Parses `postgres://user:password@host:port/dbname` and fills in
58    /// host, port, dbname, user, and password. Other builder settings
59    /// (max_size, connect_timeout, replicas) are preserved.
60    ///
61    /// Returns an error if the URL cannot be parsed.
62    pub fn url(mut self, url: &str) -> Result<Self, BsqlError> {
63        let parsed = parse_pg_url(url)?;
64        self.host = parsed.host;
65        self.port = parsed.port;
66        self.dbname = parsed.dbname;
67        self.user = parsed.user;
68        self.password = parsed.password;
69        Ok(self)
70    }
71
72    pub fn host(mut self, host: &str) -> Self {
73        self.host = Some(host.into());
74        self
75    }
76
77    pub fn port(mut self, port: u16) -> Self {
78        self.port = Some(port);
79        self
80    }
81
82    pub fn dbname(mut self, dbname: &str) -> Self {
83        self.dbname = Some(dbname.into());
84        self
85    }
86
87    pub fn user(mut self, user: &str) -> Self {
88        self.user = Some(user.into());
89        self
90    }
91
92    pub fn password(mut self, password: &str) -> Self {
93        self.password = Some(password.into());
94        self
95    }
96
97    pub fn max_size(mut self, size: usize) -> Self {
98        self.max_size = size;
99        self
100    }
101
102    /// TCP connect timeout in seconds. This is the ONLY timeout in bsql --
103    /// it exists because TCP itself will wait forever on a dead network.
104    pub fn connect_timeout(mut self, secs: u64) -> Self {
105        self.connect_timeout_secs = secs;
106        self
107    }
108
109    /// Add a read replica. SELECT queries will be routed to replicas when
110    /// the executor uses `query_raw_readonly` (generated for SELECT queries).
111    ///
112    /// Multiple replicas are selected round-robin. If a replica is unavailable,
113    /// the query falls back to the primary.
114    ///
115    /// Format: `postgres://user:password@host:port/dbname`
116    pub fn replica(mut self, url: &str) -> Self {
117        self.replica_urls.push(url.into());
118        self
119    }
120
121    pub async fn build(self) -> BsqlResult<Pool> {
122        let mut cfg = Config::new();
123        cfg.host = self.host;
124        cfg.port = self.port;
125        cfg.dbname = self.dbname;
126        cfg.user = self.user;
127        cfg.password = self.password;
128        cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
129        cfg.manager = Some(ManagerConfig {
130            recycling_method: RecyclingMethod::Fast,
131        });
132        // FIX 2: fail-fast -- zero wait timeout means acquire() never blocks
133        cfg.pool = Some(deadpool_postgres::PoolConfig {
134            max_size: self.max_size,
135            timeouts: deadpool_postgres::Timeouts {
136                wait: Some(std::time::Duration::ZERO),
137                create: None,
138                recycle: None,
139            },
140            ..Default::default()
141        });
142
143        let pool = create_deadpool(cfg)?;
144
145        // FIX 11: detect PgBouncer -- propagate connection failure
146        let mut pgbouncer = detect_pgbouncer(&pool).await?;
147
148        // Build replica pools, detecting PgBouncer on each.
149        // If ANY pool (primary or replica) is behind PgBouncer, we disable
150        // named prepared statements globally. This is conservative but safe:
151        // a mixed topology is unusual, and the performance cost of unnamed
152        // statements is negligible compared to a hard failure.
153        let mut replicas = Vec::with_capacity(self.replica_urls.len());
154        for url in &self.replica_urls {
155            let replica_pool =
156                create_pool_from_url(url, self.max_size, self.connect_timeout_secs).await?;
157            pgbouncer |= detect_pgbouncer(&replica_pool).await?;
158            replicas.push(replica_pool);
159        }
160
161        Ok(Pool {
162            primary: pool,
163            replicas,
164            replica_idx: std::sync::atomic::AtomicUsize::new(0),
165            pgbouncer,
166            singleflight: Singleflight::new(),
167        })
168    }
169}
170
171impl Pool {
172    /// Connect to PostgreSQL using a connection URL.
173    ///
174    /// Format: `postgres://user:password@host:port/dbname`
175    pub async fn connect(url: &str) -> BsqlResult<Self> {
176        Pool::builder().url(url)?.build().await
177    }
178
179    /// Create a pool builder for fine-grained configuration.
180    pub fn builder() -> PoolBuilder {
181        PoolBuilder {
182            host: None,
183            port: None,
184            dbname: None,
185            user: None,
186            password: None,
187            max_size: 16,
188            connect_timeout_secs: 5,
189            replica_urls: Vec::new(),
190        }
191    }
192
193    /// Acquire a connection from the primary pool.
194    ///
195    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
196    /// are available. Does not wait. Does not timeout. See CREDO principle #17.
197    pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
198        let conn = self.primary.get().await.map_err(BsqlError::from)?;
199
200        Ok(PoolConnection { inner: conn })
201    }
202
203    /// Whether PgBouncer was detected between the client and PostgreSQL.
204    pub fn is_pgbouncer(&self) -> bool {
205        self.pgbouncer
206    }
207
208    /// Whether read replicas are configured.
209    pub fn has_replicas(&self) -> bool {
210        !self.replicas.is_empty()
211    }
212
213    /// Begin a new transaction.
214    ///
215    /// Acquires a connection from the primary pool. `BEGIN` is sent lazily
216    /// on the first query inside the transaction (see
217    /// [`Transaction::ensure_begun`]). This eliminates one PG round-trip
218    /// when the transaction is created.
219    ///
220    /// If the transaction is committed or dropped without executing any
221    /// queries, no `BEGIN`/`COMMIT`/`ROLLBACK` is sent at all and the
222    /// connection returns to the pool cleanly.
223    ///
224    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
225    /// are available. See CREDO principle #17.
226    pub async fn begin(&self) -> BsqlResult<Transaction> {
227        let conn = self.acquire().await?;
228        Ok(Transaction::new(conn))
229    }
230
231    /// Execute a query and return a stream of rows.
232    ///
233    /// Acquires a connection from the pool and returns a [`QueryStream`]
234    /// that holds the connection alive until the stream is consumed or
235    /// dropped. Rows arrive one at a time, avoiding buffering the
236    /// entire result set in memory.
237    ///
238    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
239    /// are available. See CREDO principle #17.
240    ///
241    /// This method is only available on `Pool` (not `PoolConnection` or
242    /// `Transaction`) because the stream must own the connection for its
243    /// entire lifetime.
244    pub async fn query_stream(
245        &self,
246        sql: &str,
247        params: &[&(dyn ToSql + Sync)],
248    ) -> BsqlResult<QueryStream> {
249        let conn = self.acquire().await?;
250        let stmt = conn
251            .inner
252            .prepare_cached(sql)
253            .await
254            .map_err(BsqlError::from)?;
255
256        let row_stream = conn
257            .inner
258            .query_raw(&stmt, params.iter().copied())
259            .await
260            .map_err(BsqlError::from)?;
261
262        Ok(QueryStream::new(conn, row_stream))
263    }
264
265    /// Pre-prepare SQL statements on a connection from the pool.
266    ///
267    /// Acquires one connection and calls `prepare_cached` for each SQL
268    /// string. This populates the connection's statement cache so the
269    /// first real query on that connection pays zero PREPARE overhead.
270    ///
271    /// Call this once after creating the pool with the SQL strings your
272    /// application uses. The easiest source is your `query!()` calls:
273    ///
274    /// ```rust,ignore
275    /// pool.warmup(&[
276    ///     "SELECT id, login FROM users WHERE id = $1",
277    ///     "UPDATE tickets SET status = $1 WHERE id = $2",
278    /// ]).await?;
279    /// ```
280    ///
281    /// **Best-effort**: if the pool is exhausted or a PREPARE fails
282    /// (e.g. table was dropped), the error is returned but partial
283    /// warmup of earlier statements still takes effect.
284    ///
285    /// Calling warmup with an empty slice is a no-op.
286    pub async fn warmup(&self, sqls: &[&str]) -> BsqlResult<()> {
287        if sqls.is_empty() {
288            return Ok(());
289        }
290        let conn = self.acquire().await?;
291        for sql in sqls {
292            conn.inner
293                .prepare_cached(sql)
294                .await
295                .map_err(BsqlError::from)?;
296        }
297        Ok(())
298    }
299
300    /// Current pool status: available and total connections.
301    pub fn status(&self) -> PoolStatus {
302        let status = self.primary.status();
303        PoolStatus {
304            available: status.available,
305            size: status.size,
306            max_size: status.max_size,
307        }
308    }
309
310    // -- Internal singleflight + routing methods --
311
312    /// Execute a query on the primary with singleflight coalescing.
313    pub(crate) async fn query_raw_primary(
314        &self,
315        sql: &str,
316        params: &[&(dyn ToSql + Sync)],
317    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
318        // Singleflight ONLY for parameterless queries.
319        // Parameterized queries with different param values have the same SQL
320        // text, so keying by SQL alone would return wrong results.
321        if params.is_empty() {
322            let key = sql_key(sql);
323            self.query_with_singleflight(key, sql, params, false).await
324        } else {
325            self.execute_on_pool(sql, params, false).await
326        }
327    }
328
329    /// Execute a read-only query. Routes to a replica if available,
330    /// falls back to primary. Singleflight only for parameterless queries.
331    pub(crate) async fn query_raw_read(
332        &self,
333        sql: &str,
334        params: &[&(dyn ToSql + Sync)],
335    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
336        if self.replicas.is_empty() {
337            return self.query_raw_primary(sql, params).await;
338        }
339
340        if params.is_empty() {
341            let key = sql_key(sql);
342            // Try replica with singleflight
343            match self.query_with_singleflight(key, sql, params, true).await {
344                Ok(rows) => Ok(rows),
345                Err(_) => self.query_with_singleflight(key, sql, params, false).await,
346            }
347        } else {
348            // Parameterized — no singleflight, try replica with fallback
349            match self.execute_on_pool(sql, params, true).await {
350                Ok(rows) => Ok(rows),
351                Err(_) => self.execute_on_pool(sql, params, false).await,
352            }
353        }
354    }
355
356    /// Core singleflight execution. Acquires from primary or replica pool.
357    async fn query_with_singleflight(
358        &self,
359        key: u64,
360        sql: &str,
361        params: &[&(dyn ToSql + Sync)],
362        use_replica: bool,
363    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
364        match self.singleflight.try_join(key) {
365            FlightStatus::Follower(mut rx) => {
366                // Wait for the leader to complete
367                match rx.recv().await {
368                    Ok(rows) => Ok(rows),
369                    Err(_) => {
370                        // Leader failed or channel closed -- execute ourselves
371                        self.execute_on_pool(sql, params, use_replica).await
372                    }
373                }
374            }
375            FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
376                Ok(rows) => {
377                    self.singleflight.complete(key, Arc::clone(&rows));
378                    Ok(rows)
379                }
380                Err(e) => {
381                    self.singleflight.abandon(key);
382                    Err(e)
383                }
384            },
385        }
386    }
387
388    /// Execute a query on the appropriate pool (primary or replica).
389    async fn execute_on_pool(
390        &self,
391        sql: &str,
392        params: &[&(dyn ToSql + Sync)],
393        use_replica: bool,
394    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
395        let raw_conn = if use_replica && !self.replicas.is_empty() {
396            let idx = self
397                .replica_idx
398                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
399                % self.replicas.len();
400            self.replicas[idx].get().await.map_err(BsqlError::from)?
401        } else {
402            self.primary.get().await.map_err(BsqlError::from)?
403        };
404
405        let stmt = raw_conn
406            .prepare_cached(sql)
407            .await
408            .map_err(BsqlError::from)?;
409
410        let rows = raw_conn
411            .query(&stmt, params)
412            .await
413            .map_err(BsqlError::from)?;
414
415        Ok(Arc::from(rows))
416    }
417}
418
419impl std::fmt::Debug for Pool {
420    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        f.debug_struct("Pool")
422            .field("status", &self.status())
423            .field("is_pgbouncer", &self.pgbouncer)
424            .field("replicas", &self.replicas.len())
425            .finish()
426    }
427}
428
429/// A connection borrowed from the pool.
430///
431/// Returned to the pool when dropped.
432pub struct PoolConnection {
433    pub(crate) inner: deadpool_postgres::Object,
434}
435
/// Snapshot of pool utilization, as returned by `Pool::status`.
///
/// Equality and hashing compare all three counters, which makes snapshots
/// easy to assert on and to de-duplicate in metrics code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Connections currently idle and ready to be handed out.
    pub available: usize,
    /// Connections currently managed by the pool (idle plus in use).
    pub size: usize,
    /// Configured upper bound on `size`.
    pub max_size: usize,
}
443
/// Connection fields extracted from a PostgreSQL URL by `parse_pg_url`.
///
/// Every field is optional because a URL may omit any component.
struct ParsedUrl {
    // -- Credentials --
    /// Login role name.
    user: Option<String>,
    /// Login password, already validated as UTF-8.
    password: Option<String>,

    // -- Server location --
    /// First host listed in the URL.
    host: Option<String>,
    /// First port listed in the URL.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
}
452
453/// Parse a PostgreSQL connection URL into its component fields.
454fn parse_pg_url(url: &str) -> BsqlResult<ParsedUrl> {
455    let config: tokio_postgres::Config = url
456        .parse()
457        .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
458
459    let host = config.get_hosts().first().map(|h| match h {
460        tokio_postgres::config::Host::Tcp(s) => s.clone(),
461        #[cfg(unix)]
462        tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
463    });
464    let port = config.get_ports().first().copied();
465    let dbname = config.get_dbname().map(String::from);
466    let user = config.get_user().map(String::from);
467    let password = match config.get_password() {
468        Some(p) => Some(
469            String::from_utf8(p.to_vec())
470                .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
471        ),
472        None => None,
473    };
474    Ok(ParsedUrl {
475        host,
476        port,
477        dbname,
478        user,
479        password,
480    })
481}
482
483/// Create a deadpool-postgres pool from a connection URL.
484///
485/// Used internally for both primary and replica pools.
486async fn create_pool_from_url(
487    url: &str,
488    max_size: usize,
489    connect_timeout_secs: u64,
490) -> BsqlResult<deadpool_postgres::Pool> {
491    let parsed = parse_pg_url(url)?;
492
493    let mut cfg = Config::new();
494    cfg.host = parsed.host;
495    cfg.port = parsed.port;
496    cfg.dbname = parsed.dbname;
497    cfg.user = parsed.user;
498    cfg.password = parsed.password;
499    cfg.connect_timeout = Some(std::time::Duration::from_secs(connect_timeout_secs));
500    cfg.manager = Some(ManagerConfig {
501        recycling_method: RecyclingMethod::Fast,
502    });
503    cfg.pool = Some(deadpool_postgres::PoolConfig {
504        max_size,
505        timeouts: deadpool_postgres::Timeouts {
506            wait: Some(std::time::Duration::ZERO),
507            create: None,
508            recycle: None,
509        },
510        ..Default::default()
511    });
512
513    let pool = create_deadpool(cfg)?;
514
515    // Verify connectivity
516    let _conn = pool
517        .get()
518        .await
519        .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
520
521    Ok(pool)
522}
523
524/// Create a `deadpool_postgres::Pool` from a `Config`.
525///
526/// When the `tls` feature is enabled, connections use rustls with Mozilla's
527/// bundled root certificates. When disabled, connections use `NoTls`.
528fn create_deadpool(cfg: Config) -> BsqlResult<deadpool_postgres::Pool> {
529    #[cfg(feature = "tls")]
530    {
531        let tls = make_rustls_connect();
532        cfg.create_pool(Some(Runtime::Tokio1), tls)
533            .map_err(|e| ConnectError::create(e.to_string()))
534    }
535    #[cfg(not(feature = "tls"))]
536    {
537        cfg.create_pool(Some(Runtime::Tokio1), NoTls)
538            .map_err(|e| ConnectError::create(e.to_string()))
539    }
540}
541
/// Build a `MakeRustlsConnect` that trusts Mozilla's root certificates as
/// bundled by `webpki-roots`, with no client certificate.
///
/// Using the bundled roots avoids a runtime dependency on the OS certificate
/// store. `pub(crate)` so `listener.rs` can reuse it.
#[cfg(feature = "tls")]
pub(crate) fn make_rustls_connect() -> tokio_postgres_rustls::MakeRustlsConnect {
    let trust_anchors = {
        let mut store = rustls::RootCertStore::empty();
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        store
    };
    let client_config = rustls::ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    tokio_postgres_rustls::MakeRustlsConnect::new(client_config)
}
555
556/// Detect PgBouncer on the first connection from the pool.
557///
558/// Strategy: try `SHOW POOLS` -- only PgBouncer responds to this.
559///
560/// Returns `Err` if the initial connection fails, instead of silently
561/// assuming direct. A pool that can't connect on creation is broken.
562async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<bool> {
563    let conn = pool.get().await.map_err(|e| {
564        ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
565    })?;
566
567    // PgBouncer responds to `SHOW POOLS`; PostgreSQL does not.
568    Ok(conn.simple_query("SHOW POOLS").await.is_ok())
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert_eq!(b.max_size, 16);
        assert_eq!(b.connect_timeout_secs, 5);
        assert!(b.replica_urls.is_empty());
    }

    #[test]
    fn builder_config() {
        let b = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(b.host.as_deref(), Some("localhost"));
        assert_eq!(b.port, Some(5432));
        assert_eq!(b.dbname.as_deref(), Some("test"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        assert_eq!(b.max_size, 8);
        assert_eq!(b.connect_timeout_secs, 10);
    }

    #[test]
    fn builder_replicas() {
        let b = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");
        assert_eq!(b.replica_urls.len(), 2);
    }

    /// `url()` fills every connection field and leaves pool tuning alone.
    #[test]
    fn builder_url_fills_connection_fields() {
        let b = match Pool::builder()
            .max_size(4)
            .url("postgres://app:secret@db.example.com:6543/prod")
        {
            Ok(b) => b,
            Err(_) => panic!("a well-formed URL should parse"),
        };

        assert_eq!(b.host.as_deref(), Some("db.example.com"));
        assert_eq!(b.port, Some(6543));
        assert_eq!(b.dbname.as_deref(), Some("prod"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        assert_eq!(b.max_size, 4);
        assert_eq!(b.connect_timeout_secs, 5);
    }

    /// A non-numeric port is a parse error, not a silently dropped field.
    #[test]
    fn builder_url_rejects_invalid_port() {
        assert!(Pool::builder()
            .url("postgres://db.example.com:notaport/prod")
            .is_err());
    }

    /// Components absent from the URL come back as `None`.
    #[test]
    fn parse_pg_url_without_password_or_port() {
        let parsed = match parse_pg_url("postgres://app@localhost/test") {
            Ok(p) => p,
            Err(_) => panic!("a well-formed URL should parse"),
        };
        assert_eq!(parsed.host.as_deref(), Some("localhost"));
        assert_eq!(parsed.user.as_deref(), Some("app"));
        assert_eq!(parsed.dbname.as_deref(), Some("test"));
        assert!(parsed.password.is_none());
        assert!(parsed.port.is_none());
    }
}