database_mcp_postgres/
adapter.rs1use database_mcp_config::DatabaseConfig;
4use database_mcp_server::AppError;
5use database_mcp_sql::identifier::validate_identifier;
6use moka::future::Cache;
7use sqlx::PgPool;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
9use tracing::info;
10
/// Upper bound on how many per-database connection pools the adapter keeps
/// cached at once; pools pushed out of the cache are closed by the eviction
/// listener registered in `PostgresAdapter::new`.
const POOL_CACHE_CAPACITY: u64 = 6;
13
14#[derive(Clone)]
19pub struct PostgresAdapter {
20 pub(crate) config: DatabaseConfig,
21 default_db: String,
22 pools: Cache<String, PgPool>,
23}
24
25impl std::fmt::Debug for PostgresAdapter {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("PostgresAdapter")
28 .field("read_only", &self.config.read_only)
29 .field("default_db", &self.default_db)
30 .finish_non_exhaustive()
31 }
32}
33
34impl PostgresAdapter {
35 pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
45 let pool = PgPoolOptions::new()
46 .max_connections(config.max_pool_size)
47 .connect_with(connect_options(config))
48 .await
49 .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
50
51 info!(
52 "PostgreSQL connection pool initialized (max size: {})",
53 config.max_pool_size
54 );
55
56 let default_db = config
58 .name
59 .as_deref()
60 .filter(|n| !n.is_empty())
61 .map_or_else(|| config.user.clone(), String::from);
62
63 let pools = Cache::builder()
64 .max_capacity(POOL_CACHE_CAPACITY)
65 .eviction_listener(|_key, pool: PgPool, _cause| {
66 tokio::spawn(async move {
67 pool.close().await;
68 });
69 })
70 .build();
71
72 pools.insert(default_db.clone(), pool).await;
73
74 Ok(Self {
75 config: config.clone(),
76 default_db,
77 pools,
78 })
79 }
80
81 pub(crate) fn quote_identifier(name: &str) -> String {
83 database_mcp_sql::identifier::quote_identifier(name, '"')
84 }
85
86 pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
97 let db_key = match database {
98 Some(name) if !name.is_empty() => name,
99 _ => &self.default_db,
100 };
101
102 if let Some(pool) = self.pools.get(db_key).await {
103 return Ok(pool);
104 }
105
106 validate_identifier(db_key)?;
108
109 let config = self.config.clone();
110 let db_key_owned = db_key.to_owned();
111
112 let pool = self
113 .pools
114 .try_get_with(db_key_owned, async {
115 let mut cfg = config;
116 cfg.name = Some(db_key.to_owned());
117 PgPoolOptions::new()
118 .max_connections(cfg.max_pool_size)
119 .connect_with(connect_options(&cfg))
120 .await
121 .map_err(|e| {
122 AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
123 })
124 })
125 .await
126 .map_err(|e| match e.as_ref() {
127 AppError::Connection(msg) => AppError::Connection(msg.clone()),
128 other => AppError::Connection(other.to_string()),
129 })?;
130
131 Ok(pool)
132 }
133}
134
135fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
141 let mut opts = PgConnectOptions::new_without_pgpass()
142 .host(&config.host)
143 .port(config.port)
144 .username(&config.user);
145
146 if let Some(ref password) = config.password {
147 opts = opts.password(password);
148 }
149 if let Some(ref name) = config.name
150 && !name.is_empty()
151 {
152 opts = opts.database(name);
153 }
154
155 if config.ssl {
156 opts = if config.ssl_verify_cert {
157 opts.ssl_mode(PgSslMode::VerifyCa)
158 } else {
159 opts.ssl_mode(PgSslMode::Require)
160 };
161 if let Some(ref ca) = config.ssl_ca {
162 opts = opts.ssl_root_cert(ca);
163 }
164 if let Some(ref cert) = config.ssl_cert {
165 opts = opts.ssl_client_cert(cert);
166 }
167 if let Some(ref key) = config.ssl_key {
168 opts = opts.ssl_client_key(key);
169 }
170 }
171
172 opts
173}
174
#[cfg(test)]
mod tests {
    //! Unit tests for `connect_options`.
    //!
    //! NOTE(review): sqlx may fill unset fields from `PG*` environment
    //! variables, so the `get_database() == None` assertions assume
    //! `PGDATABASE` is unset in the test environment.

    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Baseline Postgres config with host, port, user, password and database set.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    /// An empty `name` must be treated the same as an absent one.
    #[test]
    fn connect_options_empty_database_name_is_ignored() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_username(), "pgadmin");
    }
}