1use crate::config::DatabaseConfig;
8use crate::db::backend::DatabaseBackend;
9use crate::db::identifier::validate_identifier;
10use crate::error::AppError;
11use moka::future::Cache;
12use serde_json::{Value, json};
13use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgRow, PgSslMode};
14use sqlx::{PgPool, Row};
15use sqlx_to_json::RowExt;
16use std::collections::HashMap;
17use tracing::info;
18
/// Capacity of the per-database connection-pool cache built in `PostgresBackend::new`
/// (passed to moka's `max_capacity`; no weigher is configured, so this bounds the number
/// of cached pools). Each pool may open up to `max_pool_size` connections, so the backend
/// can hold roughly `POOL_CACHE_CAPACITY * max_pool_size` server connections at once.
const POOL_CACHE_CAPACITY: u64 = 6;
21
22impl From<&DatabaseConfig> for PgConnectOptions {
28 fn from(config: &DatabaseConfig) -> Self {
29 let mut opts = PgConnectOptions::new_without_pgpass()
30 .host(&config.host)
31 .port(config.port)
32 .username(&config.user);
33
34 if let Some(ref password) = config.password {
35 opts = opts.password(password);
36 }
37 if let Some(ref name) = config.name
38 && !name.is_empty()
39 {
40 opts = opts.database(name);
41 }
42
43 if config.ssl {
44 opts = if config.ssl_verify_cert {
45 opts.ssl_mode(PgSslMode::VerifyCa)
46 } else {
47 opts.ssl_mode(PgSslMode::Require)
48 };
49 if let Some(ref ca) = config.ssl_ca {
50 opts = opts.ssl_root_cert(ca);
51 }
52 if let Some(ref cert) = config.ssl_cert {
53 opts = opts.ssl_client_cert(cert);
54 }
55 if let Some(ref key) = config.ssl_key {
56 opts = opts.ssl_client_key(key);
57 }
58 }
59
60 opts
61 }
62}
63
/// PostgreSQL implementation of the database backend.
///
/// Holds one connection pool per database name in a bounded cache. The pool for the default
/// database is created eagerly in `new`; pools for other databases are created lazily on first
/// use by `get_pool`.
#[derive(Clone)]
pub struct PostgresBackend {
    /// Connection settings; used as the template for every per-database pool (only `name` is
    /// overridden when a pool for another database is opened).
    config: DatabaseConfig,
    /// Database used when a caller passes no database or an empty database name.
    default_db: String,
    /// Per-database pools keyed by database name. Pools removed from this cache are closed by its
    /// eviction listener.
    pools: Cache<String, PgPool>,
    /// When `true`, `create_database` returns `AppError::ReadOnlyViolation`. Copied from
    /// `config.read_only` at construction.
    pub read_only: bool,
}
75
76impl std::fmt::Debug for PostgresBackend {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("PostgresBackend")
79 .field("read_only", &self.read_only)
80 .field("default_db", &self.default_db)
81 .finish_non_exhaustive()
82 }
83}
84
85impl PostgresBackend {
86 pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
96 let pool = PgPoolOptions::new()
97 .max_connections(config.max_pool_size)
98 .connect_with(config.into())
99 .await
100 .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
101
102 info!(
103 "PostgreSQL connection pool initialized (max size: {})",
104 config.max_pool_size
105 );
106
107 let default_db = config
109 .name
110 .as_deref()
111 .filter(|n| !n.is_empty())
112 .map_or_else(|| config.user.clone(), String::from);
113
114 let pools = Cache::builder()
115 .max_capacity(POOL_CACHE_CAPACITY)
116 .eviction_listener(|_key, pool: PgPool, _cause| {
117 tokio::spawn(async move {
118 pool.close().await;
119 });
120 })
121 .build();
122
123 pools.insert(default_db.clone(), pool).await;
124
125 Ok(Self {
126 config: config.clone(),
127 default_db,
128 pools,
129 read_only: config.read_only,
130 })
131 }
132}
133
134impl PostgresBackend {
135 fn quote_identifier(name: &str) -> String {
139 let escaped = name.replace('"', "\"\"");
140 format!("\"{escaped}\"")
141 }
142
143 async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
154 let db_key = match database {
155 Some(name) if !name.is_empty() => name,
156 _ => &self.default_db,
157 };
158
159 if let Some(pool) = self.pools.get(db_key).await {
160 return Ok(pool);
161 }
162
163 validate_identifier(db_key)?;
165
166 let config = self.config.clone();
167 let db_key_owned = db_key.to_owned();
168
169 let pool = self
170 .pools
171 .try_get_with(db_key_owned, async {
172 let mut cfg = config;
173 cfg.name = Some(db_key.to_owned());
174 PgPoolOptions::new()
175 .max_connections(cfg.max_pool_size)
176 .connect_with((&cfg).into())
177 .await
178 .map_err(|e| {
179 AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
180 })
181 })
182 .await
183 .map_err(|e| match e.as_ref() {
184 AppError::Connection(msg) => AppError::Connection(msg.clone()),
185 other => AppError::Connection(other.to_string()),
186 })?;
187
188 Ok(pool)
189 }
190}
191
192impl DatabaseBackend for PostgresBackend {
193 async fn list_databases(&self) -> Result<Vec<String>, AppError> {
197 let pool = self.get_pool(None).await?;
198 let rows: Vec<(String,)> =
199 sqlx::query_as("SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname")
200 .fetch_all(&pool)
201 .await
202 .map_err(|e| AppError::Query(e.to_string()))?;
203 Ok(rows.into_iter().map(|r| r.0).collect())
204 }
205
206 async fn list_tables(&self, database: &str) -> Result<Vec<String>, AppError> {
207 let db = if database.is_empty() { None } else { Some(database) };
208 let pool = self.get_pool(db).await?;
209 let rows: Vec<(String,)> =
210 sqlx::query_as("SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename")
211 .fetch_all(&pool)
212 .await
213 .map_err(|e| AppError::Query(e.to_string()))?;
214 Ok(rows.into_iter().map(|r| r.0).collect())
215 }
216
217 async fn get_table_schema(&self, database: &str, table: &str) -> Result<Value, AppError> {
218 validate_identifier(table)?;
219 let db = if database.is_empty() { None } else { Some(database) };
220 let pool = self.get_pool(db).await?;
221 let rows: Vec<PgRow> = sqlx::query(
222 r"SELECT column_name, data_type, is_nullable, column_default,
223 character_maximum_length
224 FROM information_schema.columns
225 WHERE table_schema = 'public' AND table_name = $1
226 ORDER BY ordinal_position",
227 )
228 .bind(table)
229 .fetch_all(&pool)
230 .await
231 .map_err(|e| AppError::Query(e.to_string()))?;
232
233 if rows.is_empty() {
234 return Err(AppError::TableNotFound(table.to_string()));
235 }
236
237 let mut schema: HashMap<String, Value> = HashMap::new();
238 for row in &rows {
239 let col_name: String = row.try_get("column_name").unwrap_or_default();
240 let data_type: String = row.try_get("data_type").unwrap_or_default();
241 let nullable: String = row.try_get("is_nullable").unwrap_or_default();
242 let default: Option<String> = row.try_get("column_default").ok();
243 schema.insert(
244 col_name,
245 json!({
246 "type": data_type,
247 "nullable": nullable.to_uppercase() == "YES",
248 "key": Value::Null,
249 "default": default,
250 "extra": Value::Null,
251 }),
252 );
253 }
254 Ok(json!(schema))
255 }
256
257 async fn get_table_schema_with_relations(&self, database: &str, table: &str) -> Result<Value, AppError> {
258 let schema = self.get_table_schema(database, table).await?;
259 let mut columns: HashMap<String, Value> = serde_json::from_value(schema).unwrap_or_default();
260
261 for col in columns.values_mut() {
263 if let Some(obj) = col.as_object_mut() {
264 obj.entry("foreign_key".to_string()).or_insert(Value::Null);
265 }
266 }
267
268 let db = if database.is_empty() { None } else { Some(database) };
270 let pool = self.get_pool(db).await?;
271 let fk_rows: Vec<PgRow> = sqlx::query(
272 r"SELECT
273 kcu.column_name,
274 tc.constraint_name,
275 ccu.table_name AS referenced_table,
276 ccu.column_name AS referenced_column,
277 rc.update_rule AS on_update,
278 rc.delete_rule AS on_delete
279 FROM information_schema.table_constraints tc
280 JOIN information_schema.key_column_usage kcu
281 ON tc.constraint_name = kcu.constraint_name
282 AND tc.table_schema = kcu.table_schema
283 JOIN information_schema.constraint_column_usage ccu
284 ON ccu.constraint_name = tc.constraint_name
285 AND ccu.table_schema = tc.table_schema
286 JOIN information_schema.referential_constraints rc
287 ON rc.constraint_name = tc.constraint_name
288 AND rc.constraint_schema = tc.table_schema
289 WHERE tc.constraint_type = 'FOREIGN KEY'
290 AND tc.table_name = $1
291 AND tc.table_schema = 'public'",
292 )
293 .bind(table)
294 .fetch_all(&pool)
295 .await
296 .map_err(|e| AppError::Query(e.to_string()))?;
297
298 for fk_row in &fk_rows {
299 let col_name: String = fk_row.try_get("column_name").unwrap_or_default();
300 if let Some(col_info) = columns.get_mut(&col_name)
301 && let Some(obj) = col_info.as_object_mut()
302 {
303 obj.insert(
304 "foreign_key".to_string(),
305 json!({
306 "constraint_name": fk_row.try_get::<String, _>("constraint_name").ok(),
307 "referenced_table": fk_row.try_get::<String, _>("referenced_table").ok(),
308 "referenced_column": fk_row.try_get::<String, _>("referenced_column").ok(),
309 "on_update": fk_row.try_get::<String, _>("on_update").ok(),
310 "on_delete": fk_row.try_get::<String, _>("on_delete").ok(),
311 }),
312 );
313 }
314 }
315
316 Ok(json!({
317 "table_name": table,
318 "columns": columns,
319 }))
320 }
321
322 async fn execute_query(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
323 let pool = self.get_pool(database).await?;
324 let rows: Vec<PgRow> = sqlx::query(sql)
325 .fetch_all(&pool)
326 .await
327 .map_err(|e| AppError::Query(e.to_string()))?;
328 Ok(Value::Array(rows.iter().map(RowExt::to_json).collect()))
329 }
330
331 async fn create_database(&self, name: &str) -> Result<Value, AppError> {
332 if self.read_only {
333 return Err(AppError::ReadOnlyViolation);
334 }
335 validate_identifier(name)?;
336
337 let pool = self.get_pool(None).await?;
338
339 sqlx::query(&format!("CREATE DATABASE {}", Self::quote_identifier(name)))
341 .execute(&pool)
342 .await
343 .map_err(|e| {
344 let msg = e.to_string();
345 if msg.contains("already exists") {
346 return AppError::Query(format!("Database '{name}' already exists."));
347 }
348 AppError::Query(msg)
349 })?;
350
351 Ok(json!({
352 "status": "success",
353 "message": format!("Database '{name}' created successfully."),
354 "database_name": name,
355 }))
356 }
357
358 fn dialect(&self) -> Box<dyn sqlparser::dialect::Dialect> {
359 Box::new(sqlparser::dialect::PostgreSqlDialect {})
360 }
361
362 fn read_only(&self) -> bool {
363 self.read_only
364 }
365}
366
#[cfg(test)]
mod tests {
    // Unit tests for identifier quoting and the `DatabaseConfig` -> `PgConnectOptions` conversion.
    //
    // NOTE(review): sqlx's `PgConnectOptions::new_without_pgpass` may still read `PG*` environment
    // variables (e.g. `PGDATABASE`), so the `get_database() == None` assertions assume those are
    // unset in the test environment. Confirm this.
    use super::*;
    use crate::config::DatabaseBackend;

    /// Baseline config with every connection field populated and SSL left at its default.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn quote_identifier_wraps_in_double_quotes() {
        assert_eq!(PostgresBackend::quote_identifier("users"), "\"users\"");
        assert_eq!(PostgresBackend::quote_identifier("eu-docker"), "\"eu-docker\"");
    }

    #[test]
    fn quote_identifier_escapes_double_quotes() {
        assert_eq!(PostgresBackend::quote_identifier("test\"db"), "\"test\"\"db\"");
        assert_eq!(PostgresBackend::quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    #[test]
    fn try_from_basic_config() {
        let config = base_config();
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn try_from_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn try_from_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn try_from_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    /// An empty `name` is treated the same as a missing one: no database is set on the options.
    #[test]
    fn try_from_with_empty_database_name() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn try_from_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
    }
}