// File: rbatis_wrapper/wrapper.rs

use rbatis::Error;
use rbatis::RBatis;
use serde::Serialize;

5// 添加分页结果结构体
6#[derive(Debug, Serialize)]
7pub struct Page<T> {
8    pub records: Vec<T>,         // 数据列表
9    pub total: u64,             // 总记录数
10    pub page_no: u64,           // 当前页码
11    pub page_size: u64,         // 每页大小
12    pub pages: u64,             // 总页数
13    pub has_next: bool,         // 是否有下一页
14}
15
16impl<T> Page<T> {
17    pub fn new(records: Vec<T>, total: u64, page_no: u64, page_size: u64) -> Self {
18        let pages = (total + page_size - 1) / page_size;
19        let has_next = page_no < pages;
20        
21        Self {
22            records,
23            total,
24            page_no,
25            page_size,
26            pages,
27            has_next,
28        }
29    }
30}
31
/// A small SQL builder in the style of MyBatis-Plus' `QueryWrapper`, executed through `RBatis`.
///
/// Conditions added with `eq`, `ne`, `gt`, `lt` and `like` are combined with `AND`.
///
/// # Examples
///
/// (Not compiled as a doctest: it relies on a global `RB` pool and web-framework items
/// such as `Json` / `json!`, and must run inside an async fn returning a `Result`.)
///
/// ```ignore
/// let count = QueryWrapper::new()
///     .custom_sql("select count(*) from member")
///     .get_one::<u64>(&RB, "")
///     .await?;
/// println!("count: {:?}", count);
///
/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
/// struct Member {
///     id: u64,
///     email: Option<String>
/// }
///
/// let member = QueryWrapper::new()
///     .eq("id", 7386)
///     .get_one::<Member>(&RB, "member")
///     .await?;
/// println!("member: {:?}", member);
///
/// Ok(Json(json!({
///     "code": 0,
///     "data": member,
///     "count": count,
/// })))
/// ```
#[derive(Default, Debug, Clone)]
pub struct QueryWrapper {
    /// `WHERE` predicates, joined with `AND` when the SQL is built.
    where_conditions: Vec<String>,
    /// `ORDER BY` items such as `"id DESC"`, joined with `", "`.
    order_by: Vec<String>,
    /// Columns for the `SELECT` list; empty means `*`.
    select_columns: Vec<String>,
    /// Optional `LIMIT` value.
    limit: Option<u64>,
    /// Optional `OFFSET` value.
    offset: Option<u64>,
    /// Caller-supplied base statement used instead of `SELECT … FROM table`.
    custom_sql: Option<String>,
    /// Raw `… JOIN table ON condition` fragments appended after `FROM`.
    join_conditions: Vec<String>,
}

70impl QueryWrapper {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    // 等于条件
76    pub fn eq<T: ToString>(mut self, column: &str, value: T) -> Self {
77        self.where_conditions.push(format!("{} = '{}'", column, value.to_string()));
78        self
79    }
80
81    // 不等于条件
82    pub fn ne<T: ToString>(mut self, column: &str, value: T) -> Self {
83        self.where_conditions.push(format!("{} != '{}'", column, value.to_string()));
84        self
85    }
86
87    // 大于条件
88    pub fn gt<T: ToString>(mut self, column: &str, value: T) -> Self {
89        self.where_conditions.push(format!("{} > '{}'", column, value.to_string()));
90        self
91    }
92
93    // 小于条件
94    pub fn lt<T: ToString>(mut self, column: &str, value: T) -> Self {
95        self.where_conditions.push(format!("{} < '{}'", column, value.to_string()));
96        self
97    }
98
99    // LIKE 条件
100    pub fn like(mut self, column: &str, value: &str) -> Self {
101        self.where_conditions.push(format!("{} LIKE '%{}%'", column, value));
102        self
103    }
104
105    // 指定查询列
106    pub fn select(mut self, columns: Vec<&str>) -> Self {
107        self.select_columns = columns.into_iter().map(String::from).collect();
108        self
109    }
110
111    // 排序
112    pub fn order_by(mut self, column: &str, asc: bool) -> Self {
113        let order = if asc { "ASC" } else { "DESC" };
114        self.order_by.push(format!("{} {}", column, order));
115        self
116    }
117
118    // 修改 limit 方法为引用
119    pub fn limit(&mut self, limit: u64) -> &mut Self {
120        self.limit = Some(limit);
121        self
122    }
123
124    // 修改 offset 方法为引用
125    pub fn offset(&mut self, offset: u64) -> &mut Self {
126        self.offset = Some(offset);
127        self
128    }
129
130    // 添加自定义SQL方法
131    pub fn custom_sql(mut self, sql: &str) -> Self {
132        self.custom_sql = Some(sql.to_string());
133        self
134    }
135
136    // 添加 INNER JOIN
137    pub fn inner_join(mut self, table: &str, on_condition: &str) -> Self {
138        self.join_conditions.push(format!("INNER JOIN {} ON {}", table, on_condition));
139        self
140    }
141
142    // 添加 LEFT JOIN
143    pub fn left_join(mut self, table: &str, on_condition: &str) -> Self {
144        self.join_conditions.push(format!("LEFT JOIN {} ON {}", table, on_condition));
145        self
146    }
147
148    // 添加 RIGHT JOIN
149    pub fn right_join(mut self, table: &str, on_condition: &str) -> Self {
150        self.join_conditions.push(format!("RIGHT JOIN {} ON {}", table, on_condition));
151        self
152    }
153
154    // 修改构建SQL语句方法
155    pub fn build_sql(&self, table_name: &str) -> String {
156        // 如果有自定义SQL,直接使用它
157        if let Some(custom_sql) = &self.custom_sql {
158            let mut sql = custom_sql.clone();
159            
160            // 添加WHERE条件
161            if !self.where_conditions.is_empty() {
162                if !sql.to_uppercase().contains("WHERE") {
163                    sql.push_str(" WHERE ");
164                } else {
165                    sql.push_str(" AND ");
166                }
167                sql.push_str(&self.where_conditions.join(" AND "));
168            }
169
170            // 添加排序
171            if !self.order_by.is_empty() {
172                sql.push_str(" ORDER BY ");
173                sql.push_str(&self.order_by.join(", "));
174            }
175
176            // 添加分页
177            if let Some(limit) = self.limit {
178                sql.push_str(&format!(" LIMIT {}", limit));
179            }
180            if let Some(offset) = self.offset {
181                sql.push_str(&format!(" OFFSET {}", offset));
182            }
183
184            return sql;
185        }
186
187        // 常规SQL构建
188        let select = if self.select_columns.is_empty() {
189            "*".to_string()
190        } else {
191            self.select_columns.join(", ")
192        };
193
194        let mut sql = format!("SELECT {} FROM {}", select, table_name);
195
196        // 添加JOIN条件
197        if !self.join_conditions.is_empty() {
198            sql.push_str(" ");
199            sql.push_str(&self.join_conditions.join(" "));
200        }
201
202        if !self.where_conditions.is_empty() {
203            sql.push_str(" WHERE ");
204            sql.push_str(&self.where_conditions.join(" AND "));
205        }
206
207        if !self.order_by.is_empty() {
208            sql.push_str(" ORDER BY ");
209            sql.push_str(&self.order_by.join(", "));
210        }
211
212        if let Some(limit) = self.limit {
213            sql.push_str(&format!(" LIMIT {}", limit));
214        }
215
216        if let Some(offset) = self.offset {
217            sql.push_str(&format!(" OFFSET {}", offset));
218        }
219
220        sql
221    }
222
223    // 执行查询
224    pub async fn query<T>(&self, rb: &RBatis, table_name: &str) -> Result<Vec<T>, Error>
225    where
226        T: Serialize + for<'de> serde::Deserialize<'de>,
227    {
228        let sql = self.build_sql(table_name);
229        rb.query_decode(&sql, vec![]).await
230    }
231
232    // 执行查询
233    pub async fn get_one<T>(&self, rb: &RBatis, table_name: &str) -> Result<Option<T>, Error>
234    where
235        T: Serialize + for<'de> serde::Deserialize<'de>,
236    {
237        let sql = self.build_sql(table_name);
238        rb.query_decode::<Option<T>>(&sql, vec![]).await
239    }
240
241    // 修改分页方法
242    pub async fn page<T>(&self, rb: &RBatis, table_name: &str, page_no: u64, page_size: u64) -> Result<Page<T>, Error>
243    where
244        T: Serialize + for<'de> serde::Deserialize<'de>,
245    {
246        // 1. 先查询总记录数
247        let count_sql = self.build_count_sql(table_name);
248        let total: u64 = rb.query_decode(&count_sql, vec![]).await?;
249
250        // 2. 如果有数据,再查询分页数据
251        if total > 0 {
252            // 设置分页参数
253            let offset = (page_no - 1) * page_size;
254            let mut wrapper = self.clone();
255            wrapper.limit(page_size);  // 现在这些方法返回 &mut Self
256            wrapper.offset(offset);    // 可以分开调用
257            
258            // 查询分页数据
259            let records: Vec<T> = wrapper.query(rb, table_name).await?;
260            
261            Ok(Page::new(records, total, page_no, page_size))
262        } else {
263            // 没有数据时返回空页
264            Ok(Page::new(vec![], 0, page_no, page_size))
265        }
266    }
267
268    // 修改构建统计SQL方法
269    fn build_count_sql(&self, table_name: &str) -> String {
270        if let Some(custom_sql) = &self.custom_sql {
271            // 将 WHERE 条件放入子查询内部
272            let mut inner_sql = custom_sql.clone();
273            
274            if !self.where_conditions.is_empty() {
275                if !inner_sql.to_uppercase().contains("WHERE") {
276                    inner_sql.push_str(" WHERE ");
277                } else {
278                    inner_sql.push_str(" AND ");
279                }
280                inner_sql.push_str(&self.where_conditions.join(" AND "));
281            }
282
283            // 包装成计数查询
284            format!("SELECT COUNT(*) FROM ({}) as t", inner_sql)
285        } else {
286            let mut sql = format!("SELECT COUNT(*) FROM {}", table_name);
287
288            // 添加JOIN条件
289            if !self.join_conditions.is_empty() {
290                sql.push_str(" ");
291                sql.push_str(&self.join_conditions.join(" "));
292            }
293
294            if !self.where_conditions.is_empty() {
295                sql.push_str(" WHERE ");
296                sql.push_str(&self.where_conditions.join(" AND "));
297            }
298
299            sql
300        }
301    }
302}