1use crate::config::DatabaseConfig;
7use crate::db::backend::DatabaseBackend;
8use crate::db::identifier::validate_identifier;
9use crate::error::AppError;
10use serde_json::{Value, json};
11use sqlx::mysql::{MySqlConnectOptions, MySqlPoolOptions, MySqlRow, MySqlSslMode};
12use sqlx::{Executor, MySqlPool, Row};
13use sqlx_to_json::RowExt;
14use std::collections::HashMap;
15use tracing::{error, info};
16
17impl From<&DatabaseConfig> for MySqlConnectOptions {
19 fn from(config: &DatabaseConfig) -> Self {
20 let mut opts = MySqlConnectOptions::new()
21 .host(&config.host)
22 .port(config.port)
23 .username(&config.user);
24
25 if let Some(ref password) = config.password {
26 opts = opts.password(password);
27 }
28 if let Some(ref name) = config.name
29 && !name.is_empty()
30 {
31 opts = opts.database(name);
32 }
33 if let Some(ref charset) = config.charset {
34 opts = opts.charset(charset);
35 }
36
37 if config.ssl {
38 opts = if config.ssl_verify_cert {
39 opts.ssl_mode(MySqlSslMode::VerifyCa)
40 } else {
41 opts.ssl_mode(MySqlSslMode::Required)
42 };
43 if let Some(ref ca) = config.ssl_ca {
44 opts = opts.ssl_ca(ca);
45 }
46 if let Some(ref cert) = config.ssl_cert {
47 opts = opts.ssl_client_cert(cert);
48 }
49 if let Some(ref key) = config.ssl_key {
50 opts = opts.ssl_client_key(key);
51 }
52 }
53
54 opts
55 }
56}
57
58#[derive(Clone)]
60pub struct MysqlBackend {
61 pool: MySqlPool,
62 pub read_only: bool,
63}
64
65impl std::fmt::Debug for MysqlBackend {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("MysqlBackend")
68 .field("read_only", &self.read_only)
69 .finish_non_exhaustive()
70 }
71}
72
73impl MysqlBackend {
74 pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
80 let pool = MySqlPoolOptions::new()
81 .max_connections(config.max_pool_size)
82 .connect_with(config.into())
83 .await
84 .map_err(|e| AppError::Connection(format!("Failed to connect to MySQL: {e}")))?;
85
86 info!("MySQL connection pool initialized (max size: {})", config.max_pool_size);
87
88 let backend = Self {
89 pool,
90 read_only: config.read_only,
91 };
92
93 if config.read_only {
94 backend.warn_if_file_privilege().await;
95 }
96
97 Ok(backend)
98 }
99
100 async fn warn_if_file_privilege(&self) {
101 let result: Result<(), AppError> = async {
102 let current_user: Option<String> = sqlx::query_scalar("SELECT CURRENT_USER()")
103 .fetch_optional(&self.pool)
104 .await
105 .map_err(|e| AppError::Query(e.to_string()))?;
106
107 let Some(current_user) = current_user else {
108 return Ok(());
109 };
110
111 let quoted_user = if let Some((user, host)) = current_user.split_once('@') {
112 format!("'{user}'@'{host}'")
113 } else {
114 format!("'{current_user}'")
115 };
116
117 let grants: Vec<String> = sqlx::query_scalar(&format!("SHOW GRANTS FOR {quoted_user}"))
118 .fetch_all(&self.pool)
119 .await
120 .map_err(|e| AppError::Query(e.to_string()))?;
121
122 let has_file_priv = grants.iter().any(|grant| {
123 let upper = grant.to_uppercase();
124 upper.contains("FILE") && upper.contains("ON *.*")
125 });
126
127 if has_file_priv {
128 error!(
129 "Connected database user has the global FILE privilege. \
130 Revoke FILE for the database user you are connecting as."
131 );
132 }
133
134 Ok(())
135 }
136 .await;
137
138 if let Err(e) = result {
139 tracing::debug!("Unable to determine whether FILE privilege is enabled: {e}");
140 }
141 }
142
143 fn quote_identifier(name: &str) -> String {
147 let escaped = name.replace('`', "``");
148 format!("`{escaped}`")
149 }
150
151 fn quote_string(value: &str) -> String {
155 let escaped = value.replace('\'', "''");
156 format!("'{escaped}'")
157 }
158
159 async fn query_to_json(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
165 let mut conn = self
167 .pool
168 .acquire()
169 .await
170 .map_err(|e| AppError::Connection(e.to_string()))?;
171
172 if let Some(db) = database {
174 validate_identifier(db)?;
175 let use_sql = format!("USE {}", Self::quote_identifier(db));
176 conn.execute(use_sql.as_str())
177 .await
178 .map_err(|e| AppError::Query(e.to_string()))?;
179 }
180
181 let rows: Vec<MySqlRow> = conn.fetch_all(sql).await.map_err(|e| AppError::Query(e.to_string()))?;
182 Ok(Value::Array(rows.iter().map(RowExt::to_json).collect()))
183 }
184}
185
186impl DatabaseBackend for MysqlBackend {
187 async fn list_databases(&self) -> Result<Vec<String>, AppError> {
188 let results = self
189 .query_to_json(
190 "SELECT SCHEMA_NAME AS name FROM information_schema.SCHEMATA ORDER BY SCHEMA_NAME",
191 None,
192 )
193 .await?;
194 let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
195 Ok(rows
196 .iter()
197 .filter_map(|row| row.get("name").and_then(|v| v.as_str().map(String::from)))
198 .collect())
199 }
200
201 async fn list_tables(&self, database: &str) -> Result<Vec<String>, AppError> {
202 validate_identifier(database)?;
203 let sql = format!(
204 "SELECT TABLE_NAME AS name FROM information_schema.TABLES WHERE TABLE_SCHEMA = {} ORDER BY TABLE_NAME",
205 Self::quote_string(database)
206 );
207 let results = self.query_to_json(&sql, None).await?;
208 let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
209 Ok(rows
210 .iter()
211 .filter_map(|row| row.get("name").and_then(|v| v.as_str().map(String::from)))
212 .collect())
213 }
214
215 async fn get_table_schema(&self, database: &str, table: &str) -> Result<Value, AppError> {
216 validate_identifier(database)?;
217 validate_identifier(table)?;
218
219 let sql = format!(
220 "DESCRIBE {}.{}",
221 Self::quote_identifier(database),
222 Self::quote_identifier(table)
223 );
224 let results = self.query_to_json(&sql, None).await?;
225 let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
226
227 if rows.is_empty() {
228 return Err(AppError::TableNotFound(format!("{database}.{table}")));
229 }
230
231 let mut schema: HashMap<String, Value> = HashMap::new();
232 for row in rows {
233 if let Some(col_name) = row.get("Field").and_then(|v| v.as_str()) {
234 schema.insert(
235 col_name.to_string(),
236 json!({
237 "type": row.get("Type").unwrap_or(&Value::Null),
238 "nullable": row.get("Null").and_then(|v| v.as_str()).is_some_and(|s| s.to_uppercase() == "YES"),
239 "key": row.get("Key").unwrap_or(&Value::Null),
240 "default": row.get("Default").unwrap_or(&Value::Null),
241 "extra": row.get("Extra").unwrap_or(&Value::Null),
242 }),
243 );
244 }
245 }
246
247 Ok(json!(schema))
248 }
249
250 async fn get_table_schema_with_relations(&self, database: &str, table: &str) -> Result<Value, AppError> {
251 validate_identifier(database)?;
252 validate_identifier(table)?;
253
254 let describe_sql = format!(
256 "DESCRIBE {}.{}",
257 Self::quote_identifier(database),
258 Self::quote_identifier(table)
259 );
260 let schema_results = self.query_to_json(&describe_sql, None).await?;
261 let schema_rows = schema_results.as_array().map_or([].as_slice(), Vec::as_slice);
262
263 if schema_rows.is_empty() {
264 return Err(AppError::TableNotFound(format!("{database}.{table}")));
265 }
266
267 let mut columns: HashMap<String, Value> = HashMap::new();
268 for row in schema_rows {
269 if let Some(col_name) = row.get("Field").and_then(|v| v.as_str()) {
270 columns.insert(
271 col_name.to_string(),
272 json!({
273 "type": row.get("Type").unwrap_or(&Value::Null),
274 "nullable": row.get("Null").and_then(|v| v.as_str()).is_some_and(|s| s.to_uppercase() == "YES"),
275 "key": row.get("Key").unwrap_or(&Value::Null),
276 "default": row.get("Default").unwrap_or(&Value::Null),
277 "extra": row.get("Extra").unwrap_or(&Value::Null),
278 "foreign_key": null,
279 }),
280 );
281 }
282 }
283
284 let fk_sql = r"
286 SELECT
287 kcu.COLUMN_NAME as column_name,
288 kcu.CONSTRAINT_NAME as constraint_name,
289 kcu.REFERENCED_TABLE_NAME as referenced_table,
290 kcu.REFERENCED_COLUMN_NAME as referenced_column,
291 rc.UPDATE_RULE as on_update,
292 rc.DELETE_RULE as on_delete
293 FROM information_schema.KEY_COLUMN_USAGE kcu
294 INNER JOIN information_schema.REFERENTIAL_CONSTRAINTS rc
295 ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
296 AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
297 WHERE kcu.TABLE_SCHEMA = ?
298 AND kcu.TABLE_NAME = ?
299 AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
300 ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION
301 ";
302
303 let fk_rows: Vec<MySqlRow> = sqlx::query(fk_sql)
304 .bind(database)
305 .bind(table)
306 .fetch_all(&self.pool)
307 .await
308 .map_err(|e| AppError::Query(e.to_string()))?;
309
310 for fk_row in &fk_rows {
311 let col_name: Option<String> = fk_row.try_get("column_name").ok();
312 if let Some(col_name) = col_name
313 && let Some(col_info) = columns.get_mut(&col_name)
314 && let Some(obj) = col_info.as_object_mut()
315 {
316 let constraint_name: Option<String> = fk_row.try_get("constraint_name").ok();
317 let referenced_table: Option<String> = fk_row.try_get("referenced_table").ok();
318 let referenced_column: Option<String> = fk_row.try_get("referenced_column").ok();
319 let on_update: Option<String> = fk_row.try_get("on_update").ok();
320 let on_delete: Option<String> = fk_row.try_get("on_delete").ok();
321 obj.insert(
322 "foreign_key".to_string(),
323 json!({
324 "constraint_name": constraint_name,
325 "referenced_table": referenced_table,
326 "referenced_column": referenced_column,
327 "on_update": on_update,
328 "on_delete": on_delete,
329 }),
330 );
331 }
332 }
333
334 Ok(json!({
335 "table_name": table,
336 "columns": columns,
337 }))
338 }
339
340 async fn execute_query(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
341 self.query_to_json(sql, database).await
342 }
343
344 async fn create_database(&self, name: &str) -> Result<Value, AppError> {
345 if self.read_only {
346 return Err(AppError::ReadOnlyViolation);
347 }
348 validate_identifier(name)?;
349
350 let exists: Option<Vec<u8>> =
352 sqlx::query_scalar("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
353 .bind(name)
354 .fetch_optional(&self.pool)
355 .await
356 .map_err(|e| AppError::Query(e.to_string()))?;
357
358 if exists.is_some() {
359 return Ok(json!({
360 "status": "exists",
361 "message": format!("Database '{name}' already exists."),
362 "database_name": name,
363 }));
364 }
365
366 sqlx::query(&format!(
367 "CREATE DATABASE IF NOT EXISTS {}",
368 Self::quote_identifier(name)
369 ))
370 .execute(&self.pool)
371 .await
372 .map_err(|e| AppError::Query(e.to_string()))?;
373
374 Ok(json!({
375 "status": "success",
376 "message": format!("Database '{name}' created successfully."),
377 "database_name": name,
378 }))
379 }
380
381 fn dialect(&self) -> Box<dyn sqlparser::dialect::Dialect> {
382 Box::new(sqlparser::dialect::MySqlDialect {})
383 }
384
385 fn read_only(&self) -> bool {
386 self.read_only
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    // Shadows the `DatabaseBackend` trait pulled in by the glob import above, so
    // `DatabaseBackend::Mysql` below refers to the config enum.
    use crate::config::DatabaseBackend;

    /// Minimal config with host, port, credentials and a default database set.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Mysql,
            host: "db.example.com".into(),
            port: 3307,
            user: "admin".into(),
            password: Some("s3cret".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn quote_identifier_wraps_in_backticks() {
        assert_eq!(MysqlBackend::quote_identifier("users"), "`users`");
        assert_eq!(MysqlBackend::quote_identifier("eu-docker"), "`eu-docker`");
    }

    #[test]
    fn quote_identifier_escapes_backticks() {
        assert_eq!(MysqlBackend::quote_identifier("test`db"), "`test``db`");
        assert_eq!(MysqlBackend::quote_identifier("a`b`c"), "`a``b``c`");
    }

    #[test]
    fn quote_string_wraps_and_doubles_single_quotes() {
        assert_eq!(MysqlBackend::quote_string("mydb"), "'mydb'");
        assert_eq!(MysqlBackend::quote_string("it's"), "'it''s'");
        assert_eq!(MysqlBackend::quote_string(""), "''");
    }

    #[test]
    fn try_from_basic_config() {
        let config = base_config();
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_port(), 3307);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn try_from_with_charset() {
        let config = DatabaseConfig {
            charset: Some("utf8mb4".into()),
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_charset(), "utf8mb4");
    }

    #[test]
    fn try_from_with_ssl_required() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::Required),
            "expected Required, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn try_from_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn try_from_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        // The remaining connection settings must be unaffected by a missing password.
        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_port(), 3307);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn try_from_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn try_from_with_empty_database_name() {
        // An empty name must be treated like "no default database".
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }
}