1#[cfg(feature = "mysql")]
2type DbArgs = sqlx::mysql::MySqlArguments;
3#[cfg(feature = "postgres")]
4type DbArgs = sqlx::postgres::PgArguments;
5#[cfg(feature = "sqlite")]
6type DbArgs = sqlx::sqlite::SqliteArguments;
7#[cfg(all(
10 not(feature = "mysql"),
11 not(feature = "postgres"),
12 not(feature = "sqlite"),
13 ))]
16type DbArgs = sqlx::mysql::MySqlArguments;
17
18pub trait BindArg {
22 fn bind_to(self, args: &mut DbArgs);
23}
24
25#[cfg(feature = "mysql")]
26impl<T> BindArg for T
27where
28 T: 'static + sqlx::Encode<'static, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
29{
30 fn bind_to(self, args: &mut DbArgs) {
31 use sqlx::Arguments as _;
32 let _ = args.add(self);
33 }
34}
35
36#[cfg(feature = "postgres")]
37impl<T> BindArg for T
38where
39 T: 'static + sqlx::Encode<'static, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
40{
41 fn bind_to(self, args: &mut DbArgs) {
42 use sqlx::Arguments as _;
43 let _ = args.add(self);
44 }
45}
46
47#[cfg(feature = "sqlite")]
48impl<T> BindArg for T
49where
50 T: 'static + sqlx::Encode<'static, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
51{
52 fn bind_to(self, args: &mut DbArgs) {
53 use sqlx::Arguments as _;
54 let _ = args.add(self);
55 }
56}
57
58#[derive(Default, Clone, Debug)]
90pub struct Where {
91 sql: String,
92 args: DbArgs,
93 next_index: usize,
94 pending_logic: Option<&'static str>,
95 has_any_predicate: bool,
96 skip_next_logic: bool,
97}
98
99impl Where {
100 pub fn new() -> Self {
106 Self {
107 sql: String::from("WHERE "),
108 args: Default::default(),
109 next_index: 1,
110 pending_logic: None,
111 has_any_predicate: false,
112 skip_next_logic: false,
113 }
114 }
115
116 pub fn and(mut self) -> Self {
126 self.pending_logic = Some("AND");
127 self
128 }
129
130 pub fn or(mut self) -> Self {
140 self.pending_logic = Some("OR");
141 self
142 }
143
144 fn push_logic_if_needed(&mut self) {
146 if self.skip_next_logic {
147 self.skip_next_logic = false;
148 return;
149 }
150 if self.has_any_predicate {
151 let logic = self.pending_logic.take().unwrap_or("AND");
152 self.sql.push_str(logic);
153 self.sql.push(' ');
154 }
155 }
156
157 fn placeholder(&mut self) -> String {
159 #[cfg(feature = "postgres")]
160 {
161 let p = format!("${}", self.next_index);
162 self.next_index += 1;
163 return p;
164 }
165 #[cfg(feature = "sqlite")]
166 {
167 let p = format!("?{}", self.next_index);
168 self.next_index += 1;
169 return p;
170 }
171 #[cfg(feature = "mysql")]
172 {
173 self.next_index += 1;
174 return "?".to_string();
175 }
176
177 #[allow(unreachable_code)]
183 {
184 self.next_index += 1;
185 "?".to_string()
186 }
187 }
188
189 pub fn raw(mut self, fragment: &str) -> Self {
195 self.push_logic_if_needed();
196 self.sql.push_str(fragment);
197 self.sql.push(' ');
198 self.has_any_predicate = true;
199 self
200 }
201
202 pub fn eq(mut self, col: &str, value: impl BindArg) -> Self {
212 self.push_logic_if_needed();
213 let ph = self.placeholder();
214 self.sql.push_str(col);
215 self.sql.push_str(" = ");
216 self.sql.push_str(&ph);
217 self.sql.push(' ');
218 self.has_any_predicate = true;
219 value.bind_to(&mut self.args);
220 self
221 }
222
223 pub fn ne(mut self, col: &str, value: impl BindArg) -> Self {
233 self.push_logic_if_needed();
234 let ph = self.placeholder();
235 self.sql.push_str(col);
236 self.sql.push_str(" <> ");
237 self.sql.push_str(&ph);
238 self.sql.push(' ');
239 self.has_any_predicate = true;
240 value.bind_to(&mut self.args);
241 self
242 }
243
244 pub fn lt(mut self, col: &str, value: impl BindArg) -> Self {
254 self.push_logic_if_needed();
255 let ph = self.placeholder();
256 self.sql.push_str(col);
257 self.sql.push_str(" < ");
258 self.sql.push_str(&ph);
259 self.sql.push(' ');
260 self.has_any_predicate = true;
261 value.bind_to(&mut self.args);
262 self
263 }
264
265 pub fn le(mut self, col: &str, value: impl BindArg) -> Self {
275 self.push_logic_if_needed();
276 let ph = self.placeholder();
277 self.sql.push_str(col);
278 self.sql.push_str(" <= ");
279 self.sql.push_str(&ph);
280 self.sql.push(' ');
281 self.has_any_predicate = true;
282 value.bind_to(&mut self.args);
283 self
284 }
285
286 pub fn gt(mut self, col: &str, value: impl BindArg) -> Self {
296 self.push_logic_if_needed();
297 let ph = self.placeholder();
298 self.sql.push_str(col);
299 self.sql.push_str(" > ");
300 self.sql.push_str(&ph);
301 self.sql.push(' ');
302 self.has_any_predicate = true;
303 value.bind_to(&mut self.args);
304 self
305 }
306
307 pub fn ge(mut self, col: &str, value: impl BindArg) -> Self {
317 self.push_logic_if_needed();
318 let ph = self.placeholder();
319 self.sql.push_str(col);
320 self.sql.push_str(" >= ");
321 self.sql.push_str(&ph);
322 self.sql.push(' ');
323 self.has_any_predicate = true;
324 value.bind_to(&mut self.args);
325 self
326 }
327
328 pub fn like(mut self, col: &str, value: impl BindArg) -> Self {
338 self.push_logic_if_needed();
339 let ph = self.placeholder();
340 self.sql.push_str(col);
341 self.sql.push_str(" LIKE ");
342 self.sql.push_str(&ph);
343 self.sql.push(' ');
344 self.has_any_predicate = true;
345 value.bind_to(&mut self.args);
346 self
347 }
348
349 pub fn is_null(mut self, col: &str) -> Self {
359 self.push_logic_if_needed();
360 self.sql.push_str(col);
361 self.sql.push_str(" IS NULL ");
362 self.has_any_predicate = true;
363 self
364 }
365
366 pub fn is_not_null(mut self, col: &str) -> Self {
376 self.push_logic_if_needed();
377 self.sql.push_str(col);
378 self.sql.push_str(" IS NOT NULL ");
379 self.has_any_predicate = true;
380 self
381 }
382
383 pub fn between(mut self, col: &str, start: impl BindArg, end: impl BindArg) -> Self {
393 self.push_logic_if_needed();
394 let ph1 = self.placeholder();
395 let ph2 = self.placeholder();
396 self.sql.push_str(col);
397 self.sql.push_str(" BETWEEN ");
398 self.sql.push_str(&ph1);
399 self.sql.push_str(" AND ");
400 self.sql.push_str(&ph2);
401 self.sql.push(' ');
402 self.has_any_predicate = true;
403 start.bind_to(&mut self.args);
404 end.bind_to(&mut self.args);
405 self
406 }
407
408 pub fn r#in<V>(mut self, col: &str, values: V) -> Self
418 where
419 V: IntoIterator,
420 V::Item: BindArg,
421 {
422 self.push_logic_if_needed();
423 self.sql.push_str(col);
424 self.sql.push_str(" IN (");
425 let mut first = true;
426 for v in values {
427 if !first {
428 self.sql.push_str(", ");
429 }
430 first = false;
431 let ph = self.placeholder();
432 self.sql.push_str(&ph);
433 v.bind_to(&mut self.args);
434 }
435 self.sql.push(')');
436 self.sql.push(' ');
437 self.has_any_predicate = true;
438 self
439 }
440
441 pub fn and_group<F>(mut self, f: F) -> Self
453 where
454 F: FnOnce(Where) -> Where,
455 {
456 self.pending_logic = Some("AND");
457 self.push_logic_if_needed();
458 self.sql.push('(');
459 let prev_skip = self.skip_next_logic;
460 self.skip_next_logic = true;
461 let mut new_where = f(self);
462 new_where.skip_next_logic = prev_skip;
463 if new_where.sql.ends_with(' ') {
464 new_where.sql.pop();
465 }
466 new_where.sql.push(')');
467 new_where.sql.push(' ');
468 new_where.has_any_predicate = true;
469 new_where
470 }
471
472 pub fn or_group<F>(mut self, f: F) -> Self
484 where
485 F: FnOnce(Where) -> Where,
486 {
487 self.pending_logic = Some("OR");
488 self.push_logic_if_needed();
489 self.sql.push('(');
490 let prev_skip = self.skip_next_logic;
491 self.skip_next_logic = true;
492 let mut new_where = f(self);
493 new_where.skip_next_logic = prev_skip;
494 if new_where.sql.ends_with(' ') {
495 new_where.sql.pop();
496 }
497 new_where.sql.push(')');
498 new_where.sql.push(' ');
499 new_where.has_any_predicate = true;
500 new_where
501 }
502
503 pub fn build(mut self) -> (String, DbArgs) {
509 if !self.has_any_predicate {
510 return (String::new(), self.args);
511 }
512 if self.sql.ends_with(' ') {
513 self.sql.pop();
514 }
515 (self.sql, self.args)
516 }
517}
518
#[cfg(test)]
mod test {
    use super::*;

    /// Rewrites a template written with bare `?` placeholders into the syntax of the
    /// backend the crate is compiled for (`$n` for postgres, `?n` for sqlite). This keeps
    /// the same expectations valid under every backend feature.
    fn expected(template: &str) -> String {
        let mut out = String::with_capacity(template.len());
        let mut index = 0;
        for ch in template.chars() {
            if ch != '?' {
                out.push(ch);
                continue;
            }
            index += 1;
            if cfg!(feature = "postgres") {
                out.push_str(&format!("${}", index));
            } else if cfg!(feature = "sqlite") {
                out.push_str(&format!("?{}", index));
            } else {
                out.push('?');
            }
        }
        out
    }

    /// Builds `w` and compares its SQL with the backend-adjusted `template`.
    fn assert_sql(w: Where, template: &str) {
        let (sql, _args) = w.build();
        assert_eq!(sql, expected(template));
    }

    #[test]
    fn test_where() {
        assert_sql(
            Where::new().eq("status", "active").and().ge("age", 18).or().le("age", 21),
            "WHERE status = ? AND age >= ? OR age <= ?",
        );
        assert_sql(
            Where::new().gt("age", 18).and().lt("age", 21),
            "WHERE age > ? AND age < ?",
        );
        assert_sql(Where::new().between("age", 18, 21), "WHERE age BETWEEN ? AND ?");
        assert_sql(
            Where::new().r#in("status", vec!["active", "verified", "premium"]),
            "WHERE status IN (?, ?, ?)",
        );
        assert_sql(Where::new().is_null("status"), "WHERE status IS NULL");
        assert_sql(Where::new().is_not_null("status"), "WHERE status IS NOT NULL");
        assert_sql(Where::new().like("name", "%admin%"), "WHERE name LIKE ?");
    }

    #[test]
    fn test_where_and_group() {
        assert_sql(
            Where::new()
                .eq("status", "active")
                .and_group(|w| w.ge("age", 18).or().eq("name", "root")),
            "WHERE status = ? AND (age >= ? OR name = ?)",
        );
    }

    #[test]
    fn test_where_or_group() {
        assert_sql(
            Where::new()
                .eq("name", "root")
                .or_group(|w| w.eq("name", "admin").and().ge("age", 21)),
            "WHERE name = ? OR (name = ? AND age >= ?)",
        );
    }
}