use std::fmt::Write;
use async_trait::async_trait;
use fraiseql_error::{FraiseQLError, Result};
use sqlx::{
Column, Row, TypeInfo,
mysql::{MySqlPool, MySqlPoolOptions, MySqlRow},
};
use super::where_generator::MySqlWhereGenerator;
use crate::{
dialect::MySqlDialect,
identifier::quote_mysql_identifier,
order_by::append_order_by,
traits::{
CursorValue, DatabaseAdapter, RelayDatabaseAdapter, RelayPageResult, SupportsMutations,
},
types::{
DatabaseType, JsonbValue, PoolMetrics,
sql_hints::{OrderByClause, OrderDirection},
},
where_clause::WhereClause,
};
/// MySQL-backed adapter built on an sqlx connection pool.
///
/// Every query issued by the methods in this file runs against `pool`.
/// The type derives `Clone`, which clones the contained pool handle.
#[derive(Clone)]
pub struct MySqlAdapter {
// Pool that all statements in this adapter are executed on.
pool: MySqlPool,
}
impl MySqlAdapter {
/// Connects to MySQL with the default pool size.
///
/// Delegates to [`Self::with_pool_size`] with a maximum of 10 connections.
///
/// # Errors
///
/// Propagates any error returned by [`Self::with_pool_size`].
pub async fn new(connection_string: &str) -> Result<Self> {
    const DEFAULT_MAX_CONNECTIONS: u32 = 10;
    Self::with_pool_size(connection_string, DEFAULT_MAX_CONNECTIONS).await
}
/// Connects to MySQL with explicit lower and upper pool bounds.
///
/// Unlike [`Self::with_pool_size`], this constructor does not run a
/// `SELECT 1` probe after the pool has been created.
///
/// # Errors
///
/// Returns `FraiseQLError::ConnectionPool` when the pool cannot be created.
pub async fn with_pool_config(
    connection_string: &str,
    min_size: u32,
    max_size: u32,
) -> Result<Self> {
    let options = MySqlPoolOptions::new()
        .min_connections(min_size)
        .max_connections(max_size);
    match options.connect(connection_string).await {
        Ok(pool) => Ok(Self { pool }),
        Err(e) => Err(FraiseQLError::ConnectionPool {
            message: format!("Failed to create MySQL connection pool: {e}"),
        }),
    }
}
/// Connects to MySQL with at most `max_size` pooled connections and verifies
/// the connection by running `SELECT 1`.
///
/// # Errors
///
/// * `FraiseQLError::ConnectionPool` when the pool cannot be created.
/// * `FraiseQLError::Database` when the `SELECT 1` probe fails.
pub async fn with_pool_size(connection_string: &str, max_size: u32) -> Result<Self> {
    let connected = MySqlPoolOptions::new()
        .max_connections(max_size)
        .connect(connection_string)
        .await;
    let pool = match connected {
        Ok(pool) => pool,
        Err(e) => {
            return Err(FraiseQLError::ConnectionPool {
                message: format!("Failed to create MySQL connection pool: {e}"),
            });
        },
    };

    // Round-trip a trivial statement so a broken connection surfaces here.
    if let Err(e) = sqlx::query("SELECT 1").fetch_one(&pool).await {
        return Err(FraiseQLError::Database {
            message: format!("Failed to connect to MySQL database: {e}"),
            sql_state: None,
        });
    }

    Ok(Self { pool })
}
async fn execute_raw(
&self,
sql: &str,
params: Vec<serde_json::Value>,
) -> Result<Vec<JsonbValue>> {
let mut query = sqlx::query(sql);
for param in ¶ms {
query = match param {
serde_json::Value::String(s) => query.bind(s.clone()),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
query.bind(i)
} else if let Some(f) = n.as_f64() {
query.bind(f)
} else {
query.bind(n.to_string())
}
},
serde_json::Value::Bool(b) => query.bind(*b),
serde_json::Value::Null => query.bind(Option::<String>::None),
serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
query.bind(param.to_string())
},
};
}
let rows: Vec<MySqlRow> = query.fetch_all(&self.pool).await.map_err(|e| {
let sql_state = if let sqlx::Error::Database(ref db_err) = e {
db_err.code().and_then(|c| c.parse::<u16>().ok()).and_then(map_mysql_error_code)
} else {
None
};
FraiseQLError::Database {
message: format!("MySQL query execution failed: {e}"),
sql_state,
}
})?;
let results = rows
.into_iter()
.map(|row| {
let data: serde_json::Value =
row.try_get("data").unwrap_or(serde_json::Value::Null);
JsonbValue::new(data)
})
.collect();
Ok(results)
}
}
/// Renders a JSON value as a MySQL literal for inlining into SQL text.
///
/// * strings become single-quoted literals;
/// * numbers are emitted as-is;
/// * booleans become `1` / `0`;
/// * `null` becomes `NULL`;
/// * arrays and objects become a single-quoted literal of their JSON text.
///
/// Inside quoted literals both `'` (doubled) and `\` (doubled) are escaped.
/// MySQL treats backslash as an escape character in string literals unless
/// the `NO_BACKSLASH_ESCAPES` SQL mode is enabled, so escaping the quote alone
/// lets an input ending in `\` (or containing `\'`) terminate the literal
/// early, and mangles JSON escape sequences such as `\"`.
// NOTE(review): with NO_BACKSLASH_ESCAPES enabled the doubled backslashes are
// stored literally — confirm the server's sql_mode.
fn mysql_escape_json_value(v: &serde_json::Value) -> String {
    /// Wraps `raw` in single quotes, escaping quote and backslash characters.
    fn quote_literal(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len() + 2);
        out.push('\'');
        for ch in raw.chars() {
            match ch {
                '\\' => out.push_str("\\\\"),
                '\'' => out.push_str("''"),
                other => out.push(other),
            }
        }
        out.push('\'');
        out
    }

    match v {
        serde_json::Value::String(s) => quote_literal(s),
        serde_json::Value::Number(n) => n.to_string(),
        serde_json::Value::Bool(b) => if *b { "1" } else { "0" }.to_string(),
        serde_json::Value::Null => "NULL".to_string(),
        serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
            quote_literal(&v.to_string())
        },
    }
}
#[async_trait]
impl DatabaseAdapter for MySqlAdapter {
async fn execute_with_projection(
&self,
view: &str,
projection: Option<&crate::types::SqlProjectionHint>,
where_clause: Option<&WhereClause>,
limit: Option<u32>,
offset: Option<u32>,
order_by: Option<&[OrderByClause]>,
) -> Result<Vec<JsonbValue>> {
if projection.is_none() {
return self.execute_where_query(view, where_clause, limit, offset, order_by).await;
}
let projection = projection.expect("projection is Some; None case returned above");
let mut sql = format!(
"SELECT {} FROM {}",
projection.projection_template,
quote_mysql_identifier(view)
);
let params: Vec<serde_json::Value> = if let Some(clause) = where_clause {
let generator = super::where_generator::MySqlWhereGenerator::new(MySqlDialect);
let (where_sql, where_params) = generator.generate(clause)?;
sql.push_str(" WHERE ");
sql.push_str(&where_sql);
where_params
} else {
Vec::new()
};
append_order_by(&mut sql, order_by, DatabaseType::MySQL)?;
match (limit, offset) {
(Some(lim), Some(off)) => {
write!(sql, " LIMIT {lim} OFFSET {off}").expect("write to String");
},
(Some(lim), None) => {
write!(sql, " LIMIT {lim}").expect("write to String");
},
(None, Some(off)) => {
write!(sql, " LIMIT 18446744073709551615 OFFSET {off}").expect("write to String");
},
(None, None) => {},
}
self.execute_raw(&sql, params).await
}
/// Runs `SELECT data FROM <view>` with optional filter, ordering, limit and
/// offset, binding the limit and offset as trailing parameters.
///
/// When only an offset is supplied, `LIMIT 18446744073709551615` (`u64::MAX`)
/// is emitted in front of the `OFFSET`.
///
/// # Errors
///
/// Propagates errors from WHERE generation, ORDER BY generation and query
/// execution.
async fn execute_where_query(
    &self,
    view: &str,
    where_clause: Option<&WhereClause>,
    limit: Option<u32>,
    offset: Option<u32>,
    order_by: Option<&[OrderByClause]>,
) -> Result<Vec<JsonbValue>> {
    let mut sql = format!("SELECT data FROM {}", quote_mysql_identifier(view));
    let mut params: Vec<serde_json::Value> = Vec::new();

    if let Some(clause) = where_clause {
        let (where_sql, where_params) =
            MySqlWhereGenerator::new(MySqlDialect).generate(clause)?;
        sql.push_str(" WHERE ");
        sql.push_str(&where_sql);
        params = where_params;
    }

    append_order_by(&mut sql, order_by, DatabaseType::MySQL)?;

    // LIMIT part: a bound value when given, the "unbounded" constant when only
    // an offset follows, nothing otherwise.
    if let Some(lim) = limit {
        sql.push_str(" LIMIT ?");
        params.push(serde_json::Value::from(lim));
    } else if offset.is_some() {
        sql.push_str(" LIMIT 18446744073709551615");
    }
    if let Some(off) = offset {
        sql.push_str(" OFFSET ?");
        params.push(serde_json::Value::from(off));
    }

    self.execute_raw(&sql, params).await
}
/// Identifies this adapter's backend; always `DatabaseType::MySQL`.
fn database_type(&self) -> DatabaseType {
    DatabaseType::MySQL
}
/// Checks connectivity by running `SELECT 1` on the pool.
///
/// # Errors
///
/// Returns `FraiseQLError::Database` (without a SQLSTATE) when the probe fails.
async fn health_check(&self) -> Result<()> {
    match sqlx::query("SELECT 1").fetch_one(&self.pool).await {
        Ok(_) => Ok(()),
        Err(e) => Err(FraiseQLError::Database {
            message: format!("MySQL health check failed: {e}"),
            sql_state: None,
        }),
    }
}
/// Reports a snapshot of the connection pool.
///
/// * `total_connections` — the pool's current size.
/// * `idle_connections` — the pool's idle count, saturated at `u32::MAX`.
/// * `active_connections` — total minus idle, never below zero.
/// * `waiting_requests` — always reported as 0 by this adapter.
fn pool_metrics(&self) -> PoolMetrics {
    let total = self.pool.size();
    let idle = u32::try_from(self.pool.num_idle()).unwrap_or(u32::MAX);
    PoolMetrics {
        total_connections: total,
        idle_connections: idle,
        // The two gauges are read separately, so `idle` may momentarily
        // exceed `total`; saturate instead of underflowing.
        active_connections: total.saturating_sub(idle),
        waiting_requests: 0,
    }
}
async fn execute_raw_query(
&self,
sql: &str,
) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
let rows: Vec<MySqlRow> = sqlx::query(sql).fetch_all(&self.pool).await.map_err(|e| {
let sql_state = if let sqlx::Error::Database(ref db_err) = e {
db_err.code().map(|c| c.into_owned())
} else {
None
};
FraiseQLError::Database {
message: format!("MySQL query execution failed: {e}"),
sql_state,
}
})?;
let results: Vec<std::collections::HashMap<String, serde_json::Value>> = rows
.into_iter()
.map(|row| {
let mut map = std::collections::HashMap::new();
for column in row.columns() {
let col = column.name().to_string();
let type_name = column.type_info().name();
let value = match type_name {
"BOOLEAN" | "BIT" => row
.try_get::<bool, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"TINYINT(1)" => row
.try_get::<bool, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"BIGINT UNSIGNED" => row
.try_get::<u64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"BIGINT" | "INT" | "INT UNSIGNED" | "MEDIUMINT" | "MEDIUMINT UNSIGNED"
| "SMALLINT" | "SMALLINT UNSIGNED" | "TINYINT" | "TINYINT UNSIGNED" => row
.try_get::<i64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"DOUBLE" | "FLOAT" => row
.try_get::<f64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"NEWDECIMAL" | "DECIMAL" => row
.try_get::<String, _>(col.as_str())
.map(|v| {
serde_json::from_str(&v).unwrap_or_else(|_| serde_json::json!(v))
})
.unwrap_or(serde_json::Value::Null),
"JSON" => row
.try_get::<serde_json::Value, _>(col.as_str())
.unwrap_or(serde_json::Value::Null),
_ => row
.try_get::<String, _>(col.as_str())
.map(|v| {
serde_json::from_str(&v).unwrap_or_else(|_| serde_json::json!(v))
})
.unwrap_or(serde_json::Value::Null),
};
map.insert(col, value);
}
map
})
.collect();
Ok(results)
}
/// Executes `sql` with `params` bound positionally and returns each row as a
/// map from column name to JSON value.
///
/// Parameter binding by JSON kind: strings as text; numbers as `i64`, then
/// `u64`, then `f64`, whichever fits first (so integers above `i64::MAX` keep
/// full precision); booleans as `bool`; `null` as a NULL string; arrays and
/// objects as their JSON text.
///
/// Columns are decoded according to the type name reported by the driver;
/// a column that fails to decode becomes JSON `null`. Columns of any type not
/// listed explicitly are read as text and, when that text parses as JSON, the
/// parsed value is returned instead of the string.
///
/// # Errors
///
/// Returns `FraiseQLError::Database` carrying the driver-reported code when
/// the statement fails.
async fn execute_parameterized_aggregate(
    &self,
    sql: &str,
    params: &[serde_json::Value],
) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
    /// Decodes one column of `row` into JSON based on its reported type name.
    fn decode_column(row: &MySqlRow, name: &str, type_name: &str) -> serde_json::Value {
        let decoded = match type_name {
            "BOOLEAN" | "BIT" | "TINYINT(1)" => {
                row.try_get::<bool, _>(name).map(|v| serde_json::json!(v))
            },
            "BIGINT UNSIGNED" => row.try_get::<u64, _>(name).map(|v| serde_json::json!(v)),
            "BIGINT" | "INT" | "INT UNSIGNED" | "MEDIUMINT" | "MEDIUMINT UNSIGNED"
            | "SMALLINT" | "SMALLINT UNSIGNED" | "TINYINT" | "TINYINT UNSIGNED" => {
                row.try_get::<i64, _>(name).map(|v| serde_json::json!(v))
            },
            "DOUBLE" | "FLOAT" => row.try_get::<f64, _>(name).map(|v| serde_json::json!(v)),
            "JSON" => row.try_get::<serde_json::Value, _>(name),
            // "NEWDECIMAL" / "DECIMAL" and every other type share the text path.
            _ => row.try_get::<String, _>(name).map(|v| {
                serde_json::from_str(&v).unwrap_or_else(|_| serde_json::json!(v))
            }),
        };
        decoded.unwrap_or(serde_json::Value::Null)
    }

    let mut query = sqlx::query(sql);
    for param in params {
        query = match param {
            serde_json::Value::String(s) => query.bind(s.clone()),
            serde_json::Value::Number(n) => {
                if let Some(i) = n.as_i64() {
                    query.bind(i)
                } else if let Some(u) = n.as_u64() {
                    // Above i64::MAX: bind exactly instead of rounding through f64.
                    query.bind(u)
                } else if let Some(f) = n.as_f64() {
                    query.bind(f)
                } else {
                    query.bind(n.to_string())
                }
            },
            serde_json::Value::Bool(b) => query.bind(*b),
            serde_json::Value::Null => query.bind(Option::<String>::None),
            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
                query.bind(param.to_string())
            },
        };
    }

    let rows: Vec<MySqlRow> = query.fetch_all(&self.pool).await.map_err(|e| {
        let sql_state = if let sqlx::Error::Database(ref db_err) = e {
            db_err.code().map(|c| c.into_owned())
        } else {
            None
        };
        FraiseQLError::Database {
            message: format!("MySQL parameterized aggregate query failed: {e}"),
            sql_state,
        }
    })?;

    let results = rows
        .iter()
        .map(|row| {
            row.columns()
                .iter()
                .map(|column| {
                    let name = column.name();
                    let value = decode_column(row, name, column.type_info().name());
                    (name.to_string(), value)
                })
                .collect::<std::collections::HashMap<String, serde_json::Value>>()
        })
        .collect();
    Ok(results)
}
async fn execute_function_call(
&self,
function_name: &str,
args: &[serde_json::Value],
) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
let escaped: Vec<String> = args.iter().map(mysql_escape_json_value).collect();
let call_sql =
format!("CALL {}({})", quote_mysql_identifier(function_name), escaped.join(", "));
let rows: Vec<MySqlRow> =
sqlx::raw_sql(&call_sql).fetch_all(&self.pool).await.map_err(|e| {
let sql_state = if let sqlx::Error::Database(ref db_err) = e {
db_err.code().map(|c| c.into_owned())
} else {
None
};
FraiseQLError::Database {
message: format!("MySQL stored procedure call failed ({function_name}): {e}"),
sql_state,
}
})?;
let results = rows
.into_iter()
.map(|row| {
let mut map = std::collections::HashMap::new();
for column in row.columns() {
let col = column.name().to_string();
let type_name = column.type_info().name();
let value = match type_name {
"BOOLEAN" | "BIT" | "TINYINT(1)" => row
.try_get::<bool, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"BIGINT UNSIGNED" => row
.try_get::<u64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"BIGINT" | "INT" | "INT UNSIGNED" | "MEDIUMINT" | "MEDIUMINT UNSIGNED"
| "SMALLINT" | "SMALLINT UNSIGNED" | "TINYINT" | "TINYINT UNSIGNED" => row
.try_get::<i64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"DOUBLE" | "FLOAT" => row
.try_get::<f64, _>(col.as_str())
.map(|v| serde_json::json!(v))
.unwrap_or(serde_json::Value::Null),
"NEWDECIMAL" | "DECIMAL" => row
.try_get::<String, _>(col.as_str())
.map(|v| {
serde_json::from_str(&v).unwrap_or_else(|_| serde_json::json!(v))
})
.unwrap_or(serde_json::Value::Null),
"JSON" => row
.try_get::<serde_json::Value, _>(col.as_str())
.unwrap_or(serde_json::Value::Null),
_ => row
.try_get::<String, _>(col.as_str())
.map(|v| {
serde_json::from_str(&v).unwrap_or_else(|_| serde_json::json!(v))
})
.unwrap_or(serde_json::Value::Null),
};
map.insert(col, value);
}
map
})
.collect();
Ok(results)
}
async fn explain_query(
&self,
sql: &str,
_params: &[serde_json::Value],
) -> Result<serde_json::Value> {
use sqlx::Row as _;
if sql.contains(';') {
return Err(FraiseQLError::Validation {
message: "EXPLAIN SQL must be a single statement".into(),
path: None,
});
}
let explain_sql = format!("EXPLAIN FORMAT=JSON {sql}");
let row: sqlx::mysql::MySqlRow = sqlx::query(&explain_sql)
.fetch_one(&self.pool)
.await
.map_err(|e| FraiseQLError::Database {
message: format!("MySQL EXPLAIN failed: {e}"),
sql_state: None,
})?;
let raw: String = row.try_get(0).map_err(|e| FraiseQLError::Database {
message: format!("Failed to read MySQL EXPLAIN output: {e}"),
sql_state: None,
})?;
serde_json::from_str(&raw).map_err(|e| FraiseQLError::Database {
message: format!("Failed to parse MySQL EXPLAIN JSON: {e}"),
sql_state: None,
})
}
/// Returns the `limit` most expensive statement digests from
/// `performance_schema.events_statements_summary_by_digest`, ordered by total
/// wait time.
///
/// A probe query runs first; when it fails (for example because the table is
/// not accessible) an empty list is returned instead of an error.
///
/// Timings are divided by `1e9` to produce the `*_exec_time_ms` fields.
/// `cache_hit_ratio` is always `None`; rows-examined and index-usage counters
/// are reported under `database_specific`. Individual columns that fail to
/// decode fall back to `0` / an empty string.
///
/// # Errors
///
/// Returns `FraiseQLError::Database` when the statistics query itself fails
/// after a successful probe.
async fn query_stats(&self, limit: u32) -> Result<Vec<crate::types::QueryStatEntry>> {
    /// Reads a counter column, trying the unsigned decode first and the
    /// signed one second so either column typing is accepted.
    fn unsigned_column(row: &MySqlRow, column: &str) -> u64 {
        row.try_get::<u64, _>(column)
            .or_else(|_| row.try_get::<i64, _>(column).map(i64::unsigned_abs))
            .unwrap_or(0)
    }

    // `LIMIT 0` yields no rows, so the probe must tolerate an empty result:
    // `fetch_one` would report `RowNotFound` even when the table is readable.
    let probe = sqlx::query(
        "SELECT 1 FROM performance_schema.events_statements_summary_by_digest LIMIT 0",
    )
    .fetch_optional(&self.pool)
    .await;
    if probe.is_err() {
        return Ok(vec![]);
    }

    // Dividing by the approximate-value literal `1e9` makes MySQL return
    // DOUBLE rather than DECIMAL, which is what the `f64` decodes below expect.
    let rows: Vec<MySqlRow> = sqlx::query(
        "SELECT \
         DIGEST AS query_id, \
         DIGEST_TEXT AS query_text, \
         COUNT_STAR AS calls, \
         SUM_TIMER_WAIT / 1e9 AS total_exec_time_ms, \
         AVG_TIMER_WAIT / 1e9 AS mean_exec_time_ms, \
         MIN_TIMER_WAIT / 1e9 AS min_exec_time_ms, \
         MAX_TIMER_WAIT / 1e9 AS max_exec_time_ms, \
         SUM_ROWS_SENT AS rows_returned, \
         SUM_ROWS_EXAMINED, \
         SUM_NO_INDEX_USED, \
         SUM_NO_GOOD_INDEX_USED \
         FROM performance_schema.events_statements_summary_by_digest \
         WHERE DIGEST IS NOT NULL \
         ORDER BY SUM_TIMER_WAIT DESC \
         LIMIT ?",
    )
    .bind(limit)
    .fetch_all(&self.pool)
    .await
    .map_err(|e| FraiseQLError::Database {
        message: format!("Failed to query performance_schema: {e}"),
        sql_state: None,
    })?;

    let entries = rows
        .iter()
        .map(|row| crate::types::QueryStatEntry {
            query_id: row.try_get::<String, _>("query_id").unwrap_or_default(),
            query_text: row.try_get::<String, _>("query_text").unwrap_or_default(),
            calls: unsigned_column(row, "calls"),
            total_exec_time_ms: row.try_get("total_exec_time_ms").unwrap_or(0.0),
            mean_exec_time_ms: row.try_get("mean_exec_time_ms").unwrap_or(0.0),
            min_exec_time_ms: row.try_get("min_exec_time_ms").unwrap_or(0.0),
            max_exec_time_ms: row.try_get("max_exec_time_ms").unwrap_or(0.0),
            rows_returned: unsigned_column(row, "rows_returned"),
            cache_hit_ratio: None,
            database_specific: serde_json::json!({
                "sum_rows_examined": unsigned_column(row, "SUM_ROWS_EXAMINED"),
                "sum_no_index_used": unsigned_column(row, "SUM_NO_INDEX_USED"),
                "sum_no_good_index_used": unsigned_column(row, "SUM_NO_GOOD_INDEX_USED"),
            }),
        })
        .collect();
    Ok(entries)
}
}
/// Translates a numeric MySQL server error into a five-character SQLSTATE
/// string, or `None` when the number has no mapping here.
///
/// | MySQL error number | Returned SQLSTATE |
/// |--------------------|-------------------|
/// | 1062, 1169         | `23505`           |
/// | 1048               | `23502`           |
/// | 1451, 1452         | `23503`           |
/// | 1205, 1213         | `40001`           |
pub(super) fn map_mysql_error_code(code: u16) -> Option<String> {
    let sqlstate: Option<&'static str> = match code {
        1062 | 1169 => Some("23505"),
        1048 => Some("23502"),
        1451 | 1452 => Some("23503"),
        1205 | 1213 => Some("40001"),
        _ => None,
    };
    sqlstate.map(str::to_owned)
}
/// Builds the ` ORDER BY …` clause for a relay page query.
///
/// Each entry of `order_by` sorts on
/// `JSON_UNQUOTE(JSON_EXTRACT(data, '$.<field>'))`; the already-quoted cursor
/// column is always appended as the final sort key. When `forward` is `false`
/// every direction is flipped (and the cursor column sorts `DESC`).
///
/// The field name is embedded in a single-quoted SQL literal, so both `'` and
/// `\` are escaped: MySQL treats backslash as an escape character in string
/// literals unless `NO_BACKSLASH_ESCAPES` is enabled, and escaping the quote
/// alone would let a field ending in `\` terminate the literal early.
fn build_mysql_relay_order_sql(
    quoted_col: &str,
    order_by: Option<&[OrderByClause]>,
    forward: bool,
) -> String {
    let mut parts: Vec<String> = order_by
        .unwrap_or_default()
        .iter()
        .map(|c| {
            let dir = match (c.direction, forward) {
                (OrderDirection::Asc, true) | (OrderDirection::Desc, false) => "ASC",
                (OrderDirection::Desc, true) | (OrderDirection::Asc, false) => "DESC",
            };
            // Backslashes first, so the quote doubling is not re-escaped.
            let escaped = c.field.replace('\\', "\\\\").replace('\'', "''");
            format!("JSON_UNQUOTE(JSON_EXTRACT(data, '$.{escaped}')) {dir}")
        })
        .collect();

    let cursor_dir = if forward { "ASC" } else { "DESC" };
    parts.push(format!("{quoted_col} {cursor_dir}"));

    format!(" ORDER BY {}", parts.join(", "))
}
/// Combines the optional cursor condition and the optional user filter into a
/// ` WHERE …` clause.
///
/// The user filter is always wrapped in parentheses; the cursor condition is
/// emitted as given. When both are present they are joined with ` AND `, the
/// cursor condition first. Returns an empty string when neither is present.
fn build_mysql_relay_where(cursor_sql: Option<&str>, user_sql: Option<&str>) -> String {
    let wrapped_user = user_sql.map(|u| format!("({u})"));
    let conditions: Vec<&str> = cursor_sql
        .into_iter()
        .chain(wrapped_user.as_deref())
        .collect();

    if conditions.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", conditions.join(" AND "))
    }
}
impl MySqlAdapter {
/// Executes a single-row counting statement and returns its first column as
/// `u64`.
///
/// Parameter binding by JSON kind: strings as text; numbers as `i64`, then
/// `u64`, then `f64`, whichever fits first (so integers above `i64::MAX` keep
/// full precision); booleans as `bool`; `null` as a NULL string; arrays and
/// objects as their JSON text.
///
/// The count is decoded as `i64` first and as `u64` second; when neither
/// decode succeeds the result is `0`.
///
/// # Errors
///
/// Returns `FraiseQLError::Database` (without a SQLSTATE) when the statement
/// fails or returns no row.
async fn execute_count_query(&self, sql: &str, params: Vec<serde_json::Value>) -> Result<u64> {
    let mut query = sqlx::query(sql);
    // `params` is owned, so values are moved into the query instead of cloned.
    for param in params {
        query = match param {
            serde_json::Value::String(s) => query.bind(s),
            serde_json::Value::Number(n) => {
                if let Some(i) = n.as_i64() {
                    query.bind(i)
                } else if let Some(u) = n.as_u64() {
                    // Above i64::MAX: bind exactly instead of rounding through f64.
                    query.bind(u)
                } else if let Some(f) = n.as_f64() {
                    query.bind(f)
                } else {
                    query.bind(n.to_string())
                }
            },
            serde_json::Value::Bool(b) => query.bind(b),
            serde_json::Value::Null => query.bind(Option::<String>::None),
            composite @ (serde_json::Value::Array(_) | serde_json::Value::Object(_)) => {
                query.bind(composite.to_string())
            },
        };
    }

    let row: MySqlRow =
        query.fetch_one(&self.pool).await.map_err(|e| FraiseQLError::Database {
            message: format!("MySQL COUNT query failed: {e}"),
            sql_state: None,
        })?;

    let cnt: u64 = if let Ok(v) = row.try_get::<i64, _>(0) {
        v.cast_unsigned()
    } else {
        row.try_get::<u64, _>(0).unwrap_or_default()
    };
    Ok(cnt)
}
}
// Opts `MySqlAdapter` into `SupportsMutations`. The impl body is empty, so no
// trait items are defined or overridden here.
// NOTE(review): presumably a capability/marker trait — confirm against its definition.
impl SupportsMutations for MySqlAdapter {}
impl RelayDatabaseAdapter for MySqlAdapter {
async fn execute_relay_page(
&self,
view: &str,
cursor_column: &str,
after: Option<CursorValue>,
before: Option<CursorValue>,
limit: u32,
forward: bool,
where_clause: Option<&WhereClause>,
order_by: Option<&[OrderByClause]>,
include_total_count: bool,
) -> Result<RelayPageResult> {
let quoted_view = quote_mysql_identifier(view);
let quoted_col = quote_mysql_identifier(cursor_column);
let active_cursor = if forward { after } else { before };
let (cursor_where_sql, cursor_param): (Option<String>, Option<serde_json::Value>) =
match active_cursor {
None => (None, None),
Some(CursorValue::Int64(pk)) => {
let op = if forward { ">" } else { "<" };
(
Some(format!("{quoted_col} {op} ?")),
Some(serde_json::Value::Number(pk.into())),
)
},
Some(CursorValue::Uuid(uuid)) => {
let op = if forward { ">" } else { "<" };
(Some(format!("{quoted_col} {op} ?")), Some(serde_json::Value::String(uuid)))
},
};
let (user_where_sql, user_where_params): (Option<String>, Vec<serde_json::Value>) =
if let Some(clause) = where_clause {
let generator = MySqlWhereGenerator::new(MySqlDialect);
let (sql, params) = generator.generate(clause)?;
(Some(sql), params)
} else {
(None, Vec::new())
};
let order_sql = build_mysql_relay_order_sql("ed_col, order_by, forward);
let page_where_sql =
build_mysql_relay_where(cursor_where_sql.as_deref(), user_where_sql.as_deref());
let mut page_params: Vec<serde_json::Value> = Vec::new();
if let Some(cp) = cursor_param {
page_params.push(cp);
}
page_params.extend(user_where_params.iter().cloned());
page_params.push(serde_json::Value::Number(limit.into()));
let page_sql = if forward {
format!("SELECT data FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ?")
} else {
let inner = format!(
"SELECT data, {quoted_col} AS _relay_cursor \
FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ?"
);
format!("SELECT data FROM ({inner}) _relay_page ORDER BY _relay_cursor ASC")
};
let rows = self.execute_raw(&page_sql, page_params).await?;
let total_count = if include_total_count {
let (count_sql, count_params) = if let Some(u_sql) = &user_where_sql {
(
format!("SELECT COUNT(*) FROM {quoted_view} WHERE ({u_sql})"),
user_where_params.clone(),
)
} else {
(format!("SELECT COUNT(*) FROM {quoted_view}"), vec![])
};
Some(self.execute_count_query(&count_sql, count_params).await?)
} else {
None
};
Ok(RelayPageResult::new(rows, total_count))
}
}
#[cfg(test)]
mod tests;