use std::borrow::BorrowMut;
use std::cell::Cell;
use std::collections::HashMap;
use std::time::Duration;
use once_cell::sync::OnceCell;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use serde_json::Number;
use uuid::Uuid;
use rbatis_core::db::DBConnectOption;
use crate::core::db::{DBExecResult, DBPool, DBPoolConn, DBPoolOptions, DBQuery, DBTx, DriverType};
use crate::core::sync::sync_map::SyncMap;
use crate::core::Error;
use crate::crud::CRUDTable;
use crate::plugin::intercept::SqlIntercept;
use crate::plugin::log::{LogPlugin, RbatisLog};
use crate::plugin::logic_delete::{LogicDelete, RbatisLogicDeletePlugin};
use crate::plugin::page::{IPage, IPageRequest, Page, PagePlugin, RbatisPagePlugin};
use crate::plugin::version_lock::{RbatisVersionLockPlugin, VersionLockPlugin};
use crate::sql::upper::SqlUpperCase;
use crate::sql::PageLimit;
use crate::tx::{TxGuard, TxManager, TxState};
use crate::utils::error_util::ToResult;
use crate::utils::string_util;
use crate::wrapper::Wrapper;
use py_sql::node::proxy_node::NodeFactory;
use py_sql::py_sql::PyRuntime;
use rexpr::runtime::RExprRuntime;
use std::sync::Arc;
/// Central rbatis handle: owns the connection pool, the transaction manager,
/// the py_sql/expression runtimes and the plugins used by the query/exec methods.
#[derive(Debug)]
pub struct Rbatis {
/// Connection pool; initialised at most once by `link`, `link_opt` or `link_cfg`.
pub pool: OnceCell<DBPool>,
/// Expression runtime handed to the py_sql runtime when rendering py_sql.
pub runtime_expr: RExprRuntime,
/// py_sql runtime that renders py_sql text into SQL plus bind arguments.
pub runtime_py: PyRuntime,
/// Transaction manager; shared (via `Arc`) with every `TxGuard` created by this instance.
pub tx_manager: Arc<TxManager>,
/// Builds the count SQL and the paged SQL used by `fetch_page`.
pub page_plugin: Box<dyn PagePlugin>,
/// Applied in order to every statement (and its args) before `fetch`/`exec`/`fetch_prepare`/`exec_prepare` run it.
pub sql_intercepts: Vec<Box<dyn SqlIntercept>>,
/// Optional logic-delete plugin. NOTE(review): not read by the methods of `Rbatis` shown here;
/// presumably consumed by the CRUD layer — confirm.
pub logic_plugin: Option<Box<dyn LogicDelete>>,
/// Logger; SQL text, args and row counts are written here whenever `is_enable()` returns true.
pub log_plugin: Arc<Box<dyn LogPlugin>>,
/// Optional version-lock plugin. NOTE(review): not read by the methods of `Rbatis` shown here.
pub version_lock_plugin: Option<Box<dyn VersionLockPlugin>>,
}
impl Default for Rbatis {
fn default() -> Rbatis {
Rbatis::new()
}
}
impl Drop for Rbatis {
    /// Closes the transaction manager and then, if one was linked, the pool.
    ///
    /// The async shutdown is driven to completion synchronously with
    /// `crate::core::runtime::task::block_on`, because `drop` cannot be async.
    fn drop(&mut self) {
        crate::core::runtime::task::block_on(async {
            self.tx_manager.close().await;
            // An unlinked instance has no pool to close.
            if let Some(linked_pool) = self.pool.get_mut() {
                linked_pool.close().await;
            }
        });
    }
}
/// Construction options for [`Rbatis::new_with_opt`]; see its `Default` impl for the default values.
pub struct RbatisOption {
/// Passed to `TxManager::new_arc` as its lock-wait timeout (default 60s).
pub tx_lock_wait_timeout: Duration,
/// Passed to `TxManager::new_arc` (default 1s); presumably how often transactions are checked — confirm.
pub tx_check_interval: Duration,
/// Node factories handed to the py_sql runtime (`PyRuntime::generate`).
pub generate: Vec<Box<dyn NodeFactory>>,
/// Page plugin for the instance (default `RbatisPagePlugin`).
pub page_plugin: Box<dyn PagePlugin>,
/// SQL interceptors for the instance (default: none).
pub sql_intercepts: Vec<Box<dyn SqlIntercept>>,
/// Optional logic-delete plugin (default: none).
pub logic_plugin: Option<Box<dyn LogicDelete>>,
/// Logger shared by the instance and its transaction manager (default `RbatisLog`).
pub log_plugin: Arc<Box<dyn LogPlugin>>,
/// Prefix that marks a context id as a transaction id (default `"tx:"`).
pub tx_prefix: String,
/// Optional version-lock plugin (default: none).
pub version_lock_plugin: Option<Box<dyn VersionLockPlugin>>,
}
impl Default for RbatisOption {
    /// Default options: `"tx:"` transaction prefix, 60s lock wait, 1s check interval,
    /// `RbatisPagePlugin` paging, `RbatisLog` logging, and no interceptors,
    /// node factories, logic-delete or version-lock plugins.
    fn default() -> Self {
        let default_logger: Box<dyn LogPlugin> = Box::new(RbatisLog::default());
        Self {
            tx_prefix: String::from("tx:"),
            tx_lock_wait_timeout: Duration::from_secs(60),
            tx_check_interval: Duration::from_secs(1),
            page_plugin: Box::new(RbatisPagePlugin::new()),
            log_plugin: Arc::new(default_logger),
            logic_plugin: None,
            version_lock_plugin: None,
            sql_intercepts: Vec::new(),
            generate: Vec::new(),
        }
    }
}
impl Rbatis {
/// Creates an unlinked instance with default options.
///
/// Call `link`/`link_opt`/`link_cfg` before running any SQL.
pub fn new() -> Self {
    Self::new_with_opt(Default::default())
}
/// Creates an unlinked instance from explicit [`RbatisOption`]s.
///
/// The transaction manager is built from the tx prefix, timeouts and a shared
/// handle to the log plugin. Every plugin in `option`, including
/// `version_lock_plugin`, is moved into the new instance. The pool stays
/// empty until `link`/`link_opt`/`link_cfg` is called.
pub fn new_with_opt(option: RbatisOption) -> Self {
    Self {
        pool: OnceCell::new(),
        runtime_expr: RExprRuntime::new(),
        tx_manager: TxManager::new_arc(
            &option.tx_prefix,
            option.log_plugin.clone(),
            option.tx_lock_wait_timeout,
            option.tx_check_interval,
        ),
        page_plugin: option.page_plugin,
        sql_intercepts: option.sql_intercepts,
        logic_plugin: option.logic_plugin,
        log_plugin: option.log_plugin,
        runtime_py: PyRuntime {
            cache: Default::default(),
            generate: option.generate,
        },
        // Previously hard-coded to `None`, which silently dropped the caller's plugin.
        version_lock_plugin: option.version_lock_plugin,
    }
}
/// Creates an empty [`Wrapper`] configured for the linked database's driver type.
///
/// # Panics
/// Panics with a descriptive message if the pool has not been linked, or if the
/// linked driver type is `DriverType::None`.
pub fn new_wrapper(&self) -> Wrapper {
    // Match instead of `unwrap()`: an unlinked pool used to panic with a generic
    // "called `Result::unwrap()` on an `Err` value" before this message could fire.
    let driver = match self.driver_type() {
        Ok(driver) if driver != DriverType::None => driver,
        _ => panic!("[rbatis] .new_wrapper() method must be call .link(url) to init first!"),
    };
    Wrapper::new(&driver)
}
/// Creates a [`Wrapper`] whose column formats come from `T::formats` for the linked driver.
///
/// # Panics
/// Panics if the pool has not been linked (via `new_wrapper`).
pub fn new_wrapper_table<T>(&self) -> Wrapper
where
    T: CRUDTable,
{
    // `new_wrapper` runs first so that an unlinked pool hits its descriptive panic.
    let wrapper = self.new_wrapper();
    let linked_driver = self.driver_type().unwrap();
    let table_formats = T::formats(&linked_driver);
    wrapper.set_formats(table_formats)
}
/// Links the pool to `driver_url` using default pool options.
///
/// # Errors
/// See `link_opt`.
pub async fn link(&self, driver_url: &str) -> Result<(), Error> {
    let defaults = DBPoolOptions::default();
    self.link_opt(driver_url, &defaults).await
}
/// Links the pool to `driver_url` with explicit `pool_options`.
///
/// If a pool is already linked, this does nothing and returns `Ok(())`. The URL
/// is still checked for emptiness first.
///
/// NOTE(review): two concurrent first-time calls may each build a pool. Only one
/// is stored by `OnceCell::get_or_init`; the other is dropped.
///
/// # Errors
/// Returns an error if `driver_url` is empty or the pool cannot be created.
pub async fn link_opt(
    &self,
    driver_url: &str,
    pool_options: &DBPoolOptions,
) -> Result<(), Error> {
    if driver_url.is_empty() {
        return Err(Error::from("[rbatis] link url is empty!"));
    }
    if self.pool.get().is_some() {
        return Ok(());
    }
    let created = DBPool::new_opt_str(driver_url, pool_options).await?;
    self.pool.get_or_init(move || created);
    Ok(())
}
/// Links the pool from a structured `connect_option` plus `pool_options`.
///
/// If a pool is already linked, this does nothing and returns `Ok(())`.
///
/// # Errors
/// Returns an error if the pool cannot be created.
pub async fn link_cfg(
    &self,
    connect_option: &DBConnectOption,
    pool_options: &DBPoolOptions,
) -> Result<(), Error> {
    if self.pool.get().is_some() {
        return Ok(());
    }
    let created = DBPool::new_opt(connect_option, pool_options).await?;
    self.pool.get_or_init(move || created);
    Ok(())
}
/// Replaces the instance's log plugin.
///
/// The transaction manager keeps whatever logger it was built with.
pub fn set_log_plugin<T>(&mut self, arg: T)
where
    T: LogPlugin + 'static,
{
    let boxed: Box<dyn LogPlugin> = Box::new(arg);
    self.log_plugin = Arc::new(boxed);
}
/// Installs a logic-delete plugin, or removes the current one when `arg` is `None`.
pub fn set_logic_plugin<T>(&mut self, arg: Option<T>)
where
    T: LogicDelete + 'static,
{
    self.logic_plugin = arg.map(|plugin| Box::new(plugin) as Box<dyn LogicDelete>);
}
/// Replaces the page plugin used by `fetch_page`.
pub fn set_page_plugin<T>(&mut self, arg: T)
where
    T: PagePlugin + 'static,
{
    let boxed: Box<dyn PagePlugin> = Box::new(arg);
    self.page_plugin = boxed;
}
/// Appends an interceptor; interceptors run in insertion order.
pub fn add_sql_intercept<T>(&mut self, arg: T)
where
    T: SqlIntercept + 'static,
{
    let boxed: Box<dyn SqlIntercept> = Box::new(arg);
    self.sql_intercepts.push(boxed);
}
/// Replaces the whole interceptor list with `arg`.
///
/// NOTE(review): the type parameter `T` is unused and cannot be inferred, so
/// callers must currently supply it explicitly (e.g. `set_sql_intercepts::<()>(v)`).
/// It is kept so that existing turbofish call sites keep compiling.
pub fn set_sql_intercepts<T>(&mut self, arg: Vec<Box<dyn SqlIntercept>>) {
    let previous = std::mem::replace(&mut self.sql_intercepts, arg);
    drop(previous);
}
pub fn get_pool(&self) -> Result<&DBPool, Error> {
let p = self.pool.get();
if p.is_none() {
return Err(Error::from("[rbatis] rbatis pool not inited!"));
}
return Ok(p.unwrap());
}
/// Returns the driver type of the linked pool.
///
/// # Errors
/// Returns an error if the pool has not been linked.
pub fn driver_type(&self) -> Result<DriverType, Error> {
    self.get_pool().map(|linked| linked.driver_type)
}
pub async fn begin_tx_defer(&self, when_drop_commit: bool) -> Result<TxGuard, Error> {
let tx_id = self.begin_tx().await?;
let guard = TxGuard::new(&tx_id, when_drop_commit, self.tx_manager.clone());
return Ok(guard);
}
/// Begins a transaction whose id is the tx prefix followed by a random v4 UUID.
pub async fn begin_tx(&self) -> Result<String, Error> {
    let generated_id = format!("{}{}", self.tx_manager.tx_prefix, Uuid::new_v4());
    self.begin(&generated_id).await
}
pub async fn begin_defer(
&self,
context_id: &str,
when_drop_commit: bool,
) -> Result<TxGuard, Error> {
let tx_id = self.begin(context_id).await?;
let guard = TxGuard::new(&tx_id, when_drop_commit, self.tx_manager.clone());
return Ok(guard);
}
/// Begins a transaction identified by `context_id` on a connection from the pool.
///
/// Returns the `String` produced by `TxManager::begin` (callers here use it as the tx id).
///
/// # Errors
/// Returns an error if `context_id` is empty, lacks the transaction prefix, the
/// pool is not linked, or the transaction manager fails.
pub async fn begin(&self, context_id: &str) -> Result<String, Error> {
    if context_id.is_empty() {
        return Err(Error::from("[rbatis] context_id can not be empty"));
    }
    if !self.tx_manager.is_tx_prifix_id(context_id) {
        // Arguments were previously in the wrong order: the prefix was printed in
        // place of the offending id, and the example repeated the id twice.
        return Err(Error::from(format!(
            "[rbatis] context_id: {} must be start with '{}', for example: {}{}",
            context_id, &self.tx_manager.tx_prefix, &self.tx_manager.tx_prefix, context_id
        )));
    }
    let result = self.tx_manager.begin(context_id, self.get_pool()?).await?;
    Ok(result)
}
/// Commits the transaction `context_id` through the transaction manager.
///
/// # Errors
/// Returns an error if `context_id` is empty, lacks the transaction prefix, or
/// the transaction manager fails.
pub async fn commit(&self, context_id: &str) -> Result<String, Error> {
    if context_id.is_empty() {
        return Err(Error::from("[rbatis] context_id can not be empty"));
    }
    if !self.tx_manager.is_tx_prifix_id(context_id) {
        // Report the offending id first, then the prefix, then a corrected example.
        return Err(Error::from(format!(
            "[rbatis] context_id: {} must be start with '{}', for example: {}{}",
            context_id, &self.tx_manager.tx_prefix, &self.tx_manager.tx_prefix, context_id
        )));
    }
    let result = self.tx_manager.commit(context_id).await?;
    Ok(result)
}
/// Rolls back the transaction `context_id` through the transaction manager.
///
/// # Errors
/// Returns an error if `context_id` is empty, lacks the transaction prefix, or
/// the transaction manager fails.
pub async fn rollback(&self, context_id: &str) -> Result<String, Error> {
    if context_id.is_empty() {
        return Err(Error::from("[rbatis] context_id can not be empty"));
    }
    if !self.tx_manager.is_tx_prifix_id(context_id) {
        // Report the offending id first, then the prefix, then a corrected example.
        return Err(Error::from(format!(
            "[rbatis] context_id: {} must be start with '{}', for example: {}{}",
            context_id, &self.tx_manager.tx_prefix, &self.tx_manager.tx_prefix, context_id
        )));
    }
    let result = self.tx_manager.rollback(context_id).await?;
    Ok(result)
}
/// Runs a raw (non-prepared) query and deserializes the result into `T`.
///
/// Each registered [`SqlIntercept`] may first rewrite `sql`. It receives an
/// empty argument list, which is then discarded. A `context_id` carrying the
/// transaction prefix runs on that transaction's connection; any other id runs
/// on a connection acquired from the pool. When the log plugin is enabled, the
/// SQL and the row count reported by the driver are logged.
///
/// # Errors
/// Fails if the transaction does not exist, the pool is not linked, or the
/// driver returns an error.
pub async fn fetch<T>(&self, context_id: &str, sql: &str) -> Result<T, Error>
where
    T: DeserializeOwned,
{
    let mut sql = sql.to_string();
    for item in &self.sql_intercepts {
        item.do_intercept(self, &mut sql, &mut vec![], false);
    }
    if self.log_plugin.is_enable() {
        self.log_plugin.do_log(&format!(
            "[rbatis] [{}] Query ==> {}",
            context_id,
            sql.as_str()
        ));
    }
    // Each branch yields (data, row count) directly, which removes the old
    // `let mut fetch_num = 0` that was always overwritten.
    let (result, fetch_num) = if self.tx_manager.is_tx_prifix_id(context_id) {
        let mut conn = match self.tx_manager.get_mut(context_id).await {
            Some(conn) => conn,
            None => {
                return Err(Error::from(format!(
                    "[rbatis] transaction:{} not exist!",
                    context_id
                )));
            }
        };
        let (data, num) = conn.value_mut().0.fetch(sql.as_str()).await?;
        (data, num)
    } else {
        let mut conn = self.get_pool()?.acquire().await?;
        let (data, num) = conn.fetch(sql.as_str()).await?;
        (data, num)
    };
    if self.log_plugin.is_enable() {
        self.log_plugin.do_log(&format!(
            "[rbatis] [{}] ReturnRows <== {}",
            context_id, fetch_num
        ));
    }
    Ok(result)
}
/// Executes a raw (non-prepared) statement and returns the driver's exec result.
///
/// Each registered [`SqlIntercept`] may first rewrite `sql`; it receives an
/// empty, discarded argument list. Transaction-prefixed ids run on that
/// transaction's connection, and all other ids use a pooled connection. When
/// logging is enabled, the SQL and the affected-row count are logged; a failed
/// execution is logged as 0 rows.
///
/// # Errors
/// Fails if the transaction does not exist or the pool is not linked. Driver
/// errors are returned unchanged.
pub async fn exec(&self, context_id: &str, sql: &str) -> Result<DBExecResult, Error> {
    let mut statement = sql.to_string();
    for intercept in self.sql_intercepts.iter() {
        intercept.do_intercept(self, &mut statement, &mut vec![], false);
    }
    if self.log_plugin.is_enable() {
        let line = format!("[rbatis] [{}] Exec ==> {}", context_id, &statement);
        self.log_plugin.do_log(&line);
    }
    let outcome = if self.tx_manager.is_tx_prifix_id(context_id) {
        let mut tx_conn = match self.tx_manager.get_mut(context_id).await {
            Some(found) => found,
            None => {
                return Err(Error::from(format!(
                    "[rbatis] transaction:{} not exist!",
                    context_id
                )));
            }
        };
        let executed = tx_conn.value_mut().0.execute(&statement).await;
        executed
    } else {
        let mut pooled = self.get_pool()?.acquire().await?;
        let executed = pooled.execute(&statement).await;
        executed
    };
    if self.log_plugin.is_enable() {
        match &outcome {
            Ok(done) => {
                self.log_plugin.do_log(&format!(
                    "[rbatis] [{}] RowsAffected <== {}",
                    context_id, done.rows_affected
                ));
            }
            Err(_) => {
                self.log_plugin
                    .do_log(&format!("[rbatis] [{}] RowsAffected <== {}", context_id, 0));
            }
        }
    }
    outcome
}
/// Builds a prepared query for `sql` on the linked pool and binds each value of `arg` in order.
///
/// # Errors
/// Fails if the pool is not linked or the driver rejects the query.
fn bind_arg<'arg>(
    &self,
    sql: &'arg str,
    arg: &Vec<serde_json::Value>,
) -> Result<DBQuery<'arg>, Error> {
    let linked = self.get_pool()?;
    let mut query: DBQuery = linked.make_query(sql)?;
    arg.iter().for_each(|value| {
        query.bind_value(value);
    });
    Ok(query)
}
/// Runs a prepared query with bind `args` and deserializes the result into `T`.
///
/// `sql` and a copy of `args` are first passed through every registered
/// [`SqlIntercept`] (called with `true` as the final flag, presumably "prepared
/// mode" — confirm). Transaction-prefixed ids run on that transaction's
/// connection; other ids use a pooled connection. When logging is enabled, the
/// SQL, the args and the row count reported by the driver are logged.
///
/// # Errors
/// Fails if the pool is not linked, the transaction does not exist, the query
/// cannot be built, or the driver returns an error.
pub async fn fetch_prepare<T>(
    &self,
    context_id: &str,
    sql: &str,
    args: &Vec<serde_json::Value>,
) -> Result<T, Error>
where
    T: DeserializeOwned,
{
    let mut sql = sql.to_string();
    let mut args = args.clone();
    for item in &self.sql_intercepts {
        item.do_intercept(self, &mut sql, &mut args, true);
    }
    if self.log_plugin.is_enable() {
        self.log_plugin.do_log(&format!(
            "[rbatis] [{}] Query ==> {}\n{}[rbatis] [{}] Args ==> {}",
            context_id,
            &sql,
            string_util::LOG_SPACE,
            context_id,
            serde_json::Value::Array(args.clone()).to_string()
        ));
    }
    // Each branch yields (data, row count) directly, which removes the old
    // `let mut return_num = 0` that was always overwritten.
    let (result_data, return_num) = if self.tx_manager.is_tx_prifix_id(context_id) {
        let q: DBQuery = self.bind_arg(&sql, &args)?;
        let mut conn = match self.tx_manager.get_mut(context_id).await {
            Some(conn) => conn,
            None => {
                return Err(Error::from(format!(
                    "[rbatis] transaction:{} not exist!",
                    context_id
                )));
            }
        };
        let (data, num) = conn.value_mut().0.fetch_parperd(q).await?;
        (data, num)
    } else {
        let mut conn = self.get_pool()?.acquire().await?;
        let q: DBQuery = self.bind_arg(&sql, &args)?;
        let (data, num) = conn.fetch_parperd(q).await?;
        (data, num)
    };
    if self.log_plugin.is_enable() {
        self.log_plugin.do_log(&format!(
            "[rbatis] [{}] ReturnRows <== {}",
            context_id, return_num
        ));
    }
    Ok(result_data)
}
/// Executes a prepared statement with bind `args` and returns the driver's exec result.
///
/// `sql` and a copy of `args` are first passed through every registered
/// [`SqlIntercept`]. Transaction-prefixed ids run on that transaction's
/// connection; other ids use a pooled connection. When logging is enabled, the
/// SQL, the args and the affected-row count are logged; a failed execution is
/// logged as 0 rows.
///
/// # Errors
/// Fails if the query cannot be built, the transaction does not exist, or the
/// pool is not linked. Driver errors are returned unchanged.
pub async fn exec_prepare(
    &self,
    context_id: &str,
    sql: &str,
    args: &Vec<serde_json::Value>,
) -> Result<DBExecResult, Error> {
    let mut statement = sql.to_string();
    let mut bound = args.clone();
    for intercept in self.sql_intercepts.iter() {
        intercept.do_intercept(self, &mut statement, &mut bound, true);
    }
    if self.log_plugin.is_enable() {
        let args_json = serde_json::Value::Array(bound.clone()).to_string();
        self.log_plugin.do_log(&format!(
            "[rbatis] [{}] Exec ==> {}\n{}[rbatis] [{}] Args ==> {}",
            context_id,
            &statement,
            string_util::LOG_SPACE,
            context_id,
            args_json
        ));
    }
    let outcome = if self.tx_manager.is_tx_prifix_id(context_id) {
        // The query is bound before the transaction connection is looked up.
        let query: DBQuery = self.bind_arg(&statement, &bound)?;
        let mut tx_conn = match self.tx_manager.get_mut(context_id).await {
            Some(found) => found,
            None => {
                return Err(Error::from(format!(
                    "[rbatis] transaction:{} not exist!",
                    context_id
                )));
            }
        };
        let executed = tx_conn.value_mut().0.exec_prepare(query).await;
        executed
    } else {
        let query: DBQuery = self.bind_arg(&statement, &bound)?;
        let mut pooled = self.get_pool()?.acquire().await?;
        let executed = pooled.exec_prepare(query).await;
        executed
    };
    if self.log_plugin.is_enable() {
        match &outcome {
            Ok(done) => {
                self.log_plugin.do_log(&format!(
                    "[rbatis] [{}] RowsAffected <== {}",
                    context_id, done.rows_affected
                ));
            }
            Err(_) => {
                self.log_plugin
                    .do_log(&format!("[rbatis] [{}] RowsAffected <== {}", context_id, 0));
            }
        }
    }
    outcome
}
fn py_to_sql<Arg>(&self, py_sql: &str, arg: &Arg) -> Result<(String, Vec<serde_json::Value>), Error>
where
Arg: Serialize + Send + Sync,
{
let mut arg = json!(arg);
match self
.runtime_py
.eval(&self.driver_type()?, py_sql, &mut arg, &self.runtime_expr)
{
Ok(v) => Ok(v),
Err(e) => Err(Error::from(e)),
}
}
/// Renders `py_sql` with `arg`, then runs it through `fetch_prepare`.
pub async fn py_fetch<T, Arg>(&self, context_id: &str, py_sql: &str, arg: &Arg) -> Result<T, Error>
where
    T: DeserializeOwned,
    Arg: Serialize + Send + Sync,
{
    let (rendered_sql, bound_args) = self.py_to_sql(py_sql, arg)?;
    self.fetch_prepare(context_id, &rendered_sql, &bound_args).await
}
/// Renders `py_sql` with `arg`, then runs it through `exec_prepare`.
pub async fn py_exec<Arg>(
    &self,
    context_id: &str,
    py_sql: &str,
    arg: &Arg,
) -> Result<DBExecResult, Error>
where
    Arg: Serialize + Send + Sync,
{
    let (rendered_sql, bound_args) = self.py_to_sql(py_sql, arg)?;
    self.exec_prepare(context_id, &rendered_sql, &bound_args).await
}
/// Runs a paged prepared query and returns the records wrapped in a [`Page`].
///
/// `sql` is first passed through `upper_case_sql` for the linked driver. The
/// page plugin then derives a count SQL and a paged SQL from it. When
/// `page_request.is_search_count()` is true, the count query runs first; a total
/// of 0 returns the page immediately, without records. Both queries bind the
/// same `args`.
///
/// # Errors
/// Fails if the pool is not linked, the page plugin fails, or either query fails.
pub async fn fetch_page<T>(
    &self,
    context_id: &str,
    sql: &str,
    args: &Vec<serde_json::Value>,
    page_request: &dyn IPageRequest,
) -> Result<Page<T>, Error>
where
    T: DeserializeOwned + Serialize + Send + Sync,
{
    let driver_type = self.driver_type()?;
    let upper_sql = driver_type.upper_case_sql(sql);
    let mut page = Page::new(page_request.get_page_no(), page_request.get_page_size());
    let (count_sql, page_sql) = self.page_plugin.make_page_sql(
        &driver_type,
        context_id,
        &upper_sql,
        args,
        page_request,
    )?;
    if page_request.is_search_count() {
        let counted: Option<u64> = self
            .fetch_prepare(context_id, count_sql.as_str(), args)
            .await?;
        page.set_total(counted.unwrap_or(0));
        page.pages = page.get_pages();
        if page.get_total() == 0 {
            return Ok(page);
        }
    }
    let records: Option<Vec<T>> = self
        .fetch_prepare(context_id, page_sql.as_str(), args)
        .await?;
    page.set_records(records.unwrap_or_default());
    page.pages = page.get_pages();
    Ok(page)
}
/// Renders `py_sql` with `arg`, then runs it through `fetch_page`.
pub async fn py_fetch_page<T, Arg>(
    &self,
    context_id: &str,
    py_sql: &str,
    arg: &Arg,
    page_request: &dyn IPageRequest,
) -> Result<Page<T>, Error>
where
    T: DeserializeOwned + Serialize + Send + Sync,
    Arg: Serialize + Send + Sync,
{
    let (rendered_sql, bound_args) = self.py_to_sql(py_sql, arg)?;
    self.fetch_page::<T>(context_id, &rendered_sql, &bound_args, page_request)
        .await
}
}