co_orm/
filter.rs

1#[cfg(feature = "mysql")]
2type DbArgs = sqlx::mysql::MySqlArguments;
3#[cfg(all(
4    not(feature = "mysql"),
5    not(feature = "postgres"),
6    not(feature = "sqlite"),
7    // not(feature = "mssql"),
8))]
9type DbArgs = sqlx::mysql::MySqlArguments; // default to MySQL like macros crate
10#[cfg(feature = "postgres")]
11type DbArgs = sqlx::postgres::PgArguments;
12#[cfg(feature = "sqlite")]
13type DbArgs = sqlx::sqlite::SqliteArguments;
14// #[cfg(feature = "mssql")]
15// type DbArgs = sqlx::mssql::MssqlArguments;
16
17pub trait BindArg {
18    fn bind_to(self, args: &mut DbArgs);
19}
20
21#[cfg(feature = "mysql")]
22impl<T> BindArg for T
23where
24    T: 'static + sqlx::Encode<'static, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
25{
26    fn bind_to(self, args: &mut DbArgs) {
27        use sqlx::Arguments as _;
28        let _ = args.add(self);
29    }
30}
31
#[cfg(feature = "postgres")]
impl<V> BindArg for V
where
    V: 'static + sqlx::Type<sqlx::Postgres> + sqlx::Encode<'static, sqlx::Postgres>,
{
    /// Appends `self` to the PostgreSQL argument buffer; the result of `add` is ignored.
    fn bind_to(self, buffer: &mut DbArgs) {
        let _ = sqlx::Arguments::add(buffer, self);
    }
}
42
#[cfg(feature = "sqlite")]
impl<V> BindArg for V
where
    V: 'static + sqlx::Type<sqlx::Sqlite> + sqlx::Encode<'static, sqlx::Sqlite>,
{
    /// Appends `self` to the SQLite argument buffer; the result of `add` is ignored.
    fn bind_to(self, buffer: &mut DbArgs) {
        let _ = sqlx::Arguments::add(buffer, self);
    }
}
53
54// #[cfg(feature = "mssql")]
55// impl<T> BindArg for T
56// where
57//     T: 'static + sqlx::Encode<'static, sqlx::Mssql> + sqlx::Type<sqlx::Mssql>,
58// {
59//     fn bind_to(self, args: &mut DbArgs) {
60//         use sqlx::Arguments as _;
61//         let _ = args.add(self);
62//     }
63// }
64
65/// A builder for WHERE clauses
66/// # Example:
67/// ```ignore
68/// let w = Where::new().eq("name", "admin").and().ge("age", 18);
69/// ```
70/// will generate:
71/// ```ignore
72/// WHERE name = ? AND age >= ?
73/// ```
74#[derive(Debug, Default, Clone)]
75pub struct Where {
76    sql: String,
77    args: DbArgs,
78    next_index: usize,
79    pending_logic: Option<&'static str>,
80    has_any_predicate: bool,
81    skip_next_logic: bool,
82}
83
84impl Where {
85    /// Create a new Where builder
86    /// # Example:
87    /// ```ignore
88    /// let w = Where::new();
89    /// ```
90    pub fn new() -> Self {
91        Self {
92            sql: String::from("WHERE "),
93            args: Default::default(),
94            next_index: 1,
95            pending_logic: None,
96            has_any_predicate: false,
97            skip_next_logic: false,
98        }
99    }
100
101    /// Add AND logic to the WHERE clause
102    /// # Example:
103    /// ```ignore
104    /// Where::new().eq("name", "admin").and().eq("age", 18);
105    /// ```
106    /// will generate:
107    /// ```ignore
108    /// WHERE name = ? AND age = ?
109    /// ```
110    pub fn and(mut self) -> Self {
111        self.pending_logic = Some("AND");
112        self
113    }
114
115    /// Add OR logic to the WHERE clause
116    /// # Example:
117    /// ```ignore
118    /// Where::new().eq("name", "admin").or().eq("age", 18);
119    /// ```
120    /// will generate:
121    /// ```ignore
122    /// WHERE name = ? OR age = ?
123    /// ```
124    pub fn or(mut self) -> Self {
125        self.pending_logic = Some("OR");
126        self
127    }
128
129    /// Push logic if needed
130    fn push_logic_if_needed(&mut self) {
131        if self.skip_next_logic {
132            self.skip_next_logic = false;
133            return;
134        }
135        if self.has_any_predicate {
136            let logic = self.pending_logic.take().unwrap_or("AND");
137            self.sql.push_str(logic);
138            self.sql.push(' ');
139        }
140    }
141
142    /// Generate a placeholder for the next value
143    fn placeholder(&mut self) -> String {
144        #[cfg(feature = "postgres")]
145        {
146            let p = format!("${}", self.next_index);
147            self.next_index += 1;
148            return p;
149        }
150        #[cfg(feature = "sqlite")]
151        {
152            let p = format!("?{}", self.next_index);
153            self.next_index += 1;
154            return p;
155        }
156        #[cfg(feature = "mysql")]
157        {
158            self.next_index += 1;
159            return "?".to_string();
160        }
161        // #[cfg(feature = "mssql")]
162        // {
163        //     self.next_index += 1;
164        //     return "?".to_string();
165        // }
166        #[allow(unreachable_code)]
167        {
168            self.next_index += 1;
169            "?".to_string()
170        }
171    }
172
173    /// Add a raw SQL fragment to the WHERE clause
174    /// # Example:
175    /// ```ignore
176    /// Where::new().raw("name = admin AND age >= 18");
177    /// ```
178    pub fn raw(mut self, fragment: &str) -> Self {
179        self.push_logic_if_needed();
180        self.sql.push_str(fragment);
181        self.sql.push(' ');
182        self.has_any_predicate = true;
183        self
184    }
185
186    /// Add an equality condition(=) to the WHERE clause
187    /// # Example:
188    /// ```ignore
189    /// Where::new().eq("name", "admin");
190    /// ```
191    /// will generate:
192    /// ```ignore
193    /// WHERE name = ?
194    /// ```
195    pub fn eq(mut self, col: &str, value: impl BindArg) -> Self {
196        self.push_logic_if_needed();
197        let ph = self.placeholder();
198        self.sql.push_str(col);
199        self.sql.push_str(" = ");
200        self.sql.push_str(&ph);
201        self.sql.push(' ');
202        self.has_any_predicate = true;
203        value.bind_to(&mut self.args);
204        self
205    }
206
207    /// Add a non-equality condition(<>) to the WHERE clause
208    /// # Example:
209    /// ```ignore
210    /// Where::new().ne("name", "admin");
211    /// ```
212    /// will generate:
213    /// ```ignore
214    /// WHERE name <> ?
215    /// ```
216    pub fn ne(mut self, col: &str, value: impl BindArg) -> Self {
217        self.push_logic_if_needed();
218        let ph = self.placeholder();
219        self.sql.push_str(col);
220        self.sql.push_str(" <> ");
221        self.sql.push_str(&ph);
222        self.sql.push(' ');
223        self.has_any_predicate = true;
224        value.bind_to(&mut self.args);
225        self
226    }
227
228    /// Add a less than condition(<) to the WHERE clause
229    /// # Example:
230    /// ```ignore
231    /// Where::new().lt("age", 18);
232    /// ```
233    /// will generate:
234    /// ```ignore
235    /// WHERE age < ?
236    /// ```
237    pub fn lt(mut self, col: &str, value: impl BindArg) -> Self {
238        self.push_logic_if_needed();
239        let ph = self.placeholder();
240        self.sql.push_str(col);
241        self.sql.push_str(" < ");
242        self.sql.push_str(&ph);
243        self.sql.push(' ');
244        self.has_any_predicate = true;
245        value.bind_to(&mut self.args);
246        self
247    }
248
249    /// Add a less than or equal to condition(<=) to the WHERE clause
250    /// # Example:
251    /// ```ignore
252    /// Where::new().le("age", 18);
253    /// ```
254    /// will generate:
255    /// ```ignore
256    /// WHERE age <= ?
257    /// ```
258    pub fn le(mut self, col: &str, value: impl BindArg) -> Self {
259        self.push_logic_if_needed();
260        let ph = self.placeholder();
261        self.sql.push_str(col);
262        self.sql.push_str(" <= ");
263        self.sql.push_str(&ph);
264        self.sql.push(' ');
265        self.has_any_predicate = true;
266        value.bind_to(&mut self.args);
267        self
268    }
269
270    /// Add a greater than condition(>) to the WHERE clause
271    /// # Example:
272    /// ```ignore
273    /// Where::new().gt("age", 18);
274    /// ```
275    /// will generate:
276    /// ```ignore
277    /// WHERE age > ?
278    /// ```
279    pub fn gt(mut self, col: &str, value: impl BindArg) -> Self {
280        self.push_logic_if_needed();
281        let ph = self.placeholder();
282        self.sql.push_str(col);
283        self.sql.push_str(" > ");
284        self.sql.push_str(&ph);
285        self.sql.push(' ');
286        self.has_any_predicate = true;
287        value.bind_to(&mut self.args);
288        self
289    }
290
291    /// Add a greater than or equal to condition(>=) to the WHERE clause
292    /// # Example:
293    /// ```ignore
294    /// Where::new().ge("age", 18);
295    /// ```
296    /// will generate:
297    /// ```ignore
298    /// WHERE age >= ?
299    /// ```
300    pub fn ge(mut self, col: &str, value: impl BindArg) -> Self {
301        self.push_logic_if_needed();
302        let ph = self.placeholder();
303        self.sql.push_str(col);
304        self.sql.push_str(" >= ");
305        self.sql.push_str(&ph);
306        self.sql.push(' ');
307        self.has_any_predicate = true;
308        value.bind_to(&mut self.args);
309        self
310    }
311
312    /// Add a LIKE condition to the WHERE clause
313    /// # Example:
314    /// ```ignore
315    /// Where::new().like("name", "%admin%");
316    /// ```
317    /// will generate:
318    /// ```ignore
319    /// WHERE name LIKE ?
320    /// ```
321    pub fn like(mut self, col: &str, value: impl BindArg) -> Self {
322        self.push_logic_if_needed();
323        let ph = self.placeholder();
324        self.sql.push_str(col);
325        self.sql.push_str(" LIKE ");
326        self.sql.push_str(&ph);
327        self.sql.push(' ');
328        self.has_any_predicate = true;
329        value.bind_to(&mut self.args);
330        self
331    }
332
333    /// Add a IS NULL condition to the WHERE clause
334    /// # Example:
335    /// ```ignore
336    /// Where::new().is_null("status");
337    /// ```
338    /// will generate:
339    /// ```ignore
340    /// WHERE status IS NULL
341    /// ```
342    pub fn is_null(mut self, col: &str) -> Self {
343        self.push_logic_if_needed();
344        self.sql.push_str(col);
345        self.sql.push_str(" IS NULL ");
346        self.has_any_predicate = true;
347        self
348    }
349
350    /// Add a IS NOT NULL condition to the WHERE clause
351    /// # Example:
352    /// ```ignore
353    /// Where::new().is_not_null("status");
354    /// ```
355    /// will generate:
356    /// ```ignore
357    /// WHERE status IS NOT NULL
358    /// ```
359    pub fn is_not_null(mut self, col: &str) -> Self {
360        self.push_logic_if_needed();
361        self.sql.push_str(col);
362        self.sql.push_str(" IS NOT NULL ");
363        self.has_any_predicate = true;
364        self
365    }
366
367    /// Add a BETWEEN condition to the WHERE clause
368    /// # Example:
369    /// ```ignore
370    /// Where::new().between("age", 18, 21);
371    /// ```
372    /// will generate:
373    /// ```ignore
374    /// WHERE age BETWEEN ? AND ?
375    /// ```
376    pub fn between(mut self, col: &str, start: impl BindArg, end: impl BindArg) -> Self {
377        self.push_logic_if_needed();
378        let ph1 = self.placeholder();
379        let ph2 = self.placeholder();
380        self.sql.push_str(col);
381        self.sql.push_str(" BETWEEN ");
382        self.sql.push_str(&ph1);
383        self.sql.push_str(" AND ");
384        self.sql.push_str(&ph2);
385        self.sql.push(' ');
386        self.has_any_predicate = true;
387        start.bind_to(&mut self.args);
388        end.bind_to(&mut self.args);
389        self
390    }
391
392    /// Add a IN condition to the WHERE clause
393    /// # Example:
394    /// ```ignore
395    /// Where::new().r#in("status", vec!["active", "verified", "premium"]);
396    /// ```
397    /// will generate:
398    /// ```ignore
399    /// WHERE status IN (?, ?, ?)
400    /// ```
401    pub fn r#in<V>(mut self, col: &str, values: V) -> Self
402    where
403        V: IntoIterator,
404        V::Item: BindArg,
405    {
406        self.push_logic_if_needed();
407        self.sql.push_str(col);
408        self.sql.push_str(" IN (");
409        let mut first = true;
410        for v in values {
411            if !first {
412                self.sql.push_str(", ");
413            }
414            first = false;
415            let ph = self.placeholder();
416            self.sql.push_str(&ph);
417            v.bind_to(&mut self.args);
418        }
419        self.sql.push(')');
420        self.sql.push(' ');
421        self.has_any_predicate = true;
422        self
423    }
424
425    /// Add a AND group to the WHERE clause
426    /// # Example:
427    /// ```ignore
428    /// Where::new().eq("name", "admin").and_group(|w| {
429    ///     w.eq("password", "123456").and().ge("age", 21);
430    /// });
431    /// ```
432    /// will generate:
433    /// ```ignore
434    /// WHERE name = ? AND (password = ? AND age >= ?)
435    /// ```
436    pub fn and_group<F>(mut self, f: F) -> Self
437    where
438        F: FnOnce(Where) -> Where,
439    {
440        self.pending_logic = Some("AND");
441        self.push_logic_if_needed();
442        self.sql.push('(');
443        let prev_skip = self.skip_next_logic;
444        self.skip_next_logic = true;
445        let mut new_where = f(self);
446        new_where.skip_next_logic = prev_skip;
447        if new_where.sql.ends_with(' ') {
448            new_where.sql.pop();
449        }
450        new_where.sql.push(')');
451        new_where.sql.push(' ');
452        new_where.has_any_predicate = true;
453        new_where
454    }
455
456    /// Add a OR group to the WHERE clause
457    /// # Example:
458    /// ```ignore
459    /// Where::new().eq("name", "root").or_group(|w| {
460    ///     w.eq("name", "admin").and().ge("age", 21);
461    /// });
462    /// ```
463    /// will generate:
464    /// ```ignore
465    /// WHERE name = ? AND (age >= ? OR name = ?)
466    /// ```
467    pub fn or_group<F>(mut self, f: F) -> Self
468    where
469        F: FnOnce(Where) -> Where,
470    {
471        self.pending_logic = Some("OR");
472        self.push_logic_if_needed();
473        self.sql.push('(');
474        let prev_skip = self.skip_next_logic;
475        self.skip_next_logic = true;
476        let mut new_where = f(self);
477        new_where.skip_next_logic = prev_skip;
478        if new_where.sql.ends_with(' ') {
479            new_where.sql.pop();
480        }
481        new_where.sql.push(')');
482        new_where.sql.push(' ');
483        new_where.has_any_predicate = true;
484        new_where
485    }
486
487    /// Build the WHERE clause
488    /// # Example:
489    /// ```ignore
490    /// let (sql, args) = Where::new().build();
491    /// ```
492    pub fn build(mut self) -> (String, DbArgs) {
493        if !self.has_any_predicate {
494            return (String::new(), self.args);
495        }
496        if self.sql.ends_with(' ') {
497            self.sql.pop();
498        }
499        (self.sql, self.args)
500    }
501}
502
#[cfg(test)]
mod test {
    use super::*;

    /// Builds `w` and keeps only the generated SQL text.
    fn sql_of(w: Where) -> String {
        let (sql, _) = w.build();
        sql
    }

    #[test]
    fn test_where() {
        let (sql, args) = Where::new()
            .eq("status", "active")
            .and()
            .ge("age", 18)
            .or()
            .le("age", 21)
            .build();
        assert_eq!(sql, "WHERE status = ? AND age >= ? OR age <= ?");
        println!("args: {:?}", args);

        let cases = [
            (
                Where::new().gt("age", 18).and().lt("age", 21),
                "WHERE age > ? AND age < ?",
            ),
            (
                Where::new().between("age", 18, 21),
                "WHERE age BETWEEN ? AND ?",
            ),
            (
                Where::new().r#in("status", vec!["active", "verified", "premium"]),
                "WHERE status IN (?, ?, ?)",
            ),
            (Where::new().is_null("status"), "WHERE status IS NULL"),
            (Where::new().is_not_null("status"), "WHERE status IS NOT NULL"),
            (Where::new().like("name", "%admin%"), "WHERE name LIKE ?"),
        ];
        for (w, expected) in cases {
            assert_eq!(sql_of(w), expected);
        }
    }

    #[test]
    fn test_where_and_group() {
        let grouped = Where::new()
            .eq("status", "active")
            .and_group(|w| w.ge("age", 18).or().eq("name", "root"));
        assert_eq!(
            sql_of(grouped),
            "WHERE status = ? AND (age >= ? OR name = ?)"
        );
    }

    #[test]
    fn test_where_or_group() {
        let grouped = Where::new()
            .eq("name", "root")
            .or_group(|w| w.eq("name", "admin").and().ge("age", 21));
        assert_eq!(
            sql_of(grouped),
            "WHERE name = ? OR (name = ? AND age >= ?)"
        );
    }
}
556}