Skip to main content

prax_postgres/
pool.rs

1//! Connection pool for PostgreSQL.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod, Runtime};
7use tokio_postgres::NoTls;
8use tracing::{debug, info};
9
10use crate::config::PgConfig;
11use crate::connection::PgConnection;
12use crate::error::{PgError, PgResult};
13use crate::statement::PreparedStatementCache;
14
15/// A connection pool for PostgreSQL.
16#[derive(Clone)]
17pub struct PgPool {
18    inner: Pool,
19    config: Arc<PgConfig>,
20    statement_cache: Arc<PreparedStatementCache>,
21}
22
23impl PgPool {
24    /// Create a new connection pool from configuration.
25    pub async fn new(config: PgConfig) -> PgResult<Self> {
26        Self::with_pool_config(config, PoolConfig::default()).await
27    }
28
29    /// Create a new connection pool with custom pool configuration.
30    pub async fn with_pool_config(config: PgConfig, pool_config: PoolConfig) -> PgResult<Self> {
31        let pg_config = config.to_pg_config();
32
33        let mgr_config = ManagerConfig {
34            recycling_method: RecyclingMethod::Fast,
35        };
36
37        let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
38
39        // Build pool - set runtime to tokio for timeout support
40        let mut builder = Pool::builder(mgr).max_size(pool_config.max_connections);
41
42        // Only set timeouts if they are configured
43        if let Some(timeout) = pool_config.connection_timeout {
44            builder = builder
45                .wait_timeout(Some(timeout))
46                .create_timeout(Some(timeout));
47        }
48        if let Some(timeout) = pool_config.idle_timeout {
49            builder = builder.recycle_timeout(Some(timeout));
50        }
51
52        // Set runtime for timeout support
53        builder = builder.runtime(Runtime::Tokio1);
54
55        let pool = builder
56            .build()
57            .map_err(|e| PgError::config(format!("failed to create pool: {}", e)))?;
58
59        info!(
60            host = %config.host,
61            port = %config.port,
62            database = %config.database,
63            max_connections = %pool_config.max_connections,
64            "PostgreSQL connection pool created"
65        );
66
67        Ok(Self {
68            inner: pool,
69            config: Arc::new(config),
70            statement_cache: Arc::new(PreparedStatementCache::new(
71                pool_config.statement_cache_size,
72            )),
73        })
74    }
75
76    /// Get a connection from the pool.
77    pub async fn get(&self) -> PgResult<PgConnection> {
78        debug!("Acquiring connection from pool");
79        let client = self.inner.get().await?;
80        Ok(PgConnection::new(client, self.statement_cache.clone()))
81    }
82
83    /// Get the current pool status.
84    pub fn status(&self) -> PoolStatus {
85        let status = self.inner.status();
86        PoolStatus {
87            available: status.available,
88            size: status.size,
89            max_size: status.max_size,
90            waiting: status.waiting,
91        }
92    }
93
94    /// Get the pool configuration.
95    pub fn config(&self) -> &PgConfig {
96        &self.config
97    }
98
99    /// Check if the pool is healthy by attempting to get a connection.
100    pub async fn is_healthy(&self) -> bool {
101        match self.inner.get().await {
102            Ok(client) => {
103                // Try a simple query to verify the connection is actually working
104                client.query_one("SELECT 1", &[]).await.is_ok()
105            }
106            Err(_) => false,
107        }
108    }
109
110    /// Close the pool and all connections.
111    pub fn close(&self) {
112        self.inner.close();
113        info!("PostgreSQL connection pool closed");
114    }
115
116    /// Create a builder for configuring the pool.
117    pub fn builder() -> PgPoolBuilder {
118        PgPoolBuilder::new()
119    }
120
121    /// Warm up the connection pool by pre-establishing connections.
122    ///
123    /// This eliminates the latency of establishing connections on the first queries.
124    /// The `count` parameter specifies how many connections to pre-establish.
125    ///
126    /// # Example
127    ///
128    /// ```rust,ignore
129    /// let pool = PgPool::builder()
130    ///     .url("postgresql://localhost/db")
131    ///     .max_connections(10)
132    ///     .build()
133    ///     .await?;
134    ///
135    /// // Pre-establish 5 connections
136    /// pool.warmup(5).await?;
137    /// ```
138    pub async fn warmup(&self, count: usize) -> PgResult<()> {
139        info!(count = count, "Warming up connection pool");
140
141        let count = count.min(self.inner.status().max_size);
142        let mut connections = Vec::with_capacity(count);
143
144        // Acquire connections to force establishment
145        for i in 0..count {
146            match self.inner.get().await {
147                Ok(conn) => {
148                    // Validate the connection with a simple query
149                    if let Err(e) = conn.query_one("SELECT 1", &[]).await {
150                        debug!(error = %e, "Warmup connection {} failed validation", i);
151                    } else {
152                        debug!("Warmup connection {} established", i);
153                        connections.push(conn);
154                    }
155                }
156                Err(e) => {
157                    debug!(error = %e, "Failed to establish warmup connection {}", i);
158                }
159            }
160        }
161
162        // Connections are returned to pool when dropped
163        let established = connections.len();
164        drop(connections);
165
166        info!(
167            established = established,
168            requested = count,
169            "Connection pool warmup complete"
170        );
171
172        Ok(())
173    }
174
175    /// Warm up with common prepared statements.
176    ///
177    /// This pre-prepares common SQL statements on warmed connections,
178    /// eliminating the prepare latency on first use.
179    pub async fn warmup_with_statements(&self, count: usize, statements: &[&str]) -> PgResult<()> {
180        info!(
181            count = count,
182            statements = statements.len(),
183            "Warming up connection pool with prepared statements"
184        );
185
186        let count = count.min(self.inner.status().max_size);
187        let mut connections = Vec::with_capacity(count);
188
189        for i in 0..count {
190            match self.inner.get().await {
191                Ok(conn) => {
192                    // Pre-prepare all statements
193                    for sql in statements {
194                        if let Err(e) = conn.prepare_cached(sql).await {
195                            debug!(error = %e, sql = %sql, "Failed to prepare statement");
196                        }
197                    }
198                    debug!(
199                        connection = i,
200                        statements = statements.len(),
201                        "Prepared statements on connection"
202                    );
203                    connections.push(conn);
204                }
205                Err(e) => {
206                    debug!(error = %e, "Failed to establish warmup connection {}", i);
207                }
208            }
209        }
210
211        let established = connections.len();
212        drop(connections);
213
214        info!(
215            established = established,
216            "Connection pool warmup with statements complete"
217        );
218
219        Ok(())
220    }
221}
222
/// Snapshot of the pool's occupancy, as reported by `PgPool::status`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStatus {
    /// Number of available (idle) connections.
    pub available: usize,
    /// Current total number of connections in the pool.
    pub size: usize,
    /// Maximum size of the pool.
    pub max_size: usize,
    /// Number of tasks waiting for a connection.
    pub waiting: usize,
}
235
/// Configuration for the connection pool.
///
/// Consumed by `PgPool::with_pool_config`; see the field docs for which
/// settings that constructor currently applies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Maximum number of connections in the pool.
    pub max_connections: usize,
    /// Intended minimum number of connections to keep alive.
    ///
    /// NOTE: not currently applied by `PgPool::with_pool_config`; use
    /// `PgPool::warmup` to pre-establish connections.
    pub min_connections: usize,
    /// Maximum time to wait for a free pool slot. It is also used as the
    /// timeout for opening a new connection. `None` leaves both unset.
    pub connection_timeout: Option<Duration>,
    /// Passed to deadpool as its recycle timeout.
    ///
    /// NOTE(review): despite the name, this does not evict idle connections.
    /// deadpool's recycle timeout bounds the check run when a pooled
    /// connection is reused. Confirm the intended semantics.
    pub idle_timeout: Option<Duration>,
    /// Intended maximum lifetime of a connection.
    ///
    /// NOTE: not currently applied by `PgPool::with_pool_config`.
    pub max_lifetime: Option<Duration>,
    /// Capacity of the single prepared-statement cache that a pool shares
    /// across all of its connections (not one cache per connection).
    pub statement_cache_size: usize,
}

impl Default for PoolConfig {
    /// Default settings:
    /// - 10 max connections and 1 min connection;
    /// - 30 s connection timeout;
    /// - 10 min idle timeout and 30 min max lifetime;
    /// - statement cache size 100.
    fn default() -> Self {
        Self {
            max_connections: 10,
            min_connections: 1,
            connection_timeout: Some(Duration::from_secs(30)),
            idle_timeout: Some(Duration::from_secs(10 * 60)),
            max_lifetime: Some(Duration::from_secs(30 * 60)),
            statement_cache_size: 100,
        }
    }
}
265
266/// Builder for creating a connection pool.
267#[derive(Debug, Default)]
268pub struct PgPoolBuilder {
269    config: Option<PgConfig>,
270    url: Option<String>,
271    pool_config: PoolConfig,
272}
273
274impl PgPoolBuilder {
275    /// Create a new pool builder.
276    pub fn new() -> Self {
277        Self {
278            config: None,
279            url: None,
280            pool_config: PoolConfig::default(),
281        }
282    }
283
284    /// Set the database URL.
285    pub fn url(mut self, url: impl Into<String>) -> Self {
286        self.url = Some(url.into());
287        self
288    }
289
290    /// Set the configuration.
291    pub fn config(mut self, config: PgConfig) -> Self {
292        self.config = Some(config);
293        self
294    }
295
296    /// Set the maximum number of connections.
297    pub fn max_connections(mut self, n: usize) -> Self {
298        self.pool_config.max_connections = n;
299        self
300    }
301
302    /// Set the minimum number of connections.
303    pub fn min_connections(mut self, n: usize) -> Self {
304        self.pool_config.min_connections = n;
305        self
306    }
307
308    /// Set the connection timeout.
309    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
310        self.pool_config.connection_timeout = Some(timeout);
311        self
312    }
313
314    /// Set the idle timeout.
315    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
316        self.pool_config.idle_timeout = Some(timeout);
317        self
318    }
319
320    /// Set the maximum connection lifetime.
321    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
322        self.pool_config.max_lifetime = Some(lifetime);
323        self
324    }
325
326    /// Set the prepared statement cache size.
327    pub fn statement_cache_size(mut self, size: usize) -> Self {
328        self.pool_config.statement_cache_size = size;
329        self
330    }
331
332    /// Build the connection pool.
333    pub async fn build(self) -> PgResult<PgPool> {
334        let config = if let Some(config) = self.config {
335            config
336        } else if let Some(url) = self.url {
337            PgConfig::from_url(url)?
338        } else {
339            return Err(PgError::config("no database URL or config provided"));
340        };
341
342        PgPool::with_pool_config(config, self.pool_config).await
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `PoolConfig` default, including timeouts and lifetime.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.connection_timeout, Some(Duration::from_secs(30)));
        assert_eq!(config.idle_timeout, Some(Duration::from_secs(600)));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(1800)));
        assert_eq!(config.statement_cache_size, 100);
    }

    #[test]
    fn test_pool_builder() {
        let builder = PgPoolBuilder::new()
            .url("postgresql://localhost/test")
            .max_connections(20)
            .statement_cache_size(200);

        assert!(builder.url.is_some());
        assert_eq!(builder.pool_config.max_connections, 20);
        assert_eq!(builder.pool_config.statement_cache_size, 200);
    }

    /// A fresh builder has no target and uses the default pool settings.
    #[test]
    fn test_pool_builder_initial_state() {
        let builder = PgPoolBuilder::new();
        let defaults = PoolConfig::default();

        assert!(builder.config.is_none());
        assert!(builder.url.is_none());
        assert_eq!(builder.pool_config.max_connections, defaults.max_connections);
        assert_eq!(builder.pool_config.min_connections, defaults.min_connections);
        assert_eq!(
            builder.pool_config.statement_cache_size,
            defaults.statement_cache_size
        );
    }

    /// The remaining setters each write their own `PoolConfig` field.
    #[test]
    fn test_pool_builder_timeouts() {
        let builder = PgPoolBuilder::new()
            .min_connections(2)
            .connection_timeout(Duration::from_secs(5))
            .idle_timeout(Duration::from_secs(60))
            .max_lifetime(Duration::from_secs(120));

        assert_eq!(builder.pool_config.min_connections, 2);
        assert_eq!(
            builder.pool_config.connection_timeout,
            Some(Duration::from_secs(5))
        );
        assert_eq!(builder.pool_config.idle_timeout, Some(Duration::from_secs(60)));
        assert_eq!(builder.pool_config.max_lifetime, Some(Duration::from_secs(120)));
    }
}