1use std::time::Duration;
11use tokio_postgres::{Client, NoTls};
12use tracing::debug;
13
14use crate::config::PoolConfig;
15use crate::errors::{MCPError, Result as MCPResult};
16use crate::lockfree_pool::{
17 BoxFuture, CreateFn, LockFreePool, PoolConfig as LFPoolConfig, PoolError, PooledConnection,
18 ValidateFn,
19};
20
21pub struct ConnectionPool {
23 inner: LockFreePool<Client>,
24 max_size: u32,
25}
26
27impl ConnectionPool {
28 pub async fn new(connection_string: &str, config: PoolConfig) -> anyhow::Result<Self> {
29 Self::with_session_setup(connection_string, config, Duration::ZERO, false).await
30 }
31
32 pub async fn with_statement_timeout(
39 connection_string: &str,
40 config: PoolConfig,
41 statement_timeout: Duration,
42 ) -> anyhow::Result<Self> {
43 Self::with_session_setup(connection_string, config, statement_timeout, false).await
44 }
45
46 pub async fn with_session_setup(
54 connection_string: &str,
55 config: PoolConfig,
56 statement_timeout: Duration,
57 read_only: bool,
58 ) -> anyhow::Result<Self> {
59 debug!(
60 "Creating lock-free connection pool: max_size={}, statement_timeout={:?}, read_only={}",
61 config.max_size, statement_timeout, read_only
62 );
63
64 let conn_string = connection_string.to_string();
65 let create_timeout = Duration::from_secs(5);
66 let stmt_timeout_ms = statement_timeout.as_millis();
67
68 let tls_connector = if crate::tls::wants_tls(&conn_string) {
70 Some(crate::tls::make_connector()?)
71 } else {
72 None
73 };
74
75 let create = {
76 let cs = conn_string.clone();
77 Box::new(move || {
78 let cs = cs.clone();
79 let tls = tls_connector.clone();
80 Box::pin(async move {
81 let client = match tls {
82 Some(tls) => {
83 let (client, connection) = tokio_postgres::connect(&cs, tls)
84 .await
85 .map_err(|e| e.to_string())?;
86 tokio::spawn(connection);
87 client
88 }
89 None => {
90 let (client, connection) = tokio_postgres::connect(&cs, NoTls)
91 .await
92 .map_err(|e| e.to_string())?;
93 tokio::spawn(connection);
94 client
95 }
96 };
97 if stmt_timeout_ms > 0 {
101 client
102 .batch_execute(&format!("SET statement_timeout TO '{stmt_timeout_ms}'"))
103 .await
104 .map_err(|e| e.to_string())?;
105 }
106 if read_only {
109 client
110 .batch_execute("SET default_transaction_read_only = on")
111 .await
112 .map_err(|e| e.to_string())?;
113 }
114 Ok(client)
115 }) as BoxFuture<'static, Result<Client, String>>
116 }) as CreateFn<Client>
117 };
118
119 let validate = Box::new(|client: &Client| !client.is_closed()) as ValidateFn<Client>;
120
121 let lf_config = LFPoolConfig {
122 max_size: config.max_size,
123 create_timeout,
124 wait_timeout: config.queue_timeout,
125 };
126
127 let pool = LockFreePool::new(create, validate, &lf_config);
128
129 let test_conn = pool
131 .acquire()
132 .await
133 .map_err(|e| anyhow::anyhow!("Failed to establish database connection: {e}"))?;
134 drop(test_conn);
135
136 Ok(Self {
137 inner: pool,
138 max_size: config.max_size,
139 })
140 }
141
142 pub async fn acquire(&self) -> MCPResult<PooledConnection<Client>> {
147 self.inner.acquire().await.map_err(|e| match e {
148 PoolError::Timeout => {
149 MCPError::PoolError("Connection pool timeout: no connection available".into())
150 }
151 PoolError::Closed => MCPError::PoolError("Connection pool is closed".into()),
152 PoolError::CreateFailed(msg) => {
153 MCPError::PoolError(format!("Failed to create connection: {msg}"))
154 }
155 })
156 }
157
158 pub fn release(&self, _conn: PooledConnection<Client>) {
163 }
165
166 pub fn active_count(&self) -> u32 {
167 self.inner.status().size
168 }
169
170 pub const fn max_size(&self) -> u32 {
171 self.max_size
172 }
173
174 pub fn is_closed(&self) -> bool {
175 self.inner.is_closed()
176 }
177
178 pub fn close(&self) {
180 self.inner.close();
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Picks the database URL for integration tests.
    ///
    /// Returns `None` when neither `DATABASE_URL` nor `PGHOST` is set, in
    /// which case the test is skipped.
    ///
    /// NOTE(review): when only `PGHOST` is set, the hard-coded localhost URL is
    /// used and the value of `PGHOST` itself is ignored.
    fn test_database_url() -> Option<String> {
        match std::env::var("DATABASE_URL") {
            Ok(url) => Some(url),
            Err(_) if std::env::var("PGHOST").is_ok() => {
                Some("postgres://postgres:postgres@localhost:5432/postgres".to_string())
            }
            Err(_) => None,
        }
    }

    #[test]
    fn test_config() {
        let cfg = PoolConfig {
            queue_timeout: Duration::from_secs(10),
            min_size: 2,
            max_size: 10,
        };
        assert!(cfg.min_size <= cfg.max_size);
    }

    #[tokio::test]
    async fn test_pool_create_and_acquire() {
        let url = match test_database_url() {
            Some(url) => url,
            None => {
                eprintln!("Skipping: no database available");
                return;
            }
        };

        let pool_config = PoolConfig {
            queue_timeout: Duration::from_secs(5),
            min_size: 1,
            max_size: 5,
        };
        let pool = ConnectionPool::new(&url, pool_config).await.unwrap();
        assert_eq!(pool.max_size(), 5);

        let conn = pool.acquire().await.unwrap();
        assert!(!conn.is_closed());
        pool.release(conn);

        // Give the pool a moment before inspecting its status.
        sleep(Duration::from_millis(50)).await;
        assert!(pool.active_count() > 0);
    }
}