Skip to main content

bsql_core/
pool.rs

//! Connection pool with fail-fast semantics and PgBouncer detection.
//!
//! The pool wraps `deadpool-postgres` and adds two key behaviors:
//! - **Fail-fast**: `acquire()` returns `PoolExhausted` immediately when no
//!   connections are available. It does not wait. See CREDO principle #17.
//! - **PgBouncer detection**: on pool creation, bsql detects whether the
//!   connection goes through PgBouncer and adjusts its prepared statement strategy.
8
9use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
10use tokio_postgres::NoTls;
11use tokio_postgres::types::ToSql;
12
13use crate::error::{BsqlError, BsqlResult, ConnectError};
14use crate::stream::QueryStream;
15use crate::transaction::Transaction;
16
17/// A PostgreSQL connection pool.
18///
19/// Wraps `deadpool-postgres` with fail-fast acquire semantics.
20/// All connections are returned to the pool when `PoolConnection` is dropped.
21pub struct Pool {
22    inner: deadpool_postgres::Pool,
23    pgbouncer: PgBouncerInfo,
24}
25
/// Outcome of probing for PgBouncer when the pool is created.
#[derive(Clone, Copy, Debug)]
pub(crate) struct PgBouncerInfo {
    /// Whether PgBouncer sits between the client and PostgreSQL.
    detected: bool,
    /// Whether named (server-side) prepared statements may be used.
    ///
    /// Always true for a direct connection; behind PgBouncer it is true only
    /// when PgBouncer (1.21+) has prepared-statement support enabled.
    supports_named_stmts: bool,
}

impl PgBouncerInfo {
    /// Direct connection to PostgreSQL: no PgBouncer in the path, so named
    /// prepared statements are usable.
    const DIRECT: Self = PgBouncerInfo {
        supports_named_stmts: true,
        detected: false,
    };
}
42
43/// Builder for configuring a connection pool.
44pub struct PoolBuilder {
45    host: Option<String>,
46    port: Option<u16>,
47    dbname: Option<String>,
48    user: Option<String>,
49    password: Option<String>,
50    max_size: usize,
51    connect_timeout_secs: u64,
52}
53
54impl PoolBuilder {
55    pub fn host(mut self, host: &str) -> Self {
56        self.host = Some(host.into());
57        self
58    }
59
60    pub fn port(mut self, port: u16) -> Self {
61        self.port = Some(port);
62        self
63    }
64
65    pub fn dbname(mut self, dbname: &str) -> Self {
66        self.dbname = Some(dbname.into());
67        self
68    }
69
70    pub fn user(mut self, user: &str) -> Self {
71        self.user = Some(user.into());
72        self
73    }
74
75    pub fn password(mut self, password: &str) -> Self {
76        self.password = Some(password.into());
77        self
78    }
79
80    pub fn max_size(mut self, size: usize) -> Self {
81        self.max_size = size;
82        self
83    }
84
85    /// TCP connect timeout in seconds. This is the ONLY timeout in bsql —
86    /// it exists because TCP itself will wait forever on a dead network.
87    pub fn connect_timeout(mut self, secs: u64) -> Self {
88        self.connect_timeout_secs = secs;
89        self
90    }
91
92    pub async fn build(self) -> BsqlResult<Pool> {
93        let mut cfg = Config::new();
94        cfg.host = self.host;
95        cfg.port = self.port;
96        cfg.dbname = self.dbname;
97        cfg.user = self.user;
98        cfg.password = self.password;
99        cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
100        cfg.manager = Some(ManagerConfig {
101            recycling_method: RecyclingMethod::Fast,
102        });
103        // FIX 2: fail-fast — zero wait timeout means acquire() never blocks
104        cfg.pool = Some(deadpool_postgres::PoolConfig {
105            max_size: self.max_size,
106            timeouts: deadpool_postgres::Timeouts {
107                wait: Some(std::time::Duration::ZERO),
108                create: None,
109                recycle: None,
110            },
111            ..Default::default()
112        });
113
114        let pool = cfg
115            .create_pool(Some(Runtime::Tokio1), NoTls)
116            .map_err(|e| ConnectError::create(e.to_string()))?;
117
118        // FIX 11: detect PgBouncer — propagate connection failure
119        let pgbouncer = detect_pgbouncer(&pool).await?;
120
121        Ok(Pool {
122            inner: pool,
123            pgbouncer,
124        })
125    }
126}
127
128impl Pool {
129    /// Connect to PostgreSQL using a connection URL.
130    ///
131    /// Format: `postgres://user:password@host:port/dbname`
132    pub async fn connect(url: &str) -> BsqlResult<Self> {
133        let config: tokio_postgres::Config = url
134            .parse()
135            .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
136
137        let mut cfg = Config::new();
138        cfg.host = config.get_hosts().first().map(|h| match h {
139            tokio_postgres::config::Host::Tcp(s) => s.clone(),
140            #[cfg(unix)]
141            tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
142        });
143        cfg.port = config.get_ports().first().copied();
144        cfg.dbname = config.get_dbname().map(String::from);
145        cfg.user = config.get_user().map(String::from);
146        cfg.password = config
147            .get_password()
148            .map(|p| String::from_utf8_lossy(p).into_owned());
149        cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
150        cfg.manager = Some(ManagerConfig {
151            recycling_method: RecyclingMethod::Fast,
152        });
153        // FIX 2: fail-fast — zero wait timeout means acquire() never blocks
154        cfg.pool = Some(deadpool_postgres::PoolConfig {
155            max_size: 16,
156            timeouts: deadpool_postgres::Timeouts {
157                wait: Some(std::time::Duration::ZERO),
158                create: None,
159                recycle: None,
160            },
161            ..Default::default()
162        });
163
164        let pool = cfg
165            .create_pool(Some(Runtime::Tokio1), NoTls)
166            .map_err(|e| ConnectError::create(e.to_string()))?;
167
168        // FIX 11: detect PgBouncer — propagate connection failure
169        let pgbouncer = detect_pgbouncer(&pool).await?;
170
171        Ok(Pool {
172            inner: pool,
173            pgbouncer,
174        })
175    }
176
177    /// Create a pool builder for fine-grained configuration.
178    pub fn builder() -> PoolBuilder {
179        PoolBuilder {
180            host: None,
181            port: None,
182            dbname: None,
183            user: None,
184            password: None,
185            max_size: 16,
186            connect_timeout_secs: 5,
187        }
188    }
189
190    /// Acquire a connection from the pool.
191    ///
192    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
193    /// are available. Does not wait. Does not timeout. See CREDO principle #17.
194    pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
195        let conn = self.inner.get().await.map_err(BsqlError::from)?;
196
197        Ok(PoolConnection {
198            inner: conn,
199            pgbouncer: self.pgbouncer,
200        })
201    }
202
203    /// Whether PgBouncer was detected between the client and PostgreSQL.
204    pub fn is_pgbouncer(&self) -> bool {
205        self.pgbouncer.detected
206    }
207
208    /// Whether named prepared statements can be used.
209    ///
210    /// False when PgBouncer is detected without `prepared_statements=yes`.
211    pub fn supports_named_statements(&self) -> bool {
212        self.pgbouncer.supports_named_stmts
213    }
214
215    /// Begin a new transaction.
216    ///
217    /// Acquires a connection from the pool and sends `BEGIN`. The connection
218    /// is held for the lifetime of the returned [`Transaction`].
219    ///
220    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
221    /// are available. See CREDO principle #17.
222    pub async fn begin(&self) -> BsqlResult<Transaction> {
223        let conn = self.acquire().await?;
224        conn.inner
225            .batch_execute("BEGIN")
226            .await
227            .map_err(BsqlError::from)?;
228        Ok(Transaction::new(conn))
229    }
230
231    /// Execute a query and return a stream of rows.
232    ///
233    /// Acquires a connection from the pool and returns a [`QueryStream`]
234    /// that holds the connection alive until the stream is consumed or
235    /// dropped. Rows arrive one at a time, avoiding buffering the
236    /// entire result set in memory.
237    ///
238    /// **Fail-fast**: returns `BsqlError::Pool` immediately if no connections
239    /// are available. See CREDO principle #17.
240    ///
241    /// This method is only available on `Pool` (not `PoolConnection` or
242    /// `Transaction`) because the stream must own the connection for its
243    /// entire lifetime.
244    pub async fn query_stream(
245        &self,
246        sql: &str,
247        params: &[&(dyn ToSql + Sync)],
248    ) -> BsqlResult<QueryStream> {
249        let conn = self.acquire().await?;
250        let stmt = conn
251            .inner
252            .prepare_cached(sql)
253            .await
254            .map_err(BsqlError::from)?;
255
256        let row_stream = conn
257            .inner
258            .query_raw(&stmt, params.iter().copied())
259            .await
260            .map_err(BsqlError::from)?;
261
262        Ok(QueryStream::new(conn, row_stream))
263    }
264
265    /// Current pool status: available and total connections.
266    pub fn status(&self) -> PoolStatus {
267        let status = self.inner.status();
268        PoolStatus {
269            available: status.available,
270            size: status.size,
271            max_size: status.max_size,
272        }
273    }
274}
275
276/// A connection borrowed from the pool.
277///
278/// Returned to the pool when dropped.
279pub struct PoolConnection {
280    pub(crate) inner: deadpool_postgres::Object,
281    pub(crate) pgbouncer: PgBouncerInfo,
282}
283
284impl PoolConnection {
285    /// Whether named prepared statements can be used on this connection.
286    pub fn supports_named_statements(&self) -> bool {
287        self.pgbouncer.supports_named_stmts
288    }
289}
290
/// Snapshot of pool utilization.
///
/// A plain value type: it is `Copy` and compares by value, so two snapshots
/// can be checked for equality directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Number of connections reported as available by the underlying pool.
    pub available: usize,
    /// Current number of connections reported by the underlying pool.
    pub size: usize,
    /// Maximum number of connections the pool is configured to hold.
    pub max_size: usize,
}
298
299/// Detect PgBouncer on the first connection from the pool.
300///
301/// Strategy: try `SHOW POOLS` — only PgBouncer responds to this.
302/// If PgBouncer is detected, check `SHOW CONFIG` for `prepared_statements`.
303///
304/// FIX 11: returns `Err` if the initial connection fails, instead of silently
305/// returning `DIRECT`. A pool that can't connect on creation is broken.
306async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
307    let conn = pool.get().await.map_err(|e| {
308        ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
309    })?;
310
311    // PgBouncer responds to `SHOW POOLS`; PostgreSQL does not.
312    let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
313
314    if !is_pgbouncer {
315        return Ok(PgBouncerInfo::DIRECT);
316    }
317
318    // Check if PgBouncer supports named prepared statements (1.21+)
319    let supports_named = match conn.simple_query("SHOW CONFIG").await {
320        Ok(messages) => messages.iter().any(|msg| {
321            if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
322                row.get(0) == Some("prepared_statements") && row.get(1) == Some("yes")
323            } else {
324                false
325            }
326        }),
327        Err(_) => false,
328    };
329
330    Ok(PgBouncerInfo {
331        detected: true,
332        supports_named_stmts: supports_named,
333    })
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder starts with a pool size of 16 and a 5 s connect timeout.
    #[test]
    fn builder_defaults() {
        let PoolBuilder {
            max_size,
            connect_timeout_secs,
            ..
        } = Pool::builder();

        assert_eq!(max_size, 16);
        assert_eq!(connect_timeout_secs, 5);
    }

    /// Every setter stores its argument in the matching builder field.
    #[test]
    fn builder_config() {
        let PoolBuilder {
            host,
            port,
            dbname,
            user,
            password,
            max_size,
            connect_timeout_secs,
        } = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(host.as_deref(), Some("localhost"));
        assert_eq!(port, Some(5432));
        assert_eq!(dbname.as_deref(), Some("test"));
        assert_eq!(user.as_deref(), Some("app"));
        assert_eq!(password.as_deref(), Some("secret"));
        assert_eq!(max_size, 8);
        assert_eq!(connect_timeout_secs, 10);
    }

    /// `DIRECT` means no PgBouncer and named statements allowed.
    #[test]
    fn pgbouncer_direct_defaults() {
        let PgBouncerInfo {
            detected,
            supports_named_stmts,
        } = PgBouncerInfo::DIRECT;

        assert!(!detected);
        assert!(supports_named_stmts);
    }

    /// Compile-time check that `PoolStatus` stays `Copy`.
    #[test]
    fn pool_status_type_is_copy() {
        fn requires_copy<T: Copy>() {}
        requires_copy::<PoolStatus>();
    }
}