1use crate::util::SqlExtension;
2use crate::{Dialect, ToSql};
3
/// A SQL `CASE` expression: one or more `WHEN <condition> THEN <value>` arms
/// followed by an optional `ELSE` value.
///
/// Construct with [`Case::new_when`], extend with [`Case::when`] / [`Case::els`],
/// and render through its [`ToSql`] implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Case {
    /// `(condition, value)` pairs, rendered in insertion order as `WHEN … THEN …` arms.
    cases: Vec<(Expr, Expr)>,
    /// Optional `ELSE` value; nothing is emitted for it when `None`.
    els: Option<Box<Expr>>,
}
10
11impl Case {
12 pub fn new_when<C: Into<Expr>, V: Into<Expr>>(condition: C, then_value: V) -> Self {
13 Self {
14 cases: vec![(condition.into(), then_value.into())],
15 els: None,
16 }
17 }
18
19 pub fn when(mut self, condition: Expr, value: Expr) -> Self {
20 self.cases.push((condition, value));
21 self
22 }
23
24 pub fn els<V: Into<Expr>>(mut self, value: V) -> Self {
25 self.els = Some(Box::new(value.into()));
26 self
27 }
28}
29
30impl ToSql for Case {
31 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
32 buf.push_str("CASE ");
33 for c in &self.cases {
34 buf.push_str("WHEN ");
35 buf.push_sql(&c.0, dialect);
36 buf.push_str(" THEN ");
37 buf.push_sql(&c.1, dialect);
38 }
39 if let Some(els) = &self.els {
40 buf.push_str(" ELSE ");
41 buf.push_sql(els.as_ref(), dialect);
42 }
43 buf.push_str(" END");
44 }
45}
46
/// Binary comparison operators used by [`Expr::BinOp`].
///
/// The enum is fieldless, so it is cheap to copy and can be hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Operation {
    /// `=`
    Eq,
    /// `>=`
    Gte,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `<`
    Lt,
}
56
/// A SQL expression tree, rendered to text through its [`ToSql`] implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Expr {
    /// A `CASE … END` expression.
    Case(Case),
    /// Conjunction of sub-expressions, rendered parenthesized and joined by ` AND `.
    And(Vec<Expr>),
    /// SQL text emitted verbatim — no quoting or escaping is applied.
    Raw(String),
    /// `left IS NOT DISTINCT FROM right`.
    NotDistinctFrom(Box<Expr>, Box<Expr>),
    /// A column reference, optionally qualified by schema and/or table; each
    /// present part is written with `push_quoted` and joined by `.`.
    Column {
        /// Optional schema qualifier.
        schema: Option<String>,
        /// Optional table qualifier.
        table: Option<String>,
        /// Column name.
        column: String,
    },
    /// `left <op> right` for a comparison [`Operation`].
    BinOp(Operation, Box<Expr>, Box<Expr>),
}
71
72impl Expr {
73 pub fn excluded(column: &str) -> Self {
74 Self::Raw(format!("excluded.\"{}\"", column))
75 }
76
77 pub fn column(column: &str) -> Self {
78 Self::Column {
79 schema: None,
80 table: None,
81 column: column.to_string(),
82 }
83 }
84
85 pub fn new_eq<L: Into<Expr>, R: Into<Expr>>(left: L, right: R) -> Self {
86 Self::BinOp(Operation::Eq, Box::new(left.into()), Box::new(right.into()))
87 }
88
89 pub fn table_column(table: &str, column: &str) -> Self {
90 Self::Column {
91 schema: None,
92 table: Some(table.to_string()),
93 column: column.to_string(),
94 }
95 }
96
97 pub fn schema_column(schema: &str, table: &str, column: &str) -> Self {
98 Self::Column {
99 schema: Some(schema.to_string()),
100 table: Some(table.to_string()),
101 column: column.to_string(),
102 }
103 }
104
105 pub fn new_and(and: Vec<Expr>) -> Self {
106 Self::And(and)
107 }
108
109 pub fn case(case: Case) -> Self {
110 Self::Case(case)
111 }
112
113 pub fn not_distinct_from<L: Into<Expr>, R: Into<Expr>>(left: L, right: R) -> Self {
114 Self::NotDistinctFrom(Box::new(left.into()), Box::new(right.into()))
115 }
116}
117
118impl Into<Expr> for &str {
119 fn into(self) -> Expr {
120 Expr::Raw(self.to_string())
121 }
122}
123
124impl ToSql for Expr {
125 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
126 match self {
127 Expr::Case(c) => c.write_sql(buf, dialect),
128 Expr::And(and) => {
129 buf.push('(');
130 buf.push_sql_sequence(&and, " AND ", dialect);
131 buf.push(')');
132 }
133 Expr::Raw(a) => buf.push_str(a),
134 Expr::NotDistinctFrom(l, r) => {
135 buf.push_sql(l.as_ref(), dialect);
136 buf.push_str(" IS NOT DISTINCT FROM ");
137 buf.push_sql(r.as_ref(), dialect);
138 }
139 Expr::Column {
140 schema,
141 table,
142 column,
143 } => {
144 if let Some(schema) = schema {
145 buf.push_quoted(schema);
146 buf.push('.');
147 }
148 if let Some(table) = table {
149 buf.push_quoted(table);
150 buf.push('.');
151 }
152 buf.push_quoted(column);
153 }
154 Expr::BinOp(op, l, r) => {
155 buf.push_sql(l.as_ref(), dialect);
156 buf.push_sql(op, dialect);
157 buf.push_sql(r.as_ref(), dialect);
158 }
159 }
160 }
161}
162
163impl ToSql for Operation {
164 fn write_sql(&self, buf: &mut String, _dialect: Dialect) {
165 match self {
166 Operation::Eq => buf.push_str(" = "),
167 Operation::Gte => buf.push_str(" >= "),
168 Operation::Lte => buf.push_str(" <= "),
169 Operation::Gt => buf.push_str(" > "),
170 Operation::Lt => buf.push_str(" < "),
171 }
172 }
173}