use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use sqlx::MySql;
use sqlx::Pool;
use tera::{Result as TeraResult, Tera, Value};
thread_local! {
    /// Bind collector for the template render currently running on this thread.
    ///
    /// `MysqlClient::bind_query` installs a fresh context here before rendering a template,
    /// and the `sql_bind` filter registered by `MysqlClient::init_sql_bind_filter` appends
    /// each filtered value to it when one is present. `None` means there is no collector, in
    /// which case the filter records nothing.
    static CURRENT_BIND_CTX: RefCell<Option<SqlBindContext>> =
        const { RefCell::new(None) };
}
/// MySQL client that renders SQL statements from Tera templates and runs them on a pool.
///
/// `Debug` is derived so the client can be logged and embedded in other `Debug` types.
#[derive(Clone, Debug)]
pub struct MysqlClient {
    /// Connection pool every statement is executed against.
    pub pool: Pool<MySql>,
    /// Template engine holding the SQL templates and the `sql_bind` filter.
    pub tera: Tera,
}
/// A rendered SQL statement together with its positional bind values.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OwnedSqlQuery {
    /// Rendered SQL text; values routed through `sql_bind` appear as `?` placeholders.
    pub sql: String,
    /// Bind values, in the order their `?` placeholders were emitted.
    pub binds: Vec<serde_json::Value>,
}
/// Clonable collector for bind values emitted by the `sql_bind` Tera filter.
///
/// All clones share one underlying list (through the `Arc`), so a clone moved into a
/// filter closure and the original observe the same values.
#[derive(Clone, Debug, Default)]
pub struct SqlBindContext {
    binds: Arc<Mutex<Vec<Value>>>,
}
impl SqlBindContext {
    /// Appends `val` to the end of the shared bind list.
    ///
    /// # Panics
    /// Panics if the inner mutex has been poisoned.
    pub fn add_bind(&self, val: Value) {
        let mut pending = self.binds.lock().unwrap();
        pending.push(val);
    }

    /// Removes and returns every collected bind value in insertion order, leaving the
    /// shared list empty.
    ///
    /// # Panics
    /// Panics if the inner mutex has been poisoned.
    pub fn take_binds(&self) -> Vec<Value> {
        let mut pending = self.binds.lock().unwrap();
        std::mem::take(&mut *pending)
    }
}
impl MysqlClient {
pub fn register_sql_bind_filter(&mut self, bind_ctx: SqlBindContext) {
self.tera.register_filter(
"sql_bind",
move |value: &Value, _args: &HashMap<String, Value>| -> TeraResult<Value> {
bind_ctx.add_bind(value.clone());
Ok(Value::String("?".to_string()))
},
);
}
pub fn init_sql_bind_filter(&mut self) {
self.tera.register_filter(
"sql_bind",
|_value: &Value, _args: &HashMap<String, Value>| -> TeraResult<Value> {
CURRENT_BIND_CTX.with(|ctx| {
if let Some(ref ctx) = *ctx.borrow() {
ctx.add_bind(_value.clone());
}
});
Ok(Value::String("?".to_string()))
},
);
}
pub fn debug_sql_with_binds(sql: &str, binds: &[serde_json::Value]) -> String {
let mut final_sql = String::new();
let mut bind_iter = binds.iter();
let chars = sql.chars();
for c in chars {
if c == '?' {
if let Some(v) = bind_iter.next() {
let formatted = if v.is_string() {
format!("'{}'", v.as_str().unwrap().replace("'", "''"))
} else {
v.to_string()
};
final_sql.push_str(&formatted);
} else {
final_sql.push('?');
}
} else {
final_sql.push(c);
}
}
final_sql
}
pub fn bind_query(
&self,
template_name: &str,
params: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> anyhow::Result<OwnedSqlQuery> {
let mut tera_ctx = tera::Context::new();
if let Some(_params) = params {
for (k, v) in _params {
tera_ctx.insert(k, v);
}
}
let bind_ctx = SqlBindContext::default();
CURRENT_BIND_CTX.with(|cell| {
*cell.borrow_mut() = Some(bind_ctx.clone());
});
let sql = self.tera.render(template_name, &tera_ctx)?;
let binds = bind_ctx.take_binds();
CURRENT_BIND_CTX.with(|cell| {
*cell.borrow_mut() = None;
});
log::debug!(
"sql_name={}, statement={}",
template_name,
Self::collapse_spaces(&Self::debug_sql_with_binds(&sql, &binds).replace("\n", " "))
);
Ok(OwnedSqlQuery { sql, binds })
}
fn collapse_spaces(s: &str) -> String {
s.split_whitespace().collect::<Vec<_>>().join(" ")
}
pub fn bind_query_args<'a>(
mut q: sqlx::query::Query<'a, sqlx::mysql::MySql, sqlx::mysql::MySqlArguments>,
binds: &'a [serde_json::Value],
) -> sqlx::query::Query<'a, sqlx::mysql::MySql, sqlx::mysql::MySqlArguments> {
for val in binds {
if let Some(s) = val.as_str() {
q = q.bind(s);
} else if let Some(n) = val.as_i64() {
q = q.bind(n);
} else if let Some(f) = val.as_f64() {
q = q.bind(f);
} else if val.is_boolean() {
q = q.bind(val.as_bool().unwrap());
} else {
q = q.bind(format!("{}", val));
}
}
q
}
pub async fn insert(
&self,
template_name: &str,
params: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> anyhow::Result<u64> {
let bq = self.bind_query(template_name, params)?;
let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);
let transaction = self.pool.begin().await.map_err(|e| {
log::error!("execute-failed: error={:?}", e);
anyhow::anyhow!(e)
})?;
let last_insert_id = q
.execute(&self.pool)
.await
.map_err(|e| anyhow::anyhow!(e))?
.last_insert_id();
transaction.commit().await.map_err(|e| {
log::error!("execute-failed: error={:?}", e);
anyhow::anyhow!(e)
})?;
Ok(last_insert_id)
}
pub async fn execute(
&self,
template_name: &str,
params: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> anyhow::Result<u64> {
let bq = self.bind_query(template_name, params)?;
let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);
let transaction = self.pool.begin().await.map_err(|e| {
log::error!("execute-failed: error={:?}", e);
anyhow::anyhow!(e)
})?;
let update_result_int = q
.execute(&self.pool)
.await
.map_err(|e| anyhow::anyhow!(e))?
.rows_affected();
transaction.commit().await.map_err(|e| {
log::error!("execute-failed: error={:?}", e);
anyhow::anyhow!(e)
})?;
Ok(update_result_int)
}
pub async fn query_with_rows(
&self,
template_name: &str,
params: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> anyhow::Result<Vec<sqlx::mysql::MySqlRow>> {
let bq = self.bind_query(template_name, params)?;
let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);
let rows = q.fetch_all(&self.pool).await?;
Ok(rows)
}
pub async fn query_one_row(
&self,
template_name: &str,
params: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> anyhow::Result<Option<sqlx::mysql::MySqlRow>> {
let bq = self.bind_query(template_name, params)?;
let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);
let row = q.fetch_optional(&self.pool).await?;
Ok(row)
}
}