use std::collections::HashMap;
use std::ops::Add;
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use crate::core::convert::StmtConvert;
use crate::core::db::DriverType;
use crate::core::Error;
/// Incremental builder for a SQL condition fragment together with the
/// arguments bound to its placeholders (e.g. `a = ? AND b IN ( ? , ? )`).
///
/// Builder methods append to `sql` and `args`; `check()` trims dangling
/// `AND`/`OR` connectors, marks the wrapper as checked and returns a copy,
/// or returns the recorded `error` instead.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Wrapper {
/// Database driver. Decides the placeholder text (`stmt_convert`) and
/// whether placeholders are numbered (`is_number_type`).
pub driver_type: DriverType,
/// The SQL fragment built so far.
pub sql: String,
/// Bound arguments. Builder methods push one value per placeholder they
/// append, so `args[i]` pairs with placeholder `stmt_convert(i)`.
pub args: Vec<serde_json::Value>,
/// Optional per-column templates: when a column has an entry, its
/// placeholder text replaces every `{}` in the template.
pub formats: HashMap<String, String>,
/// Error recorded by a builder call (e.g. `all_eq` given a non-object);
/// taken and returned by `check()`.
pub error: Option<Error>,
/// Set to `true` by a successful `check()`.
pub checked: bool,
}
impl Wrapper {
pub fn new(driver_type: &DriverType) -> Self {
Self {
driver_type: driver_type.clone(),
sql: "".to_string(),
args: vec![],
formats: Default::default(),
error: None,
checked: false,
}
}
pub fn from(driver_type: &DriverType, sql: &str, args: &Vec<serde_json::Value>) -> Self {
Self {
driver_type: driver_type.clone(),
sql: sql.to_string(),
args: args.clone(),
formats: HashMap::new(),
error: None,
checked: false,
}
}
pub fn check(&mut self) -> Result<Wrapper, Error> {
if self.error.is_some() {
return Err(self.error.take().unwrap());
}
self.trim_and();
self.trim_or();
self.checked = true;
return Ok(self.clone());
}
pub fn set_formats(&mut self, formats: HashMap<String, String>) -> &mut Self {
self.formats = formats;
self
}
pub fn push_wrapper(&mut self, arg: &Wrapper) -> &mut Self {
self.push(&arg.sql, &arg.args)
}
pub fn push<T>(&mut self, sql: &str, args: &[T]) -> &mut Self
where T: Serialize {
let mut new_sql = sql.to_string();
if self.driver_type.is_number_type() {
let self_arg_len = self.args.len();
for index in 0..args.len() {
let str = self.driver_type.stmt_convert(index);
new_sql = new_sql.replace(str.as_str(), self.driver_type.stmt_convert(index + args.len()).as_str());
}
for index in args.len()..self_arg_len {
let str = self.driver_type.stmt_convert(index);
new_sql = new_sql.replace(str.as_str(), self.driver_type.stmt_convert(index + args.len()).as_str());
}
}
self.sql.push_str(new_sql.as_str());
let args = json!(args);
if args.is_null() {
return self;
}
let args = args.as_array().unwrap();
for x in args {
self.args.push(x.to_owned());
}
self
}
pub fn do_if<'s, F>(&'s mut self, test: bool, method: F) -> &'s mut Self
where F: FnOnce(&'s mut Self) -> &'s mut Self {
if test {
return method(self);
}
return self;
}
pub fn do_match<'s, F>(&'s mut self, cases: &[(bool, fn(&mut Wrapper) -> &mut Wrapper)], default: F) -> &'s mut Self
where F: FnOnce(&'s mut Self) -> &'s mut Self {
for (test, case) in cases {
if *test {
return case(self);
}
}
return default(self);
}
pub fn set_sql(&mut self, sql: &str) -> &mut Self {
self.sql = sql
.replace(" and ", " AND ")
.replace(" or ", " OR ")
.replace(" where ", " WHERE ");
self
}
pub fn push_sql(&mut self, sql: &str) -> &mut Self {
let s = sql
.replace(" and ", " AND ")
.replace(" or ", " OR ")
.replace(" where ", " WHERE ");
self.sql.push_str(s.as_str());
self
}
pub fn set_args<T>(&mut self, args: &[T]) -> &mut Self where T: Serialize {
let v = json!(args);
if v.is_null() {
return self;
}
if v.is_array() {
self.args = v.as_array().unwrap_or(&vec![]).to_owned();
}
self
}
pub fn push_arg<T>(&mut self, arg: T) -> &mut Self where T: Serialize {
let v = json!(arg);
self.args.push(v);
self
}
pub fn pop_arg(&mut self) -> &mut Self {
self.args.pop();
self
}
fn not_allow_and_or(&self) -> bool {
let sql = self.sql.trim_end();
sql.ends_with(" WHERE")
|| sql.ends_with(" AND")
|| sql.ends_with(" OR")
|| sql.ends_with("(")
|| sql.ends_with(",")
|| sql.ends_with("=")
|| sql.ends_with("+")
|| sql.ends_with("-")
|| sql.ends_with("*")
|| sql.ends_with("/")
|| sql.ends_with("%")
|| sql.ends_with("^")
|| sql.ends_with(">")
|| sql.ends_with("<")
|| sql.ends_with("&")
|| sql.ends_with("|")
}
pub fn and(&mut self) -> &mut Self {
if !self.not_allow_and_or() {
self.sql.push_str(" AND ");
}
self
}
pub fn or(&mut self) -> &mut Self {
if !self.not_allow_and_or() {
self.sql.push_str(" OR ");
}
self
}
pub fn having(&mut self, sql_having: &str) -> &mut Self {
self.and();
self.sql.push_str(format!(" HAVING {} ", sql_having).as_str());
self
}
pub fn all_eq<T>(&mut self, arg: T) -> &mut Self
where T: Serialize {
self.and();
let v = json!(arg);
if v.is_null() {
self.error = Some(Error::from("[rbatis] wrapper all_eq only support object/map struct!"));
return self;
}
if !v.is_object() {
self.error = Some(Error::from("[rbatis] wrapper all_eq only support object/map struct!"));
return self;
}
let map = v.as_object().unwrap();
if map.len() == 0 {
return self;
}
let len = map.len();
let mut index = 0;
for (k, v) in map {
self.eq(k.as_str(), v);
if (index + 1) != len {
self.sql.push_str(" , ");
index += 1;
}
}
self
}
fn do_format_column(&self, column: &str, data: String) -> String {
let source = self.formats.get(column);
match source {
Some(s) => {
return s.replace("{}", &data);
}
_ => {
return data.to_string();
}
}
}
pub fn eq<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} = {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn ne<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} <> {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn order_by(&mut self, is_asc: bool, columns: &[&str]) -> &mut Self {
let len = columns.len();
if len == 0 {
return self;
}
let mut index = 0;
self.sql = self.sql.trim().trim_end_matches(" WHERE")
.trim_end_matches(" AND")
.trim_end_matches(" OR").to_string();
self.sql.push_str(" ORDER BY ");
for x in columns {
if is_asc {
self.sql.push_str(format!("{} ASC", x).as_str());
} else {
self.sql.push_str(format!("{} DESC", x, ).as_str());
}
if (index + 1) != len {
self.sql.push_str(" , ");
index += 1;
}
}
self
}
pub fn group_by(&mut self, columns: &[&str]) -> &mut Self {
let len = columns.len();
if len == 0 {
return self;
}
let mut index = 0;
self.sql = self.sql.trim()
.trim_end_matches(" WHERE")
.trim_end_matches(" AND")
.trim_end_matches(" OR").to_string();
self.sql.push_str(" GROUP BY ");
for x in columns {
self.sql.push_str(x);
if (index + 1) != len {
self.sql.push_str(" , ");
index += 1;
}
}
self
}
pub fn gt<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} > {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn ge<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} >= {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn lt<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} < {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn le<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} <= {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(obj));
self
}
pub fn between<T>(&mut self, column: &str, min: T, max: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} BETWEEN {} AND {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len())), self.do_format_column(column, self.driver_type.stmt_convert(self.args.len() + 1))));
self.args.push(json!(min));
self.args.push(json!(max));
self
}
pub fn not_between<T>(&mut self, column: &str, min: T, max: T) -> &mut Self
where T: Serialize {
self.and();
self.sql.push_str(&format!("{} NOT BETWEEN {} AND {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len())), self.do_format_column(column, self.driver_type.stmt_convert(self.args.len() + 1))));
self.args.push(json!(min));
self.args.push(json!(max));
self
}
pub fn like<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
let v = json!(obj);
let mut v_str = String::new();
if v.is_string() {
v_str = format!("%{}%", v.as_str().unwrap());
} else {
v_str = format!("%{}%", v.to_string());
}
self.sql.push_str(&format!("{} LIKE {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(v_str));
self
}
pub fn like_left<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
let v = json!(obj);
let mut v_str = String::new();
if v.is_string() {
v_str = format!("%{}", v.as_str().unwrap());
} else {
v_str = format!("%{}", v.to_string());
}
self.sql.push_str(&format!("{} LIKE {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(v_str));
self
}
pub fn like_right<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
let v = json!(obj);
let mut v_str = String::new();
if v.is_string() {
v_str = format!("{}%", v.as_str().unwrap());
} else {
v_str = format!("{}%", v.to_string());
}
self.sql.push_str(&format!("{} LIKE {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(v_str));
self
}
pub fn not_like<T>(&mut self, column: &str, obj: T) -> &mut Self
where T: Serialize {
self.and();
let v = json!(obj);
let mut v_str = String::new();
if v.is_string() {
v_str = format!("%{}%", v.as_str().unwrap());
} else {
v_str = format!("%{}%", v.to_string());
}
self.sql.push_str(&format!("{} NOT LIKE {}", column, self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
self.args.push(json!(v_str));
self
}
pub fn is_null(&mut self, column: &str) -> &mut Self {
self.and();
self.sql.push_str(column);
self.sql.push_str(" IS NULL");
self
}
pub fn is_not_null(&mut self, column: &str) -> &mut Self {
self.and();
self.sql.push_str(column);
self.sql.push_str(" IS NOT NULL");
self
}
pub fn in_array<T>(&mut self, column: &str, obj: &[T]) -> &mut Self
where T: Serialize {
self.and();
if obj.len() == 0 {
return self;
}
let arr = json!(obj);
let vec = arr.as_array().unwrap();
let mut sqls = String::new();
for x in vec {
sqls.push_str(&format!(" {} ", self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
sqls.push_str(",");
self.args.push(x.clone());
}
sqls.pop();
self.sql.push_str(format!("{} IN ({})", column, sqls).as_str());
self
}
pub fn in_<T>(&mut self, column: &str, obj: &[T]) -> &mut Self
where T: Serialize {
self.in_array(column, obj)
}
pub fn r#in<T>(&mut self, column: &str, obj: &[T]) -> &mut Self
where T: Serialize {
self.in_array(column, obj)
}
pub fn not_in<T>(&mut self, column: &str, obj: &[T]) -> &mut Self
where T: Serialize {
self.and();
let arr = json!(obj);
let vec = arr.as_array().unwrap();
let mut sqls = String::new();
for x in vec {
sqls.push_str(&format!(" {} ", self.do_format_column(column, self.driver_type.stmt_convert(self.args.len()))));
sqls.push_str(",");
self.args.push(x.clone());
}
sqls.pop();
self.sql.push_str(format!("{} NOT IN ({})", column, sqls).as_str());
self
}
pub fn trim_and(&mut self) -> &mut Self {
self.sql = self.sql.trim()
.trim_start_matches("AND ")
.trim_end_matches(" AND")
.to_string();
self
}
pub fn trim_or(&mut self) -> &mut Self {
self.sql = self.sql
.trim_start_matches("OR ")
.trim_end_matches(" OR")
.to_string();
self
}
}
#[cfg(test)]
mod test {
    use serde_json::json;
    use serde_json::Map;

    use crate::core::db::DriverType;
    use crate::utils::bencher::QPS;
    use crate::wrapper::Wrapper;

    #[test]
    fn test_trim() {
        let mut w = Wrapper::new(&DriverType::Mysql);
        w.push_sql("WHERE ");
        w.order_by(true, &["id"]);
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
        assert_eq!("ORDER BY id ASC", w.sql.as_str().trim());
    }

    #[test]
    fn test_select() {
        let mut m = Map::new();
        m.insert("a".to_string(), json!("1"));
        let w = Wrapper::new(&DriverType::Mysql)
            .eq("id", 1)
            .ne("id", 1)
            .in_array("id", &[1, 2, 3])
            .not_in("id", &[1, 2, 3])
            .all_eq(&m)
            .like("name", 1)
            .or()
            .not_like("name", "asdf")
            .between("create_time", "2020-01-01 00:00:00", "2020-12-12 00:00:00")
            .group_by(&["id"])
            .order_by(true, &["id", "name"])
            .check()
            .unwrap();
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
        let ms: Vec<&str> = w.sql.matches("?").collect();
        assert_eq!(ms.len(), w.args.len());
    }

    #[test]
    fn bench_select() {
        let mut map = Map::new();
        map.insert("a".to_string(), json!("1"));
        let total = 100000;
        let now = std::time::Instant::now();
        for _ in 0..total {
            let _w = Wrapper::new(&DriverType::Mysql)
                .eq("id", 1)
                .ne("id", 1)
                .in_array("id", &[1, 2, 3])
                .r#in("id", &[1, 2, 3])
                .in_("id", &[1, 2, 3])
                .not_in("id", &[1, 2, 3])
                .all_eq(&map)
                .like("name", 1)
                .or()
                .not_like("name", "asdf")
                .between("create_time", "2020-01-01 00:00:00", "2020-12-12 00:00:00")
                .group_by(&["id"])
                .order_by(true, &["id", "name"])
                .check()
                .unwrap();
        }
        now.time(total);
        now.qps(total);
    }

    #[test]
    fn test_link() {
        let w = Wrapper::new(&DriverType::Postgres).eq("a", "1").check().unwrap();
        let w2 = Wrapper::new(&DriverType::Postgres)
            .eq("b", "2")
            .and()
            .push_wrapper(&w)
            .check()
            .unwrap();
        println!("sql:{:?}", w2.sql.as_str());
        println!("arg:{:?}", w2.args.clone());
        // The pushed fragment must be renumbered after the existing `$1`,
        // and every placeholder of the LINKED wrapper must have an argument.
        assert_eq!(w2.sql, "b = $1 AND a = $2");
        assert_eq!(w2.args, vec![json!("2"), json!("1")]);
        let ms: Vec<&str> = w2.sql.matches("$").collect();
        assert_eq!(ms.len(), w2.args.len());
    }

    #[test]
    fn test_do_if() {
        let p = Option::<i32>::Some(1);
        let w = Wrapper::new(&DriverType::Postgres)
            .do_if(p.is_some(), |w| w.eq("a", p.clone()))
            .check()
            .unwrap();
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
        assert_eq!(&w.sql, "a = $1");
        assert_eq!(&w.args[0], &json!(p));
    }

    #[test]
    fn test_do_match() {
        let p = 1;
        let w = Wrapper::new(&DriverType::Postgres)
            .do_match(
                &[
                    (p == 0, |w| w.eq("0", "some")),
                    (p == 1, |w| w.eq("1", "some")),
                ],
                |w| w.eq("default", "default"),
            )
            .check()
            .unwrap();
        assert_eq!(&w.sql, "1 = $1");
    }

    #[test]
    fn test_wp() {
        let w = Wrapper::new(&DriverType::Postgres)
            .eq("1", "1")
            .or()
            .like("TITLE", "title")
            .or()
            .like("ORIGINAL_NAME", "saf")
            .check()
            .unwrap();
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
    }

    #[test]
    fn test_push_arg() {
        let w = Wrapper::new(&DriverType::Mysql)
            .push_sql("?,?")
            .push_arg(1)
            .push_arg("asdfasdfa")
            .check()
            .unwrap();
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
    }

    #[test]
    fn test_push_wrapper() {
        let mut outer = Wrapper::new(&DriverType::Postgres);
        let mut inner = outer.clone();
        let w = outer
            .eq("b", "b")
            .eq("b1", "b1")
            .eq("b2", "b2")
            .and()
            .push_wrapper(&inner.push_sql("(").eq("a", "a").push_sql(")").check().unwrap())
            .check()
            .unwrap();
        println!("sql:{:?}", w.sql.as_str());
        println!("arg:{:?}", w.args.clone());
        assert_eq!(w.sql.contains("b = $1"), true);
        assert_eq!(w.sql.contains("a = $4"), true);
    }

    #[test]
    fn test_bench_is_end_opt() {
        let w = Wrapper::new(&DriverType::Postgres);
        let total = 100000;
        let now = std::time::Instant::now();
        for _ in 0..total {
            w.not_allow_and_or();
        }
        now.time(total);
    }
}