rbatis_wrapper/
wrapper.rs

1use rbatis::RBatis;
2use rbatis::Error;
3use serde::Serialize;
4
/// Result of a paginated query: one page of rows plus paging metadata.
#[derive(Debug, Serialize)]
pub struct Page<T> {
    pub records: Vec<T>,         // Rows belonging to the current page
    pub total: u64,             // Total number of matching rows
    pub page_no: u64,           // Current page number
    pub page_size: u64,         // Number of rows per page
    pub pages: u64,             // Total number of pages
    pub has_next: bool,         // Whether a further page exists after this one
}
15
16impl<T> Page<T> {
17    pub fn new(records: Vec<T>, total: u64, page_no: u64, page_size: u64) -> Self {
18        let pages = (total + page_size - 1) / page_size;
19        let has_next = page_no < pages;
20        
21        Self {
22            records,
23            total,
24            page_no,
25            page_size,
26            pages,
27            has_next,
28        }
29    }
30}
31
/// SQL query builder in the style of MyBatis-Plus.
///
/// Conditions, ordering, column selection, joins and paging are collected as
/// SQL fragments and rendered into a single statement when a query is run.
///
/// # Examples
///
/// The example is marked `ignore` because it needs an initialised `RB`
/// (`RBatis`) instance and an async context.
///
/// ```ignore
/// let count = QueryWrapper::new()
///     .custom_sql("select count(*) from member")
///     .get_one::<u64>(&RB, "")
///     .await?;
/// println!("count: {:?}", count);
///
/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
/// struct Member {
///     id: u64,
///     email: Option<String>
/// }
///
/// let member = QueryWrapper::new()
///     .eq("id", 7386)
///     .get_one::<Member>(&RB, "member")
///     .await?;
/// println!("member: {:?}", member);
///
/// Ok(Json(json!({
///     "code": 0,
///     "data": member,
///     "count": count,
/// })))
/// ```
#[derive(Default, Debug, Clone)]
pub struct QueryWrapper {
    // Rendered WHERE predicates; joined with " AND " when the SQL is built.
    where_conditions: Vec<String>,
    // Rendered "column ASC|DESC" entries; joined with ", ".
    order_by: Vec<String>,
    // Columns to select; an empty list means "*".
    select_columns: Vec<String>,
    // Optional LIMIT value.
    limit: Option<u64>,
    // Optional OFFSET value.
    offset: Option<u64>,
    custom_sql: Option<String>,    // Caller-supplied base SQL used instead of the generated SELECT
    join_conditions: Vec<String>,  // Rendered JOIN clauses, appended after the FROM table
}
69
70impl QueryWrapper {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    // 等于条件
76    pub fn eq<T: ToString>(mut self, column: &str, value: T) -> Self {
77        self.where_conditions.push(format!("{} = '{}'", column, value.to_string()));
78        self
79    }
80
81    // 不等于条件
82    pub fn ne<T: ToString>(mut self, column: &str, value: T) -> Self {
83        self.where_conditions.push(format!("{} != '{}'", column, value.to_string()));
84        self
85    }
86
87    // 大于条件
88    pub fn gt<T: ToString>(mut self, column: &str, value: T) -> Self {
89        self.where_conditions.push(format!("{} > '{}'", column, value.to_string()));
90        self
91    }
92
93    // 小于条件
94    pub fn lt<T: ToString>(mut self, column: &str, value: T) -> Self {
95        self.where_conditions.push(format!("{} < '{}'", column, value.to_string()));
96        self
97    }
98
99    // LIKE 条件
100    pub fn like(mut self, column: &str, value: &str) -> Self {
101        self.where_conditions.push(format!("{} LIKE '%{}%'", column, value));
102        self
103    }
104
105    // 指定查询列
106    pub fn select(mut self, columns: Vec<&str>) -> Self {
107        self.select_columns = columns.into_iter().map(String::from).collect();
108        self
109    }
110
111    // 排序
112    pub fn order_by(mut self, column: &str, asc: bool) -> Self {
113        let order = if asc { "ASC" } else { "DESC" };
114        self.order_by.push(format!("{} {}", column, order));
115        self
116    }
117
118    // 修改 limit 方法为引用
119    pub fn limit(&mut self, limit: u64) -> &mut Self {
120        self.limit = Some(limit);
121        self
122    }
123
124    // 修改 offset 方法为引用
125    pub fn offset(&mut self, offset: u64) -> &mut Self {
126        self.offset = Some(offset);
127        self
128    }
129
130    // 添加自定义SQL方法
131    pub fn custom_sql(mut self, sql: &str) -> Self {
132        self.custom_sql = Some(sql.to_string());
133        self
134    }
135
136    // 添加 INNER JOIN
137    pub fn inner_join(mut self, table: &str, on_condition: &str) -> Self {
138        self.join_conditions.push(format!("INNER JOIN {} ON {}", table, on_condition));
139        self
140    }
141
142    // 添加 LEFT JOIN
143    pub fn left_join(mut self, table: &str, on_condition: &str) -> Self {
144        self.join_conditions.push(format!("LEFT JOIN {} ON {}", table, on_condition));
145        self
146    }
147
148    // 添加 RIGHT JOIN
149    pub fn right_join(mut self, table: &str, on_condition: &str) -> Self {
150        self.join_conditions.push(format!("RIGHT JOIN {} ON {}", table, on_condition));
151        self
152    }
153
154    // 修改构建SQL语句方法
155    pub fn build_sql(&self, table_name: &str) -> String {
156        // 如果有自定义SQL,直接使用它
157        if let Some(custom_sql) = &self.custom_sql {
158            let mut sql = custom_sql.clone();
159            
160            // 添加WHERE条件
161            if !self.where_conditions.is_empty() {
162                if !sql.to_uppercase().contains("WHERE") {
163                    sql.push_str(" WHERE ");
164                } else {
165                    sql.push_str(" AND ");
166                }
167                sql.push_str(&self.where_conditions.join(" AND "));
168            }
169
170            // 添加排序
171            if !self.order_by.is_empty() {
172                sql.push_str(" ORDER BY ");
173                sql.push_str(&self.order_by.join(", "));
174            }
175
176            // 添加分页
177            if let Some(limit) = self.limit {
178                sql.push_str(&format!(" LIMIT {}", limit));
179            }
180            if let Some(offset) = self.offset {
181                sql.push_str(&format!(" OFFSET {}", offset));
182            }
183
184            return sql;
185        }
186
187        // 常规SQL构建
188        let select = if self.select_columns.is_empty() {
189            "*".to_string()
190        } else {
191            self.select_columns.join(", ")
192        };
193
194        let mut sql = format!("SELECT {} FROM {}", select, table_name);
195
196        // 添加JOIN条件
197        if !self.join_conditions.is_empty() {
198            sql.push_str(" ");
199            sql.push_str(&self.join_conditions.join(" "));
200        }
201
202        if !self.where_conditions.is_empty() {
203            sql.push_str(" WHERE ");
204            sql.push_str(&self.where_conditions.join(" AND "));
205        }
206
207        if !self.order_by.is_empty() {
208            sql.push_str(" ORDER BY ");
209            sql.push_str(&self.order_by.join(", "));
210        }
211
212        if let Some(limit) = self.limit {
213            sql.push_str(&format!(" LIMIT {}", limit));
214        }
215
216        if let Some(offset) = self.offset {
217            sql.push_str(&format!(" OFFSET {}", offset));
218        }
219
220        sql
221    }
222
223    // 执行查询
224    pub async fn query<T>(&self, rb: &RBatis, table_name: &str) -> Result<Vec<T>, Error>
225    where
226        T: Serialize + for<'de> serde::Deserialize<'de>,
227    {
228        let sql = self.build_sql(table_name);
229        rb.query_decode(&sql, vec![]).await
230    }
231
232    // 执行查询
233    pub async fn get_one<T>(&self, rb: &RBatis, table_name: &str) -> Result<Option<T>, Error>
234    where
235        T: Serialize + for<'de> serde::Deserialize<'de>,
236    {
237        let sql = self.build_sql(table_name);
238        rb.query_decode::<Option<T>>(&sql, vec![]).await
239    }
240
241    // 执行删除
242    pub async fn delete(self, rb: &RBatis, table_name: &str) -> Result<u64, Error> {
243        let delete_sql = format!("delete from {}", table_name);
244        let sql = self.custom_sql(&delete_sql)
245            .build_sql(table_name);
246        Ok(rb.exec(&sql, vec![]).await?.rows_affected)
247    }
248
249    // 修改分页方法
250    pub async fn page<T>(&self, rb: &RBatis, table_name: &str, page_no: u64, page_size: u64) -> Result<Page<T>, Error>
251    where
252        T: Serialize + for<'de> serde::Deserialize<'de>,
253    {
254        // 1. 先查询总记录数
255        let count_sql = self.build_count_sql(table_name);
256        let total: u64 = rb.query_decode(&count_sql, vec![]).await?;
257
258        // 2. 如果有数据,再查询分页数据
259        if total > 0 {
260            // 设置分页参数
261            let offset = (page_no - 1) * page_size;
262            let mut wrapper = self.clone();
263            wrapper.limit(page_size);  // 现在这些方法返回 &mut Self
264            wrapper.offset(offset);    // 可以分开调用
265            
266            // 查询分页数据
267            let records: Vec<T> = wrapper.query(rb, table_name).await?;
268            
269            Ok(Page::new(records, total, page_no, page_size))
270        } else {
271            // 没有数据时返回空页
272            Ok(Page::new(vec![], 0, page_no, page_size))
273        }
274    }
275
276    // 修改构建统计SQL方法
277    fn build_count_sql(&self, table_name: &str) -> String {
278        if let Some(custom_sql) = &self.custom_sql {
279            // 将 WHERE 条件放入子查询内部
280            let mut inner_sql = custom_sql.clone();
281            
282            if !self.where_conditions.is_empty() {
283                if !inner_sql.to_uppercase().contains("WHERE") {
284                    inner_sql.push_str(" WHERE ");
285                } else {
286                    inner_sql.push_str(" AND ");
287                }
288                inner_sql.push_str(&self.where_conditions.join(" AND "));
289            }
290
291            // 包装成计数查询
292            format!("SELECT COUNT(*) FROM ({}) as t", inner_sql)
293        } else {
294            let mut sql = format!("SELECT COUNT(*) FROM {}", table_name);
295
296            // 添加JOIN条件
297            if !self.join_conditions.is_empty() {
298                sql.push_str(" ");
299                sql.push_str(&self.join_conditions.join(" "));
300            }
301
302            if !self.where_conditions.is_empty() {
303                sql.push_str(" WHERE ");
304                sql.push_str(&self.where_conditions.join(" AND "));
305            }
306
307            sql
308        }
309    }
310}