Skip to main content

mcp_postgres/
pool.rs

//! PostgreSQL connection pool — lock-free implementation.
//!
//! Uses `LockFreePool<tokio_postgres::Client>` internally.  No mutexes,
//! no semaphores, no kernel transitions on the hot path — only CAS loops
//! on `crossbeam::queue::ArrayQueue` and atomic size tracking.
//!
//! The `acquire()` method returns a `PooledConnection` which auto-returns
//! to the pool on `Drop`, so no explicit release call is needed
//! (`ConnectionPool::release` remains only for backward compatibility).

10use std::time::Duration;
11use tokio_postgres::{Client, NoTls};
12use tracing::debug;
13
14use crate::config::PoolConfig;
15use crate::errors::{MCPError, Result as MCPResult};
16use crate::lockfree_pool::{
17    BoxFuture, CreateFn, LockFreePool, PoolConfig as LFPoolConfig, PoolError, PooledConnection,
18    ValidateFn,
19};
20
21/// Wrapper around the lock-free connection pool.
22pub struct ConnectionPool {
23    inner: LockFreePool<Client>,
24    max_size: u32,
25}
26
27impl ConnectionPool {
28    pub async fn new(connection_string: &str, config: PoolConfig) -> anyhow::Result<Self> {
29        Self::with_session_setup(connection_string, config, Duration::ZERO, false).await
30    }
31
32    /// Create a pool whose connections enforce a server-side `statement_timeout`.
33    ///
34    /// A non-zero `statement_timeout` caps how long any single query may run,
35    /// preventing a slow/runaway query from pinning a pooled connection
36    /// indefinitely. A value of `Duration::ZERO` leaves PostgreSQL's default
37    /// (unlimited) in place.
38    pub async fn with_statement_timeout(
39        connection_string: &str,
40        config: PoolConfig,
41        statement_timeout: Duration,
42    ) -> anyhow::Result<Self> {
43        Self::with_session_setup(connection_string, config, statement_timeout, false).await
44    }
45
46    /// Create a pool, optionally enforcing `statement_timeout` and a read-only
47    /// default transaction mode on every connection.
48    ///
49    /// When `read_only` is true, each connection runs
50    /// `SET default_transaction_read_only = on`, so every statement (including
51    /// volatile functions invoked from a SELECT) is rejected at the database if
52    /// it attempts a write — a stronger guarantee than tool-name filtering.
53    pub async fn with_session_setup(
54        connection_string: &str,
55        config: PoolConfig,
56        statement_timeout: Duration,
57        read_only: bool,
58    ) -> anyhow::Result<Self> {
59        debug!(
60            "Creating lock-free connection pool: max_size={}, statement_timeout={:?}, read_only={}",
61            config.max_size, statement_timeout, read_only
62        );
63
64        let conn_string = connection_string.to_string();
65        let create_timeout = Duration::from_secs(5);
66        let stmt_timeout_ms = statement_timeout.as_millis();
67
68        // Build a TLS connector once if the connection string opts into it.
69        let tls_connector = if crate::tls::wants_tls(&conn_string) {
70            Some(crate::tls::make_connector()?)
71        } else {
72            None
73        };
74
75        let create = {
76            let cs = conn_string.clone();
77            Box::new(move || {
78                let cs = cs.clone();
79                let tls = tls_connector.clone();
80                Box::pin(async move {
81                    let client = match tls {
82                        Some(tls) => {
83                            let (client, connection) = tokio_postgres::connect(&cs, tls)
84                                .await
85                                .map_err(|e| e.to_string())?;
86                            tokio::spawn(connection);
87                            client
88                        }
89                        None => {
90                            let (client, connection) = tokio_postgres::connect(&cs, NoTls)
91                                .await
92                                .map_err(|e| e.to_string())?;
93                            tokio::spawn(connection);
94                            client
95                        }
96                    };
97                    // Apply a per-connection statement_timeout so no single query
98                    // can hold a pooled connection forever. Session-level (not LOCAL)
99                    // so it persists for every query on this connection.
100                    if stmt_timeout_ms > 0 {
101                        client
102                            .batch_execute(&format!("SET statement_timeout TO '{stmt_timeout_ms}'"))
103                            .await
104                            .map_err(|e| e.to_string())?;
105                    }
106                    // Restricted (read-only) mode: enforce at the database so a
107                    // write attempted from any statement is rejected.
108                    if read_only {
109                        client
110                            .batch_execute("SET default_transaction_read_only = on")
111                            .await
112                            .map_err(|e| e.to_string())?;
113                    }
114                    Ok(client)
115                }) as BoxFuture<'static, Result<Client, String>>
116            }) as CreateFn<Client>
117        };
118
119        let validate = Box::new(|client: &Client| !client.is_closed()) as ValidateFn<Client>;
120
121        let lf_config = LFPoolConfig {
122            max_size: config.max_size,
123            create_timeout,
124            wait_timeout: config.queue_timeout,
125        };
126
127        let pool = LockFreePool::new(create, validate, &lf_config);
128
129        // Test the pool by acquiring a connection
130        let test_conn = pool
131            .acquire()
132            .await
133            .map_err(|e| anyhow::anyhow!("Failed to establish database connection: {e}"))?;
134        drop(test_conn);
135
136        Ok(Self {
137            inner: pool,
138            max_size: config.max_size,
139        })
140    }
141
142    /// Acquire a connection from the pool.
143    ///
144    /// Returns a `PooledConnection<Client>` which implements `Deref<Target = Client>`
145    /// and automatically returns to the pool when dropped.
146    pub async fn acquire(&self) -> MCPResult<PooledConnection<Client>> {
147        self.inner.acquire().await.map_err(|e| match e {
148            PoolError::Timeout => {
149                MCPError::PoolError("Connection pool timeout: no connection available".into())
150            }
151            PoolError::Closed => MCPError::PoolError("Connection pool is closed".into()),
152            PoolError::CreateFailed(msg) => {
153                MCPError::PoolError(format!("Failed to create connection: {msg}"))
154            }
155        })
156    }
157
158    /// Release a connection back to the pool.
159    ///
160    /// With `PooledConnection`, this is automatic on `Drop`.  This method
161    /// exists for backward compatibility with existing callers.
162    pub fn release(&self, _conn: PooledConnection<Client>) {
163        // Connection auto-returns to pool on Drop
164    }
165
166    pub fn active_count(&self) -> u32 {
167        self.inner.status().size
168    }
169
170    pub const fn max_size(&self) -> u32 {
171        self.max_size
172    }
173
174    pub fn is_closed(&self) -> bool {
175        self.inner.is_closed()
176    }
177
178    /// Close the pool, dropping all idle connections.
179    pub fn close(&self) {
180        self.inner.close();
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Resolve a connection string for integration tests, or `None` when no
    /// database is configured.
    ///
    /// `DATABASE_URL` wins. Otherwise, when `PGHOST` is set, a key/value
    /// connection string is assembled from the standard `PG*` variables
    /// (with `postgres`/`5432` defaults), because `tokio_postgres` does not
    /// read those variables on its own.
    fn test_database_url() -> Option<String> {
        if let Ok(url) = std::env::var("DATABASE_URL") {
            return Some(url);
        }
        let host = std::env::var("PGHOST").ok()?;
        let var_or = |name: &str, default: &str| {
            std::env::var(name).unwrap_or_else(|_| default.to_string())
        };
        Some(format!(
            "host={host} port={} user={} password={} dbname={}",
            var_or("PGPORT", "5432"),
            var_or("PGUSER", "postgres"),
            var_or("PGPASSWORD", "postgres"),
            var_or("PGDATABASE", "postgres"),
        ))
    }

    #[test]
    fn test_config() {
        let cfg = PoolConfig {
            min_size: 2,
            max_size: 10,
            queue_timeout: Duration::from_secs(10),
        };
        assert!(cfg.max_size >= cfg.min_size);
    }

    #[tokio::test]
    async fn test_pool_create_and_acquire() {
        // Integration test: a no-op unless a PostgreSQL instance is configured.
        let url = match test_database_url() {
            Some(url) => url,
            None => {
                eprintln!("Skipping: no database available");
                return;
            }
        };
        let config = PoolConfig {
            min_size: 1,
            max_size: 5,
            queue_timeout: Duration::from_secs(5),
        };
        let pool = ConnectionPool::new(&url, config).await.unwrap();
        assert_eq!(pool.max_size(), 5);
        let conn = pool.acquire().await.unwrap();
        assert!(!conn.is_closed());
        pool.release(conn);
        sleep(Duration::from_millis(50)).await;
        assert!(pool.active_count() > 0);
    }

    #[tokio::test]
    async fn test_session_setup_applied() {
        // Integration test: verifies the per-connection SETs actually reach
        // the server session.
        let url = match test_database_url() {
            Some(url) => url,
            None => {
                eprintln!("Skipping: no database available");
                return;
            }
        };
        let config = PoolConfig {
            min_size: 1,
            max_size: 2,
            queue_timeout: Duration::from_secs(5),
        };
        let pool = ConnectionPool::with_session_setup(&url, config, Duration::from_secs(1), true)
            .await
            .unwrap();
        let conn = pool.acquire().await.unwrap();

        let read_only: String = conn
            .query_one("SHOW default_transaction_read_only", &[])
            .await
            .unwrap()
            .get(0);
        assert_eq!(read_only, "on");

        // PostgreSQL reports 1000 ms in its largest exact unit.
        let timeout: String = conn
            .query_one("SHOW statement_timeout", &[])
            .await
            .unwrap()
            .get(0);
        assert_eq!(timeout, "1s");
    }
}