Skip to main content

bsql_core/
pool.rs

1//! Connection pool with fail-fast semantics, PgBouncer detection,
2//! singleflight query coalescing, and read/write splitting.
3//!
4//! The pool wraps `deadpool-postgres` with key behaviors:
5//! - **Fail-fast**: `acquire()` returns `PoolExhausted` immediately when no
6//!   connections are available. It does not wait. See CREDO principle #17.
7//! - **PgBouncer detection**: on pool creation, bsql detects whether the
8//!   connection goes through PgBouncer and adjusts prepared statement strategy.
//! - **Singleflight** (v0.7): identical concurrent parameterless SELECT queries
//!   are coalesced into a single PG round-trip. The result is shared via `Arc<Vec<Row>>`.
11//! - **Read/write splitting** (v0.7): when replicas are configured, SELECT
12//!   queries are routed to replicas. Writes always go to the primary.
13
14use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17use tokio_postgres::NoTls;
18use tokio_postgres::types::ToSql;
19
20use crate::error::{BsqlError, BsqlResult, ConnectError};
21use crate::singleflight::{FlightStatus, Singleflight, sql_key};
22use crate::stream::QueryStream;
23use crate::transaction::Transaction;
24
/// A PostgreSQL connection pool.
///
/// Wraps `deadpool-postgres` with fail-fast acquire semantics, singleflight
/// query coalescing, and optional read/write splitting.
///
/// Pools are configured with a zero wait timeout, so acquiring from an
/// exhausted pool yields an error instead of queueing the caller.
pub struct Pool {
    /// Pool for the primary server. Used by `acquire`, transactions,
    /// streaming, and as the fallback for replica reads.
    primary: deadpool_postgres::Pool,
    /// Replica pools for read-only queries. Round-robin selection.
    /// Empty when no replicas are configured.
    replicas: Vec<deadpool_postgres::Pool>,
    /// Atomic counter for round-robin replica selection.
    replica_idx: std::sync::atomic::AtomicUsize,
    /// PgBouncer detection result for the primary; copied into every
    /// `PoolConnection` returned by `acquire`.
    pgbouncer: PgBouncerInfo,
    /// Coalesces identical concurrent parameterless queries.
    singleflight: Singleflight,
}
39
/// PgBouncer detection result.
///
/// Computed once when the pool is created and copied (the type is `Copy`)
/// into each `PoolConnection`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PgBouncerInfo {
    /// True if PgBouncer was detected between the client and PostgreSQL.
    detected: bool,
    /// True if named (server-side) prepared statements are considered safe:
    /// always true for a direct PostgreSQL connection, and for PgBouncer only
    /// when it was found to track prepared statements itself (PgBouncer 1.21+
    /// with that feature enabled). See `detect_pgbouncer` for how this is
    /// decided.
    supports_named_stmts: bool,
}
49
50impl PgBouncerInfo {
51    const DIRECT: Self = Self {
52        detected: false,
53        supports_named_stmts: true,
54    };
55}
56
/// Builder for configuring a connection pool.
///
/// Obtain one with `Pool::builder()`, chain setters, then await `build()`.
/// Every setter consumes the builder and returns it, so ignoring a setter's
/// return value silently discards the configuration; `#[must_use]` turns that
/// mistake into a compiler warning.
#[must_use = "a PoolBuilder does nothing until `build()` is awaited"]
pub struct PoolBuilder {
    /// Primary server host.
    host: Option<String>,
    /// Primary server port.
    port: Option<u16>,
    /// Database name on the primary.
    dbname: Option<String>,
    /// User to connect as on the primary.
    user: Option<String>,
    /// Password for the primary.
    password: Option<String>,
    /// Connection limit for the primary pool and for each replica pool.
    max_size: usize,
    /// TCP connect timeout for the primary, in seconds.
    connect_timeout_secs: u64,
    /// Connection URLs of read replicas, in the order they were added.
    replica_urls: Vec<String>,
}
68
69impl PoolBuilder {
70    pub fn host(mut self, host: &str) -> Self {
71        self.host = Some(host.into());
72        self
73    }
74
75    pub fn port(mut self, port: u16) -> Self {
76        self.port = Some(port);
77        self
78    }
79
80    pub fn dbname(mut self, dbname: &str) -> Self {
81        self.dbname = Some(dbname.into());
82        self
83    }
84
85    pub fn user(mut self, user: &str) -> Self {
86        self.user = Some(user.into());
87        self
88    }
89
90    pub fn password(mut self, password: &str) -> Self {
91        self.password = Some(password.into());
92        self
93    }
94
95    pub fn max_size(mut self, size: usize) -> Self {
96        self.max_size = size;
97        self
98    }
99
100    /// TCP connect timeout in seconds. This is the ONLY timeout in bsql --
101    /// it exists because TCP itself will wait forever on a dead network.
102    pub fn connect_timeout(mut self, secs: u64) -> Self {
103        self.connect_timeout_secs = secs;
104        self
105    }
106
107    /// Add a read replica. SELECT queries will be routed to replicas when
108    /// the executor uses `query_raw_readonly` (generated for SELECT queries).
109    ///
110    /// Multiple replicas are selected round-robin. If a replica is unavailable,
111    /// the query falls back to the primary.
112    ///
113    /// Format: `postgres://user:password@host:port/dbname`
114    pub fn replica(mut self, url: &str) -> Self {
115        self.replica_urls.push(url.into());
116        self
117    }
118
119    pub async fn build(self) -> BsqlResult<Pool> {
120        let mut cfg = Config::new();
121        cfg.host = self.host;
122        cfg.port = self.port;
123        cfg.dbname = self.dbname;
124        cfg.user = self.user;
125        cfg.password = self.password;
126        cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
127        cfg.manager = Some(ManagerConfig {
128            recycling_method: RecyclingMethod::Fast,
129        });
130        // FIX 2: fail-fast -- zero wait timeout means acquire() never blocks
131        cfg.pool = Some(deadpool_postgres::PoolConfig {
132            max_size: self.max_size,
133            timeouts: deadpool_postgres::Timeouts {
134                wait: Some(std::time::Duration::ZERO),
135                create: None,
136                recycle: None,
137            },
138            ..Default::default()
139        });
140
141        let pool = cfg
142            .create_pool(Some(Runtime::Tokio1), NoTls)
143            .map_err(|e| ConnectError::create(e.to_string()))?;
144
145        // FIX 11: detect PgBouncer -- propagate connection failure
146        let pgbouncer = detect_pgbouncer(&pool).await?;
147
148        // Build replica pools
149        let mut replicas = Vec::with_capacity(self.replica_urls.len());
150        for url in &self.replica_urls {
151            let replica_pool = create_pool_from_url(url, self.max_size).await?;
152            replicas.push(replica_pool);
153        }
154
155        Ok(Pool {
156            primary: pool,
157            replicas,
158            replica_idx: std::sync::atomic::AtomicUsize::new(0),
159            pgbouncer,
160            singleflight: Singleflight::new(),
161        })
162    }
163}
164
165impl Pool {
166    /// Connect to PostgreSQL using a connection URL.
167    ///
168    /// Format: `postgres://user:password@host:port/dbname`
169    pub async fn connect(url: &str) -> BsqlResult<Self> {
170        let config: tokio_postgres::Config = url
171            .parse()
172            .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
173
174        let mut cfg = Config::new();
175        cfg.host = config.get_hosts().first().map(|h| match h {
176            tokio_postgres::config::Host::Tcp(s) => s.clone(),
177            #[cfg(unix)]
178            tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
179        });
180        cfg.port = config.get_ports().first().copied();
181        cfg.dbname = config.get_dbname().map(String::from);
182        cfg.user = config.get_user().map(String::from);
183        cfg.password =
184            match config.get_password() {
185                Some(p) => Some(String::from_utf8(p.to_vec()).map_err(|_| {
186                    ConnectError::create("database password contains invalid UTF-8")
187                })?),
188                None => None,
189            };
190        cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
191        cfg.manager = Some(ManagerConfig {
192            recycling_method: RecyclingMethod::Fast,
193        });
194        // FIX 2: fail-fast -- zero wait timeout means acquire() never blocks
195        cfg.pool = Some(deadpool_postgres::PoolConfig {
196            max_size: 16,
197            timeouts: deadpool_postgres::Timeouts {
198                wait: Some(std::time::Duration::ZERO),
199                create: None,
200                recycle: None,
201            },
202            ..Default::default()
203        });
204
205        let pool = cfg
206            .create_pool(Some(Runtime::Tokio1), NoTls)
207            .map_err(|e| ConnectError::create(e.to_string()))?;
208
209        // FIX 11: detect PgBouncer -- propagate connection failure
210        let pgbouncer = detect_pgbouncer(&pool).await?;
211
212        Ok(Pool {
213            primary: pool,
214            replicas: Vec::new(),
215            replica_idx: std::sync::atomic::AtomicUsize::new(0),
216            pgbouncer,
217            singleflight: Singleflight::new(),
218        })
219    }
220
221    /// Create a pool builder for fine-grained configuration.
222    pub fn builder() -> PoolBuilder {
223        PoolBuilder {
224            host: None,
225            port: None,
226            dbname: None,
227            user: None,
228            password: None,
229            max_size: 16,
230            connect_timeout_secs: 5,
231            replica_urls: Vec::new(),
232        }
233    }
234
235    /// Acquire a connection from the primary pool.
236    ///
237    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
238    /// are available. Does not wait. Does not timeout. See CREDO principle #17.
239    pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
240        let conn = self.primary.get().await.map_err(BsqlError::from)?;
241
242        Ok(PoolConnection {
243            inner: conn,
244            pgbouncer: self.pgbouncer,
245        })
246    }
247
248    /// Whether PgBouncer was detected between the client and PostgreSQL.
249    pub fn is_pgbouncer(&self) -> bool {
250        self.pgbouncer.detected
251    }
252
253    /// Whether named prepared statements can be used.
254    ///
255    /// False when PgBouncer is detected without `prepared_statements=yes`.
256    pub fn supports_named_statements(&self) -> bool {
257        self.pgbouncer.supports_named_stmts
258    }
259
260    /// Whether read replicas are configured.
261    pub fn has_replicas(&self) -> bool {
262        !self.replicas.is_empty()
263    }
264
265    /// Begin a new transaction.
266    ///
267    /// Acquires a connection from the primary pool and sends `BEGIN`. The
268    /// connection is held for the lifetime of the returned [`Transaction`].
269    ///
270    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
271    /// are available. See CREDO principle #17.
272    pub async fn begin(&self) -> BsqlResult<Transaction> {
273        let conn = self.acquire().await?;
274        conn.inner
275            .batch_execute("BEGIN")
276            .await
277            .map_err(BsqlError::from)?;
278        Ok(Transaction::new(conn))
279    }
280
281    /// Execute a query and return a stream of rows.
282    ///
283    /// Acquires a connection from the pool and returns a [`QueryStream`]
284    /// that holds the connection alive until the stream is consumed or
285    /// dropped. Rows arrive one at a time, avoiding buffering the
286    /// entire result set in memory.
287    ///
288    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
289    /// are available. See CREDO principle #17.
290    ///
291    /// This method is only available on `Pool` (not `PoolConnection` or
292    /// `Transaction`) because the stream must own the connection for its
293    /// entire lifetime.
294    pub async fn query_stream(
295        &self,
296        sql: &str,
297        params: &[&(dyn ToSql + Sync)],
298    ) -> BsqlResult<QueryStream> {
299        let conn = self.acquire().await?;
300        let stmt = conn
301            .inner
302            .prepare_cached(sql)
303            .await
304            .map_err(BsqlError::from)?;
305
306        let row_stream = conn
307            .inner
308            .query_raw(&stmt, params.iter().copied())
309            .await
310            .map_err(BsqlError::from)?;
311
312        Ok(QueryStream::new(conn, row_stream))
313    }
314
315    /// Current pool status: available and total connections.
316    pub fn status(&self) -> PoolStatus {
317        let status = self.primary.status();
318        PoolStatus {
319            available: status.available,
320            size: status.size,
321            max_size: status.max_size,
322        }
323    }
324
325    // -- Internal singleflight + routing methods --
326
327    /// Execute a query on the primary with singleflight coalescing.
328    pub(crate) async fn query_raw_primary(
329        &self,
330        sql: &str,
331        params: &[&(dyn ToSql + Sync)],
332    ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
333        // Singleflight ONLY for parameterless queries.
334        // Parameterized queries with different param values have the same SQL
335        // text, so keying by SQL alone would return wrong results.
336        if params.is_empty() {
337            let key = sql_key(sql);
338            self.query_with_singleflight(key, sql, params, false).await
339        } else {
340            self.execute_on_pool(sql, params, false).await
341        }
342    }
343
344    /// Execute a read-only query. Routes to a replica if available,
345    /// falls back to primary. Singleflight only for parameterless queries.
346    pub(crate) async fn query_raw_read(
347        &self,
348        sql: &str,
349        params: &[&(dyn ToSql + Sync)],
350    ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
351        if self.replicas.is_empty() {
352            return self.query_raw_primary(sql, params).await;
353        }
354
355        if params.is_empty() {
356            let key = sql_key(sql);
357            // Try replica with singleflight
358            match self.query_with_singleflight(key, sql, params, true).await {
359                Ok(rows) => Ok(rows),
360                Err(_) => self.query_with_singleflight(key, sql, params, false).await,
361            }
362        } else {
363            // Parameterized — no singleflight, try replica with fallback
364            match self.execute_on_pool(sql, params, true).await {
365                Ok(rows) => Ok(rows),
366                Err(_) => self.execute_on_pool(sql, params, false).await,
367            }
368        }
369    }
370
371    /// Core singleflight execution. Acquires from primary or replica pool.
372    async fn query_with_singleflight(
373        &self,
374        key: u64,
375        sql: &str,
376        params: &[&(dyn ToSql + Sync)],
377        use_replica: bool,
378    ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
379        match self.singleflight.try_join(key) {
380            FlightStatus::Follower(mut rx) => {
381                // Wait for the leader to complete
382                match rx.recv().await {
383                    Ok(rows) => Ok(rows),
384                    Err(_) => {
385                        // Leader failed or channel closed -- execute ourselves
386                        self.execute_on_pool(sql, params, use_replica).await
387                    }
388                }
389            }
390            FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
391                Ok(rows) => {
392                    self.singleflight.complete(key, Arc::clone(&rows));
393                    Ok(rows)
394                }
395                Err(e) => {
396                    self.singleflight.abandon(key);
397                    Err(e)
398                }
399            },
400        }
401    }
402
403    /// Execute a query on the appropriate pool (primary or replica).
404    async fn execute_on_pool(
405        &self,
406        sql: &str,
407        params: &[&(dyn ToSql + Sync)],
408        use_replica: bool,
409    ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
410        let raw_conn = if use_replica && !self.replicas.is_empty() {
411            let idx = self
412                .replica_idx
413                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
414                % self.replicas.len();
415            self.replicas[idx].get().await.map_err(BsqlError::from)?
416        } else {
417            self.primary.get().await.map_err(BsqlError::from)?
418        };
419
420        let stmt = raw_conn
421            .prepare_cached(sql)
422            .await
423            .map_err(BsqlError::from)?;
424
425        let rows = raw_conn
426            .query(&stmt, params)
427            .await
428            .map_err(BsqlError::from)?;
429
430        Ok(Arc::new(rows))
431    }
432}
433
/// A connection borrowed from the pool.
///
/// Returned to the pool when dropped.
pub struct PoolConnection {
    /// The underlying deadpool-managed client.
    pub(crate) inner: deadpool_postgres::Object,
    /// PgBouncer detection result copied from the owning `Pool`.
    pub(crate) pgbouncer: PgBouncerInfo,
}
441
442impl PoolConnection {
443    /// Whether named prepared statements can be used on this connection.
444    pub fn supports_named_statements(&self) -> bool {
445        self.pgbouncer.supports_named_stmts
446    }
447}
448
/// Snapshot of pool utilization.
///
/// Values are copied from deadpool's status report at the moment of the call
/// and may be stale immediately afterwards.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Connections available to be handed out (deadpool's `available`).
    pub available: usize,
    /// Current size of the pool (deadpool's `size`).
    pub size: usize,
    /// Configured maximum size of the pool.
    pub max_size: usize,
}
456
457/// Create a deadpool-postgres pool from a connection URL.
458///
459/// Used internally for both primary and replica pools.
460async fn create_pool_from_url(url: &str, max_size: usize) -> BsqlResult<deadpool_postgres::Pool> {
461    let config: tokio_postgres::Config = url
462        .parse()
463        .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
464
465    let mut cfg = Config::new();
466    cfg.host = config.get_hosts().first().map(|h| match h {
467        tokio_postgres::config::Host::Tcp(s) => s.clone(),
468        #[cfg(unix)]
469        tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
470    });
471    cfg.port = config.get_ports().first().copied();
472    cfg.dbname = config.get_dbname().map(String::from);
473    cfg.user = config.get_user().map(String::from);
474    cfg.password = match config.get_password() {
475        Some(p) => Some(
476            String::from_utf8(p.to_vec())
477                .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
478        ),
479        None => None,
480    };
481    cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
482    cfg.manager = Some(ManagerConfig {
483        recycling_method: RecyclingMethod::Fast,
484    });
485    cfg.pool = Some(deadpool_postgres::PoolConfig {
486        max_size,
487        timeouts: deadpool_postgres::Timeouts {
488            wait: Some(std::time::Duration::ZERO),
489            create: None,
490            recycle: None,
491        },
492        ..Default::default()
493    });
494
495    let pool = cfg
496        .create_pool(Some(Runtime::Tokio1), NoTls)
497        .map_err(|e| ConnectError::create(e.to_string()))?;
498
499    // Verify connectivity
500    let _conn = pool
501        .get()
502        .await
503        .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
504
505    Ok(pool)
506}
507
508/// Detect PgBouncer on the first connection from the pool.
509///
510/// Strategy: try `SHOW POOLS` -- only PgBouncer responds to this.
511/// If PgBouncer is detected, check `SHOW CONFIG` for `prepared_statements`.
512///
513/// FIX 11: returns `Err` if the initial connection fails, instead of silently
514/// returning `DIRECT`. A pool that can't connect on creation is broken.
515async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
516    let conn = pool.get().await.map_err(|e| {
517        ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
518    })?;
519
520    // PgBouncer responds to `SHOW POOLS`; PostgreSQL does not.
521    let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
522
523    if !is_pgbouncer {
524        return Ok(PgBouncerInfo::DIRECT);
525    }
526
527    // Check if PgBouncer supports named prepared statements (1.21+)
528    let supports_named = match conn.simple_query("SHOW CONFIG").await {
529        Ok(messages) => messages.iter().any(|msg| {
530            if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
531                row.get(0) == Some("prepared_statements") && row.get(1) == Some("yes")
532            } else {
533                false
534            }
535        }),
536        Err(_) => false,
537    };
538
539    Ok(PgBouncerInfo {
540        detected: true,
541        supports_named_stmts: supports_named,
542    })
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let builder = Pool::builder();
        assert_eq!(
            (builder.max_size, builder.connect_timeout_secs),
            (16, 5)
        );
        assert_eq!(builder.replica_urls.len(), 0);
    }

    #[test]
    fn builder_config() {
        let builder = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        let text_fields = [
            (builder.host.as_deref(), "localhost"),
            (builder.dbname.as_deref(), "test"),
            (builder.user.as_deref(), "app"),
            (builder.password.as_deref(), "secret"),
        ];
        for &(actual, expected) in &text_fields {
            assert_eq!(actual, Some(expected));
        }

        assert_eq!(builder.port, Some(5432));
        assert_eq!(builder.max_size, 8);
        assert_eq!(builder.connect_timeout_secs, 10);
    }

    #[test]
    fn builder_replicas() {
        let urls = ["postgres://replica1:5432/db", "postgres://replica2:5432/db"];
        let builder = urls
            .iter()
            .fold(Pool::builder(), |acc, url| acc.replica(url));
        assert_eq!(builder.replica_urls.len(), 2);
    }

    #[test]
    fn pgbouncer_direct_defaults() {
        let PgBouncerInfo {
            detected,
            supports_named_stmts,
        } = PgBouncerInfo::DIRECT;
        assert!(!detected);
        assert!(supports_named_stmts);
    }

    #[test]
    fn pool_status_type_is_copy() {
        fn require_copy<T: Copy>() {}
        require_copy::<PoolStatus>();
    }
}