// co_orm/filter.rs

1#[cfg(feature = "mysql")]
2type DbArgs = sqlx::mysql::MySqlArguments;
3#[cfg(feature = "postgres")]
4type DbArgs = sqlx::postgres::PgArguments;
5#[cfg(feature = "sqlite")]
6type DbArgs = sqlx::sqlite::SqliteArguments;
7// #[cfg(feature = "any")]
8// type DbArgs = sqlx::any::AnyArguments<'static>;
9#[cfg(all(
10    not(feature = "mysql"),
11    not(feature = "postgres"),
12    not(feature = "sqlite"),
13    // not(feature = "any"),
14    // not(feature = "mssql"),
15))]
16type DbArgs = sqlx::mysql::MySqlArguments;
17
18// #[cfg(feature = "mssql")]
19// type DbArgs = sqlx::mssql::MssqlArguments;
20
21pub trait BindArg {
22    fn bind_to(self, args: &mut DbArgs);
23}
24
25#[cfg(feature = "mysql")]
26impl<T> BindArg for T
27where
28    T: 'static + sqlx::Encode<'static, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
29{
30    fn bind_to(self, args: &mut DbArgs) {
31        use sqlx::Arguments as _;
32        let _ = args.add(self);
33    }
34}
35
36#[cfg(feature = "postgres")]
37impl<T> BindArg for T
38where
39    T: 'static + sqlx::Encode<'static, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
40{
41    fn bind_to(self, args: &mut DbArgs) {
42        use sqlx::Arguments as _;
43        let _ = args.add(self);
44    }
45}
46
47#[cfg(feature = "sqlite")]
48impl<T> BindArg for T
49where
50    T: 'static + sqlx::Encode<'static, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
51{
52    fn bind_to(self, args: &mut DbArgs) {
53        use sqlx::Arguments as _;
54        let _ = args.add(self);
55    }
56}
57
58// #[cfg(feature = "any")]
59// impl<T> BindArg for T
60// where
61//     T: 'static + sqlx::Encode<'static, sqlx::Any> + sqlx::Type<sqlx::Any>,
62// {
63//     fn bind_to(self, args: &mut DbArgs) {
64//         use sqlx::Arguments as _;
65//         let _ = args.add(self);
66//     }
67// }
68
69// #[cfg(feature = "mssql")]
70// impl<T> BindArg for T
71// where
72//     T: 'static + sqlx::Encode<'static, sqlx::Mssql> + sqlx::Type<sqlx::Mssql>,
73// {
74//     fn bind_to(self, args: &mut DbArgs) {
75//         use sqlx::Arguments as _;
76//         let _ = args.add(self);
77//     }
78// }
79
80/// A builder for WHERE clauses
81/// # Example:
82/// ```ignore
83/// let w = Where::new().eq("name", "admin").and().ge("age", 18);
84/// ```
85/// will generate:
86/// ```ignore
87/// WHERE name = ? AND age >= ?
88/// ```
89#[derive(Default, Clone, Debug)]
90pub struct Where {
91    sql: String,
92    args: DbArgs,
93    next_index: usize,
94    pending_logic: Option<&'static str>,
95    has_any_predicate: bool,
96    skip_next_logic: bool,
97}
98
99impl Where {
100    /// Create a new Where builder
101    /// # Example:
102    /// ```ignore
103    /// let w = Where::new();
104    /// ```
105    pub fn new() -> Self {
106        Self {
107            sql: String::from("WHERE "),
108            args: Default::default(),
109            next_index: 1,
110            pending_logic: None,
111            has_any_predicate: false,
112            skip_next_logic: false,
113        }
114    }
115
116    /// Add AND logic to the WHERE clause
117    /// # Example:
118    /// ```ignore
119    /// Where::new().eq("name", "admin").and().eq("age", 18);
120    /// ```
121    /// will generate:
122    /// ```ignore
123    /// WHERE name = ? AND age = ?
124    /// ```
125    pub fn and(mut self) -> Self {
126        self.pending_logic = Some("AND");
127        self
128    }
129
130    /// Add OR logic to the WHERE clause
131    /// # Example:
132    /// ```ignore
133    /// Where::new().eq("name", "admin").or().eq("age", 18);
134    /// ```
135    /// will generate:
136    /// ```ignore
137    /// WHERE name = ? OR age = ?
138    /// ```
139    pub fn or(mut self) -> Self {
140        self.pending_logic = Some("OR");
141        self
142    }
143
144    /// Push logic if needed
145    fn push_logic_if_needed(&mut self) {
146        if self.skip_next_logic {
147            self.skip_next_logic = false;
148            return;
149        }
150        if self.has_any_predicate {
151            let logic = self.pending_logic.take().unwrap_or("AND");
152            self.sql.push_str(logic);
153            self.sql.push(' ');
154        }
155    }
156
157    /// Generate a placeholder for the next value
158    fn placeholder(&mut self) -> String {
159        #[cfg(feature = "postgres")]
160        {
161            let p = format!("${}", self.next_index);
162            self.next_index += 1;
163            return p;
164        }
165        #[cfg(feature = "sqlite")]
166        {
167            let p = format!("?{}", self.next_index);
168            self.next_index += 1;
169            return p;
170        }
171        #[cfg(feature = "mysql")]
172        {
173            self.next_index += 1;
174            return "?".to_string();
175        }
176
177        // #[cfg(feature = "mssql")]
178        // {
179        //     self.next_index += 1;
180        //     return "?".to_string();
181        // }
182        #[allow(unreachable_code)]
183        {
184            self.next_index += 1;
185            "?".to_string()
186        }
187    }
188
189    /// Add a raw SQL fragment to the WHERE clause
190    /// # Example:
191    /// ```ignore
192    /// Where::new().raw("name = admin AND age >= 18");
193    /// ```
194    pub fn raw(mut self, fragment: &str) -> Self {
195        self.push_logic_if_needed();
196        self.sql.push_str(fragment);
197        self.sql.push(' ');
198        self.has_any_predicate = true;
199        self
200    }
201
202    /// Add an equality condition(=) to the WHERE clause
203    /// # Example:
204    /// ```ignore
205    /// Where::new().eq("name", "admin");
206    /// ```
207    /// will generate:
208    /// ```ignore
209    /// WHERE name = ?
210    /// ```
211    pub fn eq(mut self, col: &str, value: impl BindArg) -> Self {
212        self.push_logic_if_needed();
213        let ph = self.placeholder();
214        self.sql.push_str(col);
215        self.sql.push_str(" = ");
216        self.sql.push_str(&ph);
217        self.sql.push(' ');
218        self.has_any_predicate = true;
219        value.bind_to(&mut self.args);
220        self
221    }
222
223    /// Add a non-equality condition(<>) to the WHERE clause
224    /// # Example:
225    /// ```ignore
226    /// Where::new().ne("name", "admin");
227    /// ```
228    /// will generate:
229    /// ```ignore
230    /// WHERE name <> ?
231    /// ```
232    pub fn ne(mut self, col: &str, value: impl BindArg) -> Self {
233        self.push_logic_if_needed();
234        let ph = self.placeholder();
235        self.sql.push_str(col);
236        self.sql.push_str(" <> ");
237        self.sql.push_str(&ph);
238        self.sql.push(' ');
239        self.has_any_predicate = true;
240        value.bind_to(&mut self.args);
241        self
242    }
243
244    /// Add a less than condition(<) to the WHERE clause
245    /// # Example:
246    /// ```ignore
247    /// Where::new().lt("age", 18);
248    /// ```
249    /// will generate:
250    /// ```ignore
251    /// WHERE age < ?
252    /// ```
253    pub fn lt(mut self, col: &str, value: impl BindArg) -> Self {
254        self.push_logic_if_needed();
255        let ph = self.placeholder();
256        self.sql.push_str(col);
257        self.sql.push_str(" < ");
258        self.sql.push_str(&ph);
259        self.sql.push(' ');
260        self.has_any_predicate = true;
261        value.bind_to(&mut self.args);
262        self
263    }
264
265    /// Add a less than or equal to condition(<=) to the WHERE clause
266    /// # Example:
267    /// ```ignore
268    /// Where::new().le("age", 18);
269    /// ```
270    /// will generate:
271    /// ```ignore
272    /// WHERE age <= ?
273    /// ```
274    pub fn le(mut self, col: &str, value: impl BindArg) -> Self {
275        self.push_logic_if_needed();
276        let ph = self.placeholder();
277        self.sql.push_str(col);
278        self.sql.push_str(" <= ");
279        self.sql.push_str(&ph);
280        self.sql.push(' ');
281        self.has_any_predicate = true;
282        value.bind_to(&mut self.args);
283        self
284    }
285
286    /// Add a greater than condition(>) to the WHERE clause
287    /// # Example:
288    /// ```ignore
289    /// Where::new().gt("age", 18);
290    /// ```
291    /// will generate:
292    /// ```ignore
293    /// WHERE age > ?
294    /// ```
295    pub fn gt(mut self, col: &str, value: impl BindArg) -> Self {
296        self.push_logic_if_needed();
297        let ph = self.placeholder();
298        self.sql.push_str(col);
299        self.sql.push_str(" > ");
300        self.sql.push_str(&ph);
301        self.sql.push(' ');
302        self.has_any_predicate = true;
303        value.bind_to(&mut self.args);
304        self
305    }
306
307    /// Add a greater than or equal to condition(>=) to the WHERE clause
308    /// # Example:
309    /// ```ignore
310    /// Where::new().ge("age", 18);
311    /// ```
312    /// will generate:
313    /// ```ignore
314    /// WHERE age >= ?
315    /// ```
316    pub fn ge(mut self, col: &str, value: impl BindArg) -> Self {
317        self.push_logic_if_needed();
318        let ph = self.placeholder();
319        self.sql.push_str(col);
320        self.sql.push_str(" >= ");
321        self.sql.push_str(&ph);
322        self.sql.push(' ');
323        self.has_any_predicate = true;
324        value.bind_to(&mut self.args);
325        self
326    }
327
328    /// Add a LIKE condition to the WHERE clause
329    /// # Example:
330    /// ```ignore
331    /// Where::new().like("name", "%admin%");
332    /// ```
333    /// will generate:
334    /// ```ignore
335    /// WHERE name LIKE ?
336    /// ```
337    pub fn like(mut self, col: &str, value: impl BindArg) -> Self {
338        self.push_logic_if_needed();
339        let ph = self.placeholder();
340        self.sql.push_str(col);
341        self.sql.push_str(" LIKE ");
342        self.sql.push_str(&ph);
343        self.sql.push(' ');
344        self.has_any_predicate = true;
345        value.bind_to(&mut self.args);
346        self
347    }
348
349    /// Add a IS NULL condition to the WHERE clause
350    /// # Example:
351    /// ```ignore
352    /// Where::new().is_null("status");
353    /// ```
354    /// will generate:
355    /// ```ignore
356    /// WHERE status IS NULL
357    /// ```
358    pub fn is_null(mut self, col: &str) -> Self {
359        self.push_logic_if_needed();
360        self.sql.push_str(col);
361        self.sql.push_str(" IS NULL ");
362        self.has_any_predicate = true;
363        self
364    }
365
366    /// Add a IS NOT NULL condition to the WHERE clause
367    /// # Example:
368    /// ```ignore
369    /// Where::new().is_not_null("status");
370    /// ```
371    /// will generate:
372    /// ```ignore
373    /// WHERE status IS NOT NULL
374    /// ```
375    pub fn is_not_null(mut self, col: &str) -> Self {
376        self.push_logic_if_needed();
377        self.sql.push_str(col);
378        self.sql.push_str(" IS NOT NULL ");
379        self.has_any_predicate = true;
380        self
381    }
382
383    /// Add a BETWEEN condition to the WHERE clause
384    /// # Example:
385    /// ```ignore
386    /// Where::new().between("age", 18, 21);
387    /// ```
388    /// will generate:
389    /// ```ignore
390    /// WHERE age BETWEEN ? AND ?
391    /// ```
392    pub fn between(mut self, col: &str, start: impl BindArg, end: impl BindArg) -> Self {
393        self.push_logic_if_needed();
394        let ph1 = self.placeholder();
395        let ph2 = self.placeholder();
396        self.sql.push_str(col);
397        self.sql.push_str(" BETWEEN ");
398        self.sql.push_str(&ph1);
399        self.sql.push_str(" AND ");
400        self.sql.push_str(&ph2);
401        self.sql.push(' ');
402        self.has_any_predicate = true;
403        start.bind_to(&mut self.args);
404        end.bind_to(&mut self.args);
405        self
406    }
407
408    /// Add a IN condition to the WHERE clause
409    /// # Example:
410    /// ```ignore
411    /// Where::new().r#in("status", vec!["active", "verified", "premium"]);
412    /// ```
413    /// will generate:
414    /// ```ignore
415    /// WHERE status IN (?, ?, ?)
416    /// ```
417    pub fn r#in<V>(mut self, col: &str, values: V) -> Self
418    where
419        V: IntoIterator,
420        V::Item: BindArg,
421    {
422        self.push_logic_if_needed();
423        self.sql.push_str(col);
424        self.sql.push_str(" IN (");
425        let mut first = true;
426        for v in values {
427            if !first {
428                self.sql.push_str(", ");
429            }
430            first = false;
431            let ph = self.placeholder();
432            self.sql.push_str(&ph);
433            v.bind_to(&mut self.args);
434        }
435        self.sql.push(')');
436        self.sql.push(' ');
437        self.has_any_predicate = true;
438        self
439    }
440
441    /// Add a AND group to the WHERE clause
442    /// # Example:
443    /// ```ignore
444    /// Where::new().eq("name", "admin").and_group(|w| {
445    ///     w.eq("password", "123456").and().ge("age", 21);
446    /// });
447    /// ```
448    /// will generate:
449    /// ```ignore
450    /// WHERE name = ? AND (password = ? AND age >= ?)
451    /// ```
452    pub fn and_group<F>(mut self, f: F) -> Self
453    where
454        F: FnOnce(Where) -> Where,
455    {
456        self.pending_logic = Some("AND");
457        self.push_logic_if_needed();
458        self.sql.push('(');
459        let prev_skip = self.skip_next_logic;
460        self.skip_next_logic = true;
461        let mut new_where = f(self);
462        new_where.skip_next_logic = prev_skip;
463        if new_where.sql.ends_with(' ') {
464            new_where.sql.pop();
465        }
466        new_where.sql.push(')');
467        new_where.sql.push(' ');
468        new_where.has_any_predicate = true;
469        new_where
470    }
471
472    /// Add a OR group to the WHERE clause
473    /// # Example:
474    /// ```ignore
475    /// Where::new().eq("name", "root").or_group(|w| {
476    ///     w.eq("name", "admin").and().ge("age", 21);
477    /// });
478    /// ```
479    /// will generate:
480    /// ```ignore
481    /// WHERE name = ? AND (age >= ? OR name = ?)
482    /// ```
483    pub fn or_group<F>(mut self, f: F) -> Self
484    where
485        F: FnOnce(Where) -> Where,
486    {
487        self.pending_logic = Some("OR");
488        self.push_logic_if_needed();
489        self.sql.push('(');
490        let prev_skip = self.skip_next_logic;
491        self.skip_next_logic = true;
492        let mut new_where = f(self);
493        new_where.skip_next_logic = prev_skip;
494        if new_where.sql.ends_with(' ') {
495            new_where.sql.pop();
496        }
497        new_where.sql.push(')');
498        new_where.sql.push(' ');
499        new_where.has_any_predicate = true;
500        new_where
501    }
502
503    /// Build the WHERE clause
504    /// # Example:
505    /// ```ignore
506    /// let (sql, args) = Where::new().build();
507    /// ```
508    pub fn build(mut self) -> (String, DbArgs) {
509        if !self.has_any_predicate {
510            return (String::new(), self.args);
511        }
512        if self.sql.ends_with(' ') {
513            self.sql.pop();
514        }
515        (self.sql, self.args)
516    }
517}
518
#[cfg(test)]
mod test {
    //! Expected strings are written with MySQL-style `?` placeholders.
    //! `normalize` maps the numbered Postgres (`$1`) and SQLite (`?1`) forms onto
    //! that style, so the same assertions hold whichever backend feature is enabled.
    use super::*;

    /// Replace every `$<digits>` / `?<digits>` placeholder with a bare `?`.
    fn normalize(sql: &str) -> String {
        let mut out = String::with_capacity(sql.len());
        let mut chars = sql.chars().peekable();
        while let Some(c) = chars.next() {
            if c == '$' || c == '?' {
                out.push('?');
                while matches!(chars.peek(), Some(d) if d.is_ascii_digit()) {
                    chars.next();
                }
            } else {
                out.push(c);
            }
        }
        out
    }

    /// Build `w` and return its SQL with placeholders normalized.
    fn sql_of(w: Where) -> String {
        let (sql, _args) = w.build();
        normalize(&sql)
    }

    #[test]
    fn test_where() {
        assert_eq!(
            sql_of(Where::new().eq("status", "active").and().ge("age", 18).or().le("age", 21)),
            "WHERE status = ? AND age >= ? OR age <= ?"
        );
        assert_eq!(
            sql_of(Where::new().gt("age", 18).and().lt("age", 21)),
            "WHERE age > ? AND age < ?"
        );
        assert_eq!(sql_of(Where::new().between("age", 18, 21)), "WHERE age BETWEEN ? AND ?");
        assert_eq!(
            sql_of(Where::new().r#in("status", vec!["active", "verified", "premium"])),
            "WHERE status IN (?, ?, ?)"
        );
        assert_eq!(sql_of(Where::new().is_null("status")), "WHERE status IS NULL");
        assert_eq!(sql_of(Where::new().is_not_null("status")), "WHERE status IS NOT NULL");
        assert_eq!(sql_of(Where::new().like("name", "%admin%")), "WHERE name LIKE ?");
    }

    #[test]
    fn test_where_and_group() {
        let w = Where::new()
            .eq("status", "active")
            .and_group(|w| w.ge("age", 18).or().eq("name", "root"));
        assert_eq!(sql_of(w), "WHERE status = ? AND (age >= ? OR name = ?)");
    }

    #[test]
    fn test_where_or_group() {
        let w = Where::new()
            .eq("name", "root")
            .or_group(|w| w.eq("name", "admin").and().ge("age", 21));
        assert_eq!(sql_of(w), "WHERE name = ? OR (name = ? AND age >= ?)");
    }
}