Skip to main content

bsql_core/
pool.rs

1//! Connection pool with fail-fast semantics, PgBouncer detection,
2//! singleflight query coalescing, and read/write splitting.
3//!
4//! The pool wraps `deadpool-postgres` with key behaviors:
5//! - **Fail-fast**: `acquire()` returns `PoolExhausted` immediately when no
6//!   connections are available. It does not wait. See CREDO principle #17.
7//! - **PgBouncer detection**: on pool creation, bsql detects whether the
8//!   connection goes through PgBouncer and adjusts prepared statement strategy.
9//! - **Singleflight** (v0.7): identical concurrent SELECT queries are coalesced
10//!   into a single PG round-trip. The result is shared via `Arc<[Row]>`.
11//! - **Read/write splitting** (v0.7): when replicas are configured, SELECT
12//!   queries are routed to replicas. Writes always go to the primary.
13
14use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17#[cfg(not(feature = "tls"))]
18use tokio_postgres::NoTls;
19use tokio_postgres::types::ToSql;
20
21use crate::error::{BsqlError, BsqlResult, ConnectError};
22use crate::singleflight::{FlightStatus, Singleflight, sql_key};
23use crate::stream::QueryStream;
24use crate::transaction::Transaction;
25
26/// A PostgreSQL connection pool.
27///
28/// Wraps `deadpool-postgres` with fail-fast acquire semantics, singleflight
29/// query coalescing, and optional read/write splitting.
30pub struct Pool {
31    primary: deadpool_postgres::Pool,
32    /// Replica pools for read-only queries. Round-robin selection.
33    /// Empty when no replicas are configured.
34    replicas: Vec<deadpool_postgres::Pool>,
35    /// Atomic counter for round-robin replica selection.
36    replica_idx: std::sync::atomic::AtomicUsize,
37    /// True if PgBouncer was detected between the client and PostgreSQL.
38    pgbouncer: bool,
39    singleflight: Singleflight,
40}
41
/// Builder for configuring a connection pool.
///
/// Obtain one from `Pool::builder()`, chain the setters, then await
/// `build()`.
#[must_use = "a PoolBuilder does nothing until `build()` is awaited"]
pub struct PoolBuilder {
    host: Option<String>,
    port: Option<u16>,
    dbname: Option<String>,
    user: Option<String>,
    password: Option<String>,
    max_size: usize,
    connect_timeout_secs: u64,
    replica_urls: Vec<String>,
}

impl std::fmt::Debug for PoolBuilder {
    // Hand-written rather than derived so credentials never reach logs:
    // the password is redacted, and replica URLs (which may embed
    // `user:password@`) are reported only by count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolBuilder")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("dbname", &self.dbname)
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("max_size", &self.max_size)
            .field("connect_timeout_secs", &self.connect_timeout_secs)
            .field("replicas", &self.replica_urls.len())
            .finish()
    }
}
53
54impl PoolBuilder {
55    /// Configure the pool from a PostgreSQL connection URL.
56    ///
57    /// Parses `postgres://user:password@host:port/dbname` and fills in
58    /// host, port, dbname, user, and password. Other builder settings
59    /// (max_size, connect_timeout, replicas) are preserved.
60    ///
61    /// Returns an error if the URL cannot be parsed.
62    pub fn url(mut self, url: &str) -> Result<Self, BsqlError> {
63        let parsed = parse_pg_url(url)?;
64        self.host = parsed.host;
65        self.port = parsed.port;
66        self.dbname = parsed.dbname;
67        self.user = parsed.user;
68        self.password = parsed.password;
69        Ok(self)
70    }
71
72    pub fn host(mut self, host: &str) -> Self {
73        self.host = Some(host.into());
74        self
75    }
76
77    pub fn port(mut self, port: u16) -> Self {
78        self.port = Some(port);
79        self
80    }
81
82    pub fn dbname(mut self, dbname: &str) -> Self {
83        self.dbname = Some(dbname.into());
84        self
85    }
86
87    pub fn user(mut self, user: &str) -> Self {
88        self.user = Some(user.into());
89        self
90    }
91
92    pub fn password(mut self, password: &str) -> Self {
93        self.password = Some(password.into());
94        self
95    }
96
97    pub fn max_size(mut self, size: usize) -> Self {
98        self.max_size = size;
99        self
100    }
101
102    /// TCP connect timeout in seconds. This is the ONLY timeout in bsql --
103    /// it exists because TCP itself will wait forever on a dead network.
104    pub fn connect_timeout(mut self, secs: u64) -> Self {
105        self.connect_timeout_secs = secs;
106        self
107    }
108
109    /// Add a read replica. SELECT queries will be routed to replicas when
110    /// the executor uses `query_raw_readonly` (generated for SELECT queries).
111    ///
112    /// Multiple replicas are selected round-robin. If a replica is unavailable,
113    /// the query falls back to the primary.
114    ///
115    /// Format: `postgres://user:password@host:port/dbname`
116    pub fn replica(mut self, url: &str) -> Self {
117        self.replica_urls.push(url.into());
118        self
119    }
120
121    pub async fn build(self) -> BsqlResult<Pool> {
122        let mut cfg = Config::new();
123        cfg.host = self.host;
124        cfg.port = self.port;
125        cfg.dbname = self.dbname;
126        cfg.user = self.user;
127        cfg.password = self.password;
128        cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
129        cfg.manager = Some(ManagerConfig {
130            recycling_method: RecyclingMethod::Fast,
131        });
132        // FIX 2: fail-fast -- zero wait timeout means acquire() never blocks
133        cfg.pool = Some(deadpool_postgres::PoolConfig {
134            max_size: self.max_size,
135            timeouts: deadpool_postgres::Timeouts {
136                wait: Some(std::time::Duration::ZERO),
137                create: None,
138                recycle: None,
139            },
140            ..Default::default()
141        });
142
143        let pool = create_deadpool(cfg)?;
144
145        // FIX 11: detect PgBouncer -- propagate connection failure
146        let mut pgbouncer = detect_pgbouncer(&pool).await?;
147
148        // Build replica pools, detecting PgBouncer on each.
149        // If ANY pool (primary or replica) is behind PgBouncer, we disable
150        // named prepared statements globally. This is conservative but safe:
151        // a mixed topology is unusual, and the performance cost of unnamed
152        // statements is negligible compared to a hard failure.
153        let mut replicas = Vec::with_capacity(self.replica_urls.len());
154        for url in &self.replica_urls {
155            let replica_pool =
156                create_pool_from_url(url, self.max_size, self.connect_timeout_secs).await?;
157            pgbouncer |= detect_pgbouncer(&replica_pool).await?;
158            replicas.push(replica_pool);
159        }
160
161        Ok(Pool {
162            primary: pool,
163            replicas,
164            replica_idx: std::sync::atomic::AtomicUsize::new(0),
165            pgbouncer,
166            singleflight: Singleflight::new(),
167        })
168    }
169}
170
171impl Pool {
172    /// Connect to PostgreSQL using a connection URL.
173    ///
174    /// Format: `postgres://user:password@host:port/dbname`
175    pub async fn connect(url: &str) -> BsqlResult<Self> {
176        Pool::builder().url(url)?.build().await
177    }
178
179    /// Create a pool builder for fine-grained configuration.
180    pub fn builder() -> PoolBuilder {
181        PoolBuilder {
182            host: None,
183            port: None,
184            dbname: None,
185            user: None,
186            password: None,
187            max_size: 16,
188            connect_timeout_secs: 5,
189            replica_urls: Vec::new(),
190        }
191    }
192
193    /// Acquire a connection from the primary pool.
194    ///
195    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
196    /// are available. Does not wait. Does not timeout. See CREDO principle #17.
197    pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
198        let conn = self.primary.get().await.map_err(BsqlError::from)?;
199
200        Ok(PoolConnection { inner: conn })
201    }
202
203    /// Whether PgBouncer was detected between the client and PostgreSQL.
204    pub fn is_pgbouncer(&self) -> bool {
205        self.pgbouncer
206    }
207
208    /// Whether read replicas are configured.
209    pub fn has_replicas(&self) -> bool {
210        !self.replicas.is_empty()
211    }
212
213    /// Begin a new transaction.
214    ///
215    /// Acquires a connection from the primary pool. `BEGIN` is sent lazily
216    /// on the first query inside the transaction (see
217    /// [`Transaction::ensure_begun`]). This eliminates one PG round-trip
218    /// when the transaction is created.
219    ///
220    /// If the transaction is committed or dropped without executing any
221    /// queries, no `BEGIN`/`COMMIT`/`ROLLBACK` is sent at all and the
222    /// connection returns to the pool cleanly.
223    ///
224    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
225    /// are available. See CREDO principle #17.
226    pub async fn begin(&self) -> BsqlResult<Transaction> {
227        let conn = self.acquire().await?;
228        Ok(Transaction::new(conn))
229    }
230
231    /// Execute a query and return a stream of rows.
232    ///
233    /// Acquires a connection from the pool and returns a [`QueryStream`]
234    /// that holds the connection alive until the stream is consumed or
235    /// dropped. Rows arrive one at a time, avoiding buffering the
236    /// entire result set in memory.
237    ///
238    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
239    /// are available. See CREDO principle #17.
240    ///
241    /// This method is only available on `Pool` (not `PoolConnection` or
242    /// `Transaction`) because the stream must own the connection for its
243    /// entire lifetime.
244    pub async fn query_stream(
245        &self,
246        sql: &str,
247        params: &[&(dyn ToSql + Sync)],
248    ) -> BsqlResult<QueryStream> {
249        let conn = self.acquire().await?;
250        let stmt = conn
251            .inner
252            .prepare_cached(sql)
253            .await
254            .map_err(BsqlError::from)?;
255
256        let row_stream = conn
257            .inner
258            .query_raw(&stmt, params.iter().copied())
259            .await
260            .map_err(BsqlError::from)?;
261
262        Ok(QueryStream::new(conn, row_stream))
263    }
264
265    /// Current pool status: available and total connections.
266    pub fn status(&self) -> PoolStatus {
267        let status = self.primary.status();
268        PoolStatus {
269            available: status.available,
270            size: status.size,
271            max_size: status.max_size,
272        }
273    }
274
275    // -- Internal singleflight + routing methods --
276
277    /// Execute a query on the primary with singleflight coalescing.
278    pub(crate) async fn query_raw_primary(
279        &self,
280        sql: &str,
281        params: &[&(dyn ToSql + Sync)],
282    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
283        // Singleflight ONLY for parameterless queries.
284        // Parameterized queries with different param values have the same SQL
285        // text, so keying by SQL alone would return wrong results.
286        if params.is_empty() {
287            let key = sql_key(sql);
288            self.query_with_singleflight(key, sql, params, false).await
289        } else {
290            self.execute_on_pool(sql, params, false).await
291        }
292    }
293
294    /// Execute a read-only query. Routes to a replica if available,
295    /// falls back to primary. Singleflight only for parameterless queries.
296    pub(crate) async fn query_raw_read(
297        &self,
298        sql: &str,
299        params: &[&(dyn ToSql + Sync)],
300    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
301        if self.replicas.is_empty() {
302            return self.query_raw_primary(sql, params).await;
303        }
304
305        if params.is_empty() {
306            let key = sql_key(sql);
307            // Try replica with singleflight
308            match self.query_with_singleflight(key, sql, params, true).await {
309                Ok(rows) => Ok(rows),
310                Err(_) => self.query_with_singleflight(key, sql, params, false).await,
311            }
312        } else {
313            // Parameterized — no singleflight, try replica with fallback
314            match self.execute_on_pool(sql, params, true).await {
315                Ok(rows) => Ok(rows),
316                Err(_) => self.execute_on_pool(sql, params, false).await,
317            }
318        }
319    }
320
321    /// Core singleflight execution. Acquires from primary or replica pool.
322    async fn query_with_singleflight(
323        &self,
324        key: u64,
325        sql: &str,
326        params: &[&(dyn ToSql + Sync)],
327        use_replica: bool,
328    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
329        match self.singleflight.try_join(key) {
330            FlightStatus::Follower(mut rx) => {
331                // Wait for the leader to complete
332                match rx.recv().await {
333                    Ok(rows) => Ok(rows),
334                    Err(_) => {
335                        // Leader failed or channel closed -- execute ourselves
336                        self.execute_on_pool(sql, params, use_replica).await
337                    }
338                }
339            }
340            FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
341                Ok(rows) => {
342                    self.singleflight.complete(key, Arc::clone(&rows));
343                    Ok(rows)
344                }
345                Err(e) => {
346                    self.singleflight.abandon(key);
347                    Err(e)
348                }
349            },
350        }
351    }
352
353    /// Execute a query on the appropriate pool (primary or replica).
354    async fn execute_on_pool(
355        &self,
356        sql: &str,
357        params: &[&(dyn ToSql + Sync)],
358        use_replica: bool,
359    ) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
360        let raw_conn = if use_replica && !self.replicas.is_empty() {
361            let idx = self
362                .replica_idx
363                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
364                % self.replicas.len();
365            self.replicas[idx].get().await.map_err(BsqlError::from)?
366        } else {
367            self.primary.get().await.map_err(BsqlError::from)?
368        };
369
370        let stmt = raw_conn
371            .prepare_cached(sql)
372            .await
373            .map_err(BsqlError::from)?;
374
375        let rows = raw_conn
376            .query(&stmt, params)
377            .await
378            .map_err(BsqlError::from)?;
379
380        Ok(Arc::from(rows))
381    }
382}
383
384impl std::fmt::Debug for Pool {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct("Pool")
387            .field("status", &self.status())
388            .field("is_pgbouncer", &self.pgbouncer)
389            .field("replicas", &self.replicas.len())
390            .finish()
391    }
392}
393
394/// A connection borrowed from the pool.
395///
396/// Returned to the pool when dropped.
397pub struct PoolConnection {
398    pub(crate) inner: deadpool_postgres::Object,
399}
400
/// Snapshot of pool utilization, as reported by `Pool::status` for the
/// primary pool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// `available` as reported by deadpool's status for the pool.
    pub available: usize,
    /// Number of connections the pool currently holds.
    pub size: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
}
408
/// Connection fields extracted from a PostgreSQL connection string by
/// `parse_pg_url`; each is `None` when the string did not specify it.
struct ParsedUrl {
    /// First host in the string (TCP host name, or a Unix socket path on Unix).
    host: Option<String>,
    /// First port in the string.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// User name.
    user: Option<String>,
    /// Password, already validated as UTF-8.
    password: Option<String>,
}
417
418/// Parse a PostgreSQL connection URL into its component fields.
419fn parse_pg_url(url: &str) -> BsqlResult<ParsedUrl> {
420    let config: tokio_postgres::Config = url
421        .parse()
422        .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
423
424    let host = config.get_hosts().first().map(|h| match h {
425        tokio_postgres::config::Host::Tcp(s) => s.clone(),
426        #[cfg(unix)]
427        tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
428    });
429    let port = config.get_ports().first().copied();
430    let dbname = config.get_dbname().map(String::from);
431    let user = config.get_user().map(String::from);
432    let password = match config.get_password() {
433        Some(p) => Some(
434            String::from_utf8(p.to_vec())
435                .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
436        ),
437        None => None,
438    };
439    Ok(ParsedUrl {
440        host,
441        port,
442        dbname,
443        user,
444        password,
445    })
446}
447
448/// Create a deadpool-postgres pool from a connection URL.
449///
450/// Used internally for both primary and replica pools.
451async fn create_pool_from_url(
452    url: &str,
453    max_size: usize,
454    connect_timeout_secs: u64,
455) -> BsqlResult<deadpool_postgres::Pool> {
456    let parsed = parse_pg_url(url)?;
457
458    let mut cfg = Config::new();
459    cfg.host = parsed.host;
460    cfg.port = parsed.port;
461    cfg.dbname = parsed.dbname;
462    cfg.user = parsed.user;
463    cfg.password = parsed.password;
464    cfg.connect_timeout = Some(std::time::Duration::from_secs(connect_timeout_secs));
465    cfg.manager = Some(ManagerConfig {
466        recycling_method: RecyclingMethod::Fast,
467    });
468    cfg.pool = Some(deadpool_postgres::PoolConfig {
469        max_size,
470        timeouts: deadpool_postgres::Timeouts {
471            wait: Some(std::time::Duration::ZERO),
472            create: None,
473            recycle: None,
474        },
475        ..Default::default()
476    });
477
478    let pool = create_deadpool(cfg)?;
479
480    // Verify connectivity
481    let _conn = pool
482        .get()
483        .await
484        .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
485
486    Ok(pool)
487}
488
489/// Create a `deadpool_postgres::Pool` from a `Config`.
490///
491/// When the `tls` feature is enabled, connections use rustls with Mozilla's
492/// bundled root certificates. When disabled, connections use `NoTls`.
493fn create_deadpool(cfg: Config) -> BsqlResult<deadpool_postgres::Pool> {
494    #[cfg(feature = "tls")]
495    {
496        let tls = make_rustls_connect();
497        cfg.create_pool(Some(Runtime::Tokio1), tls)
498            .map_err(|e| ConnectError::create(e.to_string()))
499    }
500    #[cfg(not(feature = "tls"))]
501    {
502        cfg.create_pool(Some(Runtime::Tokio1), NoTls)
503            .map_err(|e| ConnectError::create(e.to_string()))
504    }
505}
506
/// Build the rustls connector used for PostgreSQL TLS connections.
///
/// The trust store is filled from `webpki_roots` (Mozilla's bundled roots),
/// so no OS certificate store is needed at runtime. No client certificate
/// is configured. `pub(crate)` so `listener.rs` can reuse it.
#[cfg(feature = "tls")]
pub(crate) fn make_rustls_connect() -> tokio_postgres_rustls::MakeRustlsConnect {
    let mut trust_store = rustls::RootCertStore::empty();
    trust_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    tokio_postgres_rustls::MakeRustlsConnect::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(trust_store)
            .with_no_client_auth(),
    )
}
520
521/// Detect PgBouncer on the first connection from the pool.
522///
523/// Strategy: try `SHOW POOLS` -- only PgBouncer responds to this.
524///
525/// Returns `Err` if the initial connection fails, instead of silently
526/// assuming direct. A pool that can't connect on creation is broken.
527async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<bool> {
528    let conn = pool.get().await.map_err(|e| {
529        ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
530    })?;
531
532    // PgBouncer responds to `SHOW POOLS`; PostgreSQL does not.
533    Ok(conn.simple_query("SHOW POOLS").await.is_ok())
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert_eq!(b.max_size, 16);
        assert_eq!(b.connect_timeout_secs, 5);
        assert!(b.replica_urls.is_empty());
    }

    #[test]
    fn builder_config() {
        let b = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(b.host.as_deref(), Some("localhost"));
        assert_eq!(b.port, Some(5432));
        assert_eq!(b.dbname.as_deref(), Some("test"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        assert_eq!(b.max_size, 8);
        assert_eq!(b.connect_timeout_secs, 10);
    }

    #[test]
    fn builder_replicas() {
        let b = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");
        assert_eq!(b.replica_urls.len(), 2);
    }

    #[test]
    fn parse_pg_url_extracts_all_components() {
        let parsed = match parse_pg_url("postgres://app:secret@db.example.com:6543/orders") {
            Ok(parsed) => parsed,
            Err(_) => panic!("well-formed URL should parse"),
        };
        assert_eq!(parsed.host.as_deref(), Some("db.example.com"));
        assert_eq!(parsed.port, Some(6543));
        assert_eq!(parsed.dbname.as_deref(), Some("orders"));
        assert_eq!(parsed.user.as_deref(), Some("app"));
        assert_eq!(parsed.password.as_deref(), Some("secret"));
    }

    #[test]
    fn parse_pg_url_rejects_invalid_port() {
        assert!(parse_pg_url("postgres://app@localhost:notaport/db").is_err());
    }

    #[test]
    fn url_preserves_pool_level_settings() {
        let b = match Pool::builder()
            .max_size(4)
            .connect_timeout(9)
            .replica("postgres://replica1:5432/db")
            .url("postgres://app@localhost:5432/db")
        {
            Ok(b) => b,
            Err(_) => panic!("well-formed URL should parse"),
        };
        assert_eq!(b.max_size, 4);
        assert_eq!(b.connect_timeout_secs, 9);
        assert_eq!(b.replica_urls.len(), 1);
        assert_eq!(b.host.as_deref(), Some("localhost"));
        assert_eq!(b.user.as_deref(), Some("app"));
    }
}