database_mcp_postgres/
connection.rs1use database_mcp_backend::error::AppError;
4use database_mcp_backend::identifier::validate_identifier;
5use database_mcp_config::DatabaseConfig;
6use moka::future::Cache;
7use sqlx::PgPool;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
9use tracing::info;
10
/// Upper bound handed to the pool cache's `max_capacity`: how many
/// per-database connection pools may be cached before older ones are evicted.
const POOL_CACHE_CAPACITY: u64 = 6;

14#[derive(Clone)]
19pub struct PostgresBackend {
20 config: DatabaseConfig,
21 default_db: String,
22 pools: Cache<String, PgPool>,
23 pub read_only: bool,
24}
25
26impl std::fmt::Debug for PostgresBackend {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("PostgresBackend")
29 .field("read_only", &self.read_only)
30 .field("default_db", &self.default_db)
31 .finish_non_exhaustive()
32 }
33}
34
35impl PostgresBackend {
36 pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
46 let pool = PgPoolOptions::new()
47 .max_connections(config.max_pool_size)
48 .connect_with(connect_options(config))
49 .await
50 .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
51
52 info!(
53 "PostgreSQL connection pool initialized (max size: {})",
54 config.max_pool_size
55 );
56
57 let default_db = config
59 .name
60 .as_deref()
61 .filter(|n| !n.is_empty())
62 .map_or_else(|| config.user.clone(), String::from);
63
64 let pools = Cache::builder()
65 .max_capacity(POOL_CACHE_CAPACITY)
66 .eviction_listener(|_key, pool: PgPool, _cause| {
67 tokio::spawn(async move {
68 pool.close().await;
69 });
70 })
71 .build();
72
73 pools.insert(default_db.clone(), pool).await;
74
75 Ok(Self {
76 config: config.clone(),
77 default_db,
78 pools,
79 read_only: config.read_only,
80 })
81 }
82
83 pub(crate) fn quote_identifier(name: &str) -> String {
85 database_mcp_backend::identifier::quote_identifier(name, '"')
86 }
87
88 pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
99 let db_key = match database {
100 Some(name) if !name.is_empty() => name,
101 _ => &self.default_db,
102 };
103
104 if let Some(pool) = self.pools.get(db_key).await {
105 return Ok(pool);
106 }
107
108 validate_identifier(db_key)?;
110
111 let config = self.config.clone();
112 let db_key_owned = db_key.to_owned();
113
114 let pool = self
115 .pools
116 .try_get_with(db_key_owned, async {
117 let mut cfg = config;
118 cfg.name = Some(db_key.to_owned());
119 PgPoolOptions::new()
120 .max_connections(cfg.max_pool_size)
121 .connect_with(connect_options(&cfg))
122 .await
123 .map_err(|e| {
124 AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
125 })
126 })
127 .await
128 .map_err(|e| match e.as_ref() {
129 AppError::Connection(msg) => AppError::Connection(msg.clone()),
130 other => AppError::Connection(other.to_string()),
131 })?;
132
133 Ok(pool)
134 }
135}
136
137fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
143 let mut opts = PgConnectOptions::new_without_pgpass()
144 .host(&config.host)
145 .port(config.port)
146 .username(&config.user);
147
148 if let Some(ref password) = config.password {
149 opts = opts.password(password);
150 }
151 if let Some(ref name) = config.name
152 && !name.is_empty()
153 {
154 opts = opts.database(name);
155 }
156
157 if config.ssl {
158 opts = if config.ssl_verify_cert {
159 opts.ssl_mode(PgSslMode::VerifyCa)
160 } else {
161 opts.ssl_mode(PgSslMode::Require)
162 };
163 if let Some(ref ca) = config.ssl_ca {
164 opts = opts.ssl_root_cert(ca);
165 }
166 if let Some(ref cert) = config.ssl_cert {
167 opts = opts.ssl_client_cert(cert);
168 }
169 if let Some(ref key) = config.ssl_key {
170 opts = opts.ssl_client_key(key);
171 }
172 }
173
174 opts
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// A fully populated Postgres config shared by the tests below.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    /// Options `connect_options` starts from before applying the config.
    ///
    /// sqlx fills these from `PG*` environment variables, so expectations for
    /// fields the config leaves unset come from here instead of being
    /// hard-coded. That keeps the tests stable when, for example, `PGDATABASE`
    /// or `PGSSLMODE` is exported in the shell running them.
    fn inherited_defaults() -> PgConnectOptions {
        PgConnectOptions::new_without_pgpass()
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_ssl_settings_ignored_when_ssl_disabled() {
        let config = DatabaseConfig {
            ssl: false,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(
            format!("{:?}", opts.get_ssl_mode()),
            format!("{:?}", inherited_defaults().get_ssl_mode()),
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), inherited_defaults().get_database());
    }

    #[test]
    fn connect_options_with_empty_database_name() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), inherited_defaults().get_database());
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // sqlx exposes no password getter; this checks construction succeeds
        // and the remaining fields are still applied.
        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_username(), "pgadmin");
    }
}