1use crate::util::SqlExtension;
2use crate::{Dialect, ToSql};
3
4#[derive(Debug, Clone, PartialEq, Eq)]
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6pub struct Case {
7 cases: Vec<(Expr, Expr)>,
8 els: Option<Box<Expr>>,
9}
10
11impl Case {
12 pub fn new_when<C: Into<Expr>, V: Into<Expr>>(condition: C, then_value: V) -> Self {
13 Self {
14 cases: vec![(condition.into(), then_value.into())],
15 els: None,
16 }
17 }
18
19 pub fn when(mut self, condition: Expr, value: Expr) -> Self {
20 self.cases.push((condition, value));
21 self
22 }
23
24 pub fn els<V: Into<Expr>>(mut self, value: V) -> Self {
25 self.els = Some(Box::new(value.into()));
26 self
27 }
28}
29
30impl ToSql for Case {
31 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
32 buf.push_str("CASE ");
33 for c in &self.cases {
34 buf.push_str("WHEN ");
35 buf.push_sql(&c.0, dialect);
36 buf.push_str(" THEN ");
37 buf.push_sql(&c.1, dialect);
38 }
39 if let Some(els) = &self.els {
40 buf.push_str(" ELSE ");
41 buf.push_sql(els.as_ref(), dialect);
42 }
43 buf.push_str(" END");
44 }
45}
46
/// Binary comparison operators usable in `Expr::BinOp`.
///
/// The enum is fieldless, so it derives `Copy` and can be passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Operation {
    /// `=`
    Eq,
    /// `>=`
    Gte,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `<`
    Lt,
}
56
57#[derive(Debug, Clone, PartialEq, Eq)]
58#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
59pub enum Expr {
60 Case(Case),
61 And(Vec<Expr>),
62 Raw(String),
63 NotDistinctFrom(Box<Expr>, Box<Expr>),
64 Column {
65 schema: Option<String>,
66 table: Option<String>,
67 column: String,
68 },
69 BinOp(Operation, Box<Expr>, Box<Expr>),
70}
71
72impl Expr {
73 pub fn excluded(column: &str) -> Self {
74 Self::Raw(format!("excluded.\"{}\"", column))
75 }
76
77 pub fn column(column: &str) -> Self {
78 Self::Column {
79 schema: None,
80 table: None,
81 column: column.to_string(),
82 }
83 }
84
85 pub fn new_eq<L: Into<Expr>, R: Into<Expr>>(left: L, right: R) -> Self {
86 Self::BinOp(Operation::Eq, Box::new(left.into()), Box::new(right.into()))
87 }
88
89 pub fn table_column(table: &str, column: &str) -> Self {
90 Self::Column {
91 schema: None,
92 table: Some(table.to_string()),
93 column: column.to_string(),
94 }
95 }
96
97 pub fn schema_column(schema: &str, table: &str, column: &str) -> Self {
98 Self::Column {
99 schema: Some(schema.to_string()),
100 table: Some(table.to_string()),
101 column: column.to_string(),
102 }
103 }
104
105 pub fn new_and(and: Vec<Expr>) -> Self {
106 Self::And(and)
107 }
108
109 pub fn case(case: Case) -> Self {
110 Self::Case(case)
111 }
112
113 pub fn not_distinct_from<L: Into<Expr>, R: Into<Expr>>(left: L, right: R) -> Self {
114 Self::NotDistinctFrom(Box::new(left.into()), Box::new(right.into()))
115 }
116}
117
118impl Into<Expr> for &str {
119 fn into(self) -> Expr {
120 Expr::Raw(self.to_string())
121 }
122}
123
124impl ToSql for Expr {
125 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
126 match self {
127 Expr::Case(c) => c.write_sql(buf, dialect),
128 Expr::And(and) => {
129 buf.push('(');
130 for (i, expr) in and.iter().enumerate() {
131 if i > 0 {
132 buf.push_str(" AND ");
133 }
134 buf.push('(');
135 expr.write_sql(buf, dialect);
136 buf.push(')');
137 }
138 buf.push(')');
139 }
140 Expr::Raw(a) => buf.push_str(a),
141 Expr::NotDistinctFrom(l, r) => {
142 buf.push_sql(l.as_ref(), dialect);
143 buf.push_str(" IS NOT DISTINCT FROM ");
144 buf.push_sql(r.as_ref(), dialect);
145 }
146 Expr::Column {
147 schema,
148 table,
149 column,
150 } => {
151 if let Some(schema) = schema {
152 buf.push_quoted(schema);
153 buf.push('.');
154 }
155 if let Some(table) = table {
156 buf.push_quoted(table);
157 buf.push('.');
158 }
159 buf.push_quoted(column);
160 }
161 Expr::BinOp(op, l, r) => {
162 buf.push_sql(l.as_ref(), dialect);
163 buf.push_sql(op, dialect);
164 buf.push_sql(r.as_ref(), dialect);
165 }
166 }
167 }
168}
169
170impl ToSql for Operation {
171 fn write_sql(&self, buf: &mut String, _dialect: Dialect) {
172 match self {
173 Operation::Eq => buf.push_str(" = "),
174 Operation::Gte => buf.push_str(" >= "),
175 Operation::Lte => buf.push_str(" <= "),
176 Operation::Gt => buf.push_str(" > "),
177 Operation::Lt => buf.push_str(" < "),
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_and_clauses_wrapped_in_parentheses() {
        let expr = Expr::And(vec![
            Expr::Raw("a = 1".to_string()),
            Expr::Raw("b = 2".to_string()),
            Expr::Raw("c = 3".to_string()),
        ]);
        let sql = expr.to_sql(Dialect::Postgres);
        assert_eq!(sql, "((a = 1) AND (b = 2) AND (c = 3))");
    }

    #[test]
    fn test_binop_eq_renders_spaced_operator() {
        let expr = Expr::new_eq("a", "1");
        assert_eq!(expr.to_sql(Dialect::Postgres), "a = 1");
    }

    #[test]
    fn test_not_distinct_from() {
        let expr = Expr::not_distinct_from("a", "b");
        assert_eq!(expr.to_sql(Dialect::Postgres), "a IS NOT DISTINCT FROM b");
    }

    #[test]
    fn test_single_arm_case_with_else() {
        let expr = Expr::case(Case::new_when("x = 1", "'one'").els("'other'"));
        assert_eq!(
            expr.to_sql(Dialect::Postgres),
            "CASE WHEN x = 1 THEN 'one' ELSE 'other' END"
        );
    }

    #[test]
    fn test_excluded_column() {
        assert_eq!(
            Expr::excluded("id").to_sql(Dialect::Postgres),
            "excluded.\"id\""
        );
    }
}