use std::future::Future;
use std::sync::Arc;
use crate::error::DBError;
use mysql_async::prelude::Queryable;
use mysql_async::{Opts, Pool};
use sqlw::{FromRow, Query, RowCell, RowError, RowLike, Value};
/// Borrowed view over a single `mysql_async` result row.
///
/// It implements `sqlw`'s [`RowLike`] so that `FromRow` implementations can read the row's
/// columns by name.
pub struct MySqlRowRef<'a> {
    /// The driver row being read; it is only borrowed and never modified.
    row: &'a mysql_async::Row,
}
impl<'a> MySqlRowRef<'a> {
    /// Wraps a borrowed `mysql_async::Row` so it can be read through [`RowLike`].
    pub fn new(row: &'a mysql_async::Row) -> Self {
        Self { row }
    }

    /// Returns the position of the first column whose name is exactly `name`.
    ///
    /// # Errors
    ///
    /// Returns [`RowError::ColumnNotFound`] when no column carries that name.
    fn column_index(&self, name: &str) -> Result<usize, RowError> {
        self.row
            .columns()
            .iter()
            .position(|col| col.name_str() == name)
            .ok_or_else(|| RowError::ColumnNotFound {
                name: name.to_string(),
            })
    }

    /// Converts the value stored at `index` into an owned [`Value`].
    ///
    /// Conversion is driven by the decoded `mysql_async::Value` variant instead of
    /// `Row::get::<T>`, which panics whenever the conversion to `T` fails. That panic path
    /// covers SQL `NULL`, unsigned values outside `i32`, temporal columns and non-UTF-8 blobs.
    /// SQL `NULL` always maps to [`Value::Null`]. Raw byte payloads are interpreted according
    /// to the declared column type (see [`Self::bytes_to_value`]). DATE, DATETIME/TIMESTAMP
    /// and TIME values are rendered as MySQL-style text.
    ///
    /// # Errors
    ///
    /// Returns [`RowError::Any`] when any of the following holds:
    ///
    /// - `index` is out of range;
    /// - the value was already moved out of the row;
    /// - an unsigned integer exceeds `i64::MAX`;
    /// - a numeric column's textual payload does not parse.
    fn get_value_by_index(&self, index: usize) -> Result<Value, RowError> {
        use mysql_async::consts::ColumnType;

        let columns = self.row.columns();
        let column = columns
            .get(index)
            .ok_or_else(|| RowError::Any(format!("Column index {} not found", index)))?;
        // `Row::as_ref` yields `None` once the value has been moved out with `Row::take`.
        let raw = self.row.as_ref(index).ok_or_else(|| {
            RowError::Any(format!("Column index {} has no value (already taken)", index))
        })?;

        match raw {
            mysql_async::Value::NULL => Ok(Value::Null),
            mysql_async::Value::Int(v) => Ok(Value::Int(*v)),
            mysql_async::Value::UInt(v) => i64::try_from(*v).map(Value::Int).map_err(|_| {
                RowError::Any(format!(
                    "Column `{}`: unsigned value {} does not fit in i64",
                    column.name_str(),
                    v
                ))
            }),
            mysql_async::Value::Float(v) => Ok(Value::Float(f64::from(*v))),
            mysql_async::Value::Double(v) => Ok(Value::Float(*v)),
            mysql_async::Value::Bytes(bytes) => {
                Self::bytes_to_value(column.column_type(), &column.name_str(), bytes)
            }
            mysql_async::Value::Date(year, month, day, hour, minute, second, micros) => {
                let mut text = format!("{:04}-{:02}-{:02}", year, month, day);
                // DATE columns carry no time of day; DATETIME/TIMESTAMP do.
                if !matches!(column.column_type(), ColumnType::MYSQL_TYPE_DATE) {
                    text.push_str(&format!(" {:02}:{:02}:{:02}", hour, minute, second));
                    text.push_str(&Self::fractional_seconds(*micros));
                }
                Ok(Value::Text(text))
            }
            mysql_async::Value::Time(negative, days, hours, minutes, seconds, micros) => {
                // A TIME value may span several days; fold them back into the hour field.
                let total_hours = u64::from(*days) * 24 + u64::from(*hours);
                let sign = if *negative { "-" } else { "" };
                Ok(Value::Text(format!(
                    "{}{:02}:{:02}:{:02}{}",
                    sign,
                    total_hours,
                    minutes,
                    seconds,
                    Self::fractional_seconds(*micros)
                )))
            }
        }
    }

    /// Interprets a raw byte payload using the column's declared type.
    ///
    /// - Integer columns are parsed as `i64` and float columns as `f64`. Such payloads appear
    ///   when a row was decoded with the text protocol.
    /// - Every other column type is returned as [`Value::Text`] when the bytes are valid UTF-8.
    ///   Otherwise they are returned as [`Value::Blob`].
    ///
    /// # Errors
    ///
    /// Returns [`RowError::Any`] if a numeric column's payload is not a parseable number.
    fn bytes_to_value(
        column_type: mysql_async::consts::ColumnType,
        column_name: &str,
        bytes: &[u8],
    ) -> Result<Value, RowError> {
        use mysql_async::consts::ColumnType;

        let parse_error = || {
            RowError::Any(format!(
                "Column `{}`: cannot parse {:?} as a number",
                column_name,
                String::from_utf8_lossy(bytes)
            ))
        };
        let as_str = || std::str::from_utf8(bytes).ok();

        match column_type {
            ColumnType::MYSQL_TYPE_TINY
            | ColumnType::MYSQL_TYPE_SHORT
            | ColumnType::MYSQL_TYPE_INT24
            | ColumnType::MYSQL_TYPE_LONG
            | ColumnType::MYSQL_TYPE_LONGLONG => as_str()
                .and_then(|s| s.parse::<i64>().ok())
                .map(Value::Int)
                .ok_or_else(parse_error),
            ColumnType::MYSQL_TYPE_FLOAT | ColumnType::MYSQL_TYPE_DOUBLE => as_str()
                .and_then(|s| s.parse::<f64>().ok())
                .map(Value::Float)
                .ok_or_else(parse_error),
            // Strings, DECIMAL, JSON, BLOB, ...: keep as text when possible, raw bytes otherwise.
            _ => Ok(match String::from_utf8(bytes.to_vec()) {
                Ok(text) => Value::Text(text),
                Err(err) => Value::Blob(err.into_bytes()),
            }),
        }
    }

    /// Formats a microsecond count as a `.ffffff` suffix, or an empty string when it is zero.
    fn fractional_seconds(micros: u32) -> String {
        if micros == 0 {
            String::new()
        } else {
            format!(".{:06}", micros)
        }
    }
}
impl RowLike for MySqlRowRef<'_> {
    /// Resolves the column called `name` and returns its value as an owned cell.
    ///
    /// # Errors
    ///
    /// Propagates the [`RowError`] from the name lookup or from the value conversion.
    fn cell<'b>(&'b self, name: &str) -> Result<RowCell<'b>, RowError> {
        self.column_index(name)
            .and_then(|position| self.get_value_by_index(position))
            .map(RowCell::Owned)
    }
}
/// `sqlw` query executor backed by a `mysql_async` connection [`Pool`].
///
/// Cloning an executor clones only the [`Arc`], so all clones share one pool.
#[derive(Clone, Debug)]
pub struct MySqlExecutor {
    /// Connection pool shared by every clone of this executor.
    pool: Arc<Pool>,
}
impl MySqlExecutor {
    /// Builds an executor from a `connector` closure whose future resolves to a [`Pool`].
    ///
    /// The closure is called exactly once, and its future is awaited before returning.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Connection`] when the connector's future yields an error.
    pub async fn new<F, Fut>(connector: F) -> Result<Self, DBError>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<Pool, mysql_async::Error>> + Send + 'static,
    {
        let pending = connector();
        match pending.await {
            Ok(pool) => Ok(Self {
                pool: Arc::new(pool),
            }),
            Err(err) => Err(DBError::Connection(err.into())),
        }
    }

    /// Borrows the underlying connection pool.
    pub fn pool(&self) -> &Pool {
        &*self.pool
    }

    /// Builds an executor from a MySQL connection URL.
    ///
    /// The URL is parsed into [`Opts`], and the pool is created with `Pool::new`. Nothing in
    /// this function is awaited.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Connection`] when `url` cannot be parsed.
    pub async fn from_url(url: &str) -> Result<Self, DBError> {
        Opts::from_url(url)
            .map_err(|err| DBError::Connection(err.into()))
            .map(|opts| Self {
                pool: Arc::new(Pool::new(opts)),
            })
    }
}
impl sqlw::QueryExecutor for MySqlExecutor {
    type Error = DBError;

    /// Executes `query` as a prepared statement and discards any result set.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] if a pooled connection cannot be obtained or the
    /// statement fails.
    fn query_void(&self, query: Query) -> impl Future<Output = Result<(), DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();
            let mysql_params: Vec<mysql_async::Value> =
                args.into_iter().map(sqlw_to_mysql_value).collect();
            let mut conn = pool
                .get_conn()
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            conn.exec_drop(&sql, mysql_params)
                .await
                .map_err(|e| DBError::Execution(e.into()))
        }
    }

    /// Executes `query` and maps the first returned row, if any, into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] for connection or statement failures. Returns
    /// [`DBError::Mapping`] when `T::from_row` rejects the row.
    fn query_one<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Option<T>, DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();
            let mysql_params: Vec<mysql_async::Value> =
                args.into_iter().map(sqlw_to_mysql_value).collect();
            let mut conn = pool
                .get_conn()
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            let result: Option<mysql_async::Row> = conn
                .exec_first(&sql, mysql_params)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            match result {
                Some(row) => {
                    let row_ref = MySqlRowRef::new(&row);
                    T::from_row(&row_ref).map(Some).map_err(DBError::Mapping)
                }
                None => Ok(None),
            }
        }
    }

    /// Executes `query` and maps every returned row into `T`, preserving row order.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] for connection or statement failures. Returns
    /// [`DBError::Mapping`] for the first row that `T::from_row` rejects, which is the same
    /// variant `query_one` uses for mapping failures.
    fn query_list<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Vec<T>, DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();
            let mysql_params: Vec<mysql_async::Value> =
                args.into_iter().map(sqlw_to_mysql_value).collect();
            let mut conn = pool
                .get_conn()
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            let rows: Vec<mysql_async::Row> = conn
                .exec(&sql, mysql_params)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            // Map errors explicitly (as `query_one` does) rather than through an implicit
            // `From` conversion, so both methods report mapping failures identically.
            rows.iter()
                .map(|row| T::from_row(&MySqlRowRef::new(row)).map_err(DBError::Mapping))
                .collect::<Result<Vec<T>, DBError>>()
        }
    }
}
fn sqlw_to_mysql_value(value: Value) -> mysql_async::Value {
match value {
Value::Text(s) => mysql_async::Value::Bytes(s.into_bytes()),
Value::Int(i) => mysql_async::Value::Int(i),
Value::Float(f) => mysql_async::Value::Double(f),
Value::Bool(b) => mysql_async::Value::Int(i64::from(b)),
Value::Blob(b) => mysql_async::Value::Bytes(b),
Value::Null => mysql_async::Value::NULL,
}
}