use crate::{recordset::ColumnInfo, DatabaseError, DatabaseResult, Recordset};
use mysql_async::{prelude::*, Conn, Opts, OptsBuilder, Pool, Row, Value as MySQLValue};
use serde_json::Value as JsonValue;
/// A MySQL connection handle built on an `mysql_async` connection pool.
///
/// Both fields are populated by `connect` and cleared (set to `None`) by `close`.
#[derive(Debug)]
pub struct MySQLConnection {
    /// Pool the connection was obtained from; `None` once closed.
    pool: Option<Pool>,
    /// Connection checked out from `pool` and verified with `SELECT 1` during
    /// `connect`; `None` once closed.
    conn: Option<Conn>,
}
impl MySQLConnection {
pub async fn connect(connection_string: &str) -> DatabaseResult<Self> {
let opts = OptsBuilder::from_opts(connection_string).map_err(|e| {
DatabaseError::connection_error(format!("Invalid MySQL connection string: {}", e))
})?;
let pool = Pool::new(opts);
let mut conn = pool.get_conn().await.map_err(|e| {
DatabaseError::connection_error(format!("Failed to connect to MySQL: {}", e))
})?;
conn.query_drop("SELECT 1").await.map_err(|e| {
DatabaseError::connection_error(format!("MySQL connection test failed: {}", e))
})?;
Ok(Self {
pool: Some(pool),
conn: Some(conn),
})
}
pub async fn execute(&self, query: &str) -> DatabaseResult<Recordset> {
let conn = self.conn.as_ref().ok_or(DatabaseError::NotConnected)?;
let result: Vec<Row> = conn
.clone()
.query(query)
.await
.map_err(|e| DatabaseError::query_error(format!("MySQL query failed: {}", e)))?;
if result.is_empty() {
return Ok(Recordset::empty());
}
let first_row = &result[0];
let columns: Vec<ColumnInfo> = first_row
.columns()
.iter()
.enumerate()
.map(|(i, col)| ColumnInfo {
name: col.name_str().to_string(),
data_type: format!("{:?}", col.column_type()),
ordinal: i,
})
.collect();
let mut data_rows = Vec::new();
for row in result {
let mut row_data = Vec::new();
for col in row.columns() {
let col_name = col.name_str();
let value = mysql_value_to_json(&row, col_name)?;
row_data.push(value);
}
data_rows.push(row_data);
}
Ok(Recordset::new(columns, data_rows))
}
pub async fn execute_command(&self, command: &str) -> DatabaseResult<u64> {
let conn = self.conn.as_ref().ok_or(DatabaseError::NotConnected)?;
let result = conn
.clone()
.query_drop(command)
.await
.map_err(|e| DatabaseError::query_error(format!("MySQL command failed: {}", e)))?;
let mut conn_mut = self.conn.as_ref().unwrap().clone();
let affected = conn_mut
.exec_drop(command, ())
.await
.map(|_| {
0u64
})
.unwrap_or(0);
Ok(affected)
}
pub async fn close(&mut self) -> DatabaseResult<()> {
if let Some(conn) = self.conn.take() {
drop(conn);
}
if let Some(pool) = self.pool.take() {
pool.disconnect().await.map_err(|e| {
DatabaseError::query_error(format!("Failed to disconnect MySQL pool: {}", e))
})?;
}
Ok(())
}
pub async fn is_connected(&self) -> bool {
if let Some(conn) = &self.conn {
conn.clone().query_drop("SELECT 1").await.is_ok()
} else {
false
}
}
}
/// Converts the cell named `col_name` in `row` into a JSON value.
///
/// Mapping rules:
/// - `NULL` becomes JSON `null`.
/// - `Bytes` holding UTF-8 text becomes a JSON number when the text parses as an `i64`,
///   a `u64`, or a finite `f64`; otherwise it becomes a JSON string.
/// - Non-UTF-8 bytes become the placeholder string `"(binary data)"`.
/// - Integer and floating-point values map directly. Non-finite floats become `null`,
///   because serde_json's `From<f64>` does that.
/// - Dates and times are rendered as strings. A six-digit fractional part is appended
///   only when the microsecond component is non-zero.
///
/// # Errors
///
/// Returns a conversion error if `row.get(col_name)` yields no value, for example when
/// no column has that name.
fn mysql_value_to_json(row: &Row, col_name: &str) -> DatabaseResult<JsonValue> {
    // `ok_or_else` builds the error message only on the failure path; `ok_or` would
    // format it for every cell.
    let value: MySQLValue = row.get(col_name).ok_or_else(|| {
        DatabaseError::conversion_error(format!("Column '{}' not found", col_name))
    })?;
    let json_value = match value {
        MySQLValue::NULL => JsonValue::Null,
        MySQLValue::Bytes(bytes) => match String::from_utf8(bytes) {
            Ok(text) => text_to_json(text),
            Err(_) => JsonValue::from("(binary data)"),
        },
        MySQLValue::Int(i) => JsonValue::from(i),
        MySQLValue::UInt(u) => JsonValue::from(u),
        MySQLValue::Float(f) => JsonValue::from(f64::from(f)),
        MySQLValue::Double(d) => JsonValue::from(d),
        MySQLValue::Date(year, month, day, hour, minute, second, micros) => {
            JsonValue::from(format!(
                "{:04}-{:02}-{:02} {:02}:{:02}:{:02}{}",
                year,
                month,
                day,
                hour,
                minute,
                second,
                fraction_suffix(micros)
            ))
        }
        MySQLValue::Time(negative, days, hours, minutes, seconds, micros) => {
            let sign = if negative { "-" } else { "" };
            JsonValue::from(format!(
                "{}{} {:02}:{:02}:{:02}{}",
                sign,
                days,
                hours,
                minutes,
                seconds,
                fraction_suffix(micros)
            ))
        }
    };
    Ok(json_value)
}

/// Converts decoded cell text into a JSON number when it is numeric, else a JSON string.
///
/// Parsing is tried in this order:
/// 1. `i64`.
/// 2. `u64`, so integers above `i64::MAX` stay exact instead of being rounded
///    through `f64`.
/// 3. `f64`.
///
/// Non-finite floats are kept as strings. Text such as `"NaN"`, `"inf"` or `"1e400"`
/// parses as an `f64`, but serde_json's `From<f64>` would turn it into `null` and lose
/// the original value.
fn text_to_json(text: String) -> JsonValue {
    if let Ok(n) = text.parse::<i64>() {
        return JsonValue::from(n);
    }
    if let Ok(n) = text.parse::<u64>() {
        return JsonValue::from(n);
    }
    match text.parse::<f64>() {
        Ok(f) if f.is_finite() => JsonValue::from(f),
        _ => JsonValue::from(text),
    }
}

/// Returns `.ffffff` (six zero-padded digits) for a non-zero microsecond component, or an
/// empty string when it is zero.
fn fraction_suffix(micros: u32) -> String {
    if micros == 0 {
        String::new()
    } else {
        format!(".{:06}", micros)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mysql_connection_string_parsing() {
        // `Opts::from_url` reports malformed URLs as `Err`. `OptsBuilder::from_opts`
        // returns a builder rather than a `Result`, so it cannot be used to check
        // parsing.
        let valid = [
            "mysql://root:password@localhost:3306/testdb",
            "mysql://user@localhost/mydb",
            "mysql://user:pass@192.168.1.100:3307/data",
        ];
        for conn_str in valid {
            assert!(Opts::from_url(conn_str).is_ok(), "Failed to parse: {}", conn_str);
        }

        let invalid = ["not a url", "postgres://user@localhost/db"];
        for conn_str in invalid {
            assert!(
                Opts::from_url(conn_str).is_err(),
                "Unexpectedly parsed: {}",
                conn_str
            );
        }
    }

    // NOTE(review): this exercises the local mirror `mysql_value_to_json_test`, not
    // `mysql_value_to_json` (which needs a `Row`). The two can drift apart: the mirror
    // keeps numeric text such as "42" as a string, whereas the real converter parses it
    // into a number.
    #[test]
    fn test_value_conversion() {
        let json = mysql_value_to_json_test(MySQLValue::NULL);
        assert_eq!(json, JsonValue::Null);
        let json = mysql_value_to_json_test(MySQLValue::Int(42));
        assert_eq!(json, JsonValue::from(42));
        let json = mysql_value_to_json_test(MySQLValue::Bytes(b"hello".to_vec()));
        assert_eq!(json, JsonValue::from("hello"));
        let json = mysql_value_to_json_test(MySQLValue::Bytes(vec![0xff, 0xfe]));
        assert_eq!(json, JsonValue::from("(binary data)"));
    }

    /// Simplified stand-in for `mysql_value_to_json` that converts a bare value; any
    /// variant other than `NULL`, `Int` or `Bytes` maps to `null`.
    fn mysql_value_to_json_test(value: MySQLValue) -> JsonValue {
        match value {
            MySQLValue::NULL => JsonValue::Null,
            MySQLValue::Int(i) => JsonValue::from(i),
            MySQLValue::Bytes(bytes) => String::from_utf8(bytes)
                .map(JsonValue::from)
                .unwrap_or(JsonValue::from("(binary data)")),
            _ => JsonValue::Null,
        }
    }
}