Skip to main content

hyperdb_api/
pool.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Async connection pool for Hyper database.
5//!
6//! This module provides connection pooling via [`deadpool`] for efficient
7//! connection reuse in async applications.
8//!
9//! # Example
10//!
11//! ```no_run
12//! use hyperdb_api::pool::{create_pool, PoolConfig};
13//! use hyperdb_api::CreateMode;
14//!
15//! #[tokio::main]
16//! async fn main() -> hyperdb_api::Result<()> {
17//!     // Create a pool configuration
18//!     let config = PoolConfig::new("localhost:7483", "example.hyper")
19//!         .create_mode(CreateMode::CreateIfNotExists)
20//!         .max_size(16);
21//!
22//!     // Build the pool
23//!     let pool = create_pool(config)?;
24//!
25//!     // Get a connection from the pool
26//!     let conn = pool.get().await.map_err(|e| hyperdb_api::Error::new(e.to_string()))?;
27//!
28//!     // Use the connection
29//!     conn.execute_command("SELECT 1").await?;
30//!
31//!     // Connection is returned to pool when dropped
32//!     Ok(())
33//! }
34//! ```
35//!
36//! # Lifecycle hooks
37//!
38//! `PoolConfig` supports two async lifecycle hooks for users who need to
39//! customize per-connection or per-checkout behavior:
40//!
41//! - `after_connect` runs once on every newly-opened connection (useful for
42//!   `SET search_path`, prepared-statement warmup, etc.)
43//! - `before_acquire` runs every time a connection is checked out (useful
44//!   for session reset, telemetry, custom health checks)
45//!
46//! `health_check(bool)` toggles the default per-checkout `SELECT 1` probe —
47//! disable it on hot paths where the roundtrip cost outweighs the value of
48//! catching a half-dead connection at acquire time.
49//!
50//! ```no_run
51//! use hyperdb_api::pool::{create_pool, PoolConfig};
52//! use hyperdb_api::CreateMode;
53//!
54//! # #[tokio::main]
55//! # async fn main() -> hyperdb_api::Result<()> {
56//! let config = PoolConfig::new("localhost:7483", "example.hyper")
57//!     .create_mode(CreateMode::CreateIfNotExists)
58//!     .max_size(16)
59//!     .health_check(false) // skip per-checkout SELECT 1
60//!     .after_connect(|conn| Box::pin(async move {
61//!         conn.execute_command("SET search_path TO public").await?;
62//!         Ok(())
63//!     }));
64//! let _pool = create_pool(config)?;
65//! # Ok(())
66//! # }
67//! ```
68
69use std::pin::Pin;
70use std::sync::Arc;
71
72use deadpool::managed::{self, Manager, Metrics, RecycleError, RecycleResult};
73use tokio::sync::Mutex as AsyncMutex;
74
75use crate::async_connection::AsyncConnection;
76use crate::error::{Error, Result};
77use crate::CreateMode;
78
79/// Future returned by pool lifecycle hooks.
80pub type HookFuture<'a> = Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>>;
81
82/// A hook that runs once on every newly-opened connection (after authentication
83/// and any database-creation handshake). Use it to set session variables, install
84/// statement caches, warm prepared statements, etc.
85///
86/// Returning `Err` from the hook causes pool creation to fail and the connection
87/// to be dropped.
88pub type AfterConnectHook = Arc<dyn Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static>;
89
90/// A hook that runs every time a connection is checked out of the pool, before
91/// it is handed to the caller. Use it for per-acquire health checks, session
92/// resets, or telemetry.
93///
94/// Returning `Err` from the hook causes the connection to be evicted (the pool
95/// retries with another connection or builds a new one).
96pub type BeforeAcquireHook =
97    Arc<dyn Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static>;
98
99/// Configuration for the connection pool.
100#[derive(Clone)]
101pub struct PoolConfig {
102    /// Server endpoint (e.g., "localhost:7483" or "<http://localhost:7484>")
103    pub endpoint: String,
104    /// Database path
105    pub database: String,
106    /// Database creation mode (only used for first connection)
107    pub create_mode: CreateMode,
108    /// Optional username for authentication
109    pub user: Option<String>,
110    /// Optional password for authentication
111    pub password: Option<String>,
112    /// Maximum number of connections in the pool
113    pub max_size: usize,
114    /// If `false`, skip the per-checkout `SELECT 1` health probe in `recycle()`.
115    /// Defaults to `true`. Disable for hot paths where the roundtrip cost matters
116    /// more than detecting a half-dead connection at acquire time. The pool
117    /// still reaps connections via [`AsyncConnection::is_alive`].
118    pub health_check: bool,
119    /// Optional hook run on every newly-opened connection (see [`AfterConnectHook`]).
120    pub after_connect: Option<AfterConnectHook>,
121    /// Optional hook run on every checkout (see [`BeforeAcquireHook`]).
122    pub before_acquire: Option<BeforeAcquireHook>,
123}
124
125impl std::fmt::Debug for PoolConfig {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("PoolConfig")
128            .field("endpoint", &self.endpoint)
129            .field("database", &self.database)
130            .field("create_mode", &self.create_mode)
131            .field("user", &self.user)
132            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
133            .field("max_size", &self.max_size)
134            .field("health_check", &self.health_check)
135            .field(
136                "after_connect",
137                &self.after_connect.as_ref().map(|_| "<fn>"),
138            )
139            .field(
140                "before_acquire",
141                &self.before_acquire.as_ref().map(|_| "<fn>"),
142            )
143            .finish()
144    }
145}
146
147impl PoolConfig {
148    /// Creates a new pool configuration.
149    pub fn new(endpoint: impl Into<String>, database: impl Into<String>) -> Self {
150        Self {
151            endpoint: endpoint.into(),
152            database: database.into(),
153            create_mode: CreateMode::DoNotCreate,
154            user: None,
155            password: None,
156            max_size: 16,
157            health_check: true,
158            after_connect: None,
159            before_acquire: None,
160        }
161    }
162
163    /// Sets the database creation mode.
164    #[must_use]
165    pub fn create_mode(mut self, mode: CreateMode) -> Self {
166        self.create_mode = mode;
167        self
168    }
169
170    #[must_use]
171    /// Sets authentication credentials.
172    pub fn auth(mut self, user: impl Into<String>, password: impl Into<String>) -> Self {
173        self.user = Some(user.into());
174        self.password = Some(password.into());
175        self
176    }
177
178    /// Sets the maximum pool size.
179    #[must_use]
180    pub fn max_size(mut self, size: usize) -> Self {
181        self.max_size = size;
182        self
183    }
184
185    /// Enables or disables the per-checkout `SELECT 1` health probe.
186    /// Defaults to enabled. Disable on hot paths where the roundtrip cost
187    /// outweighs the value of catching a dead connection at acquire time.
188    #[must_use]
189    pub fn health_check(mut self, enabled: bool) -> Self {
190        self.health_check = enabled;
191        self
192    }
193
194    /// Installs a hook that runs on every newly-opened connection.
195    ///
196    /// Use this to apply session-level setup (e.g. `SET search_path`, install
197    /// prepared statements). The hook is called once per physical connection,
198    /// not per checkout.
199    #[must_use]
200    pub fn after_connect<F>(mut self, hook: F) -> Self
201    where
202        F: Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static,
203    {
204        self.after_connect = Some(Arc::new(hook));
205        self
206    }
207
208    /// Installs a hook that runs on every connection checkout, before the
209    /// connection is handed to the caller.
210    ///
211    /// Returning `Err` from the hook evicts the connection from the pool;
212    /// the caller's `pool.get()` then retries with another connection or
213    /// builds a new one. Use this for per-acquire health checks beyond the
214    /// default `SELECT 1` probe (e.g. validating session state).
215    #[must_use]
216    pub fn before_acquire<F>(mut self, hook: F) -> Self
217    where
218        F: Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static,
219    {
220        self.before_acquire = Some(Arc::new(hook));
221        self
222    }
223}
224
225/// Connection pool manager for `AsyncConnection`.
226///
227/// The first call to [`Manager::create`] holds an async mutex while attempting
228/// to open a connection with the configured [`CreateMode`]. Concurrent callers
229/// wait for that attempt to finish, then use `CreateMode::DoNotCreate`. If the
230/// first attempt fails, the next caller retries with the original create_mode
231/// (for idempotent modes only — `Create` is not retried because a sibling
232/// connection may have already created the database).
233#[derive(Debug)]
234pub struct ConnectionManager {
235    config: Arc<PoolConfig>,
236    /// Synchronizes the first-connection attempt across concurrent callers.
237    /// `Some(())` after the first successful attempt; held while a first
238    /// attempt is in progress to serialize concurrent races. The value is the
239    /// outcome of the first call (the database is now known to exist).
240    init_lock: Arc<AsyncMutex<bool>>,
241}
242
243impl ConnectionManager {
244    /// Creates a new connection manager.
245    #[must_use]
246    pub fn new(config: PoolConfig) -> Self {
247        Self {
248            config: Arc::new(config),
249            init_lock: Arc::new(AsyncMutex::new(false)),
250        }
251    }
252
253    async fn open(&self, mode: CreateMode) -> Result<AsyncConnection> {
254        if let (Some(user), Some(password)) = (&self.config.user, &self.config.password) {
255            AsyncConnection::connect_with_auth(
256                &self.config.endpoint,
257                &self.config.database,
258                mode,
259                user,
260                password,
261            )
262            .await
263        } else {
264            AsyncConnection::connect(&self.config.endpoint, &self.config.database, mode).await
265        }
266    }
267}
268
269impl Manager for ConnectionManager {
270    type Type = AsyncConnection;
271    type Error = Error;
272
273    async fn create(&self) -> Result<AsyncConnection> {
274        // Fast path: if the first connection already succeeded, just open with
275        // DoNotCreate. We hold the lock briefly to read the flag.
276        // (Lock is uncontended after the first connection — fast path is cheap.)
277        let conn = {
278            let initialized = self.init_lock.lock().await;
279            if *initialized {
280                drop(initialized);
281                self.open(CreateMode::DoNotCreate).await?
282            } else {
283                drop(initialized);
284                // Slow path: first creation. Acquire the lock and re-check (in
285                // case another waiter raced us), then attempt with the
286                // configured mode.
287                let mut initialized = self.init_lock.lock().await;
288                if *initialized {
289                    drop(initialized);
290                    self.open(CreateMode::DoNotCreate).await?
291                } else {
292                    let result = self.open(self.config.create_mode).await;
293                    if result.is_ok() {
294                        *initialized = true;
295                    }
296                    // On failure leave `initialized = false` so the next caller
297                    // retries with the original create_mode.
298                    result?
299                }
300            }
301        };
302
303        // Run the after_connect hook (if any) before handing the connection
304        // to the pool. Hook errors propagate as connection-creation errors.
305        if let Some(hook) = self.config.after_connect.as_ref() {
306            hook(&conn).await?;
307        }
308        Ok(conn)
309    }
310
311    async fn recycle(
312        &self,
313        conn: &mut AsyncConnection,
314        _metrics: &Metrics,
315    ) -> RecycleResult<Self::Error> {
316        // Optional `SELECT 1` health probe. Off by default if the user
317        // disables it via PoolConfig::health_check(false).
318        if self.config.health_check {
319            conn.execute_command("SELECT 1")
320                .await
321                .map_err(RecycleError::Backend)?;
322        }
323        // Per-checkout user hook (e.g. session reset, telemetry).
324        if let Some(hook) = self.config.before_acquire.as_ref() {
325            hook(conn).await.map_err(RecycleError::Backend)?;
326        }
327        Ok(())
328    }
329}
330
331/// A pool of async connections to a Hyper database.
332///
333/// This pool manages a set of reusable connections, automatically creating
334/// new connections when needed and recycling them after use.
335pub type Pool = managed::Pool<ConnectionManager>;
336
337/// A pooled connection wrapper.
338pub type PooledConnection = managed::Object<ConnectionManager>;
339
340/// Creates a new connection pool from configuration.
341///
342/// # Errors
343///
344/// Returns [`Error::Other`] wrapping the `deadpool` builder failure if
345/// the pool cannot be constructed (e.g. invalid `max_size`). Connections
346/// themselves are opened lazily on first use, so endpoint/auth errors
347/// surface from [`Pool::get`](managed::Pool::get), not here.
348pub fn create_pool(config: PoolConfig) -> Result<Pool> {
349    let max_size = config.max_size;
350    let manager = ConnectionManager::new(config);
351    Pool::builder(manager)
352        .max_size(max_size)
353        .build()
354        .map_err(|e| Error::new(format!("Failed to create pool: {e}")))
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new("localhost:7483", "test.hyper")
            .create_mode(CreateMode::CreateIfNotExists)
            .auth("user", "pass")
            .max_size(32);

        assert_eq!(config.endpoint, "localhost:7483");
        assert_eq!(config.database, "test.hyper");
        assert_eq!(config.create_mode, CreateMode::CreateIfNotExists);
        assert_eq!(config.user, Some("user".to_string()));
        assert_eq!(config.password, Some("pass".to_string()));
        assert_eq!(config.max_size, 32);
    }

    /// `PoolConfig::new` must keep its documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost:7483", "test.hyper");

        assert_eq!(config.create_mode, CreateMode::DoNotCreate);
        assert_eq!(config.user, None);
        assert_eq!(config.password, None);
        assert_eq!(config.max_size, 16);
        assert!(config.health_check);
        assert!(config.after_connect.is_none());
        assert!(config.before_acquire.is_none());
    }

    /// The health-check toggle and both hook setters are stored on the config.
    #[test]
    fn test_pool_config_health_check_and_hooks() {
        let config = PoolConfig::new("localhost:7483", "test.hyper")
            .health_check(false)
            .after_connect(|_conn| Box::pin(async { Ok(()) }))
            .before_acquire(|_conn| Box::pin(async { Ok(()) }));

        assert!(!config.health_check);
        assert!(config.after_connect.is_some());
        assert!(config.before_acquire.is_some());
    }

    /// The password must never leak through `Debug` output.
    #[test]
    fn test_pool_config_debug_redacts_password() {
        let config = PoolConfig::new("localhost:7483", "test.hyper").auth("user", "s3cret");
        let rendered = format!("{config:?}");

        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("s3cret"));
        assert!(rendered.contains("localhost:7483"));
    }
}