1use std::borrow::Cow;
2use std::fmt::{Display, Formatter, Result as FmtResult};
3use std::marker::PhantomData;
4
5use serde::de::DeserializeOwned;
6
7
/// Associates a row type with the database table it is read from.
pub trait TableName {
    /// Returns the table name placed after `FROM` in generated statements.
    fn table_name() -> &'static str;
}
12
13pub trait Wrapper<T>
15where
16 T: DeserializeOwned,
17 T: TableName,
18{
19 fn generate_sql(&self) -> String;
20}
21
/// Comparison operators that a single condition can apply to a column.
#[derive(Debug, Clone, PartialEq)]
pub enum CompareOperator {
    /// `=`
    Eq,
    /// `<>`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `LIKE`; the `like_left` builder method prefixes the value with `%`.
    LikeLeft,
    /// `LIKE`; the `like_right` builder method suffixes the value with `%`.
    LikeRight,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
    /// `IN`
    In,
    /// `NOT IN`
    NotIn,
    /// `BETWEEN`
    Between,
    /// `NOT BETWEEN`
    NotBetween,
}

impl Display for CompareOperator {
    /// Writes the SQL symbol or keyword for the operator.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let sql = match self {
            Self::Eq => "=",
            Self::Ne => "<>",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            // All three LIKE flavours share one keyword; they differ only in
            // where the wildcard is placed inside the value.
            Self::Like | Self::LikeLeft | Self::LikeRight => "LIKE",
            Self::NotLike => "NOT LIKE",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
            Self::Between => "BETWEEN",
            Self::NotBetween => "NOT BETWEEN",
        };
        f.write_str(sql)
    }
}
65
/// Connectors used to join two conditions.
#[derive(Debug, Clone, PartialEq)]
pub enum LogicalOperator {
    /// `AND`
    And,
    /// `OR`
    Or,
}

impl Display for LogicalOperator {
    /// Writes the SQL keyword for the connector.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let keyword = match self {
            Self::And => "AND",
            Self::Or => "OR",
        };
        f.write_str(keyword)
    }
}
81
/// Sort direction of one `ORDER BY` term.
#[derive(Debug, Clone, PartialEq)]
pub enum OrderDirection {
    /// `ASC`
    Asc,
    /// `DESC`
    Desc,
}

impl Display for OrderDirection {
    /// Writes the SQL keyword for the direction.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let keyword = match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        };
        f.write_str(keyword)
    }
}
97
/// The operand(s) attached to a condition.
#[derive(Debug, Clone)]
pub enum ConditionValue<'a> {
    /// No operand, as used by `IS NULL` / `IS NOT NULL`.
    None,
    /// Exactly one operand, as used by the binary comparisons and `LIKE`.
    Single(Cow<'a, str>),
    /// A list of operands, as used by `IN` / `NOT IN`.
    Multiple(Vec<Cow<'a, str>>),
    /// A pair of bounds, as used by `BETWEEN` / `NOT BETWEEN`.
    Range(Cow<'a, str>, Cow<'a, str>),
}
106
107#[derive(Debug, Clone)]
109pub struct Condition<'a> {
110 pub column: Cow<'a, str>,
111 pub operator: CompareOperator,
112 pub value: ConditionValue<'a>,
113}
114
115#[derive(Debug, Clone)]
117pub struct OrderBy<'a> {
118 pub column: Cow<'a, str>,
119 pub direction: OrderDirection,
120}
121
122#[derive(Debug, Clone)]
124pub enum ConditionNode<'a> {
125 Leaf(Condition<'a>),
126 Branch {
127 left: Box<ConditionNode<'a>>,
128 op: LogicalOperator,
129 right: Box<ConditionNode<'a>>,
130 },
131 Group(Box<ConditionNode<'a>>),
132 Empty,
133}
134
135impl<'a> Default for ConditionNode<'a> {
136 fn default() -> Self {
137 ConditionNode::Empty
138 }
139}
140
141#[derive(Debug, Clone, Default)]
143pub struct QueryWrapper<'a, T>
144where
145 T: DeserializeOwned,
146 T: TableName,
147{
148 pub root_condition: ConditionNode<'a>,
149 pub order_by: Vec<OrderBy<'a>>,
150 pub group_by: Vec<Cow<'a, str>>,
151 pub having: Option<Cow<'a, str>>,
152 pub limit: Option<usize>,
153 pub offset: Option<usize>,
154 pub select_columns: Vec<Cow<'a, str>>,
155 pub custom_sql: Option<Cow<'a, str>>,
156 _marker: PhantomData<T>,
157}
158
159impl<'a, T> QueryWrapper<'a, T>
160where
161 T: DeserializeOwned,
162 T: TableName,
163{
164 pub fn new() -> Self {
166 Self {
167 root_condition: ConditionNode::Empty,
168 order_by: Vec::new(),
169 group_by: Vec::new(),
170 having: None,
171 limit: None,
172 offset: None,
173 select_columns: Vec::new(),
174 custom_sql: None,
175 _marker: PhantomData,
176 }
177 }
178
179 fn add_condition(&mut self, condition: Condition<'a>) {
181 let new_node = ConditionNode::Leaf(condition);
182
183 match &self.root_condition {
184 ConditionNode::Empty => {
185 self.root_condition = new_node;
186 }
187 _ => {
188 let old_root = std::mem::replace(&mut self.root_condition, ConditionNode::Empty);
189 self.root_condition = ConditionNode::Branch {
190 left: Box::new(old_root),
191 op: LogicalOperator::And,
192 right: Box::new(new_node),
193 };
194 }
195 }
196 }
197
198 pub fn eq<F, V>(mut self, column: F, value: V) -> Self
200 where
201 F: FnOnce() -> V,
202 V: Into<Cow<'a, str>>,
203 {
204 let column_str = column().into();
205 self.add_condition(Condition {
206 column: column_str,
207 operator: CompareOperator::Eq,
208 value: ConditionValue::Single(value.into()),
209 });
210 self
211 }
212
213 pub fn ne<F, V>(mut self, column: F, value: V) -> Self
215 where
216 F: FnOnce() -> V,
217 V: Into<Cow<'a, str>>,
218 {
219 let column_str = column().into();
220 self.add_condition(Condition {
221 column: column_str,
222 operator: CompareOperator::Ne,
223 value: ConditionValue::Single(value.into()),
224 });
225 self
226 }
227
228 pub fn gt<F, V>(mut self, column: F, value: V) -> Self
230 where
231 F: FnOnce() -> V,
232 V: Into<Cow<'a, str>>,
233 {
234 let column_str = column().into();
235 self.add_condition(Condition {
236 column: column_str,
237 operator: CompareOperator::Gt,
238 value: ConditionValue::Single(value.into()),
239 });
240 self
241 }
242
243 pub fn ge<F, V>(mut self, column: F, value: V) -> Self
245 where
246 F: FnOnce() -> V,
247 V: Into<Cow<'a, str>>,
248 {
249 let column_str = column().into();
250 self.add_condition(Condition {
251 column: column_str,
252 operator: CompareOperator::Ge,
253 value: ConditionValue::Single(value.into()),
254 });
255 self
256 }
257
258 pub fn lt<F, V>(mut self, column: F, value: V) -> Self
260 where
261 F: FnOnce() -> V,
262 V: Into<Cow<'a, str>>,
263 {
264 let column_str = column().into();
265 self.add_condition(Condition {
266 column: column_str,
267 operator: CompareOperator::Lt,
268 value: ConditionValue::Single(value.into()),
269 });
270 self
271 }
272
273 pub fn le<F, V>(mut self, column: F, value: V) -> Self
275 where
276 F: FnOnce() -> V,
277 V: Into<Cow<'a, str>>,
278 {
279 let column_str = column().into();
280 self.add_condition(Condition {
281 column: column_str,
282 operator: CompareOperator::Le,
283 value: ConditionValue::Single(value.into()),
284 });
285 self
286 }
287
288 pub fn like<F, V>(mut self, column: F, value: V) -> Self
290 where
291 F: FnOnce() -> V,
292 V: Into<Cow<'a, str>>,
293 {
294 let column_str = column().into();
295 let value_str = value.into();
296 self.add_condition(Condition {
297 column: column_str,
298 operator: CompareOperator::Like,
299 value: ConditionValue::Single(value_str),
300 });
301 self
302 }
303
304 pub fn not_like<F, V>(mut self, column: F, value: V) -> Self
306 where
307 F: FnOnce() -> V,
308 V: Into<Cow<'a, str>>,
309 {
310 let column_str = column().into();
311 let value_str = value.into();
312 let formatted_value = format!("%{}%", value_str);
313
314 self.add_condition(Condition {
315 column: column_str,
316 operator: CompareOperator::NotLike,
317 value: ConditionValue::Single(Cow::Owned(formatted_value)),
318 });
319 self
320 }
321
322 pub fn like_left<F, V>(mut self, column: F, value: V) -> Self
324 where
325 F: FnOnce() -> V,
326 V: Into<Cow<'a, str>>,
327 {
328 let column_str = column().into();
329 let value_str = value.into();
330 let formatted_value = format!("%{}", value_str);
331
332 self.add_condition(Condition {
333 column: column_str,
334 operator: CompareOperator::LikeLeft,
335 value: ConditionValue::Single(Cow::Owned(formatted_value)),
336 });
337 self
338 }
339
340 pub fn like_right<F, V>(mut self, column: F, value: V) -> Self
342 where
343 F: FnOnce() -> V,
344 V: Into<Cow<'a, str>>,
345 {
346 let column_str = column().into();
347 let value_str = value.into();
348 let formatted_value = format!("{}%", value_str);
349
350 self.add_condition(Condition {
351 column: column_str,
352 operator: CompareOperator::LikeRight,
353 value: ConditionValue::Single(Cow::Owned(formatted_value)),
354 });
355 self
356 }
357
358 pub fn is_null<F, V>(mut self, column: F) -> Self
360 where
361 F: FnOnce() -> V,
362 V: Into<Cow<'a, str>>,
363 {
364 let column_str = column().into();
365 self.add_condition(Condition {
366 column: column_str,
367 operator: CompareOperator::IsNull,
368 value: ConditionValue::None,
369 });
370 self
371 }
372
373 pub fn is_not_null<F, V>(mut self, column: F) -> Self
375 where
376 F: FnOnce() -> V,
377 V: Into<Cow<'a, str>>,
378 {
379 let column_str = column().into();
380 self.add_condition(Condition {
381 column: column_str,
382 operator: CompareOperator::IsNotNull,
383 value: ConditionValue::None,
384 });
385 self
386 }
387
388 pub fn r#in<F, V, I>(mut self, column: F, values: I) -> Self
390 where
391 F: FnOnce() -> V,
392 V: Into<Cow<'a, str>>,
393 I: IntoIterator<Item = V>,
394 {
395 let column_str = column().into();
396 let values_vec: Vec<Cow<'a, str>> = values.into_iter().map(|v| v.into()).collect();
397
398 self.add_condition(Condition {
399 column: column_str,
400 operator: CompareOperator::In,
401 value: ConditionValue::Multiple(values_vec),
402 });
403 self
404 }
405
406 pub fn not_in<F, V, I>(mut self, column: F, values: I) -> Self
408 where
409 F: FnOnce() -> V,
410 V: Into<Cow<'a, str>>,
411 I: IntoIterator<Item = V>,
412 {
413 let column_str = column().into();
414 let values_vec: Vec<Cow<'a, str>> = values.into_iter().map(|v| v.into()).collect();
415 self.add_condition(Condition {
416 column: column_str,
417 operator: CompareOperator::NotIn,
418 value: ConditionValue::Multiple(values_vec),
419 });
420 self
421 }
422
423 pub fn between<F, V>(mut self, column: F, value1: V, value2: V) -> Self
425 where
426 F: FnOnce() -> V,
427 V: Into<Cow<'a, str>>,
428 {
429 let column_str = column().into();
430 self.add_condition(Condition {
431 column: column_str,
432 operator: CompareOperator::Between,
433 value: ConditionValue::Range(value1.into(), value2.into()),
434 });
435 self
436 }
437
438 pub fn not_between<F, V>(mut self, column: F, value1: V, value2: V) -> Self
440 where
441 F: FnOnce() -> V,
442 V: Into<Cow<'a, str>>,
443 {
444 let column_str = column().into();
445 self.add_condition(Condition {
446 column: column_str,
447 operator: CompareOperator::NotBetween,
448 value: ConditionValue::Range(value1.into(), value2.into()),
449 });
450 self
451 }
452
453 pub fn or(mut self) -> Self {
455 self.change_last_operator(LogicalOperator::Or);
457 self
458 }
459
460 fn change_last_operator(&mut self, op: LogicalOperator) {
462 fn find_last_branch<'b>(node: &'b mut ConditionNode) -> Option<&'b mut LogicalOperator> {
463 match node {
464 ConditionNode::Branch { left: _, op, right } => {
465 if let Some(last_op) = find_last_branch(right) {
466 Some(last_op)
467 } else {
468 Some(op)
469 }
470 }
471 ConditionNode::Group(inner) => find_last_branch(inner),
472 _ => None,
473 }
474 }
475
476 if let Some(last_op) = find_last_branch(&mut self.root_condition) {
477 *last_op = op;
478 }
479 }
480
481 pub fn nested<F>(mut self, f: F) -> Self
483 where
484 F: FnOnce(QueryWrapper<'a, T>) -> QueryWrapper<'a, T>,
485 {
486 let nested = f(QueryWrapper::new());
487
488 if let ConditionNode::Empty = nested.root_condition {
489 return self;
490 }
491
492 let nested_node = ConditionNode::Group(Box::new(nested.root_condition));
493
494 match &self.root_condition {
495 ConditionNode::Empty => {
496 self.root_condition = nested_node;
497 }
498 _ => {
499 let old_root = std::mem::replace(&mut self.root_condition, ConditionNode::Empty);
500 self.root_condition = ConditionNode::Branch {
501 left: Box::new(old_root),
502 op: LogicalOperator::And,
503 right: Box::new(nested_node),
504 };
505 }
506 }
507
508 self
509 }
510
511 pub fn apply<V: Into<Cow<'a, str>>>(mut self, sql: V) -> Self {
513 self.custom_sql = Some(sql.into());
514 self
515 }
516
517 pub fn order_by<F, V>(mut self, column: F, direction: OrderDirection) -> Self
519 where
520 F: FnOnce() -> V,
521 V: Into<Cow<'a, str>>,
522 {
523 let column_str = column().into();
524 self.order_by.push(OrderBy {
525 column: column_str,
526 direction,
527 });
528 self
529 }
530
531 pub fn order_by_asc<F, V>(self, column: F) -> Self
533 where
534 F: FnOnce() -> V,
535 V: Into<Cow<'a, str>>,
536 {
537 self.order_by(column, OrderDirection::Asc)
538 }
539
540 pub fn order_by_desc<F, V>(self, column: F) -> Self
542 where
543 F: FnOnce() -> V,
544 V: Into<Cow<'a, str>>,
545 {
546 self.order_by(column, OrderDirection::Desc)
547 }
548
549 pub fn group_by<V, I>(mut self, columns: I) -> Self
551 where
552 V: Into<Cow<'a, str>>,
553 I: IntoIterator<Item = V>,
554 {
555 self.group_by = columns.into_iter().map(|c| c.into()).collect();
556 self
557 }
558
559 pub fn having<V: Into<Cow<'a, str>>>(mut self, condition: V) -> Self {
561 self.having = Some(condition.into());
562 self
563 }
564
565 pub fn select<V, I>(mut self, columns: I) -> Self
567 where
568 V: Into<Cow<'a, str>>,
569 I: IntoIterator<Item = V>,
570 {
571 self.select_columns = columns.into_iter().map(|c| c.into()).collect();
572 self
573 }
574
575 pub fn limit(mut self, limit: usize) -> Self {
577 self.limit = Some(limit);
578 self
579 }
580
581 pub fn offset(mut self, offset: usize) -> Self {
583 self.offset = Some(offset);
584 self
585 }
586
587 pub fn build_where_clause(&self) -> String {
589 fn build_condition_node(node: &ConditionNode) -> String {
590 match node {
591 ConditionNode::Empty => String::new(),
592 ConditionNode::Leaf(condition) => match &condition.operator {
593 CompareOperator::IsNull | CompareOperator::IsNotNull => {
594 format!("{} {}", condition.column, condition.operator)
595 }
596 CompareOperator::In | CompareOperator::NotIn => {
597 if let ConditionValue::Multiple(values) = &condition.value {
598 let values_str = values
599 .iter()
600 .map(|v| format!("'{}'", v))
601 .collect::<Vec<_>>()
602 .join(", ");
603 format!(
604 "{} {} ({})",
605 condition.column, condition.operator, values_str
606 )
607 } else {
608 String::new()
609 }
610 }
611 CompareOperator::Between | CompareOperator::NotBetween => {
612 if let ConditionValue::Range(value1, value2) = &condition.value {
613 format!(
614 "{} {} '{}' AND '{}'",
615 condition.column, condition.operator, value1, value2
616 )
617 } else {
618 String::new()
619 }
620 }
621 _ => {
622 if let ConditionValue::Single(value) = &condition.value {
623 format!("{} {} '{}'", condition.column, condition.operator, value)
624 } else {
625 String::new()
626 }
627 }
628 },
629 ConditionNode::Branch { left, op, right } => {
630 let left_str = build_condition_node(left);
631 let right_str = build_condition_node(right);
632
633 if left_str.is_empty() {
634 right_str
635 } else if right_str.is_empty() {
636 left_str
637 } else {
638 format!("{} {} {}", left_str, op, right_str)
639 }
640 }
641 ConditionNode::Group(inner) => {
642 let inner_str = build_condition_node(inner);
643 if inner_str.is_empty() {
644 inner_str
645 } else {
646 format!("({})", inner_str)
647 }
648 }
649 }
650 }
651
652 let condition_str = build_condition_node(&self.root_condition);
653
654 if condition_str.is_empty() && self.custom_sql.is_none() {
655 return String::new();
656 }
657
658 let mut where_clause = String::from("WHERE ");
659
660 if !condition_str.is_empty() {
661 where_clause.push_str(&condition_str);
662 }
663
664 if let Some(sql) = &self.custom_sql {
665 if !condition_str.is_empty() {
666 where_clause.push_str(" AND ");
667 }
668 where_clause.push_str(sql);
669 }
670
671 where_clause
672 }
673
674 pub fn build_order_by_clause(&self) -> String {
676 if self.order_by.is_empty() {
677 return String::new();
678 }
679
680 let order_by_str = self
681 .order_by
682 .iter()
683 .map(|o| format!("{} {}", o.column, o.direction))
684 .collect::<Vec<_>>()
685 .join(", ");
686
687 format!("ORDER BY {}", order_by_str)
688 }
689
690 pub fn build_group_by_clause(&self) -> String {
692 if self.group_by.is_empty() {
693 return String::new();
694 }
695
696 format!("GROUP BY {}", self.group_by.join(", "))
697 }
698
699 pub fn build_having_clause(&self) -> String {
701 if let Some(having) = &self.having {
702 format!("HAVING {}", having)
703 } else {
704 String::new()
705 }
706 }
707
708 pub fn build_limit_offset_clause(&self) -> String {
710 let mut clause = String::new();
711
712 if let Some(limit) = self.limit {
713 clause.push_str(&format!("LIMIT {}", limit));
714 }
715
716 if let Some(offset) = self.offset {
717 if !clause.is_empty() {
718 clause.push_str(" ");
719 }
720 clause.push_str(&format!("OFFSET {}", offset));
721 }
722
723 clause
724 }
725
726 pub fn build_select_clause(&self) -> String {
728 if self.select_columns.is_empty() {
729 return String::from("*");
730 }
731
732 self.select_columns.join(", ")
733 }
734
735 fn build_sql(&self, table_name: &str) -> String {
737 let select_clause = self.build_select_clause();
738 let where_clause = self.build_where_clause();
739 let group_by_clause = self.build_group_by_clause();
740 let having_clause = self.build_having_clause();
741 let order_by_clause = self.build_order_by_clause();
742 let limit_offset_clause = self.build_limit_offset_clause();
743
744 let mut sql = format!("SELECT {} FROM {}", select_clause, table_name);
745
746 if !where_clause.is_empty() {
747 sql.push_str(&format!(" {}", where_clause));
748 }
749
750 if !group_by_clause.is_empty() {
751 sql.push_str(&format!(" {}", group_by_clause));
752 }
753
754 if !having_clause.is_empty() {
755 sql.push_str(&format!(" {}", having_clause));
756 }
757
758 if !order_by_clause.is_empty() {
759 sql.push_str(&format!(" {}", order_by_clause));
760 }
761
762 if !limit_offset_clause.is_empty() {
763 sql.push_str(&format!(" {}", limit_offset_clause));
764 }
765
766 sql
767 }
768
769 fn to_snake_case(pascal_case: &str) -> Cow<'_, str> {
771 if pascal_case.chars().all(|c| c.is_lowercase() || c == '_') {
772 return Cow::Borrowed(pascal_case);
773 }
774
775 let mut name = String::with_capacity(pascal_case.len() + 4);
776 let mut chars = pascal_case.chars().peekable();
777
778 while let Some(c) = chars.next() {
779 if c.is_uppercase() {
780 if !name.is_empty() && chars.peek().map_or(false, |next| next.is_lowercase()) {
781 name.push('_');
782 }
783 name.push(c.to_lowercase().next().unwrap());
784 } else {
785 name.push(c);
786 }
787 }
788 Cow::Owned(name)
789 }
790}
791
792impl<'a, T> Wrapper<T> for QueryWrapper<'_, T>
793where
794 T: DeserializeOwned,
795 T: TableName,
796{
797 fn generate_sql(&self) -> String {
798 self.build_sql(T::table_name())
799 }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal row type used to instantiate the wrapper in these tests.
    #[derive(Clone, Default, Deserialize)]
    struct Test;

    impl Test {
        /// Column accessor passed to builder methods in place of a closure.
        pub fn name() -> &'static str {
            "name"
        }
    }

    impl TableName for Test {
        fn table_name() -> &'static str {
            "test"
        }
    }

    /// Two conditions are joined with `AND` by default.
    #[test]
    fn test_simple_query() {
        let query = QueryWrapper::<Test>::new()
            .eq(|| "name", "张三")
            .gt(|| "age", "18")
            .build_sql("users");

        assert_eq!(
            query,
            "SELECT * FROM users WHERE name = '张三' AND age > '18'"
        );
    }

    /// `or()` followed by a nested group, plus ordering and paging.
    #[test]
    fn test_complex_query() {
        let query = QueryWrapper::<Test>::new()
            .eq(|| "status", "active")
            .or()
            .nested(|q| q.eq(|| "role", "admin").gt(|| "level", "5"))
            .order_by_desc(|| "created_at")
            .limit(10)
            .offset(20)
            .build_sql("users");

        assert_eq!(
            query,
            "SELECT * FROM users WHERE status = 'active' OR (role = 'admin' AND level > '5') ORDER BY created_at DESC LIMIT 10 OFFSET 20"
        );
    }

    #[test]
    fn test_in_condition() {
        let query = QueryWrapper::<Test>::new()
            .r#in(|| "id", vec!["1", "2", "3"])
            .build_sql("users");

        assert_eq!(query, "SELECT * FROM users WHERE id IN ('1', '2', '3')");
    }

    #[test]
    fn test_between_condition() {
        let query = QueryWrapper::<Test>::new()
            .between(|| "age", "18", "30")
            .build_sql("users");

        assert_eq!(query, "SELECT * FROM users WHERE age BETWEEN '18' AND '30'");
    }

    /// `like` is expected to wrap the value in `%` wildcards on both sides.
    #[test]
    fn test_like_condition() {
        let query = QueryWrapper::<Test>::new()
            .like(Test::name, "张")
            .build_sql("users");

        assert_eq!(query, "SELECT * FROM users WHERE name LIKE '%张%'");
    }

    #[test]
    fn test_group_by_having() {
        let query = QueryWrapper::<Test>::new()
            .select(vec!["department", "COUNT(*) as count"])
            .group_by(vec!["department"])
            .having("COUNT(*) > 5")
            .build_sql("employees");

        assert_eq!(
            query,
            "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING COUNT(*) > 5"
        );
    }

    /// `generate_sql` takes the table name from `Test::table_name()`.
    #[test]
    fn test_count() {
        let query = QueryWrapper::<Test>::new()
            .select(vec!["COUNT(1) as count"])
            .eq(|| "name", "huihui")
            .generate_sql();
        assert_eq!(query, "SELECT COUNT(1) as count FROM test WHERE name = 'huihui'");
    }
}