Skip to main content

database_mcp/db/
mysql.rs

1//! MySQL/MariaDB backend implementation via sqlx.
2//!
3//! Implements [`DatabaseBackend`] for `MySQL` and `MariaDB` databases
4//! using sqlx's `MySqlPool`.
5
6use crate::config::DatabaseConfig;
7use crate::db::backend::DatabaseBackend;
8use crate::db::identifier::validate_identifier;
9use crate::error::AppError;
10use serde_json::{Value, json};
11use sqlx::mysql::{MySqlConnectOptions, MySqlPoolOptions, MySqlRow, MySqlSslMode};
12use sqlx::{Executor, MySqlPool, Row};
13use sqlx_to_json::RowExt;
14use std::collections::HashMap;
15use tracing::{error, info};
16
17/// Converts [`DatabaseConfig`] into [`MySqlConnectOptions`].
18impl From<&DatabaseConfig> for MySqlConnectOptions {
19    fn from(config: &DatabaseConfig) -> Self {
20        let mut opts = MySqlConnectOptions::new()
21            .host(&config.host)
22            .port(config.port)
23            .username(&config.user);
24
25        if let Some(ref password) = config.password {
26            opts = opts.password(password);
27        }
28        if let Some(ref name) = config.name
29            && !name.is_empty()
30        {
31            opts = opts.database(name);
32        }
33        if let Some(ref charset) = config.charset {
34            opts = opts.charset(charset);
35        }
36
37        if config.ssl {
38            opts = if config.ssl_verify_cert {
39                opts.ssl_mode(MySqlSslMode::VerifyCa)
40            } else {
41                opts.ssl_mode(MySqlSslMode::Required)
42            };
43            if let Some(ref ca) = config.ssl_ca {
44                opts = opts.ssl_ca(ca);
45            }
46            if let Some(ref cert) = config.ssl_cert {
47                opts = opts.ssl_client_cert(cert);
48            }
49            if let Some(ref key) = config.ssl_key {
50                opts = opts.ssl_client_key(key);
51            }
52        }
53
54        opts
55    }
56}
57
58/// MySQL/MariaDB database backend.
59#[derive(Clone)]
60pub struct MysqlBackend {
61    pool: MySqlPool,
62    pub read_only: bool,
63}
64
65impl std::fmt::Debug for MysqlBackend {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("MysqlBackend")
68            .field("read_only", &self.read_only)
69            .finish_non_exhaustive()
70    }
71}
72
73impl MysqlBackend {
74    /// Creates a new `MySQL` backend from configuration.
75    ///
76    /// # Errors
77    ///
78    /// Returns [`AppError::Connection`] if the connection fails.
79    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
80        let pool = MySqlPoolOptions::new()
81            .max_connections(config.max_pool_size)
82            .connect_with(config.into())
83            .await
84            .map_err(|e| AppError::Connection(format!("Failed to connect to MySQL: {e}")))?;
85
86        info!("MySQL connection pool initialized (max size: {})", config.max_pool_size);
87
88        let backend = Self {
89            pool,
90            read_only: config.read_only,
91        };
92
93        if config.read_only {
94            backend.warn_if_file_privilege().await;
95        }
96
97        Ok(backend)
98    }
99
100    async fn warn_if_file_privilege(&self) {
101        let result: Result<(), AppError> = async {
102            let current_user: Option<String> = sqlx::query_scalar("SELECT CURRENT_USER()")
103                .fetch_optional(&self.pool)
104                .await
105                .map_err(|e| AppError::Query(e.to_string()))?;
106
107            let Some(current_user) = current_user else {
108                return Ok(());
109            };
110
111            let quoted_user = if let Some((user, host)) = current_user.split_once('@') {
112                format!("'{user}'@'{host}'")
113            } else {
114                format!("'{current_user}'")
115            };
116
117            let grants: Vec<String> = sqlx::query_scalar(&format!("SHOW GRANTS FOR {quoted_user}"))
118                .fetch_all(&self.pool)
119                .await
120                .map_err(|e| AppError::Query(e.to_string()))?;
121
122            let has_file_priv = grants.iter().any(|grant| {
123                let upper = grant.to_uppercase();
124                upper.contains("FILE") && upper.contains("ON *.*")
125            });
126
127            if has_file_priv {
128                error!(
129                    "Connected database user has the global FILE privilege. \
130                     Revoke FILE for the database user you are connecting as."
131                );
132            }
133
134            Ok(())
135        }
136        .await;
137
138        if let Err(e) = result {
139            tracing::debug!("Unable to determine whether FILE privilege is enabled: {e}");
140        }
141    }
142
143    /// Wraps `name` in backticks for safe use in `MySQL` SQL statements.
144    ///
145    /// Escapes internal backticks by doubling them.
146    fn quote_identifier(name: &str) -> String {
147        let escaped = name.replace('`', "``");
148        format!("`{escaped}`")
149    }
150
151    /// Wraps a value in single quotes for use as a SQL string literal.
152    ///
153    /// Escapes internal single quotes by doubling them.
154    fn quote_string(value: &str) -> String {
155        let escaped = value.replace('\'', "''");
156        format!("'{escaped}'")
157    }
158
159    /// Executes raw SQL and converts rows to JSON maps.
160    ///
161    /// Uses the text protocol via `Executor::fetch_all(&str)` instead of prepared
162    /// statements, because `MySQL` 9+ doesn't support SHOW commands as prepared
163    /// statements, and the text protocol returns all values as strings.
164    async fn query_to_json(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
165        // Acquire a single connection so USE and the query run on the same session
166        let mut conn = self
167            .pool
168            .acquire()
169            .await
170            .map_err(|e| AppError::Connection(e.to_string()))?;
171
172        // Switch database if needed
173        if let Some(db) = database {
174            validate_identifier(db)?;
175            let use_sql = format!("USE {}", Self::quote_identifier(db));
176            conn.execute(use_sql.as_str())
177                .await
178                .map_err(|e| AppError::Query(e.to_string()))?;
179        }
180
181        let rows: Vec<MySqlRow> = conn.fetch_all(sql).await.map_err(|e| AppError::Query(e.to_string()))?;
182        Ok(Value::Array(rows.iter().map(RowExt::to_json).collect()))
183    }
184}
185
186impl DatabaseBackend for MysqlBackend {
187    async fn list_databases(&self) -> Result<Vec<String>, AppError> {
188        let results = self
189            .query_to_json(
190                "SELECT SCHEMA_NAME AS name FROM information_schema.SCHEMATA ORDER BY SCHEMA_NAME",
191                None,
192            )
193            .await?;
194        let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
195        Ok(rows
196            .iter()
197            .filter_map(|row| row.get("name").and_then(|v| v.as_str().map(String::from)))
198            .collect())
199    }
200
201    async fn list_tables(&self, database: &str) -> Result<Vec<String>, AppError> {
202        validate_identifier(database)?;
203        let sql = format!(
204            "SELECT TABLE_NAME AS name FROM information_schema.TABLES WHERE TABLE_SCHEMA = {} ORDER BY TABLE_NAME",
205            Self::quote_string(database)
206        );
207        let results = self.query_to_json(&sql, None).await?;
208        let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
209        Ok(rows
210            .iter()
211            .filter_map(|row| row.get("name").and_then(|v| v.as_str().map(String::from)))
212            .collect())
213    }
214
215    async fn get_table_schema(&self, database: &str, table: &str) -> Result<Value, AppError> {
216        validate_identifier(database)?;
217        validate_identifier(table)?;
218
219        let sql = format!(
220            "DESCRIBE {}.{}",
221            Self::quote_identifier(database),
222            Self::quote_identifier(table)
223        );
224        let results = self.query_to_json(&sql, None).await?;
225        let rows = results.as_array().map_or([].as_slice(), Vec::as_slice);
226
227        if rows.is_empty() {
228            return Err(AppError::TableNotFound(format!("{database}.{table}")));
229        }
230
231        let mut schema: HashMap<String, Value> = HashMap::new();
232        for row in rows {
233            if let Some(col_name) = row.get("Field").and_then(|v| v.as_str()) {
234                schema.insert(
235                    col_name.to_string(),
236                    json!({
237                        "type": row.get("Type").unwrap_or(&Value::Null),
238                        "nullable": row.get("Null").and_then(|v| v.as_str()).is_some_and(|s| s.to_uppercase() == "YES"),
239                        "key": row.get("Key").unwrap_or(&Value::Null),
240                        "default": row.get("Default").unwrap_or(&Value::Null),
241                        "extra": row.get("Extra").unwrap_or(&Value::Null),
242                    }),
243                );
244            }
245        }
246
247        Ok(json!(schema))
248    }
249
250    async fn get_table_schema_with_relations(&self, database: &str, table: &str) -> Result<Value, AppError> {
251        validate_identifier(database)?;
252        validate_identifier(table)?;
253
254        // 1. Get basic schema
255        let describe_sql = format!(
256            "DESCRIBE {}.{}",
257            Self::quote_identifier(database),
258            Self::quote_identifier(table)
259        );
260        let schema_results = self.query_to_json(&describe_sql, None).await?;
261        let schema_rows = schema_results.as_array().map_or([].as_slice(), Vec::as_slice);
262
263        if schema_rows.is_empty() {
264            return Err(AppError::TableNotFound(format!("{database}.{table}")));
265        }
266
267        let mut columns: HashMap<String, Value> = HashMap::new();
268        for row in schema_rows {
269            if let Some(col_name) = row.get("Field").and_then(|v| v.as_str()) {
270                columns.insert(
271                    col_name.to_string(),
272                    json!({
273                        "type": row.get("Type").unwrap_or(&Value::Null),
274                        "nullable": row.get("Null").and_then(|v| v.as_str()).is_some_and(|s| s.to_uppercase() == "YES"),
275                        "key": row.get("Key").unwrap_or(&Value::Null),
276                        "default": row.get("Default").unwrap_or(&Value::Null),
277                        "extra": row.get("Extra").unwrap_or(&Value::Null),
278                        "foreign_key": null,
279                    }),
280                );
281            }
282        }
283
284        // 2. Get FK relationships
285        let fk_sql = r"
286            SELECT
287                kcu.COLUMN_NAME as column_name,
288                kcu.CONSTRAINT_NAME as constraint_name,
289                kcu.REFERENCED_TABLE_NAME as referenced_table,
290                kcu.REFERENCED_COLUMN_NAME as referenced_column,
291                rc.UPDATE_RULE as on_update,
292                rc.DELETE_RULE as on_delete
293            FROM information_schema.KEY_COLUMN_USAGE kcu
294            INNER JOIN information_schema.REFERENTIAL_CONSTRAINTS rc
295                ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
296                AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
297            WHERE kcu.TABLE_SCHEMA = ?
298              AND kcu.TABLE_NAME = ?
299              AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
300            ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION
301        ";
302
303        let fk_rows: Vec<MySqlRow> = sqlx::query(fk_sql)
304            .bind(database)
305            .bind(table)
306            .fetch_all(&self.pool)
307            .await
308            .map_err(|e| AppError::Query(e.to_string()))?;
309
310        for fk_row in &fk_rows {
311            let col_name: Option<String> = fk_row.try_get("column_name").ok();
312            if let Some(col_name) = col_name
313                && let Some(col_info) = columns.get_mut(&col_name)
314                && let Some(obj) = col_info.as_object_mut()
315            {
316                let constraint_name: Option<String> = fk_row.try_get("constraint_name").ok();
317                let referenced_table: Option<String> = fk_row.try_get("referenced_table").ok();
318                let referenced_column: Option<String> = fk_row.try_get("referenced_column").ok();
319                let on_update: Option<String> = fk_row.try_get("on_update").ok();
320                let on_delete: Option<String> = fk_row.try_get("on_delete").ok();
321                obj.insert(
322                    "foreign_key".to_string(),
323                    json!({
324                        "constraint_name": constraint_name,
325                        "referenced_table": referenced_table,
326                        "referenced_column": referenced_column,
327                        "on_update": on_update,
328                        "on_delete": on_delete,
329                    }),
330                );
331            }
332        }
333
334        Ok(json!({
335            "table_name": table,
336            "columns": columns,
337        }))
338    }
339
340    async fn execute_query(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
341        self.query_to_json(sql, database).await
342    }
343
344    async fn create_database(&self, name: &str) -> Result<Value, AppError> {
345        if self.read_only {
346            return Err(AppError::ReadOnlyViolation);
347        }
348        validate_identifier(name)?;
349
350        // Check existence — use Vec<u8> because MySQL 9 returns BINARY columns
351        let exists: Option<Vec<u8>> =
352            sqlx::query_scalar("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
353                .bind(name)
354                .fetch_optional(&self.pool)
355                .await
356                .map_err(|e| AppError::Query(e.to_string()))?;
357
358        if exists.is_some() {
359            return Ok(json!({
360                "status": "exists",
361                "message": format!("Database '{name}' already exists."),
362                "database_name": name,
363            }));
364        }
365
366        sqlx::query(&format!(
367            "CREATE DATABASE IF NOT EXISTS {}",
368            Self::quote_identifier(name)
369        ))
370        .execute(&self.pool)
371        .await
372        .map_err(|e| AppError::Query(e.to_string()))?;
373
374        Ok(json!({
375            "status": "success",
376            "message": format!("Database '{name}' created successfully."),
377            "database_name": name,
378        }))
379    }
380
381    fn dialect(&self) -> Box<dyn sqlparser::dialect::Dialect> {
382        Box::new(sqlparser::dialect::MySqlDialect {})
383    }
384
385    fn read_only(&self) -> bool {
386        self.read_only
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseBackend;

    /// Baseline configuration shared by the connect-options tests.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Mysql,
            host: "db.example.com".into(),
            port: 3307,
            user: "admin".into(),
            password: Some("s3cret".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn quote_identifier_wraps_in_backticks() {
        assert_eq!(MysqlBackend::quote_identifier("users"), "`users`");
        assert_eq!(MysqlBackend::quote_identifier("eu-docker"), "`eu-docker`");
    }

    #[test]
    fn quote_identifier_escapes_backticks() {
        assert_eq!(MysqlBackend::quote_identifier("test`db"), "`test``db`");
        assert_eq!(MysqlBackend::quote_identifier("a`b`c"), "`a``b``c`");
    }

    #[test]
    fn quote_string_wraps_in_single_quotes() {
        assert_eq!(MysqlBackend::quote_string("mydb"), "'mydb'");
        assert_eq!(MysqlBackend::quote_string(""), "''");
    }

    #[test]
    fn quote_string_doubles_single_quotes() {
        assert_eq!(MysqlBackend::quote_string("o'brien"), "'o''brien'");
    }

    // The conversion under test is `From`, not `TryFrom`; names reflect that.
    #[test]
    fn from_basic_config() {
        let config = base_config();
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_port(), 3307);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn from_with_charset() {
        let config = DatabaseConfig {
            charset: Some("utf8mb4".into()),
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_charset(), "utf8mb4");
    }

    #[test]
    fn from_with_ssl_required() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::Required),
            "expected Required, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn from_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn from_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        // Should not panic — password is simply omitted
        assert_eq!(opts.get_host(), "db.example.com");
    }

    #[test]
    fn from_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn from_with_empty_database_name() {
        // An empty name must be treated like "no default database".
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = MySqlConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }
}