Skip to main content

database_mcp_mysql/
adapter.rs

1//! MySQL/MariaDB adapter definition and connection configuration.
2//!
//! Builds [`MySqlConnectOptions`] from a [`DatabaseConfig`] and, when the
//! configuration is read-only, warns at startup if the connected user holds the
//! global `FILE` privilege.
5
6use database_mcp_config::DatabaseConfig;
7use database_mcp_server::AppError;
8use sqlx::MySqlPool;
9use sqlx::mysql::{MySqlConnectOptions, MySqlPoolOptions, MySqlSslMode};
10use tracing::{error, info};
11
12/// MySQL/MariaDB database adapter.
13#[derive(Clone)]
14pub struct MysqlAdapter {
15    pub(crate) config: DatabaseConfig,
16    pub(crate) pool: MySqlPool,
17}
18
19impl std::fmt::Debug for MysqlAdapter {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("MysqlAdapter")
22            .field("read_only", &self.config.read_only)
23            .finish_non_exhaustive()
24    }
25}
26
27impl MysqlAdapter {
28    /// Creates a new `MySQL` adapter from configuration.
29    ///
30    /// # Errors
31    ///
32    /// Returns [`AppError::Connection`] if the connection fails.
33    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
34        let pool = MySqlPoolOptions::new()
35            .max_connections(config.max_pool_size)
36            .connect_with(connect_options(config))
37            .await
38            .map_err(|e| AppError::Connection(format!("Failed to connect to MySQL: {e}")))?;
39
40        info!("MySQL connection pool initialized (max size: {})", config.max_pool_size);
41
42        let backend = Self {
43            config: config.clone(),
44            pool,
45        };
46
47        if backend.config.read_only {
48            backend.warn_if_file_privilege().await;
49        }
50
51        Ok(backend)
52    }
53
54    /// Wraps `name` in backticks for safe use in `MySQL` SQL statements.
55    pub(crate) fn quote_identifier(name: &str) -> String {
56        database_mcp_sql::identifier::quote_identifier(name, '`')
57    }
58
59    /// Wraps a value in single quotes for use as a SQL string literal.
60    ///
61    /// Escapes internal single quotes by doubling them.
62    pub(crate) fn quote_string(value: &str) -> String {
63        let escaped = value.replace('\'', "''");
64        format!("'{escaped}'")
65    }
66
67    async fn warn_if_file_privilege(&self) {
68        let result: Result<(), AppError> = async {
69            let current_user: Option<String> = sqlx::query_scalar("SELECT CURRENT_USER()")
70                .fetch_optional(&self.pool)
71                .await
72                .map_err(|e| AppError::Query(e.to_string()))?;
73
74            let Some(current_user) = current_user else {
75                return Ok(());
76            };
77
78            let quoted_user = if let Some((user, host)) = current_user.split_once('@') {
79                format!("'{user}'@'{host}'")
80            } else {
81                format!("'{current_user}'")
82            };
83
84            let grants: Vec<String> = sqlx::query_scalar(&format!("SHOW GRANTS FOR {quoted_user}"))
85                .fetch_all(&self.pool)
86                .await
87                .map_err(|e| AppError::Query(e.to_string()))?;
88
89            let has_file_priv = grants.iter().any(|grant| {
90                let upper = grant.to_uppercase();
91                upper.contains("FILE") && upper.contains("ON *.*")
92            });
93
94            if has_file_priv {
95                error!(
96                    "Connected database user has the global FILE privilege. \
97                     Revoke FILE for the database user you are connecting as."
98                );
99            }
100
101            Ok(())
102        }
103        .await;
104
105        if let Err(e) = result {
106            tracing::debug!("Unable to determine whether FILE privilege is enabled: {e}");
107        }
108    }
109}
110
111/// Builds [`MySqlConnectOptions`] from a [`DatabaseConfig`].
112fn connect_options(config: &DatabaseConfig) -> MySqlConnectOptions {
113    let mut opts = MySqlConnectOptions::new()
114        .host(&config.host)
115        .port(config.port)
116        .username(&config.user);
117
118    if let Some(ref password) = config.password {
119        opts = opts.password(password);
120    }
121    if let Some(ref name) = config.name
122        && !name.is_empty()
123    {
124        opts = opts.database(name);
125    }
126    if let Some(ref charset) = config.charset {
127        opts = opts.charset(charset);
128    }
129
130    if config.ssl {
131        opts = if config.ssl_verify_cert {
132            opts.ssl_mode(MySqlSslMode::VerifyCa)
133        } else {
134            opts.ssl_mode(MySqlSslMode::Required)
135        };
136        if let Some(ref ca) = config.ssl_ca {
137            opts = opts.ssl_ca(ca);
138        }
139        if let Some(ref cert) = config.ssl_cert {
140            opts = opts.ssl_client_cert(cert);
141        }
142        if let Some(ref key) = config.ssl_key {
143            opts = opts.ssl_client_key(key);
144        }
145    }
146
147    opts
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Baseline MySQL config; each test overrides only the fields it exercises.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Mysql,
            host: "db.example.com".into(),
            port: 3307,
            user: "admin".into(),
            password: Some("s3cret".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_port(), 3307);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_charset() {
        let config = DatabaseConfig {
            charset: Some("utf8mb4".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_charset(), "utf8mb4");
    }

    #[test]
    fn connect_options_with_ssl_required() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::Required),
            "expected Required, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // Should not panic — password is simply omitted
        assert_eq!(opts.get_host(), "db.example.com");
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_ignores_empty_database_name() {
        // An empty name must behave like "no database", not select "".
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn quote_string_doubles_single_quotes() {
        assert_eq!(MysqlAdapter::quote_string("plain"), "'plain'");
        assert_eq!(MysqlAdapter::quote_string("it's"), "'it''s'");
        assert_eq!(MysqlAdapter::quote_string(""), "''");
    }
}