use crate::entity::{DeleteParam, Id, ListQuery, PageQuery, QueryOption};
use crate::error::{CoolError, CoolResult, PageResult};
use crate::event::{events, global_event_manager, SoftDeleteEvent};
use async_trait::async_trait;
use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
use serde_json::Value;
use std::sync::Arc;
/// Which write operation a [`BaseService::modify_before`] / [`BaseService::modify_after`] hook
/// is being invoked for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModifyType {
    /// Passed by `add`.
    Add,
    /// Passed by `update`.
    Update,
    /// Passed by `delete`.
    Delete,
}
#[async_trait]
pub trait BaseService: Send + Sync {
fn db(&self) -> &DatabaseConnection;
fn table_name(&self) -> &str;
async fn add(&self, data: Value) -> CoolResult<Value> {
let data = self.modify_before(data, ModifyType::Add).await?;
let columns: Vec<String> = data
.as_object()
.map(|obj| obj.keys().cloned().collect())
.unwrap_or_default();
if columns.is_empty() {
return Err(CoolError::validate("数据不能为空"));
}
let placeholders: Vec<String> = columns.iter().map(|_| "?".to_string()).collect();
let sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
self.table_name(),
columns.join(", "),
placeholders.join(", ")
);
let values: Vec<sea_orm::Value> = columns
.iter()
.filter_map(|col| data.get(col))
.map(json_to_sea_value)
.collect();
let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
let result = self.db().execute(stmt).await?;
let id = result.last_insert_id();
self.modify_after(data.clone(), ModifyType::Add).await?;
Ok(serde_json::json!({ "id": id }))
}
async fn delete(&self, param: DeleteParam) -> CoolResult<()> {
let data = serde_json::to_value(¶m)?;
self.modify_before(data.clone(), ModifyType::Delete).await?;
let ids_str = param
.ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(",");
let sql = format!(
"DELETE FROM {} WHERE id IN ({})",
self.table_name(),
ids_str
);
let stmt = Statement::from_string(self.db().get_database_backend(), sql);
self.db().execute(stmt).await?;
self.modify_after(data, ModifyType::Delete).await?;
Ok(())
}
async fn soft_delete(&self, ids: Vec<Id>) -> CoolResult<()> {
let now = chrono::Utc::now();
let ids_str = ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(",");
let sql = format!(
"UPDATE {} SET delete_time = '{}' WHERE id IN ({})",
self.table_name(),
now.format("%Y-%m-%d %H:%M:%S"),
ids_str
);
let stmt = Statement::from_string(self.db().get_database_backend(), sql);
self.db().execute(stmt).await?;
global_event_manager()
.emit(
events::SOFT_DELETE,
SoftDeleteEvent {
entity: self.table_name().to_string(),
ids: ids.clone(),
tenant_id: None,
},
)
.await;
Ok(())
}
async fn update(&self, data: Value) -> CoolResult<()> {
let id = data
.get("id")
.ok_or_else(CoolError::no_id)?
.as_i64()
.ok_or_else(CoolError::no_id)?;
let data = self.modify_before(data, ModifyType::Update).await?;
let updates: Vec<String> = data
.as_object()
.map(|obj| {
obj.iter()
.filter(|(k, _)| *k != "id")
.map(|(k, _)| format!("{} = ?", k))
.collect()
})
.unwrap_or_default();
if updates.is_empty() {
return Ok(());
}
let sql = format!(
"UPDATE {} SET {} WHERE id = ?",
self.table_name(),
updates.join(", ")
);
let mut values: Vec<sea_orm::Value> = data
.as_object()
.map(|obj| {
obj.iter()
.filter(|(k, _)| *k != "id")
.map(|(_, v)| json_to_sea_value(v))
.collect()
})
.unwrap_or_default();
values.push(sea_orm::Value::BigInt(Some(id)));
let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
self.db().execute(stmt).await?;
self.modify_after(data, ModifyType::Update).await?;
Ok(())
}
async fn info(&self, id: Id, _ignore_fields: Option<Vec<String>>) -> CoolResult<Option<Value>> {
let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", self.table_name());
let stmt = Statement::from_sql_and_values(
self.db().get_database_backend(),
&sql,
vec![sea_orm::Value::BigInt(Some(id))],
);
let result = self.db().query_one(stmt).await?;
Ok(result.map(|row| self.map_row(row)))
}
async fn page(
&self,
query: PageQuery,
mut option: QueryOption,
) -> CoolResult<PageResult<Value>> {
let offset = query.offset();
let size = query.size;
validate_query_safety(&query, &option)?;
let select_sql = if option.select.is_empty() {
"*".to_string()
} else {
for col in &option.select {
validate_identifier(col)?;
}
option.select.join(", ")
};
if !option.left_join.is_empty() {
option.joins.extend(option.left_join.clone());
}
let mut from_sql = self.table_name().to_string();
for join in &option.joins {
validate_identifier(&join.entity)?;
validate_identifier(&join.alias)?;
if join.condition.contains(|c| {
c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
}) {
return Err(CoolError::validate("关联条件包含非法字符"));
}
let join_kw = match join.join_type {
crate::entity::JoinType::Inner => "INNER JOIN",
crate::entity::JoinType::Left => "LEFT JOIN",
};
from_sql.push(' ');
from_sql.push_str(join_kw);
from_sql.push(' ');
from_sql.push_str(&join.entity);
from_sql.push_str(" AS ");
from_sql.push_str(&join.alias);
from_sql.push_str(" ON ");
from_sql.push_str(&join.condition);
}
let mut where_sql = String::new();
let mut params: Vec<sea_orm::Value> = Vec::new();
if let Some(ref kw) = query.key_word {
if !option.key_word_like_fields.is_empty() {
where_sql.push_str(" WHERE ");
for (idx, col) in option.key_word_like_fields.iter().enumerate() {
if idx > 0 {
where_sql.push_str(" OR ");
}
where_sql.push_str(&format!("{} LIKE ?", col));
params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
}
}
}
for cond in &option.where_and {
if cond.contains(';')
|| cond.contains("--")
|| cond.contains("/*")
|| cond.contains("*/")
{
return Err(CoolError::validate("where_and 包含非法字符"));
}
if where_sql.is_empty() {
where_sql.push_str(" WHERE ");
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(cond);
}
for frag in &option.extra_where {
if frag.sql.contains(';')
|| frag.sql.contains("--")
|| frag.sql.contains("/*")
|| frag.sql.contains("*/")
{
return Err(CoolError::validate("extra_where 包含非法字符"));
}
if where_sql.is_empty() {
where_sql.push_str(" WHERE ");
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(&format!("({})", frag.sql));
for p in &frag.params {
params.push(json_to_sea_value(p));
}
}
let mut order_sql = String::new();
if let Some(ref order_field) = query.order {
validate_identifier(order_field)?;
let asc = query.is_asc();
order_sql.push_str(" ORDER BY ");
order_sql.push_str(order_field);
order_sql.push_str(if asc { " ASC" } else { " DESC" });
} else if !option.order_by.is_empty() {
order_sql.push_str(" ORDER BY ");
let mut first = true;
for (col, asc) in &option.order_by {
validate_identifier(col)?;
if !first {
order_sql.push_str(", ");
}
first = false;
order_sql.push_str(col);
order_sql.push_str(if *asc { " ASC" } else { " DESC" });
}
}
let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
let count_stmt = Statement::from_sql_and_values(
self.db().get_database_backend(),
count_sql,
params.clone(),
);
let count_result = self.db().query_one(count_stmt).await?;
let total: u64 = count_result
.as_ref()
.and_then(|r| r.try_get_by_index::<i64>(0).ok())
.map(|v| v as u64)
.unwrap_or(0);
let sql = format!(
"SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
select_sql, from_sql, where_sql, order_sql
);
let mut data_params = params;
data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
let stmt =
Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
let results = self.db().query_all(stmt).await?;
let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
Ok(PageResult::new(list, query.page, size, total))
}
async fn page_with_filters(
&self,
query: PageQuery,
filters: &Value,
mut option: QueryOption,
) -> CoolResult<PageResult<Value>> {
let offset = query.offset();
let size = query.size;
validate_query_safety(&query, &option)?;
if !option.left_join.is_empty() {
option.joins.extend(option.left_join.clone());
}
let select_sql = if option.select.is_empty() {
"*".to_string()
} else {
for col in &option.select {
validate_identifier(col)?;
}
option.select.join(", ")
};
let mut from_sql = self.table_name().to_string();
for join in &option.joins {
validate_identifier(&join.entity)?;
validate_identifier(&join.alias)?;
if join.condition.contains(|c| {
c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
}) {
return Err(CoolError::validate("关联条件包含非法字符"));
}
let join_kw = match join.join_type {
crate::entity::JoinType::Inner => "INNER JOIN",
crate::entity::JoinType::Left => "LEFT JOIN",
};
from_sql.push(' ');
from_sql.push_str(join_kw);
from_sql.push(' ');
from_sql.push_str(&join.entity);
from_sql.push_str(" AS ");
from_sql.push_str(&join.alias);
from_sql.push_str(" ON ");
from_sql.push_str(&join.condition);
}
let mut where_sql = String::new();
let mut params: Vec<sea_orm::Value> = Vec::new();
let mut has_where = false;
if let Some(ref kw) = query.key_word {
if !option.key_word_like_fields.is_empty() {
validate_keyword(kw)?;
where_sql.push_str(" WHERE ");
has_where = true;
for (idx, col) in option.key_word_like_fields.iter().enumerate() {
validate_identifier(col)?;
if idx > 0 {
where_sql.push_str(" OR ");
}
where_sql.push_str(&format!("{} LIKE ?", col));
params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
}
}
}
for cond in &option.field_eq {
validate_identifier(&cond.column)?;
if let Some(value) = filters.get(&cond.request_param) {
if !value.is_null() {
if !has_where {
where_sql.push_str(" WHERE ");
has_where = true;
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(&format!("{} = ?", cond.column));
params.push(json_to_sea_value(value));
}
}
}
for cond in &option.field_like {
validate_identifier(&cond.column)?;
if let Some(value) = filters.get(&cond.request_param) {
if let Some(val_str) = value.as_str() {
validate_keyword(val_str)?;
if !has_where {
where_sql.push_str(" WHERE ");
has_where = true;
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(&format!("{} LIKE ?", cond.column));
params.push(sea_orm::Value::String(Some(Box::new(format!(
"%{}%",
val_str
)))));
}
}
}
for cond in &option.where_and {
if cond.contains(';')
|| cond.contains("--")
|| cond.contains("/*")
|| cond.contains("*/")
{
return Err(CoolError::validate("where_and 包含非法字符"));
}
if !has_where {
where_sql.push_str(" WHERE ");
has_where = true;
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(cond);
}
for frag in &option.extra_where {
if frag.sql.contains(';')
|| frag.sql.contains("--")
|| frag.sql.contains("/*")
|| frag.sql.contains("*/")
{
return Err(CoolError::validate("extra_where 包含非法字符"));
}
if !has_where {
where_sql.push_str(" WHERE ");
has_where = true;
} else {
where_sql.push_str(" AND ");
}
where_sql.push_str(&format!("({})", frag.sql));
for p in &frag.params {
params.push(json_to_sea_value(p));
}
}
let mut order_sql = String::new();
if let Some(ref order_field) = query.order {
validate_identifier(order_field)?;
let asc = query.is_asc();
order_sql.push_str(" ORDER BY ");
order_sql.push_str(order_field);
order_sql.push_str(if asc { " ASC" } else { " DESC" });
} else if !option.order_by.is_empty() {
order_sql.push_str(" ORDER BY ");
let mut first = true;
for (col, asc) in &option.order_by {
validate_identifier(col)?;
if !first {
order_sql.push_str(", ");
}
first = false;
order_sql.push_str(col);
order_sql.push_str(if *asc { " ASC" } else { " DESC" });
}
}
let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
let count_stmt = Statement::from_sql_and_values(
self.db().get_database_backend(),
count_sql,
params.clone(),
);
let count_result = self.db().query_one(count_stmt).await?;
let total: u64 = count_result
.as_ref()
.and_then(|r| r.try_get_by_index::<i64>(0).ok())
.map(|v| v as u64)
.unwrap_or(0);
let sql = format!(
"SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
select_sql, from_sql, where_sql, order_sql
);
let mut data_params = params;
data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
let stmt =
Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
let results = self.db().query_all(stmt).await?;
let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
Ok(PageResult::new(list, query.page, size, total))
}
async fn list(&self, query: ListQuery, option: QueryOption) -> CoolResult<Vec<Value>> {
let select_sql = if option.select.is_empty() {
"*".to_string()
} else {
for col in &option.select {
validate_identifier(col)?;
}
option.select.join(", ")
};
let mut from_sql = self.table_name().to_string();
for join in &option.joins {
validate_identifier(&join.entity)?;
validate_identifier(&join.alias)?;
if join.condition.contains(|c| {
c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
}) {
return Err(CoolError::validate("关联条件包含非法字符"));
}
let join_kw = match join.join_type {
crate::entity::JoinType::Inner => "INNER JOIN",
crate::entity::JoinType::Left => "LEFT JOIN",
};
from_sql.push(' ');
from_sql.push_str(join_kw);
from_sql.push(' ');
from_sql.push_str(&join.entity);
from_sql.push_str(" AS ");
from_sql.push_str(&join.alias);
from_sql.push_str(" ON ");
from_sql.push_str(&join.condition);
}
let mut where_sql = String::new();
let mut params: Vec<sea_orm::Value> = Vec::new();
if let Some(ref kw) = query.key_word {
if !option.key_word_like_fields.is_empty() {
validate_keyword(kw)?;
where_sql.push_str(" WHERE ");
for (idx, col) in option.key_word_like_fields.iter().enumerate() {
validate_identifier(col)?;
if idx > 0 {
where_sql.push_str(" OR ");
}
where_sql.push_str(&format!("{} LIKE ?", col));
params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
}
}
}
let mut order_sql = String::new();
if let Some(ref order_field) = query.order {
validate_identifier(order_field)?;
let asc = query
.sort
.as_ref()
.map(|s| s.to_lowercase() == "asc")
.unwrap_or(false);
order_sql.push_str(" ORDER BY ");
order_sql.push_str(order_field);
order_sql.push_str(if asc { " ASC" } else { " DESC" });
} else if !option.order_by.is_empty() {
order_sql.push_str(" ORDER BY ");
let mut first = true;
for (col, asc) in &option.order_by {
validate_identifier(col)?;
if !first {
order_sql.push_str(", ");
}
first = false;
order_sql.push_str(col);
order_sql.push_str(if *asc { " ASC" } else { " DESC" });
}
}
let sql = format!(
"SELECT {} FROM {}{}{}",
select_sql, from_sql, where_sql, order_sql
);
let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, params);
let results = self.db().query_all(stmt).await?;
Ok(results.into_iter().map(|row| self.map_row(row)).collect())
}
async fn native_query(&self, sql: &str, params: Vec<Value>) -> CoolResult<Vec<Value>> {
let values: Vec<sea_orm::Value> =
params.into_iter().map(|v| json_to_sea_value(&v)).collect();
let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, values);
let results = self.db().query_all(stmt).await?;
Ok(results.into_iter().map(|row| self.map_row(row)).collect())
}
async fn execute(&self, sql: &str) -> CoolResult<u64> {
let stmt = Statement::from_string(self.db().get_database_backend(), sql.to_string());
let result = self.db().execute(stmt).await?;
Ok(result.rows_affected())
}
async fn modify_before(&self, data: Value, _modify_type: ModifyType) -> CoolResult<Value> {
Ok(data)
}
async fn modify_after(&self, _data: Value, _modify_type: ModifyType) -> CoolResult<()> {
Ok(())
}
fn map_row(&self, row: sea_orm::QueryResult) -> Value {
Value::String(format!("{:?}", row))
}
}
/// Converts a JSON value into a SeaORM bind value.
///
/// - `null` → `String(None)` (a NULL bind)
/// - booleans → `Bool`
/// - integers → `BigInt` when they fit in `i64`, otherwise `BigUnsigned`; routing large unsigned
///   values through `f64` would silently round them
/// - remaining numbers → `Double`
/// - strings → `String`
/// - arrays / objects → their JSON text as `String`
fn json_to_sea_value(v: &Value) -> sea_orm::Value {
    match v {
        Value::Null => sea_orm::Value::String(None),
        Value::Bool(b) => sea_orm::Value::Bool(Some(*b)),
        Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                sea_orm::Value::BigInt(Some(i))
            } else if let Some(u) = n.as_u64() {
                sea_orm::Value::BigUnsigned(Some(u))
            } else if let Some(f) = n.as_f64() {
                sea_orm::Value::Double(Some(f))
            } else {
                sea_orm::Value::String(Some(Box::new(n.to_string())))
            }
        }
        Value::String(s) => sea_orm::Value::String(Some(Box::new(s.clone()))),
        Value::Array(_) | Value::Object(_) => {
            sea_orm::Value::String(Some(Box::new(v.to_string())))
        }
    }
}
/// Wraps the `Debug` rendering of a raw query row in a JSON string value.
// Not referenced anywhere in this file, hence the dead-code allowance.
#[allow(dead_code)]
fn row_to_json(row: sea_orm::QueryResult) -> Value {
    let rendered = format!("{:?}", row);
    Value::String(rendered)
}
/// Checks that `ident` is safe to splice into SQL text as a column/table reference.
///
/// Only ASCII letters, digits, `_`, `.`, `,` and spaces are accepted, so qualified names such as
/// `t.col` and comma/space separated lists pass. Empty input and comment/terminator tokens
/// (`--`, `/*`, `*/`, `;`) are reported with their own messages, checked in that order.
fn validate_identifier(ident: &str) -> CoolResult<()> {
    const FORBIDDEN_TOKENS: [&str; 4] = ["--", "/*", "*/", ";"];
    if ident.is_empty() {
        return Err(CoolError::validate("字段名不能为空"));
    }
    if FORBIDDEN_TOKENS.iter().any(|token| ident.contains(*token)) {
        return Err(CoolError::validate("字段名包含非法字符"));
    }
    let is_allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '.' | ',' | ' ');
    if ident.chars().any(|c| !is_allowed(c)) {
        return Err(CoolError::validate(
            "字段名仅允许字母、数字、下划线、点、逗号和空格",
        ));
    }
    Ok(())
}
/// Rejects search text longer than 256 characters or containing `--`, `/*`, `*/` or `;`.
///
/// The limit counts Unicode scalar values rather than UTF-8 bytes. A byte limit would reject
/// multi-byte text (e.g. Chinese, 3 bytes per character) at about 86 characters.
fn validate_keyword(kw: &str) -> CoolResult<()> {
    const MAX_CHARS: usize = 256;
    // `nth` stops after MAX_CHARS + 1 chars instead of walking an arbitrarily long input.
    if kw.chars().nth(MAX_CHARS).is_some() {
        return Err(CoolError::validate("关键字过长"));
    }
    if kw.contains("--") || kw.contains("/*") || kw.contains("*/") || kw.contains(';') {
        return Err(CoolError::validate("关键字包含非法字符"));
    }
    Ok(())
}
/// Validates the caller-influenced parts of a paged keyword search before any SQL is built: the
/// keyword itself (when present) and every column it will be matched against.
fn validate_query_safety(query: &PageQuery, option: &QueryOption) -> CoolResult<()> {
    if let Some(keyword) = &query.key_word {
        validate_keyword(keyword)?;
    }
    option
        .key_word_like_fields
        .iter()
        .try_for_each(|field| validate_identifier(field))
}
/// Minimal [`BaseService`] bound to a single table, relying on all default method bodies.
pub struct SimpleService {
    /// Shared connection handed out by [`BaseService::db`].
    db: Arc<DatabaseConnection>,
    /// Table name returned by [`BaseService::table_name`].
    table: String,
}

impl SimpleService {
    /// Creates a service that operates on `table` through `db`.
    pub fn new(db: Arc<DatabaseConnection>, table: impl Into<String>) -> Self {
        let table = table.into();
        Self { db, table }
    }
}

#[async_trait]
impl BaseService for SimpleService {
    fn db(&self) -> &DatabaseConnection {
        self.db.as_ref()
    }

    fn table_name(&self) -> &str {
        self.table.as_str()
    }
}