pgbatis 0.1.50

pgbatis 用于操作数据库数据的增删改查
Documentation
/*
 * @Author: venom
 * @Date: 2021-08-02 10:21:41
 * @LastEditors: BuddyCoder
 * @LastEditTime: 2025-11-12 18:50:05
 * @Description:
 * @FilePath: /buddyServer/Users/caixiaocong/workspace/rust/rust_crates/pgbatis/src/wrapper.rs
 * MIT
 */
use crate::{ColumnExt, Parameters};
use std::collections::HashMap;
use std::fmt;
use tokio_postgres::types::ToSql;

use tracing::*;

/// Errors returned by the `Wrapper` builder methods.

#[derive(Debug)]
pub enum WrapperError {
    /// The `ORDER BY` text passed to `Wrapper::set_order_by` was rejected by
    /// `check_order_by`.
    OrderBy,
}

// Marker impl with an empty body: every `Error` method keeps its default
// implementation, so no underlying source error is exposed.
impl std::error::Error for WrapperError {}

impl WrapperError {
    pub fn to_string(&self) -> String {
        match self {
            WrapperError::OrderBy => "根据不同的元素返回响应的字符串信息 for OrderBy".to_string(),
        }
    }
}

impl fmt::Display for WrapperError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            WrapperError::OrderBy => write!(f, "Error: OrderBy"),
        }
    }
}

/// Fluent builder that collects `WHERE` conditions, their bound parameters,
/// ordering, paging and the list of columns to return.
///
/// Condition methods push one SQL fragment (containing `#args` placeholders)
/// onto `where_sql` and the matching parameter references onto `args`;
/// `build` then numbers the placeholders and joins the fragments with `AND`.
#[derive(Clone)]
pub struct Wrapper<'a> {
    //  pub args_number:u32,
    /// Condition fragments, in the order they were added.
    pub where_sql: Vec<String>,
    /// Borrowed parameter values, in the same order as the `#args`
    /// placeholders that appear in `where_sql`.
    pub args: Vec<&'a (dyn ToSql + Sync)>,
    /// Initialised empty by `new`; no method in this file reads or writes it.
    pub formats: HashMap<String, String>,
    /// Optional row limit stored by `limit()`; not read by any method in this file.
    pub limit: Option<u64>,
    /// Rows per page used by `get_page_info` (`new` sets it to 20).
    pub page_size: u64,
    /// Requested page number; `get_page_info` treats values below 1 as page 1.
    pub page_no: u64,
    /// Text placed after `ORDER BY`; an empty string means no ordering.
    pub order_by: &'a str,
    /// Column names to return; when empty, `get_recoder_field` falls back to
    /// the full field list of the target type.
    pub recoder_field: Vec<&'a str>,
    //     pub order_by : Option<&'a dyn ColumnExt>,
    //    // pub order_by:  Option<&'a (dyn ColumnExt + 'a)>,
    /// `true` for descending order, `false` for ascending.
    pub desc: bool,
}

/// Replaces every `#args` placeholder in `temp_where` with a numbered
/// PostgreSQL parameter (`$n`), counting upwards from `args_number`.
///
/// Returns the rewritten fragment together with the next unused parameter
/// number, so consecutive fragments can continue the same numbering.
///
/// The text is scanned once from left to right, so the stack depth stays
/// constant and the cost is linear in the fragment length no matter how many
/// placeholders it holds.
fn format_where_sql(temp_where: String, args_number: u32) -> (String, u32) {
    const PLACEHOLDER: &str = "#args";

    let mut next_index = args_number;
    let mut rendered = String::with_capacity(temp_where.len());
    let mut rest = temp_where.as_str();

    while let Some(pos) = rest.find(PLACEHOLDER) {
        // Copy the text before the placeholder, then emit `$n` in its place.
        rendered.push_str(&rest[..pos]);
        rendered.push('$');
        rendered.push_str(&next_index.to_string());
        next_index += 1;
        rest = &rest[pos + PLACEHOLDER.len()..];
    }
    rendered.push_str(rest);

    (rendered, next_index)
}

/// Guards `ORDER BY` text against SQL injection.
///
/// Returns `false` when `order_by` contains a SQL line-comment marker (`--`)
/// or a statement separator (`;`), and `true` otherwise.
///
/// NOTE(review): this is a deny-list of two markers only; any other text is
/// accepted unchanged. Confirm that callers never forward raw user input.
pub fn check_order_by(order_by: &str) -> bool {
    !(order_by.contains("--") || order_by.contains(';'))
}

impl<'a> Wrapper<'a> {
    pub fn new() -> Self {
        Self {
            //  args_number,
            where_sql: Vec::with_capacity(200),
            args: Vec::with_capacity(200),
            formats: Default::default(),
            limit: None,
            page_size: 20,
            page_no: 0,
            order_by: "",
            desc: false,
            recoder_field: Vec::new(),
        }
    }
    /// Adds an equality condition: `"column" = <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn eq<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" = #args ", column.get()));
        self.args.push(obj);
        self
    }

    /// Adds an inequality condition: `"column" <> <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn not_eq<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" <> #args ", column.get()));
        self.args.push(obj);
        self
    }

    /// Adds a `"column" is NULL` condition; no parameter is bound.
    ///
    /// The type parameter `T` is not used by the body, so it cannot be
    /// inferred and callers have to name it explicitly.
    pub fn is_null<T>(mut self, column: &dyn ColumnExt) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" is NULL ", column.get()));
        self
    }

    /// Adds a `"column" is not NULL` condition; no parameter is bound.
    ///
    /// The type parameter `T` is not used by the body, so it cannot be
    /// inferred and callers have to name it explicitly.
    pub fn is_not_null<T>(mut self, column: &dyn ColumnExt) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" is not NULL ", column.get()));
        self
    }

    /// Adds a range condition: `"column" between <min> AND <max>`.
    ///
    /// `min` and `max` are bound as two consecutive query parameters.
    pub fn between<T>(mut self, column: &dyn ColumnExt, min: &'a T, max: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        let fragment = format!(" \"{}\" between  #args AND #args ", column.get());
        self.where_sql.push(fragment);
        // Parameter order must match the placeholder order: lower bound first.
        for bound in [min, max] {
            self.args.push(bound);
        }
        self
    }

    /// Adds an excluded-range condition: `"column" not between <min> AND <max>`.
    ///
    /// `min` and `max` are bound as two consecutive query parameters.
    pub fn not_between<T>(mut self, column: &dyn ColumnExt, min: &'a T, max: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        let column_name = column.get();
        // The column must come before `not between`; putting the operator
        // first (`not between "col" $1 AND $2`) is not valid SQL.
        let s = format!(" \"{}\" not between #args AND #args ", column_name);
        self.where_sql.push(s);
        self.args.push(min);
        self.args.push(max);
        self
    }

    /// Adds a pattern condition: `"column" like <obj>`.
    ///
    /// Whether the match is a prefix, suffix or substring match is decided by
    /// the caller, who places the `%` wildcard inside the bound value.
    pub fn like<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" like #args ", column.get()));
        self.args.push(obj);
        self
    }

     

    /// Adds a membership condition: `"column" in (<obj>)`.
    ///
    /// Exactly one placeholder is emitted and `obj` is bound as that single
    /// parameter.
    ///
    /// NOTE(review): with one placeholder the list has one element; matching
    /// against an array parameter would presumably need `= ANY(...)` instead.
    /// Confirm what callers pass here.
    pub fn set_in<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" in (#args) ", column.get()));
        self.args.push(obj);
        self
    }

    // 字符串拼接的写法 有安全隐患,可以进行SQL注入攻击。
    //    //只支持字符串
    //     pub fn like_left(mut self, column: &str, obj: &str) -> Self
    //     {
    //         let s = format!(" \"{}\" like '%{}' ",column,obj);
    //         self.where_sql.push(s);
    //         self
    //     }

    //     //只支持字符串
    //     pub fn like_right(mut self, column: &str, obj: &str) -> Self
    //     {
    //         let s = format!(" \"{}\" like '{}%' ",column,obj);
    //         self.where_sql.push(s);
    //         self
    //     }

    /// Adds a greater-than condition: `"column" > <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn gt<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" > #args ", column.get()));
        self.args.push(obj);
        self
    }

    /// Adds a greater-than-or-equal condition: `"column" >= <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn ge<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" >= #args ", column.get()));
        self.args.push(obj);
        self
    }

    /// Adds a less-than condition: `"column" < <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn lt<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" < #args ", column.get()));
        self.args.push(obj);
        self
    }

    /// Adds a less-than-or-equal condition: `"column" <= <obj>`.
    ///
    /// `obj` is bound as a query parameter, not spliced into the SQL text.
    pub fn le<T>(mut self, column: &dyn ColumnExt, obj: &'a T) -> Self
    where
        T: ToSql + Sync,
    {
        self.where_sql.push(format!(" \"{}\" <= #args ", column.get()));
        self.args.push(obj);
        self
    }

    //偏移返回
    pub fn limit(mut self, limit: u64) -> Self {
        self.limit = Some(limit);
        self
    }

    //设置翻页
    pub fn set_pages(mut self, page_no: u64, page_size: u64) -> Self {
        self.page_no = page_no;
        self.page_size = page_size;
        self
    }

    /// Sets the `ORDER BY` text and the sort direction.
    ///
    /// # Errors
    ///
    /// Returns `WrapperError::OrderBy` when `check_order_by` rejects
    /// `order_by`; the wrapper is consumed in that case.
    pub fn set_order_by(mut self, order_by: &'a str, desc: bool) -> Result<Self, WrapperError> {
        if check_order_by(order_by) {
            self.order_by = order_by;
            self.desc = desc;
            Ok(self)
        } else {
            Err(WrapperError::OrderBy)
        }
    }

    ///设置排序参数是ColumnExt 类型
    pub fn set_order_by_column_ext(mut self, order_by: &dyn ColumnExt, desc: bool) -> Self {
        self.order_by = order_by.get();
        self.desc = desc;
        self
    }

    /// Renders the `ORDER BY` clause, consuming the wrapper.
    ///
    /// Returns `None` when no ordering was set (empty `order_by`); otherwise
    /// the clause with `DESC` or `ASC` according to `desc`.
    pub fn get_order_by(self) -> Option<String> {
        if self.order_by.is_empty() {
            return None;
        }
        let clause = if self.desc {
            format!(" ORDER BY {}  DESC ", self.order_by)
        } else {
            format!(" ORDER BY {}  ASC ", self.order_by)
        };
        Some(clause)
    }

    /// Computes the paging clause, consuming the wrapper.
    ///
    /// Returns `(limit_str, offset, page_no, page_size)`:
    /// * `limit_str` — the ` LIMIT <page_size> OFFSET <offset>` SQL text,
    /// * `offset` — number of rows skipped, `(page_no - 1) * page_size`,
    /// * `page_no` — the effective page number (values below 1 become 1),
    /// * `page_size` — rows per page, as stored on the wrapper.
    pub fn get_page_info(self) -> (String, u64, u64, u64) {
        debug!(
            "get_page_info: page_no {},page_size {}",
            self.page_no, self.page_size
        );
        // Pages are 1-based; a page number of 0 is treated as the first page.
        let page_no = self.page_no.max(1);
        // Saturate instead of overflowing: a plain `*` panics in debug builds
        // and silently wraps to a wrong offset in release builds when the
        // requested page is absurdly large.
        let offset = (page_no - 1).saturating_mul(self.page_size);
        let limit_str = format!(" LIMIT {} OFFSET {}", self.page_size, offset);
        (limit_str, offset, page_no, self.page_size)
    }
    /// Renders the collected conditions into a `WHERE` body and returns it
    /// together with the bound parameters, consuming the wrapper.
    ///
    /// Every `#args` placeholder is replaced by a numbered parameter (`$n`),
    /// numbered consecutively across all fragments starting at `args_number`
    /// (numbering starts from 1 when the wrapper supplies the first
    /// parameters of the statement). Fragments are joined with ` AND `; with
    /// no conditions the SQL string is empty.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn build(self, args_number: u32) -> Result<(String, Vec<&'a (dyn ToSql + Sync)>), String> {
        let mut next_index = args_number;
        // `self` is owned, so the fragments are moved rather than cloned.
        let fragments: Vec<String> = self
            .where_sql
            .into_iter()
            .map(|fragment| {
                let (rendered, following) = format_where_sql(fragment, next_index);
                next_index = following;
                rendered
            })
            .collect();
        Ok((fragments.join(" AND "), self.args))
    }

    /// Adds `column` to the list of columns the query should return.
    pub fn set_recoder_field(mut self, column: &dyn ColumnExt) -> Self {
        self.recoder_field.push(column.get());
        self
    }

    /// Returns the comma-separated column list to select, consuming the wrapper.
    ///
    /// When no column was chosen with `set_recoder_field`, the full list from
    /// `T::get_field_list()` is returned instead.
    pub fn get_recoder_field<T>(self) -> String
    where
        T: Parameters,
    {
        match self.recoder_field.as_slice() {
            [] => T::get_field_list(),
            selected => selected.join(","),
        }
    }


    /// Adds a membership condition against a list of string literals:
    /// `"column" in ('a', 'b', ...)`.
    ///
    /// The values are written into the SQL text as literals rather than bound
    /// as parameters, so every embedded single quote is doubled to keep a
    /// value from terminating its literal. An empty `obj` produces ` FALSE `,
    /// because `in ()` is not valid SQL and membership in an empty list can
    /// never hold.
    ///
    /// NOTE(review): quote doubling is sufficient only when the server treats
    /// backslashes in ordinary string literals as plain characters
    /// (`standard_conforming_strings = on`) — confirm the server setting.
    pub fn in_array_string(mut self, column: &dyn ColumnExt, obj: &Vec<String>) -> Self {
        let column_name = column.get();
        let fragment = if obj.is_empty() {
            " FALSE ".to_string()
        } else {
            let literals = obj
                .iter()
                .map(|value| format!("'{}'", value.replace('\'', "''")))
                .collect::<Vec<String>>()
                .join(", ");
            format!(" \"{}\" in ({}) ", column_name, literals)
        };
        self.where_sql.push(fragment);
        self
    }
}