zero4rs 2.0.0

zero4rs is a powerful, pragmatic, and extremely fast web framework for Rust
Documentation
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;

use sqlx::MySql;
use sqlx::Pool;

use tera::{Result as TeraResult, Tera, Value};

// Per-thread slot for the bind context of the template render currently in progress.
// `MysqlClient::bind_query` stores a fresh `SqlBindContext` here before rendering and
// writes `None` back after a successful render; the `sql_bind` filter installed by
// `MysqlClient::init_sql_bind_filter` pushes each filtered value into whatever context
// is present and does nothing when the slot is empty.
thread_local! {
    static CURRENT_BIND_CTX: RefCell<Option<SqlBindContext>> = const { RefCell::new(None) };
}

/// Bundles a MySQL connection pool with the Tera engine used to render SQL templates.
#[derive(Clone)]
pub struct MysqlClient {
    /// sqlx pool that the query helpers run statements against.
    pub pool: Pool<MySql>,
    /// Tera instance that `bind_query` renders named SQL templates with.
    pub tera: Tera,
}

/// A rendered SQL statement together with the values to bind to it, as produced by
/// `MysqlClient::bind_query`.
#[derive(Clone, Default)]
pub struct OwnedSqlQuery {
    /// Rendered SQL text; values passed through the `sql_bind` filter appear as `?`.
    pub sql: String,
    /// Values collected during rendering, in the order the filter saw them.
    pub binds: Vec<serde_json::Value>,
}

/// Shared collector for values captured while a SQL template is rendered.
///
/// Clones share the same underlying list because the storage sits behind an `Arc`.
#[derive(Clone, Default)]
pub struct SqlBindContext {
    // Collected values in insertion order, guarded by a mutex.
    binds: Arc<Mutex<Vec<Value>>>,
}

impl SqlBindContext {
    /// Appends `val` to the end of the shared bind list.
    ///
    /// # Panics
    /// Panics if the mutex guarding the list has been poisoned.
    pub fn add_bind(&self, val: Value) {
        let mut collected = self.binds.lock().unwrap();
        collected.push(val);
    }

    /// Removes every collected value and returns them in insertion order,
    /// leaving the shared list empty.
    ///
    /// # Panics
    /// Panics if the mutex guarding the list has been poisoned.
    pub fn take_binds(&self) -> Vec<Value> {
        let mut collected = self.binds.lock().unwrap();
        std::mem::take(&mut *collected)
    }
}

impl MysqlClient {
    /// Registers the `sql_bind` Tera filter backed by an explicit bind context.
    ///
    /// Each time the filter is applied it pushes a clone of the filtered value into
    /// `bind_ctx` and renders a `?` placeholder in its place.
    pub fn register_sql_bind_filter(&mut self, bind_ctx: SqlBindContext) {
        let collect_bind =
            move |value: &Value, _args: &HashMap<String, Value>| -> TeraResult<Value> {
                // Record the value every time `sql_bind` is applied.
                let captured = value.clone();
                bind_ctx.add_bind(captured);

                // The template output only receives a `?` placeholder.
                let placeholder = String::from("?");
                Ok(Value::String(placeholder))
            };

        self.tera.register_filter("sql_bind", collect_bind);
    }

    pub fn init_sql_bind_filter(&mut self) {
        self.tera.register_filter(
            "sql_bind",
            |_value: &Value, _args: &HashMap<String, Value>| -> TeraResult<Value> {
                CURRENT_BIND_CTX.with(|ctx| {
                    if let Some(ref ctx) = *ctx.borrow() {
                        ctx.add_bind(_value.clone());
                    }
                });

                Ok(Value::String("?".to_string()))
            },
        );
    }

    /// 用于调试的函数:将 SQL 中的 `?` 占位符替换成绑定值的可读形式。
    pub fn debug_sql_with_binds(sql: &str, binds: &[serde_json::Value]) -> String {
        let mut final_sql = String::new();
        let mut bind_iter = binds.iter();

        let chars = sql.chars();

        for c in chars {
            if c == '?' {
                if let Some(v) = bind_iter.next() {
                    let formatted = if v.is_string() {
                        format!("'{}'", v.as_str().unwrap().replace("'", "''"))
                    } else {
                        v.to_string()
                    };
                    final_sql.push_str(&formatted);
                } else {
                    final_sql.push('?');
                }
            } else {
                final_sql.push(c);
            }
        }

        final_sql
    }

    /// Renders the SQL template `template_name` and collects the values that passed
    /// through the `sql_bind` filter during rendering.
    ///
    /// Every entry of `params` is inserted into the Tera context under its key. A fresh
    /// `SqlBindContext` is placed in the thread-local `CURRENT_BIND_CTX` slot for the
    /// duration of the render; binds are only collected when the `sql_bind` filter in
    /// use reads that slot (see `init_sql_bind_filter`).
    ///
    /// # Errors
    /// Returns the Tera error if rendering fails. The thread-local slot is cleared
    /// before the error is propagated, so a failed render leaves no stale context.
    pub fn bind_query(
        &self,
        template_name: &str,
        params: Option<&std::collections::HashMap<String, serde_json::Value>>,
    ) -> anyhow::Result<OwnedSqlQuery> {
        let mut tera_ctx = tera::Context::new();

        if let Some(params) = params {
            for (k, v) in params {
                tera_ctx.insert(k, v);
            }
        }

        let bind_ctx = SqlBindContext::default();

        // Install the bind context for the current thread.
        CURRENT_BIND_CTX.with(|cell| {
            *cell.borrow_mut() = Some(bind_ctx.clone());
        });

        // Render first and keep the result, so the slot can be cleared on the error
        // path as well as on success.
        let rendered = self.tera.render(template_name, &tera_ctx);

        // Clear the thread-local context before any early return.
        CURRENT_BIND_CTX.with(|cell| {
            *cell.borrow_mut() = None;
        });

        let sql = rendered?;
        let binds = bind_ctx.take_binds();

        // `collapse_spaces` already folds newlines, as they are whitespace.
        log::debug!(
            "sql_name={}, statement={}",
            template_name,
            Self::collapse_spaces(&Self::debug_sql_with_binds(&sql, &binds))
        );

        Ok(OwnedSqlQuery { sql, binds })
    }

    // Collapses every run of whitespace into a single space and drops leading and
    // trailing whitespace.
    fn collapse_spaces(s: &str) -> String {
        let mut collapsed = String::new();

        for word in s.split_whitespace() {
            // Words are never empty, so a non-empty buffer means a word came before.
            if !collapsed.is_empty() {
                collapsed.push(' ');
            }
            collapsed.push_str(word);
        }

        collapsed
    }

    /// Binds each JSON value in `binds` to `q`, in order, and returns the query.
    ///
    /// Mapping from JSON to bound parameter:
    /// - `null` is bound as SQL `NULL` (not as the text `"null"`);
    /// - booleans are bound as `bool`, strings as `&str`;
    /// - numbers are bound as `i64` when they fit, otherwise as `u64`, otherwise as
    ///   `f64`, so unsigned values above `i64::MAX` keep their exact value;
    /// - arrays and objects are bound as their JSON text.
    pub fn bind_query_args<'a>(
        mut q: sqlx::query::Query<'a, sqlx::mysql::MySql, sqlx::mysql::MySqlArguments>,
        binds: &'a [serde_json::Value],
    ) -> sqlx::query::Query<'a, sqlx::mysql::MySql, sqlx::mysql::MySqlArguments> {
        for val in binds {
            q = match val {
                serde_json::Value::Null => q.bind(Option::<&str>::None),
                serde_json::Value::Bool(flag) => q.bind(*flag),
                serde_json::Value::String(text) => q.bind(text.as_str()),
                serde_json::Value::Number(num) => {
                    if let Some(signed) = num.as_i64() {
                        q.bind(signed)
                    } else if let Some(unsigned) = num.as_u64() {
                        q.bind(unsigned)
                    } else if let Some(float) = num.as_f64() {
                        q.bind(float)
                    } else {
                        // No native representation available; fall back to the text.
                        q.bind(num.to_string())
                    }
                }
                // Arrays and objects are sent as their JSON text.
                other => q.bind(other.to_string()),
            };
        }

        q
    }

    /// Renders the template `template_name`, runs the resulting statement inside a
    /// transaction, commits, and returns the statement's `last_insert_id`.
    ///
    /// The statement is executed on the transaction's own connection, so it is really
    /// covered by the commit/rollback below and no second pooled connection is taken.
    ///
    /// # Errors
    /// Returns an error if rendering the template, starting the transaction, executing
    /// the statement, or committing fails.
    pub async fn insert(
        &self,
        template_name: &str,
        params: Option<&std::collections::HashMap<String, serde_json::Value>>,
    ) -> anyhow::Result<u64> {
        let bq = self.bind_query(template_name, params)?;
        let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);

        let mut transaction = self.pool.begin().await.map_err(|e| {
            log::error!("execute-failed: error={:?}", e);
            anyhow::anyhow!(e)
        })?;

        // Run on the transaction's connection, not on the pool.
        let last_insert_id = q
            .execute(&mut *transaction)
            .await
            .map_err(|e| {
                log::error!("execute-failed: error={:?}", e);
                anyhow::anyhow!(e)
            })?
            .last_insert_id();

        // If commit or rollback have not been called before the Transaction object goes out of scope
        // (i.e. Drop is invoked), a rollback command is queued to be executed as soon as an
        // opportunity arises.
        transaction.commit().await.map_err(|e| {
            log::error!("execute-failed: error={:?}", e);
            anyhow::anyhow!(e)
        })?;

        Ok(last_insert_id)
    }

    /// Renders the template `template_name`, runs the resulting statement inside a
    /// transaction, commits, and returns the number of rows affected.
    ///
    /// The statement is executed on the transaction's own connection, so it is really
    /// covered by the commit/rollback below and no second pooled connection is taken.
    ///
    /// # Errors
    /// Returns an error if rendering the template, starting the transaction, executing
    /// the statement, or committing fails.
    pub async fn execute(
        &self,
        template_name: &str,
        params: Option<&std::collections::HashMap<String, serde_json::Value>>,
    ) -> anyhow::Result<u64> {
        let bq = self.bind_query(template_name, params)?;
        let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);

        let mut transaction = self.pool.begin().await.map_err(|e| {
            log::error!("execute-failed: error={:?}", e);
            anyhow::anyhow!(e)
        })?;

        // Run on the transaction's connection, not on the pool.
        let update_result_int = q
            .execute(&mut *transaction)
            .await
            .map_err(|e| {
                log::error!("execute-failed: error={:?}", e);
                anyhow::anyhow!(e)
            })?
            .rows_affected();

        // If commit or rollback have not been called before the Transaction object goes out of scope
        // (i.e. Drop is invoked), a rollback command is queued to be executed as soon as an
        // opportunity arises.
        transaction.commit().await.map_err(|e| {
            log::error!("execute-failed: error={:?}", e);
            anyhow::anyhow!(e)
        })?;

        Ok(update_result_int)
    }

    pub async fn query_with_rows(
        &self,
        template_name: &str,
        params: Option<&std::collections::HashMap<String, serde_json::Value>>,
    ) -> anyhow::Result<Vec<sqlx::mysql::MySqlRow>> {
        let bq = self.bind_query(template_name, params)?;
        let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);

        let rows = q.fetch_all(&self.pool).await?;

        Ok(rows)
    }

    pub async fn query_one_row(
        &self,
        template_name: &str,
        params: Option<&std::collections::HashMap<String, serde_json::Value>>,
    ) -> anyhow::Result<Option<sqlx::mysql::MySqlRow>> {
        let bq = self.bind_query(template_name, params)?;
        let q = Self::bind_query_args(sqlx::query(&bq.sql), &bq.binds);

        let row = q.fetch_optional(&self.pool).await?;

        Ok(row)
    }
}