1#[cfg(feature = "mysql")]
2type DbArgs = sqlx::mysql::MySqlArguments;
3#[cfg(all(
4 not(feature = "mysql"),
5 not(feature = "postgres"),
6 not(feature = "sqlite"),
7 ))]
9type DbArgs = sqlx::mysql::MySqlArguments; #[cfg(feature = "postgres")]
11type DbArgs = sqlx::postgres::PgArguments;
12#[cfg(feature = "sqlite")]
13type DbArgs = sqlx::sqlite::SqliteArguments;
14pub trait BindArg {
18 fn bind_to(self, args: &mut DbArgs);
19}
20
/// MySQL backend: any `'static` value sqlx can encode as a MySQL type is bindable.
#[cfg(feature = "mysql")]
impl<T> BindArg for T
where
    T: sqlx::Encode<'static, sqlx::MySql> + sqlx::Type<sqlx::MySql> + 'static,
{
    fn bind_to(self, args: &mut DbArgs) {
        // The result of `add` is intentionally discarded (best effort); an
        // encode failure is not reported from here.
        let _ = sqlx::Arguments::add(args, self);
    }
}

/// PostgreSQL backend: any `'static` value sqlx can encode as a Postgres type is bindable.
#[cfg(feature = "postgres")]
impl<T> BindArg for T
where
    T: sqlx::Encode<'static, sqlx::Postgres> + sqlx::Type<sqlx::Postgres> + 'static,
{
    fn bind_to(self, args: &mut DbArgs) {
        // The result of `add` is intentionally discarded (best effort); an
        // encode failure is not reported from here.
        let _ = sqlx::Arguments::add(args, self);
    }
}

/// SQLite backend: any `'static` value sqlx can encode as a SQLite type is bindable.
#[cfg(feature = "sqlite")]
impl<T> BindArg for T
where
    T: sqlx::Encode<'static, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite> + 'static,
{
    fn bind_to(self, args: &mut DbArgs) {
        // The result of `add` is intentionally discarded (best effort); an
        // encode failure is not reported from here.
        let _ = sqlx::Arguments::add(args, self);
    }
}

54#[derive(Debug, Default, Clone)]
75pub struct Where {
76 sql: String,
77 args: DbArgs,
78 next_index: usize,
79 pending_logic: Option<&'static str>,
80 has_any_predicate: bool,
81 skip_next_logic: bool,
82}
83
84impl Where {
85 pub fn new() -> Self {
91 Self {
92 sql: String::from("WHERE "),
93 args: Default::default(),
94 next_index: 1,
95 pending_logic: None,
96 has_any_predicate: false,
97 skip_next_logic: false,
98 }
99 }
100
101 pub fn and(mut self) -> Self {
111 self.pending_logic = Some("AND");
112 self
113 }
114
115 pub fn or(mut self) -> Self {
125 self.pending_logic = Some("OR");
126 self
127 }
128
129 fn push_logic_if_needed(&mut self) {
131 if self.skip_next_logic {
132 self.skip_next_logic = false;
133 return;
134 }
135 if self.has_any_predicate {
136 let logic = self.pending_logic.take().unwrap_or("AND");
137 self.sql.push_str(logic);
138 self.sql.push(' ');
139 }
140 }
141
142 fn placeholder(&mut self) -> String {
144 #[cfg(feature = "postgres")]
145 {
146 let p = format!("${}", self.next_index);
147 self.next_index += 1;
148 return p;
149 }
150 #[cfg(feature = "sqlite")]
151 {
152 let p = format!("?{}", self.next_index);
153 self.next_index += 1;
154 return p;
155 }
156 #[cfg(feature = "mysql")]
157 {
158 self.next_index += 1;
159 return "?".to_string();
160 }
161 #[allow(unreachable_code)]
167 {
168 self.next_index += 1;
169 "?".to_string()
170 }
171 }
172
173 pub fn raw(mut self, fragment: &str) -> Self {
179 self.push_logic_if_needed();
180 self.sql.push_str(fragment);
181 self.sql.push(' ');
182 self.has_any_predicate = true;
183 self
184 }
185
186 pub fn eq(mut self, col: &str, value: impl BindArg) -> Self {
196 self.push_logic_if_needed();
197 let ph = self.placeholder();
198 self.sql.push_str(col);
199 self.sql.push_str(" = ");
200 self.sql.push_str(&ph);
201 self.sql.push(' ');
202 self.has_any_predicate = true;
203 value.bind_to(&mut self.args);
204 self
205 }
206
207 pub fn ne(mut self, col: &str, value: impl BindArg) -> Self {
217 self.push_logic_if_needed();
218 let ph = self.placeholder();
219 self.sql.push_str(col);
220 self.sql.push_str(" <> ");
221 self.sql.push_str(&ph);
222 self.sql.push(' ');
223 self.has_any_predicate = true;
224 value.bind_to(&mut self.args);
225 self
226 }
227
228 pub fn lt(mut self, col: &str, value: impl BindArg) -> Self {
238 self.push_logic_if_needed();
239 let ph = self.placeholder();
240 self.sql.push_str(col);
241 self.sql.push_str(" < ");
242 self.sql.push_str(&ph);
243 self.sql.push(' ');
244 self.has_any_predicate = true;
245 value.bind_to(&mut self.args);
246 self
247 }
248
249 pub fn le(mut self, col: &str, value: impl BindArg) -> Self {
259 self.push_logic_if_needed();
260 let ph = self.placeholder();
261 self.sql.push_str(col);
262 self.sql.push_str(" <= ");
263 self.sql.push_str(&ph);
264 self.sql.push(' ');
265 self.has_any_predicate = true;
266 value.bind_to(&mut self.args);
267 self
268 }
269
270 pub fn gt(mut self, col: &str, value: impl BindArg) -> Self {
280 self.push_logic_if_needed();
281 let ph = self.placeholder();
282 self.sql.push_str(col);
283 self.sql.push_str(" > ");
284 self.sql.push_str(&ph);
285 self.sql.push(' ');
286 self.has_any_predicate = true;
287 value.bind_to(&mut self.args);
288 self
289 }
290
291 pub fn ge(mut self, col: &str, value: impl BindArg) -> Self {
301 self.push_logic_if_needed();
302 let ph = self.placeholder();
303 self.sql.push_str(col);
304 self.sql.push_str(" >= ");
305 self.sql.push_str(&ph);
306 self.sql.push(' ');
307 self.has_any_predicate = true;
308 value.bind_to(&mut self.args);
309 self
310 }
311
312 pub fn like(mut self, col: &str, value: impl BindArg) -> Self {
322 self.push_logic_if_needed();
323 let ph = self.placeholder();
324 self.sql.push_str(col);
325 self.sql.push_str(" LIKE ");
326 self.sql.push_str(&ph);
327 self.sql.push(' ');
328 self.has_any_predicate = true;
329 value.bind_to(&mut self.args);
330 self
331 }
332
333 pub fn is_null(mut self, col: &str) -> Self {
343 self.push_logic_if_needed();
344 self.sql.push_str(col);
345 self.sql.push_str(" IS NULL ");
346 self.has_any_predicate = true;
347 self
348 }
349
350 pub fn is_not_null(mut self, col: &str) -> Self {
360 self.push_logic_if_needed();
361 self.sql.push_str(col);
362 self.sql.push_str(" IS NOT NULL ");
363 self.has_any_predicate = true;
364 self
365 }
366
367 pub fn between(mut self, col: &str, start: impl BindArg, end: impl BindArg) -> Self {
377 self.push_logic_if_needed();
378 let ph1 = self.placeholder();
379 let ph2 = self.placeholder();
380 self.sql.push_str(col);
381 self.sql.push_str(" BETWEEN ");
382 self.sql.push_str(&ph1);
383 self.sql.push_str(" AND ");
384 self.sql.push_str(&ph2);
385 self.sql.push(' ');
386 self.has_any_predicate = true;
387 start.bind_to(&mut self.args);
388 end.bind_to(&mut self.args);
389 self
390 }
391
392 pub fn r#in<V>(mut self, col: &str, values: V) -> Self
402 where
403 V: IntoIterator,
404 V::Item: BindArg,
405 {
406 self.push_logic_if_needed();
407 self.sql.push_str(col);
408 self.sql.push_str(" IN (");
409 let mut first = true;
410 for v in values {
411 if !first {
412 self.sql.push_str(", ");
413 }
414 first = false;
415 let ph = self.placeholder();
416 self.sql.push_str(&ph);
417 v.bind_to(&mut self.args);
418 }
419 self.sql.push(')');
420 self.sql.push(' ');
421 self.has_any_predicate = true;
422 self
423 }
424
425 pub fn and_group<F>(mut self, f: F) -> Self
437 where
438 F: FnOnce(Where) -> Where,
439 {
440 self.pending_logic = Some("AND");
441 self.push_logic_if_needed();
442 self.sql.push('(');
443 let prev_skip = self.skip_next_logic;
444 self.skip_next_logic = true;
445 let mut new_where = f(self);
446 new_where.skip_next_logic = prev_skip;
447 if new_where.sql.ends_with(' ') {
448 new_where.sql.pop();
449 }
450 new_where.sql.push(')');
451 new_where.sql.push(' ');
452 new_where.has_any_predicate = true;
453 new_where
454 }
455
456 pub fn or_group<F>(mut self, f: F) -> Self
468 where
469 F: FnOnce(Where) -> Where,
470 {
471 self.pending_logic = Some("OR");
472 self.push_logic_if_needed();
473 self.sql.push('(');
474 let prev_skip = self.skip_next_logic;
475 self.skip_next_logic = true;
476 let mut new_where = f(self);
477 new_where.skip_next_logic = prev_skip;
478 if new_where.sql.ends_with(' ') {
479 new_where.sql.pop();
480 }
481 new_where.sql.push(')');
482 new_where.sql.push(' ');
483 new_where.has_any_predicate = true;
484 new_where
485 }
486
487 pub fn build(mut self) -> (String, DbArgs) {
493 if !self.has_any_predicate {
494 return (String::new(), self.args);
495 }
496 if self.sql.ends_with(' ') {
497 self.sql.pop();
498 }
499 (self.sql, self.args)
500 }
501}
502
#[cfg(test)]
mod test {
    use super::*;

    /// Builds `w` and keeps only the generated SQL text.
    fn sql_of(w: Where) -> String {
        w.build().0
    }

    #[test]
    fn test_where() {
        let (sql, args) = Where::new()
            .eq("status", "active")
            .and()
            .ge("age", 18)
            .or()
            .le("age", 21)
            .build();
        assert_eq!(sql, "WHERE status = ? AND age >= ? OR age <= ?");
        println!("args: {:?}", args);

        let cases = [
            (
                Where::new().gt("age", 18).and().lt("age", 21),
                "WHERE age > ? AND age < ?",
            ),
            (
                Where::new().between("age", 18, 21),
                "WHERE age BETWEEN ? AND ?",
            ),
            (
                Where::new().r#in("status", vec!["active", "verified", "premium"]),
                "WHERE status IN (?, ?, ?)",
            ),
            (Where::new().is_null("status"), "WHERE status IS NULL"),
            (Where::new().is_not_null("status"), "WHERE status IS NOT NULL"),
            (Where::new().like("name", "%admin%"), "WHERE name LIKE ?"),
        ];
        for (w, expected) in cases {
            assert_eq!(sql_of(w), expected);
        }
    }

    #[test]
    fn test_where_and_group() {
        let sql = sql_of(
            Where::new()
                .eq("status", "active")
                .and_group(|g| g.ge("age", 18).or().eq("name", "root")),
        );
        assert_eq!(sql, "WHERE status = ? AND (age >= ? OR name = ?)");
    }

    #[test]
    fn test_where_or_group() {
        let sql = sql_of(
            Where::new()
                .eq("name", "root")
                .or_group(|g| g.eq("name", "admin").and().ge("age", 21)),
        );
        assert_eq!(sql, "WHERE name = ? OR (name = ? AND age >= ?)");
    }
}