#![allow(clippy::doc_markdown)]
use async_trait::async_trait;
use serde_json::{json, Map, Value};
use sqlx::mysql::{MySqlArguments, MySqlPool, MySqlRow};
use sqlx::query::Query;
use sqlx::{Column, MySql, Row, TypeInfo};
use pmcp_server_toolkit::sql::{
translate_placeholders, ConnectorError, Dialect, SqlConnector, TranslatedSql,
};
pub mod dev_mock;
/// [`SqlConnector`] implementation for MySQL, backed by a sqlx connection pool.
pub struct MysqlConnector {
    /// Pool created by `MySqlPool::connect_lazy` in `connect`, so no connection is
    /// opened until the first query runs.
    pool: MySqlPool,
    /// Database name parsed from the connection URL (empty string when the URL names
    /// none); bound as `table_schema` when introspecting `information_schema`.
    database: String,
}
/// Returns `url` with the password portion of its userinfo replaced by `***`.
///
/// Intended for embedding connection URLs in error messages. Only the text between
/// `://` and the next `/` is inspected. If that span contains no `@`, or the userinfo
/// before the `@` contains no `:` (i.e. no password), the URL is returned unchanged.
///
/// The userinfo is delimited by the LAST `@` in that span, so an unescaped `@` inside
/// the password (e.g. `mysql://user:p@ss@host/db`) is still fully redacted.
fn sanitize_url(url: &str) -> String {
    let Some(scheme_end) = url.find("://") else {
        return url.to_string();
    };
    let authority_start = scheme_end + 3;
    let rest = &url[authority_start..];
    // Only '/' ends the span, so a '?' or '#' inside a password cannot cut it short
    // and leave part of the secret outside the redacted region.
    let authority_end = rest.find('/').unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    let Some(at) = authority.rfind('@') else {
        return url.to_string();
    };
    // A username cannot contain an unescaped ':', so the first one starts the password.
    let Some(colon) = authority[..at].find(':') else {
        return url.to_string();
    };
    let user = &authority[..colon];
    // `rest[at..]` is the '@', the host, and everything after the authority.
    format!("{}{user}:***{}", &url[..authority_start], &rest[at..])
}
/// Extracts the database name from the first path segment of a connection URL.
///
/// Returns `None` when the URL has no `://`, has no `/` after it, or the first path
/// segment is empty. The segment ends at the first `?` (query string), `#` (fragment),
/// or further `/`. The name is returned exactly as written; no percent-decoding is done.
fn parse_database_from_url(url: &str) -> Option<String> {
    let (_scheme, after_scheme) = url.split_once("://")?;
    let (_authority, path) = after_scheme.split_once('/')?;
    // '#' is included so a fragment does not leak into the schema name used later.
    let db = path.split(['?', '#', '/']).next().unwrap_or_default();
    (!db.is_empty()).then(|| db.to_string())
}
impl MysqlConnector {
    /// Creates a connector for the MySQL server addressed by `url`.
    ///
    /// The pool is built with `MySqlPool::connect_lazy`, so no connection is attempted
    /// here. The database name is taken from the URL path and is left empty when the
    /// URL does not name one.
    ///
    /// # Errors
    /// Returns [`ConnectorError::Connection`] when sqlx rejects the URL. The message
    /// embeds the URL with any password replaced by `***`.
    pub async fn connect(url: &str) -> Result<Self, ConnectorError> {
        let pool = match MySqlPool::connect_lazy(url) {
            Ok(pool) => pool,
            Err(e) => {
                return Err(ConnectorError::Connection(format!(
                    "mysql url ({}): {e}",
                    sanitize_url(url)
                )));
            }
        };
        let database = parse_database_from_url(url).unwrap_or_default();
        Ok(Self { pool, database })
    }
}
/// Appends one JSON value to `q` as its next positional (`?`) argument.
///
/// Mapping:
/// - `null` → SQL `NULL`
/// - booleans → `bool`
/// - integers that fit in `i64` → `i64`
/// - larger non-negative integers → `u64`
/// - all other numbers → `f64`
/// - strings → owned `String`
/// - arrays and objects → their JSON text from `serde_json::to_string`
fn bind_one<'q>(
    q: Query<'q, MySql, MySqlArguments>,
    v: &Value,
) -> Query<'q, MySql, MySqlArguments> {
    match v {
        Value::Null => q.bind(None::<&str>),
        Value::Bool(b) => q.bind(*b),
        Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                q.bind(i)
            } else if let Some(u) = n.as_u64() {
                // Only reached for values above i64::MAX. Binding as u64 keeps them
                // exact instead of rounding them through f64.
                q.bind(u)
            } else {
                q.bind(n.as_f64().unwrap_or(0.0))
            }
        },
        Value::String(s) => q.bind(s.clone()),
        arr_or_obj => q.bind(serde_json::to_string(arr_or_obj).unwrap_or_default()),
    }
}
/// Decodes column `idx` of `row` into JSON, choosing the Rust type from the sqlx type
/// name (`type_info().name()`) passed as `type_name`.
///
/// SQL `NULL` yields `Value::Null`. A value that cannot be decoded as the chosen type
/// also yields `Value::Null` rather than an error. Types with no dedicated arm are
/// attempted as `String`.
/// NOTE(review): this likely still nulls out DATE/DATETIME/TIMESTAMP columns, which
/// sqlx does not treat as text-compatible; confirm before relying on them.
fn column_to_value(row: &MySqlRow, idx: usize, type_name: &str) -> Value {
    match type_name {
        "BIGINT" | "INT" | "MEDIUMINT" | "SMALLINT" | "TINYINT" => row
            .try_get::<Option<i64>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |i| json!(i)),
        // sqlx reports unsigned integer columns with an " UNSIGNED" suffix. Without this
        // arm they fell through to the String decode, failed, and came back as NULL.
        "BIGINT UNSIGNED" | "INT UNSIGNED" | "MEDIUMINT UNSIGNED" | "SMALLINT UNSIGNED"
        | "TINYINT UNSIGNED" => row
            .try_get::<Option<u64>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |u| json!(u)),
        "DOUBLE" | "FLOAT" => row
            .try_get::<Option<f64>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |f| json!(f)),
        "BOOLEAN" | "BOOL" => row
            .try_get::<Option<bool>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |b| json!(b)),
        // DECIMAL values are sent as length-encoded text, but sqlx's String type check
        // rejects the DECIMAL column type. Decoding unchecked keeps the exact digits as
        // a JSON string.
        "DECIMAL" => row
            .try_get_unchecked::<Option<String>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |s| json!(s)),
        _ => row
            .try_get::<Option<String>, _>(idx)
            .ok()
            .flatten()
            .map_or(Value::Null, |s| json!(s)),
    }
}
fn row_to_value(row: &MySqlRow) -> Value {
let mut obj = Map::new();
for (idx, col) in row.columns().iter().enumerate() {
obj.insert(
col.name().to_string(),
column_to_value(row, idx, col.type_info().name()),
);
}
Value::Object(obj)
}
fn schema_col(row: &MySqlRow, name: &str) -> String {
row.try_get::<String, _>(name).unwrap_or_default()
}
/// Wraps a MySQL identifier in backticks, doubling any backtick it contains.
fn quote_mysql_ident(name: &str) -> String {
    format!("`{}`", name.replace('`', "``"))
}

/// Renders `information_schema.columns` rows as `CREATE TABLE` statements.
///
/// Rows must already be grouped by table. A new statement starts whenever the
/// `table_name` value changes from the previous row. Each column contributes its name,
/// its bare `data_type`, and ` NOT NULL` when `is_nullable` is `"NO"`. Keys, defaults,
/// and lengths are not included. Column definitions are comma-separated so each
/// statement is syntactically valid.
fn format_information_schema_as_ddl(rows: &[MySqlRow]) -> String {
    const TABLE_FOOTER: &str = "\n) ENGINE=InnoDB;\n";
    let mut out = String::new();
    let mut current_table: Option<String> = None;
    for row in rows {
        let table = schema_col(row, "table_name");
        let column = schema_col(row, "column_name");
        let data_type = schema_col(row, "data_type");
        let is_nullable = schema_col(row, "is_nullable");
        if current_table.as_deref() == Some(table.as_str()) {
            // Another column of the same table: terminate the previous definition.
            out.push_str(",\n");
        } else {
            if current_table.is_some() {
                out.push_str(TABLE_FOOTER);
            }
            out.push_str(&format!("CREATE TABLE {} (\n", quote_mysql_ident(&table)));
            current_table = Some(table);
        }
        let not_null = if is_nullable == "NO" { " NOT NULL" } else { "" };
        out.push_str(&format!(
            " {} {data_type}{not_null}",
            quote_mysql_ident(&column)
        ));
    }
    if current_table.is_some() {
        out.push_str(TABLE_FOOTER);
    }
    out
}
#[async_trait]
impl SqlConnector for MysqlConnector {
    fn dialect(&self) -> Dialect {
        Dialect::MySql
    }

    /// Runs `sql` after rewriting its placeholders with `translate_placeholders`.
    ///
    /// Each placeholder name returned by the translation is bound, in order, to the
    /// value of the first `params` entry with that name. Names with no matching entry
    /// are bound as SQL `NULL`. Every returned row is converted to a JSON object keyed
    /// by column name.
    ///
    /// # Errors
    /// Returns [`ConnectorError::Query`] if executing the statement fails.
    async fn execute(
        &self,
        sql: &str,
        params: &[(String, Value)],
    ) -> Result<Vec<Value>, ConnectorError> {
        let TranslatedSql {
            sql: translated,
            ordered_params,
        } = translate_placeholders(sql, Dialect::MySql);
        let mut q = sqlx::query(&translated);
        for name in &ordered_params {
            let v = params
                .iter()
                .find(|(k, _)| k == name)
                .map_or(Value::Null, |(_, v)| v.clone());
            q = bind_one(q, &v);
        }
        let rows = q
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ConnectorError::Query(e.to_string()))?;
        Ok(rows.iter().map(row_to_value).collect())
    }

    /// Describes every column of the connected database as `CREATE TABLE` text.
    ///
    /// # Errors
    /// Returns [`ConnectorError::Schema`] if the `information_schema` query fails.
    async fn schema_text(&self) -> Result<String, ConnectorError> {
        // Explicit lowercase aliases: MySQL 8.0 reports information_schema result
        // headers in uppercase (e.g. `TABLE_NAME`) regardless of how the query spells
        // them, while the DDL formatter looks the columns up by lowercase name.
        let rows = sqlx::query(
            "SELECT table_name AS table_name, column_name AS column_name, \
             data_type AS data_type, is_nullable AS is_nullable \
             FROM information_schema.columns WHERE table_schema = ? \
             ORDER BY table_name, ordinal_position",
        )
        .bind(&self.database)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ConnectorError::Schema(e.to_string()))?;
        Ok(format_information_schema_as_ddl(&rows))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- sanitize_url ------------------------------------------------------

    #[test]
    fn test_sanitize_url_redacts_password() {
        let redacted = sanitize_url("mysql://user:secret@host/db");
        assert_eq!(redacted, "mysql://user:***@host/db");
    }

    #[test]
    fn test_sanitize_url_without_password_unchanged() {
        // (input, reason): a URL without a password must round-trip untouched.
        let cases = [
            ("mysql://host/db", "no userinfo → unchanged"),
            ("mysql://user@host/db", "user without password → unchanged"),
        ];
        for (url, reason) in cases {
            assert_eq!(sanitize_url(url), url, "{reason}");
        }
    }

    // ---- parse_database_from_url --------------------------------------------

    #[test]
    fn test_parse_database_from_url() {
        assert_eq!(
            parse_database_from_url("mysql://localhost/mydb"),
            Some("mydb".to_string())
        );
        // (input, expected database, reason)
        let cases: [(&str, Option<&str>, &str); 3] = [
            (
                "mysql://user:pw@host:3306/shop?ssl=true",
                Some("shop"),
                "query string and port are stripped",
            ),
            ("mysql://localhost/", None, "empty database segment → None"),
            ("not a url", None, "no scheme → None"),
        ];
        for (url, expected, reason) in cases {
            assert_eq!(parse_database_from_url(url).as_deref(), expected, "{reason}");
        }
    }

    // ---- bind_one -------------------------------------------------------------

    #[test]
    fn test_bind_one_dispatch() {
        // One sample per JSON variant; building the query must not panic.
        let samples = [
            Value::Null,
            json!(true),
            json!(42_i64),
            json!(2.5_f64),
            json!("hello"),
            json!([1, 2, 3]),
            json!({"k": "v"}),
        ];
        samples.iter().for_each(|sample| {
            let _ = bind_one(sqlx::query("SELECT ?"), sample);
        });
    }

    // ---- MysqlConnector::connect ------------------------------------------------

    #[tokio::test]
    async fn test_connect_lazy_returns_ok_without_network() {
        let outcome = MysqlConnector::connect("mysql://localhost/db").await;
        assert!(
            outcome.is_ok(),
            "connect_lazy must return Ok without a reachable server"
        );
    }

    #[tokio::test]
    async fn test_connect_invalid_url_returns_err() {
        let msg = match MysqlConnector::connect("not a url").await {
            Ok(_) => panic!("malformed URL must error"),
            Err(ConnectorError::Connection(msg)) => msg,
            Err(other) => panic!("expected ConnectorError::Connection, got {other:?}"),
        };
        assert!(
            !msg.contains("password"),
            "error text must not contain the literal 'password'; got: {msg:?}"
        );
    }

    #[tokio::test]
    async fn test_connect_invalid_url_does_not_echo_password() {
        // Whether or not sqlx accepts this URL, an error must never echo the secret.
        match MysqlConnector::connect("mysql://u:hunter2@@@bad url/db").await {
            Ok(_) => {},
            Err(ConnectorError::Connection(msg)) => assert!(
                !msg.contains("hunter2"),
                "error text must not echo the password; got: {msg:?}"
            ),
            Err(other) => panic!("expected ConnectorError::Connection, got {other:?}"),
        }
    }
}