database_mcp_mysql/
adapter.rs1use database_mcp_config::DatabaseConfig;
7use database_mcp_server::AppError;
8use sqlx::MySqlPool;
9use sqlx::mysql::{MySqlConnectOptions, MySqlPoolOptions, MySqlSslMode};
10use tracing::{error, info};
11
12#[derive(Clone)]
14pub struct MysqlAdapter {
15 pub(crate) config: DatabaseConfig,
16 pub(crate) pool: MySqlPool,
17}
18
19impl std::fmt::Debug for MysqlAdapter {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("MysqlAdapter")
22 .field("read_only", &self.config.read_only)
23 .finish_non_exhaustive()
24 }
25}
26
27impl MysqlAdapter {
28 pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
34 let pool = MySqlPoolOptions::new()
35 .max_connections(config.max_pool_size)
36 .connect_with(connect_options(config))
37 .await
38 .map_err(|e| AppError::Connection(format!("Failed to connect to MySQL: {e}")))?;
39
40 info!("MySQL connection pool initialized (max size: {})", config.max_pool_size);
41
42 let backend = Self {
43 config: config.clone(),
44 pool,
45 };
46
47 if backend.config.read_only {
48 backend.warn_if_file_privilege().await;
49 }
50
51 Ok(backend)
52 }
53
54 pub(crate) fn quote_identifier(name: &str) -> String {
56 database_mcp_sql::identifier::quote_identifier(name, '`')
57 }
58
59 pub(crate) fn quote_string(value: &str) -> String {
63 let escaped = value.replace('\'', "''");
64 format!("'{escaped}'")
65 }
66
67 async fn warn_if_file_privilege(&self) {
68 let result: Result<(), AppError> = async {
69 let current_user: Option<String> = sqlx::query_scalar("SELECT CURRENT_USER()")
70 .fetch_optional(&self.pool)
71 .await
72 .map_err(|e| AppError::Query(e.to_string()))?;
73
74 let Some(current_user) = current_user else {
75 return Ok(());
76 };
77
78 let quoted_user = if let Some((user, host)) = current_user.split_once('@') {
79 format!("'{user}'@'{host}'")
80 } else {
81 format!("'{current_user}'")
82 };
83
84 let grants: Vec<String> = sqlx::query_scalar(&format!("SHOW GRANTS FOR {quoted_user}"))
85 .fetch_all(&self.pool)
86 .await
87 .map_err(|e| AppError::Query(e.to_string()))?;
88
89 let has_file_priv = grants.iter().any(|grant| {
90 let upper = grant.to_uppercase();
91 upper.contains("FILE") && upper.contains("ON *.*")
92 });
93
94 if has_file_priv {
95 error!(
96 "Connected database user has the global FILE privilege. \
97 Revoke FILE for the database user you are connecting as."
98 );
99 }
100
101 Ok(())
102 }
103 .await;
104
105 if let Err(e) = result {
106 tracing::debug!("Unable to determine whether FILE privilege is enabled: {e}");
107 }
108 }
109}
110
111fn connect_options(config: &DatabaseConfig) -> MySqlConnectOptions {
113 let mut opts = MySqlConnectOptions::new()
114 .host(&config.host)
115 .port(config.port)
116 .username(&config.user);
117
118 if let Some(ref password) = config.password {
119 opts = opts.password(password);
120 }
121 if let Some(ref name) = config.name
122 && !name.is_empty()
123 {
124 opts = opts.database(name);
125 }
126 if let Some(ref charset) = config.charset {
127 opts = opts.charset(charset);
128 }
129
130 if config.ssl {
131 opts = if config.ssl_verify_cert {
132 opts.ssl_mode(MySqlSslMode::VerifyCa)
133 } else {
134 opts.ssl_mode(MySqlSslMode::Required)
135 };
136 if let Some(ref ca) = config.ssl_ca {
137 opts = opts.ssl_ca(ca);
138 }
139 if let Some(ref cert) = config.ssl_cert {
140 opts = opts.ssl_client_cert(cert);
141 }
142 if let Some(ref key) = config.ssl_key {
143 opts = opts.ssl_client_key(key);
144 }
145 }
146
147 opts
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// A typical configuration shared by the tests below; individual tests
    /// override single fields with struct-update syntax.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Mysql,
            host: "db.example.com".into(),
            port: 3307,
            user: "admin".into(),
            password: Some("s3cret".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_port(), 3307);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_charset() {
        let config = DatabaseConfig {
            charset: Some("utf8mb4".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_charset(), "utf8mb4");
    }

    #[test]
    fn connect_options_with_ssl_required() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::Required),
            "expected Required, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), MySqlSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_ssl_ignores_verify_flag() {
        // With `ssl` off, the TLS mode must stay at sqlx's default even if
        // `ssl_verify_cert` is set.
        let config = DatabaseConfig {
            ssl: false,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(
            format!("{:?}", opts.get_ssl_mode()),
            format!("{:?}", MySqlConnectOptions::new().get_ssl_mode())
        );
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "db.example.com");
        assert_eq!(opts.get_username(), "admin");
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_with_empty_database_name() {
        // An empty name must not select a database.
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn quote_string_doubles_single_quotes() {
        assert_eq!(MysqlAdapter::quote_string("it's"), "'it''s'");
        assert_eq!(MysqlAdapter::quote_string(""), "''");
    }
}