cool_core/service/
base.rs

1//! 服务基类
2//!
3//! 对应 TypeScript 版本的 `service/base.ts`
4
5use crate::entity::{DeleteParam, Id, ListQuery, PageQuery, QueryOption};
6use crate::error::{CoolError, CoolResult, PageResult};
7use crate::event::{events, global_event_manager, SoftDeleteEvent};
8use async_trait::async_trait;
9use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
10use serde_json::Value;
11use std::sync::Arc;
12
/// Kind of write operation reported to the `BaseService` modification hooks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModifyType {
    /// Insert a new record.
    Add,
    /// Update an existing record.
    Update,
    /// Delete records.
    Delete,
}
23
24/// 服务基类 trait
25///
26/// 提供通用的 CRUD 操作
27#[async_trait]
28pub trait BaseService: Send + Sync {
29    /// 获取数据库连接
30    fn db(&self) -> &DatabaseConnection;
31
32    /// 获取表名
33    fn table_name(&self) -> &str;
34
35    /// 新增
36    async fn add(&self, data: Value) -> CoolResult<Value> {
37        let data = self.modify_before(data, ModifyType::Add).await?;
38
39        // 构建 INSERT SQL
40        let columns: Vec<String> = data
41            .as_object()
42            .map(|obj| obj.keys().cloned().collect())
43            .unwrap_or_default();
44
45        if columns.is_empty() {
46            return Err(CoolError::validate("数据不能为空"));
47        }
48
49        let placeholders: Vec<String> = columns.iter().map(|_| "?".to_string()).collect();
50        let sql = format!(
51            "INSERT INTO {} ({}) VALUES ({})",
52            self.table_name(),
53            columns.join(", "),
54            placeholders.join(", ")
55        );
56
57        let values: Vec<sea_orm::Value> = columns
58            .iter()
59            .filter_map(|col| data.get(col))
60            .map(json_to_sea_value)
61            .collect();
62
63        let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
64
65        let result = self.db().execute(stmt).await?;
66        let id = result.last_insert_id();
67
68        self.modify_after(data.clone(), ModifyType::Add).await?;
69
70        Ok(serde_json::json!({ "id": id }))
71    }
72
73    /// 删除
74    async fn delete(&self, param: DeleteParam) -> CoolResult<()> {
75        let data = serde_json::to_value(&param)?;
76        self.modify_before(data.clone(), ModifyType::Delete).await?;
77
78        let ids_str = param
79            .ids
80            .iter()
81            .map(|id| id.to_string())
82            .collect::<Vec<_>>()
83            .join(",");
84
85        let sql = format!(
86            "DELETE FROM {} WHERE id IN ({})",
87            self.table_name(),
88            ids_str
89        );
90        let stmt = Statement::from_string(self.db().get_database_backend(), sql);
91        self.db().execute(stmt).await?;
92
93        self.modify_after(data, ModifyType::Delete).await?;
94
95        Ok(())
96    }
97
98    /// 软删除
99    async fn soft_delete(&self, ids: Vec<Id>) -> CoolResult<()> {
100        let now = chrono::Utc::now();
101        let ids_str = ids
102            .iter()
103            .map(|id| id.to_string())
104            .collect::<Vec<_>>()
105            .join(",");
106
107        let sql = format!(
108            "UPDATE {} SET delete_time = '{}' WHERE id IN ({})",
109            self.table_name(),
110            now.format("%Y-%m-%d %H:%M:%S"),
111            ids_str
112        );
113        let stmt = Statement::from_string(self.db().get_database_backend(), sql);
114        self.db().execute(stmt).await?;
115
116        // 触发软删除事件,方便其他模块(如回收站、日志等)进行处理
117        //
118        // 说明:
119        // - entity 使用表名,保持与 TS 版本中实体名称语义接近
120        // - tenant_id 当前无法在基类中获取,先置为 None,后续可在有租户上下文的 Service 中扩展
121        global_event_manager()
122            .emit(
123                events::SOFT_DELETE,
124                SoftDeleteEvent {
125                    entity: self.table_name().to_string(),
126                    ids: ids.clone(),
127                    tenant_id: None,
128                },
129            )
130            .await;
131
132        Ok(())
133    }
134
135    /// 修改
136    async fn update(&self, data: Value) -> CoolResult<()> {
137        let id = data
138            .get("id")
139            .ok_or_else(CoolError::no_id)?
140            .as_i64()
141            .ok_or_else(CoolError::no_id)?;
142
143        let data = self.modify_before(data, ModifyType::Update).await?;
144
145        // 构建 UPDATE SQL
146        let updates: Vec<String> = data
147            .as_object()
148            .map(|obj| {
149                obj.iter()
150                    .filter(|(k, _)| *k != "id")
151                    .map(|(k, _)| format!("{} = ?", k))
152                    .collect()
153            })
154            .unwrap_or_default();
155
156        if updates.is_empty() {
157            return Ok(());
158        }
159
160        let sql = format!(
161            "UPDATE {} SET {} WHERE id = ?",
162            self.table_name(),
163            updates.join(", ")
164        );
165
166        let mut values: Vec<sea_orm::Value> = data
167            .as_object()
168            .map(|obj| {
169                obj.iter()
170                    .filter(|(k, _)| *k != "id")
171                    .map(|(_, v)| json_to_sea_value(v))
172                    .collect()
173            })
174            .unwrap_or_default();
175        values.push(sea_orm::Value::BigInt(Some(id)));
176
177        let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), &sql, values);
178
179        self.db().execute(stmt).await?;
180        self.modify_after(data, ModifyType::Update).await?;
181
182        Ok(())
183    }
184
185    /// 查询单条记录
186    async fn info(&self, id: Id, _ignore_fields: Option<Vec<String>>) -> CoolResult<Option<Value>> {
187        let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", self.table_name());
188        let stmt = Statement::from_sql_and_values(
189            self.db().get_database_backend(),
190            &sql,
191            vec![sea_orm::Value::BigInt(Some(id))],
192        );
193
194        let result = self.db().query_one(stmt).await?;
195
196        Ok(result.map(|row| self.map_row(row)))
197    }
198
199    /// 分页查询
200    ///
201    /// 说明:
202    /// - 支持基础的关键字模糊查询与排序(基于 `QueryOption`)
203    /// - 在生成 SQL 之前会做字段名和关键字长度的安全检查,避免明显的 SQL 注入
204    /// - 更复杂的关联和条件可以通过 `native_query` 或上层服务自定义
205    ///
206    /// 与 TS 版本的 QueryOp 对齐:
207    /// - `key_word_like_fields`:关键字模糊查询
208    /// - `order_by`:排序配置
209    /// - `field_eq` / `field_like`:可通过 [`page_with_filters`] 传入自定义过滤参数使用
210    async fn page(
211        &self,
212        query: PageQuery,
213        mut option: QueryOption,
214    ) -> CoolResult<PageResult<Value>> {
215        let offset = query.offset();
216        let size = query.size;
217
218        // 基础参数安全性检查
219        validate_query_safety(&query, &option)?;
220
221        // 构建 SELECT 子句
222        let select_sql = if option.select.is_empty() {
223            "*".to_string()
224        } else {
225            // 这里只是简单地拼接字段名,默认调用方传入的是完整的列名(带别名)
226            for col in &option.select {
227                validate_identifier(col)?;
228            }
229            option.select.join(", ")
230        };
231
232        // 合并 left_join 到 joins,兼容 TS 命名
233        if !option.left_join.is_empty() {
234            option.joins.extend(option.left_join.clone());
235        }
236
237        // 构建 FROM + JOIN 子句
238        let mut from_sql = self.table_name().to_string();
239        for join in &option.joins {
240            validate_identifier(&join.entity)?;
241            validate_identifier(&join.alias)?;
242            // 条件里也做一层基础校验,防止明显的注入特征
243            if join.condition.contains(|c| {
244                c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
245            }) {
246                return Err(CoolError::validate("关联条件包含非法字符"));
247            }
248
249            let join_kw = match join.join_type {
250                crate::entity::JoinType::Inner => "INNER JOIN",
251                crate::entity::JoinType::Left => "LEFT JOIN",
252            };
253            from_sql.push(' ');
254            from_sql.push_str(join_kw);
255            from_sql.push(' ');
256            from_sql.push_str(&join.entity);
257            from_sql.push_str(" AS ");
258            from_sql.push_str(&join.alias);
259            from_sql.push_str(" ON ");
260            from_sql.push_str(&join.condition);
261        }
262
263        // 构建 WHERE 子句和参数(关键字模糊 + where_and)
264        let mut where_sql = String::new();
265        let mut params: Vec<sea_orm::Value> = Vec::new();
266
267        if let Some(ref kw) = query.key_word {
268            if !option.key_word_like_fields.is_empty() {
269                where_sql.push_str(" WHERE ");
270                for (idx, col) in option.key_word_like_fields.iter().enumerate() {
271                    if idx > 0 {
272                        where_sql.push_str(" OR ");
273                    }
274                    where_sql.push_str(&format!("{} LIKE ?", col));
275                    params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
276                }
277            }
278        }
279
280        // 追加 where_and 片段(仅做非常基础的安全检查)
281        for cond in &option.where_and {
282            if cond.contains(';')
283                || cond.contains("--")
284                || cond.contains("/*")
285                || cond.contains("*/")
286            {
287                return Err(CoolError::validate("where_and 包含非法字符"));
288            }
289            if where_sql.is_empty() {
290                where_sql.push_str(" WHERE ");
291            } else {
292                where_sql.push_str(" AND ");
293            }
294            where_sql.push_str(cond);
295        }
296
297        // 追加 extra_where 片段(带参数)
298        for frag in &option.extra_where {
299            if frag.sql.contains(';')
300                || frag.sql.contains("--")
301                || frag.sql.contains("/*")
302                || frag.sql.contains("*/")
303            {
304                return Err(CoolError::validate("extra_where 包含非法字符"));
305            }
306            if where_sql.is_empty() {
307                where_sql.push_str(" WHERE ");
308            } else {
309                where_sql.push_str(" AND ");
310            }
311            where_sql.push_str(&format!("({})", frag.sql));
312            for p in &frag.params {
313                params.push(json_to_sea_value(p));
314            }
315        }
316
317        // 构建 ORDER BY 子句:优先使用前端传入的 order/sort,其次使用 QueryOption.order_by
318        let mut order_sql = String::new();
319        if let Some(ref order_field) = query.order {
320            validate_identifier(order_field)?;
321            let asc = query.is_asc();
322            order_sql.push_str(" ORDER BY ");
323            order_sql.push_str(order_field);
324            order_sql.push_str(if asc { " ASC" } else { " DESC" });
325        } else if !option.order_by.is_empty() {
326            order_sql.push_str(" ORDER BY ");
327            let mut first = true;
328            for (col, asc) in &option.order_by {
329                validate_identifier(col)?;
330                if !first {
331                    order_sql.push_str(", ");
332                }
333                first = false;
334                order_sql.push_str(col);
335                order_sql.push_str(if *asc { " ASC" } else { " DESC" });
336            }
337        }
338
339        // 查询总数
340        let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
341        let count_stmt = Statement::from_sql_and_values(
342            self.db().get_database_backend(),
343            count_sql,
344            params.clone(),
345        );
346        let count_result = self.db().query_one(count_stmt).await?;
347        let total: u64 = count_result
348            .as_ref()
349            .and_then(|r| r.try_get_by_index::<i64>(0).ok())
350            .map(|v| v as u64)
351            .unwrap_or(0);
352
353        // 查询数据
354        let sql = format!(
355            "SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
356            select_sql, from_sql, where_sql, order_sql
357        );
358        let mut data_params = params;
359        data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
360        data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
361
362        let stmt =
363            Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
364        let results = self.db().query_all(stmt).await?;
365
366        let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
367
368        Ok(PageResult::new(list, query.page, size, total))
369    }
370
371    /// 分页查询(带过滤参数)
372    ///
373    /// 对齐 TS 版本中 `fieldEq` / `fieldLike` 的能力:
374    /// - `filters` 为一个 JSON 对象,key 为 `FieldCondition.request_param`
375    /// - `field_eq`:生成 `AND column = ?`
376    /// - `field_like`:生成 `AND column LIKE ?`
377    async fn page_with_filters(
378        &self,
379        query: PageQuery,
380        filters: &Value,
381        mut option: QueryOption,
382    ) -> CoolResult<PageResult<Value>> {
383        // 先复用基础分页逻辑的大部分实现,但手动展开 WHERE 构建,加入 eq/like 条件
384        let offset = query.offset();
385        let size = query.size;
386
387        // 基础参数安全性检查
388        validate_query_safety(&query, &option)?;
389
390        // 合并 left_join 到 joins,兼容 TS 命名
391        if !option.left_join.is_empty() {
392            option.joins.extend(option.left_join.clone());
393        }
394
395        // 构建 SELECT 子句
396        let select_sql = if option.select.is_empty() {
397            "*".to_string()
398        } else {
399            for col in &option.select {
400                validate_identifier(col)?;
401            }
402            option.select.join(", ")
403        };
404
405        // 构建 FROM + JOIN 子句
406        let mut from_sql = self.table_name().to_string();
407        for join in &option.joins {
408            validate_identifier(&join.entity)?;
409            validate_identifier(&join.alias)?;
410            if join.condition.contains(|c| {
411                c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
412            }) {
413                return Err(CoolError::validate("关联条件包含非法字符"));
414            }
415
416            let join_kw = match join.join_type {
417                crate::entity::JoinType::Inner => "INNER JOIN",
418                crate::entity::JoinType::Left => "LEFT JOIN",
419            };
420            from_sql.push(' ');
421            from_sql.push_str(join_kw);
422            from_sql.push(' ');
423            from_sql.push_str(&join.entity);
424            from_sql.push_str(" AS ");
425            from_sql.push_str(&join.alias);
426            from_sql.push_str(" ON ");
427            from_sql.push_str(&join.condition);
428        }
429
430        // 构建 WHERE 子句和参数:关键字 + 等值 + 模糊字段 + where_and
431        let mut where_sql = String::new();
432        let mut params: Vec<sea_orm::Value> = Vec::new();
433        let mut has_where = false;
434
435        // 关键字模糊查询
436        if let Some(ref kw) = query.key_word {
437            if !option.key_word_like_fields.is_empty() {
438                validate_keyword(kw)?;
439                where_sql.push_str(" WHERE ");
440                has_where = true;
441                for (idx, col) in option.key_word_like_fields.iter().enumerate() {
442                    validate_identifier(col)?;
443                    if idx > 0 {
444                        where_sql.push_str(" OR ");
445                    }
446                    where_sql.push_str(&format!("{} LIKE ?", col));
447                    params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
448                }
449            }
450        }
451
452        // 等值字段
453        for cond in &option.field_eq {
454            validate_identifier(&cond.column)?;
455            if let Some(value) = filters.get(&cond.request_param) {
456                if !value.is_null() {
457                    if !has_where {
458                        where_sql.push_str(" WHERE ");
459                        has_where = true;
460                    } else {
461                        where_sql.push_str(" AND ");
462                    }
463                    where_sql.push_str(&format!("{} = ?", cond.column));
464                    params.push(json_to_sea_value(value));
465                }
466            }
467        }
468
469        // 模糊匹配字段
470        for cond in &option.field_like {
471            validate_identifier(&cond.column)?;
472            if let Some(value) = filters.get(&cond.request_param) {
473                if let Some(val_str) = value.as_str() {
474                    validate_keyword(val_str)?;
475                    if !has_where {
476                        where_sql.push_str(" WHERE ");
477                        has_where = true;
478                    } else {
479                        where_sql.push_str(" AND ");
480                    }
481                    where_sql.push_str(&format!("{} LIKE ?", cond.column));
482                    params.push(sea_orm::Value::String(Some(Box::new(format!(
483                        "%{}%",
484                        val_str
485                    )))));
486                }
487            }
488        }
489
490        // 追加 where_and 片段(仅做非常基础的安全检查)
491        for cond in &option.where_and {
492            if cond.contains(';')
493                || cond.contains("--")
494                || cond.contains("/*")
495                || cond.contains("*/")
496            {
497                return Err(CoolError::validate("where_and 包含非法字符"));
498            }
499            if !has_where {
500                where_sql.push_str(" WHERE ");
501                has_where = true;
502            } else {
503                where_sql.push_str(" AND ");
504            }
505            where_sql.push_str(cond);
506        }
507
508        // 追加 extra_where 片段(带参数)
509        for frag in &option.extra_where {
510            if frag.sql.contains(';')
511                || frag.sql.contains("--")
512                || frag.sql.contains("/*")
513                || frag.sql.contains("*/")
514            {
515                return Err(CoolError::validate("extra_where 包含非法字符"));
516            }
517            if !has_where {
518                where_sql.push_str(" WHERE ");
519                has_where = true;
520            } else {
521                where_sql.push_str(" AND ");
522            }
523            where_sql.push_str(&format!("({})", frag.sql));
524            for p in &frag.params {
525                params.push(json_to_sea_value(p));
526            }
527        }
528
529        // 构建 ORDER BY 子句
530        let mut order_sql = String::new();
531        if let Some(ref order_field) = query.order {
532            validate_identifier(order_field)?;
533            let asc = query.is_asc();
534            order_sql.push_str(" ORDER BY ");
535            order_sql.push_str(order_field);
536            order_sql.push_str(if asc { " ASC" } else { " DESC" });
537        } else if !option.order_by.is_empty() {
538            order_sql.push_str(" ORDER BY ");
539            let mut first = true;
540            for (col, asc) in &option.order_by {
541                validate_identifier(col)?;
542                if !first {
543                    order_sql.push_str(", ");
544                }
545                first = false;
546                order_sql.push_str(col);
547                order_sql.push_str(if *asc { " ASC" } else { " DESC" });
548            }
549        }
550
551        // 查询总数
552        let count_sql = format!("SELECT COUNT(*) as count FROM {}{}", from_sql, where_sql);
553        let count_stmt = Statement::from_sql_and_values(
554            self.db().get_database_backend(),
555            count_sql,
556            params.clone(),
557        );
558        let count_result = self.db().query_one(count_stmt).await?;
559        let total: u64 = count_result
560            .as_ref()
561            .and_then(|r| r.try_get_by_index::<i64>(0).ok())
562            .map(|v| v as u64)
563            .unwrap_or(0);
564
565        // 查询数据
566        let sql = format!(
567            "SELECT {} FROM {}{}{} LIMIT ? OFFSET ?",
568            select_sql, from_sql, where_sql, order_sql
569        );
570        let mut data_params = params;
571        data_params.push(sea_orm::Value::BigUnsigned(Some(size)));
572        data_params.push(sea_orm::Value::BigUnsigned(Some(offset)));
573
574        let stmt =
575            Statement::from_sql_and_values(self.db().get_database_backend(), sql, data_params);
576        let results = self.db().query_all(stmt).await?;
577
578        let list: Vec<Value> = results.into_iter().map(|row| self.map_row(row)).collect();
579
580        Ok(PageResult::new(list, query.page, size, total))
581    }
582
583    /// 列表查询(不分页)
584    ///
585    /// 说明:
586    /// - 支持基础的关键字模糊查询与排序
587    /// - 在生成 SQL 之前会做字段名和关键字长度的安全检查
588    async fn list(&self, query: ListQuery, option: QueryOption) -> CoolResult<Vec<Value>> {
589        // 构建 SELECT 子句
590        let select_sql = if option.select.is_empty() {
591            "*".to_string()
592        } else {
593            for col in &option.select {
594                validate_identifier(col)?;
595            }
596            option.select.join(", ")
597        };
598
599        // 构建 FROM + JOIN 子句
600        let mut from_sql = self.table_name().to_string();
601        for join in &option.joins {
602            validate_identifier(&join.entity)?;
603            validate_identifier(&join.alias)?;
604            if join.condition.contains(|c| {
605                c == ';' || c == '#' || c == '\'' || c == '"' || c == '\n' || c == '\r'
606            }) {
607                return Err(CoolError::validate("关联条件包含非法字符"));
608            }
609
610            let join_kw = match join.join_type {
611                crate::entity::JoinType::Inner => "INNER JOIN",
612                crate::entity::JoinType::Left => "LEFT JOIN",
613            };
614            from_sql.push(' ');
615            from_sql.push_str(join_kw);
616            from_sql.push(' ');
617            from_sql.push_str(&join.entity);
618            from_sql.push_str(" AS ");
619            from_sql.push_str(&join.alias);
620            from_sql.push_str(" ON ");
621            from_sql.push_str(&join.condition);
622        }
623
624        let mut where_sql = String::new();
625        let mut params: Vec<sea_orm::Value> = Vec::new();
626
627        if let Some(ref kw) = query.key_word {
628            if !option.key_word_like_fields.is_empty() {
629                validate_keyword(kw)?;
630                where_sql.push_str(" WHERE ");
631                for (idx, col) in option.key_word_like_fields.iter().enumerate() {
632                    validate_identifier(col)?;
633                    if idx > 0 {
634                        where_sql.push_str(" OR ");
635                    }
636                    where_sql.push_str(&format!("{} LIKE ?", col));
637                    params.push(sea_orm::Value::String(Some(Box::new(format!("%{}%", kw)))));
638                }
639            }
640        }
641
642        let mut order_sql = String::new();
643        if let Some(ref order_field) = query.order {
644            validate_identifier(order_field)?;
645            let asc = query
646                .sort
647                .as_ref()
648                .map(|s| s.to_lowercase() == "asc")
649                .unwrap_or(false);
650            order_sql.push_str(" ORDER BY ");
651            order_sql.push_str(order_field);
652            order_sql.push_str(if asc { " ASC" } else { " DESC" });
653        } else if !option.order_by.is_empty() {
654            order_sql.push_str(" ORDER BY ");
655            let mut first = true;
656            for (col, asc) in &option.order_by {
657                validate_identifier(col)?;
658                if !first {
659                    order_sql.push_str(", ");
660                }
661                first = false;
662                order_sql.push_str(col);
663                order_sql.push_str(if *asc { " ASC" } else { " DESC" });
664            }
665        }
666
667        let sql = format!(
668            "SELECT {} FROM {}{}{}",
669            select_sql, from_sql, where_sql, order_sql
670        );
671        let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, params);
672        let results = self.db().query_all(stmt).await?;
673
674        Ok(results.into_iter().map(|row| self.map_row(row)).collect())
675    }
676
677    /// 原生 SQL 查询
678    async fn native_query(&self, sql: &str, params: Vec<Value>) -> CoolResult<Vec<Value>> {
679        let values: Vec<sea_orm::Value> =
680            params.into_iter().map(|v| json_to_sea_value(&v)).collect();
681
682        let stmt = Statement::from_sql_and_values(self.db().get_database_backend(), sql, values);
683
684        let results = self.db().query_all(stmt).await?;
685
686        Ok(results.into_iter().map(|row| self.map_row(row)).collect())
687    }
688
689    /// 执行 SQL
690    async fn execute(&self, sql: &str) -> CoolResult<u64> {
691        let stmt = Statement::from_string(self.db().get_database_backend(), sql.to_string());
692        let result = self.db().execute(stmt).await?;
693        Ok(result.rows_affected())
694    }
695
696    /// 修改前置钩子
697    async fn modify_before(&self, data: Value, _modify_type: ModifyType) -> CoolResult<Value> {
698        Ok(data)
699    }
700
701    /// 修改后置钩子
702    async fn modify_after(&self, _data: Value, _modify_type: ModifyType) -> CoolResult<()> {
703        Ok(())
704    }
705
706    /// 行数据映射
707    ///
708    /// 说明:
709    /// - 默认实现返回调试字符串,保证兼容所有数据库类型
710    /// - 具体业务 Service 可以重写此方法,根据表结构把 `QueryResult` 精准映射为字段级 JSON
711    fn map_row(&self, row: sea_orm::QueryResult) -> Value {
712        Value::String(format!("{:?}", row))
713    }
714}
715
716/// JSON Value 转 SeaORM Value
717fn json_to_sea_value(v: &Value) -> sea_orm::Value {
718    match v {
719        Value::Null => sea_orm::Value::String(None),
720        Value::Bool(b) => sea_orm::Value::Bool(Some(*b)),
721        Value::Number(n) => {
722            if let Some(i) = n.as_i64() {
723                sea_orm::Value::BigInt(Some(i))
724            } else if let Some(f) = n.as_f64() {
725                sea_orm::Value::Double(Some(f))
726            } else {
727                sea_orm::Value::String(Some(Box::new(n.to_string())))
728            }
729        }
730        Value::String(s) => sea_orm::Value::String(Some(Box::new(s.clone()))),
731        _ => sea_orm::Value::String(Some(Box::new(v.to_string()))),
732    }
733}
734
735/// 将 SeaORM 行数据转换为 `serde_json::Value`
736///
737/// 说明:
738/// - 遍历行中的列和值,根据列名构建 JSON 对象
739/// - 只处理常见基础类型,其余类型统一转为字符串,保证不 panic
740#[allow(dead_code)]
741fn row_to_json(row: sea_orm::QueryResult) -> Value {
742    // 目前 SeaORM 未公开 QueryResult 的列/值字段结构,
743    // 这里先使用 Debug 格式整体输出,保证接口可用。
744    Value::String(format!("{:?}", row))
745}
746
747/// 校验 SQL 标识符(列名、排序字段等)的安全性
748///
749/// 约束:
750/// - 仅允许字母、数字、下划线、点、逗号和空格
751/// - 不允许出现 SQL 关键字符组合:`--`、`/*`、`*/`、`;`
752fn validate_identifier(ident: &str) -> CoolResult<()> {
753    if ident.is_empty() {
754        return Err(CoolError::validate("字段名不能为空"));
755    }
756
757    if ident.contains("--") || ident.contains("/*") || ident.contains("*/") || ident.contains(';') {
758        return Err(CoolError::validate("字段名包含非法字符"));
759    }
760
761    if !ident
762        .chars()
763        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == ',' || c == ' ')
764    {
765        return Err(CoolError::validate(
766            "字段名仅允许字母、数字、下划线、点、逗号和空格",
767        ));
768    }
769
770    Ok(())
771}
772
773/// 校验关键字搜索内容,避免过长或明显注入字符
774fn validate_keyword(kw: &str) -> CoolResult<()> {
775    // 限制关键字长度,避免异常长输入
776    if kw.len() > 256 {
777        return Err(CoolError::validate("关键字过长"));
778    }
779
780    // 简单过滤明显的注入特征
781    if kw.contains("--") || kw.contains("/*") || kw.contains("*/") || kw.contains(';') {
782        return Err(CoolError::validate("关键字包含非法字符"));
783    }
784
785    Ok(())
786}
787
788/// 对查询配置做基础安全检查
789fn validate_query_safety(query: &PageQuery, option: &QueryOption) -> CoolResult<()> {
790    // 校验关键字
791    if let Some(ref kw) = query.key_word {
792        validate_keyword(kw)?;
793    }
794
795    // 校验模糊查询字段
796    for col in &option.key_word_like_fields {
797        validate_identifier(col)?;
798    }
799
800    Ok(())
801}
802
803/// 简化版服务
804///
805/// 用于快速创建服务实例
806pub struct SimpleService {
807    db: Arc<DatabaseConnection>,
808    table: String,
809}
810
811impl SimpleService {
812    pub fn new(db: Arc<DatabaseConnection>, table: impl Into<String>) -> Self {
813        Self {
814            db,
815            table: table.into(),
816        }
817    }
818}
819
820#[async_trait]
821impl BaseService for SimpleService {
822    fn db(&self) -> &DatabaseConnection {
823        &self.db
824    }
825
826    fn table_name(&self) -> &str {
827        &self.table
828    }
829}