1use crate::entity::{DeleteParam, Id, ListQuery, PageQuery, QueryOption};
6use crate::error::{CoolError, CoolResult, PageResult};
7use crate::event::{events, global_event_manager, SoftDeleteEvent};
8use async_trait::async_trait;
9use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
10use serde_json::Value;
11use std::sync::Arc;
12
/// Tells the `modify_before` / `modify_after` hooks which write operation
/// triggered them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModifyType {
    /// Passed to the hooks by `add`.
    Add,
    /// Passed to the hooks by `update`.
    Update,
    /// Passed to the hooks by `delete`.
    Delete,
}
23
24#[async_trait]
28pub trait BaseService: Send + Sync {
/// Returns the database connection that every default method of this trait
/// executes its statements on.
fn db(&self) -> &DatabaseConnection;
31
/// Returns the table name. The default methods splice it verbatim into the
/// SQL they build, so it must already be a safe identifier.
fn table_name(&self) -> &str;
34
35 async fn add(&self, data: Value) -> CoolResult<Value> {
37 let data = self.modify_before(data, ModifyType::Add).await?;
38
39 let columns: Vec<String> = data
41 .as_object()
42 .map(|obj| obj.keys().cloned().collect())
43 .unwrap_or_default();
44
45 if columns.is_empty() {
46 return Err(CoolError::validate("数据不能为空"));
47 }
48
49 let placeholders: Vec<String> = columns.iter().map(|_| "?".to_string()).collect();
50 let sql = format!(
51 "INSERT INTO {} ({}) VALUES ({})",
52 self.table_name(),
53 columns.join(", "),
54 placeholders.join(", ")
55 );
56
57 let values: Vec<sea_orm::Value> = columns
58 .iter()
59 .filter_map(|col| data.get(col))
60 .map(json_to_sea_value)
61 .collect();
62
63 let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
64
65 let result = self.db().execute(stmt).await?;
66 let id = result.last_insert_id();
67
68 self.modify_after(data.clone(), ModifyType::Add).await?;
69
70 Ok(serde_json::json!({ "id": id }))
71 }
72
73 async fn delete(&self, param: DeleteParam) -> CoolResult<()> {
75 let data = serde_json::to_value(¶m)?;
76 self.modify_before(data.clone(), ModifyType::Delete).await?;
77
78 let ids_str = param
79 .ids
80 .iter()
81 .map(|id| id.to_string())
82 .collect::<Vec<_>>()
83 .join(",");
84
85 let sql = format!(
86 "DELETE FROM {} WHERE id IN ({})",
87 self.table_name(),
88 ids_str
89 );
90 let stmt = Statement::from_string(self.db().get_database_backend(), sql);
91 self.db().execute(stmt).await?;
92
93 self.modify_after(data, ModifyType::Delete).await?;
94
95 Ok(())
96 }
97
98 async fn soft_delete(&self, ids: Vec<Id>) -> CoolResult<()> {
100 let now = chrono::Utc::now();
101 let ids_str = ids
102 .iter()
103 .map(|id| id.to_string())
104 .collect::<Vec<_>>()
105 .join(",");
106
107 let sql = format!(
108 "UPDATE {} SET delete_time = '{}' WHERE id IN ({})",
109 self.table_name(),
110 now.format("%Y-%m-%d %H:%M:%S"),
111 ids_str
112 );
113 let stmt = Statement::from_string(self.db().get_database_backend(), sql);
114 self.db().execute(stmt).await?;
115
116 global_event_manager()
122 .emit(
123 events::SOFT_DELETE,
124 SoftDeleteEvent {
125 entity: self.table_name().to_string(),
126 ids: ids.clone(),
127 tenant_id: None,
128 },
129 )
130 .await;
131
132 Ok(())
133 }
134
135 async fn update(&self, data: Value) -> CoolResult<()> {
137 let id = data
138 .get("id")
139 .ok_or_else(CoolError::no_id)?
140 .as_i64()
141 .ok_or_else(CoolError::no_id)?;
142
143 let data = self.modify_before(data, ModifyType::Update).await?;
144
145 let updates: Vec<String> = data
147 .as_object()
148 .map(|obj| {
149 obj.iter()
150 .filter(|(k, _)| *k != "id")
151 .map(|(k, _)| format!("{} = ?", k))
152 .collect()
153 })
154 .unwrap_or_default();
155
156 if updates.is_empty() {
157 return Ok(());
158 }
159
160 let sql = format!(
161 "UPDATE {} SET {} WHERE id = ?",
162 self.table_name(),
163 updates.join(", ")
164 );
165
166 let mut values: Vec<sea_orm::Value> = data
167 .as_object()
168 .map(|obj| {
169 obj.iter()
170 .filter(|(k, _)| *k != "id")
171 .map(|(_, v)| json_to_sea_value(v))
172 .collect()
173 })
174 .unwrap_or_default();
175 values.push(sea_orm::Value::BigInt(Some(id)));
176
177 let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
178
179 self.db().execute(stmt).await?;
180 self.modify_after(data, ModifyType::Update).await?;
181
182 Ok(())
183 }
184
185 async fn info(&self, id: Id, _ignore_fields: Option<Vec<String>>) -> CoolResult<Option<Value>> {
187 let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", self.table_name());
188 let stmt = Statement::from_sql_and_values(
189 self.db().get_database_backend(),
190 &sql,
191 vec![sea_orm::Value::BigInt(Some(id))],
192 );
193
194 let result = self.db().query_one(stmt).await?;
195
196 Ok(result.map(|row| self.map_row(row)))
197 }
198
199 async fn page(
211 &self,
212 query: PageQuery,
213 mut option: QueryOption,
214 ) -> CoolResult<PageResult<Value>> {
215 let offset = query.offset();
216 let size = query.size;
217
218 validate_query_safety(&query, &option)?;
220
221 let select_sql = if option.select.is_empty() {
223 "*".to_string()
224 } else {
225 for col in &option.select {
227 validate_identifier(col)?;
228 }
229 option.select.join(", ")
230 };
231
232 if !option.left_join.is_empty() {
234 option.joins.extend(option.left_join.clone());
235 }
236
237 let mut from_sql = self.table_name().to_string();
239 for join in &option.joins {
240 validate_identifier(&join.entity)?;
241 validate_identifier(&join.alias)?;
242 if join.condition.contains(|c| {
244 c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
245 }) {
246 return Err(CoolError::validate("关联条件包含非法字符"));
247 }
248
249 let join_kw = match join.join_type {
250 crate::entity::JoinType::Inner => "INNER JOIN",
251 crate::entity::JoinType::Left => "LEFT JOIN",
252 };
253 from_sql.push(' ');
254 from_sql.push_str(join_kw);
255 from_sql.push(' ');
256 from_sql.push_str(&join.entity);
257 from_sql.push_str(" AS ");
258 from_sql.push_str(&join.alias);
259 from_sql.push_str(" ON ");
260 from_sql.push_str(&join.condition);
261 }
262
263 let mut where_sql = String::new();
265 let mut params: Vec<sea_orm::Value> = Vec::new();
266
267 if let Some(ref kw) = query.key_word {
268 if !option.key_word_like_fields.is_empty() {
269 where_sql.push_str(" WHERE ");
270 for (idx, col) in option.key_word_like_fields.iter().enumerate() {
271 if idx > 0 {
272 where_sql.push_str(" OR ");
273 }
274 where_sql.push_str(&format!("{} LIKE ?", col));
275 params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
276 }
277 }
278 }
279
280 for cond in &option.where_and {
282 if cond.contains(';')
283 || cond.contains("--")
284 || cond.contains("/*")
285 || cond.contains("*/")
286 {
287 return Err(CoolError::validate("where_and 包含非法字符"));
288 }
289 if where_sql.is_empty() {
290 where_sql.push_str(" WHERE ");
291 } else {
292 where_sql.push_str(" AND ");
293 }
294 where_sql.push_str(cond);
295 }
296
297 for frag in &option.extra_where {
299 if frag.sql.contains(';')
300 || frag.sql.contains("--")
301 || frag.sql.contains("/*")
302 || frag.sql.contains("*/")
303 {
304 return Err(CoolError::validate("extra_where 包含非法字符"));
305 }
306 if where_sql.is_empty() {
307 where_sql.push_str(" WHERE ");
308 } else {
309 where_sql.push_str(" AND ");
310 }
311 where_sql.push_str(&format!("({})", frag.sql));
312 for p in &frag.params {
313 params.push(json_to_sea_value(p));
314 }
315 }
316
317 let mut order_sql = String::new();
319 if let Some(ref order_field) = query.order {
320 validate_identifier(order_field)?;
321 let asc = query.is_asc();
322 order_sql.push_str(" ORDER BY ");
323 order_sql.push_str(order_field);
324 order_sql.push_str(if asc { " ASC" } else { " DESC" });
325 } else if !option.order_by.is_empty() {
326 order_sql.push_str(" ORDER BY ");
327 let mut first = true;
328 for (col, asc) in &option.order_by {
329 validate_identifier(col)?;
330 if !first {
331 order_sql.push_str(", ");
332 }
333 first = false;
334 order_sql.push_str(col);
335 order_sql.push_str(if *asc { " ASC" } else { " DESC" });
336 }
337 }
338
339 let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
341 let count_stmt = Statement::from_sql_and_values(
342 self.db().get_database_backend(),
343 count_sql,
344 params.clone(),
345 );
346 let count_result = self.db().query_one(count_stmt).await?;
347 let total: u64 = count_result
348 .as_ref()
349 .and_then(|r| r.try_get_by_index::<i64>(0).ok())
350 .map(|v| v as u64)
351 .unwrap_or(0);
352
353 let sql = format!(
355 "SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
356 select_sql, from_sql, where_sql, order_sql
357 );
358 let mut data_params = params;
359 data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
360 data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
361
362 let stmt =
363 Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
364 let results = self.db().query_all(stmt).await?;
365
366 let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
367
368 Ok(PageResult::new(list, query.page, size, total))
369 }
370
371 async fn page_with_filters(
378 &self,
379 query: PageQuery,
380 filters: &Value,
381 mut option: QueryOption,
382 ) -> CoolResult<PageResult<Value>> {
383 let offset = query.offset();
385 let size = query.size;
386
387 validate_query_safety(&query, &option)?;
389
390 if !option.left_join.is_empty() {
392 option.joins.extend(option.left_join.clone());
393 }
394
395 let select_sql = if option.select.is_empty() {
397 "*".to_string()
398 } else {
399 for col in &option.select {
400 validate_identifier(col)?;
401 }
402 option.select.join(", ")
403 };
404
405 let mut from_sql = self.table_name().to_string();
407 for join in &option.joins {
408 validate_identifier(&join.entity)?;
409 validate_identifier(&join.alias)?;
410 if join.condition.contains(|c| {
411 c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
412 }) {
413 return Err(CoolError::validate("关联条件包含非法字符"));
414 }
415
416 let join_kw = match join.join_type {
417 crate::entity::JoinType::Inner => "INNER JOIN",
418 crate::entity::JoinType::Left => "LEFT JOIN",
419 };
420 from_sql.push(' ');
421 from_sql.push_str(join_kw);
422 from_sql.push(' ');
423 from_sql.push_str(&join.entity);
424 from_sql.push_str(" AS ");
425 from_sql.push_str(&join.alias);
426 from_sql.push_str(" ON ");
427 from_sql.push_str(&join.condition);
428 }
429
430 let mut where_sql = String::new();
432 let mut params: Vec<sea_orm::Value> = Vec::new();
433 let mut has_where = false;
434
435 if let Some(ref kw) = query.key_word {
437 if !option.key_word_like_fields.is_empty() {
438 validate_keyword(kw)?;
439 where_sql.push_str(" WHERE ");
440 has_where = true;
441 for (idx, col) in option.key_word_like_fields.iter().enumerate() {
442 validate_identifier(col)?;
443 if idx > 0 {
444 where_sql.push_str(" OR ");
445 }
446 where_sql.push_str(&format!("{} LIKE ?", col));
447 params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
448 }
449 }
450 }
451
452 for cond in &option.field_eq {
454 validate_identifier(&cond.column)?;
455 if let Some(value) = filters.get(&cond.request_param) {
456 if !value.is_null() {
457 if !has_where {
458 where_sql.push_str(" WHERE ");
459 has_where = true;
460 } else {
461 where_sql.push_str(" AND ");
462 }
463 where_sql.push_str(&format!("{} = ?", cond.column));
464 params.push(json_to_sea_value(value));
465 }
466 }
467 }
468
469 for cond in &option.field_like {
471 validate_identifier(&cond.column)?;
472 if let Some(value) = filters.get(&cond.request_param) {
473 if let Some(val_str) = value.as_str() {
474 validate_keyword(val_str)?;
475 if !has_where {
476 where_sql.push_str(" WHERE ");
477 has_where = true;
478 } else {
479 where_sql.push_str(" AND ");
480 }
481 where_sql.push_str(&format!("{} LIKE ?", cond.column));
482 params.push(sea_orm::Value::String(Some(Box::new(format!(
483 "%{}%",
484 val_str
485 )))));
486 }
487 }
488 }
489
490 for cond in &option.where_and {
492 if cond.contains(';')
493 || cond.contains("--")
494 || cond.contains("/*")
495 || cond.contains("*/")
496 {
497 return Err(CoolError::validate("where_and 包含非法字符"));
498 }
499 if !has_where {
500 where_sql.push_str(" WHERE ");
501 has_where = true;
502 } else {
503 where_sql.push_str(" AND ");
504 }
505 where_sql.push_str(cond);
506 }
507
508 for frag in &option.extra_where {
510 if frag.sql.contains(';')
511 || frag.sql.contains("--")
512 || frag.sql.contains("/*")
513 || frag.sql.contains("*/")
514 {
515 return Err(CoolError::validate("extra_where 包含非法字符"));
516 }
517 if !has_where {
518 where_sql.push_str(" WHERE ");
519 has_where = true;
520 } else {
521 where_sql.push_str(" AND ");
522 }
523 where_sql.push_str(&format!("({})", frag.sql));
524 for p in &frag.params {
525 params.push(json_to_sea_value(p));
526 }
527 }
528
529 let mut order_sql = String::new();
531 if let Some(ref order_field) = query.order {
532 validate_identifier(order_field)?;
533 let asc = query.is_asc();
534 order_sql.push_str(" ORDER BY ");
535 order_sql.push_str(order_field);
536 order_sql.push_str(if asc { " ASC" } else { " DESC" });
537 } else if !option.order_by.is_empty() {
538 order_sql.push_str(" ORDER BY ");
539 let mut first = true;
540 for (col, asc) in &option.order_by {
541 validate_identifier(col)?;
542 if !first {
543 order_sql.push_str(", ");
544 }
545 first = false;
546 order_sql.push_str(col);
547 order_sql.push_str(if *asc { " ASC" } else { " DESC" });
548 }
549 }
550
551 let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
553 let count_stmt = Statement::from_sql_and_values(
554 self.db().get_database_backend(),
555 count_sql,
556 params.clone(),
557 );
558 let count_result = self.db().query_one(count_stmt).await?;
559 let total: u64 = count_result
560 .as_ref()
561 .and_then(|r| r.try_get_by_index::<i64>(0).ok())
562 .map(|v| v as u64)
563 .unwrap_or(0);
564
565 let sql = format!(
567 "SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
568 select_sql, from_sql, where_sql, order_sql
569 );
570 let mut data_params = params;
571 data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
572 data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
573
574 let stmt =
575 Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
576 let results = self.db().query_all(stmt).await?;
577
578 let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
579
580 Ok(PageResult::new(list, query.page, size, total))
581 }
582
583 async fn list(&self, query: ListQuery, option: QueryOption) -> CoolResult<Vec<Value>> {
589 let select_sql = if option.select.is_empty() {
591 "*".to_string()
592 } else {
593 for col in &option.select {
594 validate_identifier(col)?;
595 }
596 option.select.join(", ")
597 };
598
599 let mut from_sql = self.table_name().to_string();
601 for join in &option.joins {
602 validate_identifier(&join.entity)?;
603 validate_identifier(&join.alias)?;
604 if join.condition.contains(|c| {
605 c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
606 }) {
607 return Err(CoolError::validate("关联条件包含非法字符"));
608 }
609
610 let join_kw = match join.join_type {
611 crate::entity::JoinType::Inner => "INNER JOIN",
612 crate::entity::JoinType::Left => "LEFT JOIN",
613 };
614 from_sql.push(' ');
615 from_sql.push_str(join_kw);
616 from_sql.push(' ');
617 from_sql.push_str(&join.entity);
618 from_sql.push_str(" AS ");
619 from_sql.push_str(&join.alias);
620 from_sql.push_str(" ON ");
621 from_sql.push_str(&join.condition);
622 }
623
624 let mut where_sql = String::new();
625 let mut params: Vec<sea_orm::Value> = Vec::new();
626
627 if let Some(ref kw) = query.key_word {
628 if !option.key_word_like_fields.is_empty() {
629 validate_keyword(kw)?;
630 where_sql.push_str(" WHERE ");
631 for (idx, col) in option.key_word_like_fields.iter().enumerate() {
632 validate_identifier(col)?;
633 if idx > 0 {
634 where_sql.push_str(" OR ");
635 }
636 where_sql.push_str(&format!("{} LIKE ?", col));
637 params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
638 }
639 }
640 }
641
642 let mut order_sql = String::new();
643 if let Some(ref order_field) = query.order {
644 validate_identifier(order_field)?;
645 let asc = query
646 .sort
647 .as_ref()
648 .map(|s| s.to_lowercase() == "asc")
649 .unwrap_or(false);
650 order_sql.push_str(" ORDER BY ");
651 order_sql.push_str(order_field);
652 order_sql.push_str(if asc { " ASC" } else { " DESC" });
653 } else if !option.order_by.is_empty() {
654 order_sql.push_str(" ORDER BY ");
655 let mut first = true;
656 for (col, asc) in &option.order_by {
657 validate_identifier(col)?;
658 if !first {
659 order_sql.push_str(", ");
660 }
661 first = false;
662 order_sql.push_str(col);
663 order_sql.push_str(if *asc { " ASC" } else { " DESC" });
664 }
665 }
666
667 let sql = format!(
668 "SELECT {} FROM {}{}{}",
669 select_sql, from_sql, where_sql, order_sql
670 );
671 let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, params);
672 let results = self.db().query_all(stmt).await?;
673
674 Ok(results.into_iter().map(|row| self.map_row(row)).collect())
675 }
676
677 async fn native_query(&self, sql: &str, params: Vec<Value>) -> CoolResult<Vec<Value>> {
679 let values: Vec<sea_orm::Value> =
680 params.into_iter().map(|v| json_to_sea_value(&v)).collect();
681
682 let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, values);
683
684 let results = self.db().query_all(stmt).await?;
685
686 Ok(results.into_iter().map(|row| self.map_row(row)).collect())
687 }
688
689 async fn execute(&self, sql: &str) -> CoolResult<u64> {
691 let stmt = Statement::from_string(self.db().get_database_backend(), sql.to_string());
692 let result = self.db().execute(stmt).await?;
693 Ok(result.rows_affected())
694 }
695
/// Hook invoked by `add`, `update` and `delete` before their statement runs.
///
/// `add` and `update` build their statement from the value returned here;
/// `delete` only propagates an error and discards the returned value.
/// The default implementation returns `data` unchanged.
async fn modify_before(&self, data: Value, _modify_type: ModifyType) -> CoolResult<Value> {
    Ok(data)
}
700
/// Hook invoked by `add`, `update` and `delete` after their statement has
/// executed successfully. An error returned here is propagated to the caller.
/// The default implementation does nothing.
async fn modify_after(&self, _data: Value, _modify_type: ModifyType) -> CoolResult<()> {
    Ok(())
}
705
706 fn map_row(&self, row: sea_orm::QueryResult) -> Value {
712 Value::String(format!("{:?}", row))
713 }
714}
715
716fn json_to_sea_value(v: &Value) -> sea_orm::Value {
718 match v {
719 Value::Null => sea_orm::Value::String(None),
720 Value::Bool(b) => sea_orm::Value::Bool(Some(*b)),
721 Value::Number(n) => {
722 if let Some(i) = n.as_i64() {
723 sea_orm::Value::BigInt(Some(i))
724 } else if let Some(f) = n.as_f64() {
725 sea_orm::Value::Double(Some(f))
726 } else {
727 sea_orm::Value::String(Some(Box::new(n.to_string())))
728 }
729 }
730 Value::String(s) => sea_orm::Value::String(Some(Box::new(s.clone()))),
731 _ => sea_orm::Value::String(Some(Box::new(v.to_string()))),
732 }
733}
734
735#[allow(dead_code)]
741fn row_to_json(row: sea_orm::QueryResult) -> Value {
742 Value::String(format!("{:?}", row))
745}
746
747fn validate_identifier(ident: &str) -> CoolResult<()> {
753 if ident.is_empty() {
754 return Err(CoolError::validate("字段名不能为空"));
755 }
756
757 if ident.contains("--") || ident.contains("/*") || ident.contains("*/") || ident.contains(';') {
758 return Err(CoolError::validate("字段名包含非法字符"));
759 }
760
761 if !ident
762 .chars()
763 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == ',' || c == ' ')
764 {
765 return Err(CoolError::validate(
766 "字段名仅允许字母、数字、下划线、点、逗号和空格",
767 ));
768 }
769
770 Ok(())
771}
772
773fn validate_keyword(kw: &str) -> CoolResult<()> {
775 if kw.len() > 256 {
777 return Err(CoolError::validate("关键字过长"));
778 }
779
780 if kw.contains("--") || kw.contains("/*") || kw.contains("*/") || kw.contains(';') {
782 return Err(CoolError::validate("关键字包含非法字符"));
783 }
784
785 Ok(())
786}
787
788fn validate_query_safety(query: &PageQuery, option: &QueryOption) -> CoolResult<()> {
790 if let Some(ref kw) = query.key_word {
792 validate_keyword(kw)?;
793 }
794
795 for col in &option.key_word_like_fields {
797 validate_identifier(col)?;
798 }
799
800 Ok(())
801}
802
/// A `BaseService` implementation that only supplies a connection and a
/// table name, relying on the trait's default methods for everything else.
pub struct SimpleService {
    /// Shared database connection, handed out by `db()`.
    db: Arc<DatabaseConnection>,
    /// Table name, handed out by `table_name()`.
    table: String,
}
810
811impl SimpleService {
812 pub fn new(db: Arc<DatabaseConnection>, table: impl Into<String>) -> Self {
813 Self {
814 db,
815 table: table.into(),
816 }
817 }
818}
819
820#[async_trait]
821impl BaseService for SimpleService {
822 fn db(&self) -> &DatabaseConnection {
823 &self.db
824 }
825
826 fn table_name(&self) -> &str {
827 &self.table
828 }
829}