use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use mysql_async::consts::ColumnType;
use mysql_async::prelude::Queryable;
use mysql_async::{Conn, Opts, OptsBuilder, Params, Pool, PoolConstraints, PoolOpts};
use mysql_async::{Row as MyRow, Value as MyValue};
use tokio::sync::Mutex;
use crate::driver::ExecuteResult;
use crate::error::OrmError;
use crate::row::Row;
use crate::value::Value;
const BINARY_CHARSET: u16 = 63;
#[derive(Clone)]
pub struct MysqlPool {
pool: Pool,
statements: Arc<AtomicU64>,
}
impl MysqlPool {
pub fn new(url: &str, max_connections: u32) -> crate::Result<Self> {
if max_connections == 0 {
return Err(OrmError::configuration("max_connections must be at least 1"));
}
let opts = Opts::from_url(url)
.map_err(|e| OrmError::configuration("invalid MySQL url").with_source(e))?;
let constraints = PoolConstraints::new(0, max_connections as usize)
.ok_or_else(|| OrmError::configuration("invalid MySQL pool size"))?;
let opts = OptsBuilder::from_opts(opts)
.pool_opts(PoolOpts::default().with_constraints(constraints));
Ok(Self { pool: Pool::new(opts), statements: Arc::new(AtomicU64::new(0)) })
}
pub async fn fetch_all(&self, sql: String, params: Vec<Value>) -> crate::Result<Vec<Row>> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut conn = self.get().await?;
let rows = query(&mut conn, &sql, params).await;
rows
}
pub async fn execute(&self, sql: String, params: Vec<Value>) -> crate::Result<ExecuteResult> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut conn = self.get().await?;
execute(&mut conn, &sql, params).await
}
pub async fn execute_batch(&self, sql: String) -> crate::Result<()> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut conn = self.get().await?;
execute_batch(&mut conn, &sql).await
}
pub fn statement_count(&self) -> u64 {
self.statements.load(Ordering::Relaxed)
}
pub async fn close(&self) {
let _ = self.pool.clone().disconnect().await;
}
pub(crate) async fn acquire_pinned(&self) -> crate::Result<PinnedMysql> {
let conn = self.get().await?;
Ok(PinnedMysql {
conn: Mutex::new(Some(conn)),
statements: Arc::clone(&self.statements),
})
}
async fn get(&self) -> crate::Result<Conn> {
self.pool
.get_conn()
.await
.map_err(|e| OrmError::connection("cannot acquire a MySQL connection").with_source(e))
}
}
async fn query(conn: &mut Conn, sql: &str, params: Vec<Value>) -> crate::Result<Vec<Row>> {
let bound = Params::Positional(params.into_iter().map(to_my_value).collect());
let rows: Vec<MyRow> = conn
.exec(sql, bound)
.await
.map_err(|e| OrmError::query("query execution failed").with_source(e))?;
rows.into_iter().map(read_row).collect()
}
async fn execute(conn: &mut Conn, sql: &str, params: Vec<Value>) -> crate::Result<ExecuteResult> {
if params.is_empty() {
conn.query_drop(sql)
.await
.map_err(|e| OrmError::query("statement execution failed").with_source(e))?;
} else {
let bound = Params::Positional(params.into_iter().map(to_my_value).collect());
conn.exec_drop(sql, bound)
.await
.map_err(|e| OrmError::query("statement execution failed").with_source(e))?;
}
Ok(ExecuteResult {
rows_affected: conn.affected_rows(),
last_insert_rowid: conn.last_insert_id().unwrap_or(0) as i64,
})
}
async fn execute_batch(conn: &mut Conn, sql: &str) -> crate::Result<()> {
for statement in sql.split(';') {
let statement = statement.trim();
if statement.is_empty() {
continue;
}
conn.query_drop(statement)
.await
.map_err(|e| OrmError::query("statement batch failed").with_source(e))?;
}
Ok(())
}
fn to_my_value(value: Value) -> MyValue {
match value {
Value::Null => MyValue::NULL,
Value::Bool(b) => MyValue::Int(i64::from(b)),
Value::Int(i) => MyValue::Int(i),
Value::Real(r) => MyValue::Double(r),
Value::Text(s) => MyValue::Bytes(s.into_bytes()),
Value::Blob(b) => MyValue::Bytes(b),
Value::Timestamp(ts) => {
let utc = ts.to_offset(time::UtcOffset::UTC);
MyValue::Date(
utc.year() as u16,
u8::from(utc.month()),
utc.day(),
utc.hour(),
utc.minute(),
utc.second(),
utc.microsecond(),
)
}
Value::Json(json) => MyValue::Bytes(json.to_string().into_bytes()),
Value::Uuid(uuid) => MyValue::Bytes(uuid.to_string().into_bytes()),
Value::Array(items) => MyValue::Bytes(format!("{items:?}").into_bytes()),
}
}
fn read_row(row: MyRow) -> crate::Result<Row> {
let columns = row.columns();
let names: Arc<[String]> = columns
.iter()
.map(|column| column.name_str().into_owned())
.collect::<Vec<_>>()
.into();
let mut values = Vec::with_capacity(columns.len());
for (index, column) in columns.iter().enumerate() {
let raw = row.as_ref(index).cloned().unwrap_or(MyValue::NULL);
values.push(read_value(&raw, column.column_type(), column.character_set())?);
}
Ok(Row::with_columns(names, values))
}
fn read_value(value: &MyValue, column_type: ColumnType, charset: u16) -> crate::Result<Value> {
Ok(match value {
MyValue::NULL => Value::Null,
MyValue::Int(i) => Value::Int(*i),
MyValue::UInt(u) => Value::Int(*u as i64),
MyValue::Float(f) => Value::Real(f64::from(*f)),
MyValue::Double(f) => Value::Real(*f),
MyValue::Date(year, month, day, hour, min, sec, micro) => {
let date = time::Date::from_calendar_date(
i32::from(*year),
time::Month::try_from(*month)
.map_err(|_| OrmError::conversion("invalid month in DATETIME"))?,
*day,
)
.map_err(|_| OrmError::conversion("invalid DATETIME date"))?;
let clock = time::Time::from_hms_micro(*hour, *min, *sec, *micro)
.map_err(|_| OrmError::conversion("invalid DATETIME time"))?;
Value::Timestamp(time::PrimitiveDateTime::new(date, clock).assume_utc())
}
MyValue::Time(neg, days, hours, mins, secs, micros) => Value::Text(format!(
"{}{}:{:02}:{:02}:{:02}.{:06}",
if *neg { "-" } else { "" },
days,
hours,
mins,
secs,
micros
)),
MyValue::Bytes(bytes) => {
if column_type == ColumnType::MYSQL_TYPE_JSON {
serde_json::from_slice(bytes)
.map(Value::Json)
.unwrap_or_else(|_| Value::Text(String::from_utf8_lossy(bytes).into_owned()))
} else if is_binary_blob(column_type, charset) {
Value::Blob(bytes.clone())
} else {
Value::Text(String::from_utf8_lossy(bytes).into_owned())
}
}
})
}
fn is_binary_blob(column_type: ColumnType, charset: u16) -> bool {
matches!(
column_type,
ColumnType::MYSQL_TYPE_BLOB
| ColumnType::MYSQL_TYPE_TINY_BLOB
| ColumnType::MYSQL_TYPE_MEDIUM_BLOB
| ColumnType::MYSQL_TYPE_LONG_BLOB
| ColumnType::MYSQL_TYPE_GEOMETRY
) && charset == BINARY_CHARSET
}
pub(crate) struct PinnedMysql {
conn: Mutex<Option<Conn>>,
statements: Arc<AtomicU64>,
}
impl PinnedMysql {
pub(crate) async fn fetch_all(
&self,
sql: String,
params: Vec<Value>,
) -> crate::Result<Vec<Row>> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut guard = self.conn.lock().await;
let conn = guard
.as_mut()
.ok_or_else(|| OrmError::query("pinned connection is unavailable"))?;
query(conn, &sql, params).await
}
pub(crate) async fn execute(
&self,
sql: String,
params: Vec<Value>,
) -> crate::Result<ExecuteResult> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut guard = self.conn.lock().await;
let conn = guard
.as_mut()
.ok_or_else(|| OrmError::query("pinned connection is unavailable"))?;
execute(conn, &sql, params).await
}
pub(crate) async fn execute_batch(&self, sql: String) -> crate::Result<()> {
self.statements.fetch_add(1, Ordering::Relaxed);
let mut guard = self.conn.lock().await;
let conn = guard
.as_mut()
.ok_or_else(|| OrmError::query("pinned connection is unavailable"))?;
execute_batch(conn, &sql).await
}
pub(crate) fn rollback_now(&self) {
let Ok(mut guard) = self.conn.try_lock() else { return };
let Some(mut conn) = guard.take() else { return };
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn(async move {
let _ = conn.query_drop("ROLLBACK").await;
});
}
}
}