pgbatis/
wrapper.rs

1/*
2 * @Author: venom
3 * @Date: 2021-08-02 10:21:41
4 * @LastEditors: BuddyCoder
5 * @LastEditTime: 2023-05-15 10:11:57
6 * @Description:
7 * @FilePath: /pgbatis/src/wrapper.rs
8 * MIT
9 */
10use crate::{ColumnExt, Parameters};
11use std::collections::HashMap;
12use std::fmt;
13use tokio_postgres::types::ToSql;
14
15use tracing::*;
16
/// Errors returned by `Wrapper` configuration methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WrapperError {
    /// The ORDER BY text was rejected by `check_order_by`.
    OrderBy,
}

impl std::error::Error for WrapperError {}

impl WrapperError {
    /// Returns the legacy message for this error.
    ///
    /// This inherent method shadows `ToString::to_string` (which would go through
    /// `Display`), so `err.to_string()` and `format!("{}", err)` yield different text.
    /// Code holding a `dyn std::error::Error` only ever sees the `Display` text.
    /// NOTE(review): kept unchanged because existing callers may depend on this
    /// message; consider unifying it with `Display`.
    #[allow(clippy::inherent_to_string_shadow_display)]
    pub fn to_string(&self) -> String {
        match self {
            WrapperError::OrderBy => "根据不同的元素返回响应的字符串信息 for OrderBy".to_string(),
        }
    }
}

impl fmt::Display for WrapperError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WrapperError::OrderBy => write!(f, "Error: OrderBy"),
        }
    }
}
41
42#[derive(Clone)]
43pub struct Wrapper<'a> {
44    //  pub args_number:u32,
45    pub where_sql: Vec<String>,
46    pub args: Vec<&'a (dyn ToSql + Sync)>,
47    pub formats: HashMap<String, String>,
48    pub limit: Option<u64>,
49    pub page_size: u64,
50    pub page_no: u64,
51    pub order_by: &'a str,
52    pub recoder_field: Vec<&'a str>,
53    //     pub order_by : Option<&'a dyn ColumnExt>,
54    //    // pub order_by:  Option<&'a (dyn ColumnExt + 'a)>,
55    pub desc: bool,
56}
57
/// Rewrites every `#args` marker in `temp_where`, left to right, as a numbered
/// placeholder: `$args_number`, `$args_number + 1`, and so on.
///
/// Returns the rewritten fragment together with the next unused placeholder number.
/// A fragment without markers is returned unchanged with `args_number` as-is.
fn format_where_sql(temp_where: String, args_number: u32) -> (String, u32) {
    let mut segments = temp_where.split("#args");
    // `split` always yields at least one (possibly empty) leading segment.
    let mut rendered = String::from(segments.next().unwrap_or_default());
    let mut next_number = args_number;
    for segment in segments {
        rendered.push('$');
        rendered.push_str(&next_number.to_string());
        next_number += 1;
        rendered.push_str(segment);
    }
    (rendered, next_number)
}
70
/// Guards `set_order_by` against SQL injection through the raw ORDER BY text.
///
/// `get_order_by` interpolates the value verbatim into ` ORDER BY <value> ASC|DESC `.
/// The value is therefore accepted only if every character is alphanumeric (Unicode),
/// whitespace, `_`, `.`, `,` or `"`.
///
/// That admits:
/// - plain and double-quoted column names
/// - `table.column` references
/// - comma-separated column lists
/// - bare keywords
///
/// It rejects:
/// - comments (`--`, `/* */`)
/// - statement separators (`;`)
/// - string literals, operators and casts
/// - function calls and parenthesised sub-queries
///
/// An empty string is accepted; it means "no ordering".
///
/// NOTE(review): expressions such as `lower(name)` or `id::text`, and quoted
/// identifiers containing other punctuation, are rejected by this whitelist;
/// confirm no caller needs them.
pub fn check_order_by(order_by: &str) -> bool {
    order_by.chars().all(|c| {
        c.is_alphanumeric() || c.is_whitespace() || matches!(c, '_' | '.' | ',' | '"')
    })
}
86
87impl<'a> Wrapper<'a> {
88    pub fn new() -> Self {
89        Self {
90            //  args_number,
91            where_sql: Vec::with_capacity(200),
92            args: Vec::with_capacity(200),
93            formats: Default::default(),
94            limit: None,
95            page_size: 20,
96            page_no: 0,
97            order_by: "",
98            desc: false,
99            recoder_field: Vec::new(),
100        }
101    }
102    /// 等于
103    pub fn eq<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
104    where
105        T: ToSql + Sync,
106    {
107        let column_name = column.get();
108        let s = format!(" \"{}\" = #args ", column_name);
109        self.where_sql.push(s);
110        self.args.push(obj);
111        self
112    }
113
114    /// 不等于
115    pub fn not_eq<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
116    where
117        T: ToSql + Sync,
118    {
119        let column_name = column.get();
120        let s = format!(" \"{}\" <> #args ", column_name);
121        self.where_sql.push(s);
122        self.args.push(obj);
123        self
124    }
125
126    /// 空值
127    pub fn is_null<T>(mut self, column: &dyn ColumnExt) -> Self
128    where
129        T: ToSql + Sync,
130    {
131        let column_name = column.get();
132        let s = format!(" \"{}\" is NULL ", column_name);
133        self.where_sql.push(s);
134        self
135    }
136
137    /// 非空值
138    pub fn is_not_null<T>(mut self, column: &dyn ColumnExt) -> Self
139    where
140        T: ToSql + Sync,
141    {
142        let column_name = column.get();
143        let s = format!(" \"{}\" is not NULL ", column_name);
144        self.where_sql.push(s);
145        self
146    }
147
148    /// 在范围
149    pub fn between<T>(mut self, column: &dyn ColumnExt, min: &'a T, max: &'a T) -> Self
150    where
151        T: ToSql + Sync,
152    {
153        let column_name = column.get();
154        let s = format!(" between \"{}\" #args AND #args ", column_name);
155        self.where_sql.push(s);
156        self.args.push(min);
157        self.args.push(max);
158        self
159    }
160
161    /// 不在范围
162    pub fn not_between<T>(mut self, column: &dyn ColumnExt, min: &'a T, max: &'a T) -> Self
163    where
164        T: ToSql + Sync,
165    {
166        let column_name = column.get();
167        let s = format!(" not between \"{}\" #args AND #args ", column_name);
168        self.where_sql.push(s);
169        self.args.push(min);
170        self.args.push(max);
171        self
172    }
173
174    /// like 向前匹配还是向后匹配由前端进行控制
175    pub fn like<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
176    where
177        T: ToSql + Sync,
178    {
179        let column_name = column.get();
180        let s = format!(" \"{}\" like #args ", column_name);
181        self.where_sql.push(s);
182        self.args.push(obj);
183        self
184    }
185
186    /// in
187    pub fn set_in<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
188    where
189        T: ToSql + Sync,
190    {
191        let column_name = column.get();
192        let s = format!(" \"{}\" in (#args) ", column_name);
193        self.where_sql.push(s);
194        self.args.push(obj);
195        self
196    }
197
198    // 字符串拼接的写法 有安全隐患,可以进行SQL注入攻击。
199    //    //只支持字符串
200    //     pub fn like_left(mut self, column: &str, obj: &str) -> Self
201    //     {
202    //         let s = format!(" \"{}\" like '%{}' ",column,obj);
203    //         self.where_sql.push(s);
204    //         self
205    //     }
206
207    //     //只支持字符串
208    //     pub fn like_right(mut self, column: &str, obj: &str) -> Self
209    //     {
210    //         let s = format!(" \"{}\" like '{}%' ",column,obj);
211    //         self.where_sql.push(s);
212    //         self
213    //     }
214
215    /// 大于
216    pub fn gt<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
217    where
218        T: ToSql + Sync,
219    {
220        let column_name = column.get();
221        let s = format!(" \"{}\" > #args ", column_name);
222        self.where_sql.push(s);
223        self.args.push(obj);
224        self
225    }
226
227    /// 大于等于
228    pub fn ge<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
229    where
230        T: ToSql + Sync,
231    {
232        let column_name = column.get();
233        let s = format!(" \"{}\" >= #args ", column_name);
234        self.where_sql.push(s);
235        self.args.push(obj);
236        self
237    }
238
239    /// 小于
240    pub fn lt<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
241    where
242        T: ToSql + Sync,
243    {
244        let column_name = column.get();
245        let s = format!(" \"{}\" < #args ", column_name);
246        self.where_sql.push(s);
247        self.args.push(obj);
248        self
249    }
250
251    /// 小于等于
252    pub fn le<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
253    where
254        T: ToSql + Sync,
255    {
256        let column_name = column.get();
257        let s = format!(" \"{}\" <= #args ", column_name);
258        self.where_sql.push(s);
259        self.args.push(obj);
260        self
261    }
262
263    //偏移返回
264    pub fn limit(mut self, limit: u64) -> Self {
265        self.limit = Some(limit);
266        self
267    }
268
269    //设置翻页
270    pub fn set_pages(mut self, page_no: u64, page_size: u64) -> Self {
271        self.page_no = page_no;
272        self.page_size = page_size;
273        self
274    }
275
276    ///设置排序
277    pub fn set_order_by(mut self, order_by: &'a str, desc: bool) -> Result<Self, WrapperError> {
278        if !check_order_by(order_by) {
279            return Err(WrapperError::OrderBy);
280        }
281        self.order_by = order_by;
282        self.desc = desc;
283        Ok(self)
284    }
285
286    ///设置排序参数是ColumnExt 类型
287    pub fn set_order_by_column_ext(mut self, order_by: &dyn ColumnExt, desc: bool) -> Self {
288        self.order_by = order_by.get();
289        self.desc = desc;
290        self
291    }
292
293    pub fn get_order_by(self) -> Option<String> {
294        if self.order_by.is_empty() {
295            return None;
296        } else {
297            let order_by = if self.desc {
298                format!(" ORDER BY {}  DESC ", self.order_by)
299            } else {
300                format!(" ORDER BY {}  ASC ", self.order_by)
301            };
302            return Some(order_by);
303        }
304    }
305
306    // (limit_str,limit,self.page_no,self.page_size)
307    pub fn get_page_info(self) -> (String, u64, u64, u64) {
308        debug!(
309            "get_page_info: page_no {},page_size {}",
310            self.page_no, self.page_size
311        );
312        let page_no = if self.page_no < 1 { 1 } else { self.page_no };
313        let limit = (page_no - 1) * self.page_size;
314        let limit_str = format!(" LIMIT {} OFFSET {}", self.page_size, limit);
315        (limit_str, limit, page_no, self.page_size)
316    }
317    /**
318     * @description: 
319     * @param {*} self
320     * @param {u32} args_number  从1开始
321     * @param {*} Vec
322     * @return {*}
323     */    
324    pub fn build(self, args_number: u32) -> Result<(String, Vec<&'a (dyn ToSql + Sync)>), String> {
325        let mut args_number = args_number;
326        let mut temp_where_sql = Vec::new();
327        for mut temp_where in self.where_sql.iter() {
328            let (wheresql, temp_args_number) =
329                format_where_sql((&mut temp_where).to_string(), args_number);
330            args_number = temp_args_number;
331            temp_where_sql.push(wheresql);
332        }
333        let sql = temp_where_sql.join(" AND ");
334        let args = self.args;
335        return Ok((sql, args));
336    }
337
338    //设置返回的字段名称
339    pub fn set_recoder_field(mut self, column: &dyn ColumnExt) -> Self {
340        let column_name = column.get();
341        self.recoder_field.push(column_name);
342        self
343    }
344
345    pub fn get_recoder_field<T>(self) -> String
346    where
347        T: Parameters,
348    {
349        if self.recoder_field.len() == 0 {
350            T::get_field_list()
351        } else {
352            self.recoder_field.join(",")
353        }
354    }
355
356
357    pub fn in_array_string(mut self, column: &dyn ColumnExt, obj: &Vec<String>) -> Self          
358    {
359        let column_name = column.get();
360        let array_str = obj.iter().map(|s| format!("'{}'", s)).collect::<Vec<String>>().join(", ");
361        //let my_text_array = postgres_array::Array::from_vec(obj).to_string();
362        let s = format!(" \"{}\" in ({}) ", column_name,array_str);
363        self.where_sql.push(s);
364        //self.args.push();
365        self
366    }
367}