1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::ast;
5use utoipa::ToSchema;
6
7use crate::{
8 error::ParseError, query_metadata::FromClauseIdentifier, support::case_fold_identifier,
9 unsupported,
10};
11
12use super::support::{extract_qualified_column, remove_outer_parens};
13
14#[must_use]
15pub const fn is_binary_operator_supported(op: &ast::BinaryOperator) -> bool {
16 matches!(
17 op,
18 &ast::BinaryOperator::Gt
19 | &ast::BinaryOperator::GtEq
20 | &ast::BinaryOperator::Lt
21 | &ast::BinaryOperator::LtEq
22 | &ast::BinaryOperator::Eq
23 | &ast::BinaryOperator::NotEq
24 )
25}
26
27#[must_use]
28pub const fn is_expression_supported(op: &ast::Expr) -> bool {
29 matches!(
30 op,
31 &ast::Expr::IsNull(..)
32 | &ast::Expr::IsNotNull(..)
33 | &ast::Expr::IsTrue(..)
34 | &ast::Expr::IsNotTrue(..)
35 | &ast::Expr::IsFalse(..)
36 | &ast::Expr::IsNotFalse(..)
37 )
38}
39
/// A comparison applied to a single column, extracted from a SQL predicate.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
// Internally tagged: the variant name is written to a `type` field next to the variant's own
// fields, e.g. `{"type": "Lt", "value": "..."}` or `{"type": "IsNull"}`.
#[serde(tag = "type")]
pub enum CompareOp {
    /// `column < value`; `value` is the non-column operand in string form.
    Lt { value: String },
    /// `column <= value`; `value` is the non-column operand in string form.
    LtEq { value: String },
    /// `column > value`; `value` is the non-column operand in string form.
    Gt { value: String },
    /// `column >= value`; `value` is the non-column operand in string form.
    GtEq { value: String },
    /// `column = value`; `value` is the non-column operand in string form.
    Eq { value: String },
    /// `column <> value`; `value` is the non-column operand in string form.
    NotEq { value: String },
    /// `column IS NULL`. This is the `Default` variant.
    #[default]
    IsNull,
    /// `column IS NOT NULL`.
    IsNotNull,
    /// `column IS TRUE`.
    IsTrue,
    /// `column IS NOT TRUE`.
    IsNotTrue,
    /// `column IS FALSE`.
    IsFalse,
    /// `column IS NOT FALSE`.
    IsNotFalse,
}
70
71impl Display for CompareOp {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Self::Lt { value: _ } => write!(f, "Less than"),
75 Self::LtEq { value: _ } => write!(f, "Less than or equal"),
76 Self::Gt { value: _ } => write!(f, "Greater than"),
77 Self::GtEq { value: _ } => write!(f, "Greater than or equal"),
78 Self::Eq { value: _ } => write!(f, "Equal"),
79 Self::NotEq { value: _ } => write!(f, "Not equal"),
80 Self::IsNull => write!(f, "Is null"),
81 Self::IsNotNull => write!(f, "Is not null"),
82 Self::IsTrue => write!(f, "Is true"),
83 Self::IsNotTrue => write!(f, "Is not true"),
84 Self::IsFalse => write!(f, "Is false"),
85 Self::IsNotFalse => write!(f, "Is not false"),
86 }
87 }
88}
89
90impl CompareOp {
91 pub(crate) fn from_binary_operator(
92 op: &ast::BinaryOperator,
93 value: String,
94 reverse: bool,
95 ) -> Result<Self, ParseError> {
96 let comparison = match op {
97 ast::BinaryOperator::Lt if reverse => Self::Gt { value },
98 ast::BinaryOperator::Lt => Self::Lt { value },
99 ast::BinaryOperator::LtEq if reverse => Self::GtEq { value },
100 ast::BinaryOperator::LtEq => Self::LtEq { value },
101 ast::BinaryOperator::Gt if reverse => Self::Lt { value },
102 ast::BinaryOperator::Gt => Self::Gt { value },
103 ast::BinaryOperator::GtEq if reverse => Self::LtEq { value },
104 ast::BinaryOperator::GtEq => Self::GtEq { value },
105 ast::BinaryOperator::Eq => Self::Eq { value },
106 ast::BinaryOperator::NotEq => Self::NotEq { value },
107 _ => {
108 return Err(unsupported!(format!("the {op} operator.")));
109 }
110 };
111 Ok(comparison)
112 }
113
114 pub(crate) fn from_expr(op: &ast::Expr) -> Result<Self, ParseError> {
115 let comparison = match op {
116 ast::Expr::IsNull(_) => Self::IsNull,
117 ast::Expr::IsNotNull(_) => Self::IsNotNull,
118 ast::Expr::IsTrue(_) => Self::IsTrue,
119 ast::Expr::IsNotTrue(_) => Self::IsNotTrue,
120 ast::Expr::IsFalse(_) => Self::IsFalse,
121 ast::Expr::IsNotFalse(_) => Self::IsNotFalse,
122 _ => {
123 return Err(unsupported!(format!("the {op} operator.")));
124 }
125 };
126 Ok(comparison)
127 }
128}
129
/// One side of a binary comparison, classified by whether it refers to a column.
#[derive(Debug)]
pub(crate) enum ComparisonOperand<'a> {
    /// A column reference, holding the column name as produced by `case_fold_identifier`
    /// (bare identifier) or `extract_qualified_column` (compound identifier).
    Column(String),
    /// Any expression that is not a column reference, borrowed from the parsed AST.
    Other(&'a ast::Expr),
}
136
137impl<'a> ComparisonOperand<'a> {
138 pub(crate) fn from_expression(
139 from_clause_identifier: FromClauseIdentifier<'_>,
140 expr: &'a ast::Expr,
141 ) -> Result<Self, ParseError> {
142 let expr = remove_outer_parens(expr);
143 match expr {
144 ast::Expr::Identifier(ident) => Ok(Self::Column(case_fold_identifier(ident))),
145 ast::Expr::CompoundIdentifier(name_parts) => {
146 extract_qualified_column(from_clause_identifier, expr, name_parts).map(Self::Column)
147 }
148 _ => Ok(Self::Other(expr)),
149 }
150 }
151}
152
153pub(crate) fn analyze_comparison_operands<'a>(
154 binary_expr: &'a ast::Expr,
155 left: ComparisonOperand<'a>,
156 right: ComparisonOperand<'a>,
157) -> Result<(String, &'a ast::Expr, bool), ParseError> {
158 match (left, right) {
159 (ComparisonOperand::Column(column), ComparisonOperand::Other(value)) => {
160 Ok((column, value, false))
161 }
162 (ComparisonOperand::Other(value), ComparisonOperand::Column(column)) => {
163 Ok((column, value, true))
165 }
166 _ => Err(unsupported!(format!(
167 "{binary_expr}. Only comparisons between a column and a constant are supported.",
168 ))),
169 }
170}
171
#[cfg(test)]
mod tests {
    use sqlparser::ast::Ident;

    use crate::{
        comparison::{is_binary_operator_supported, is_expression_supported, CompareOp},
        error::ParseError,
    };

    use super::ast;

    /// Builds an unquoted identifier expression named `name`.
    fn identifier(name: &str) -> ast::Expr {
        ast::Expr::Identifier(Ident {
            value: name.to_string(),
            quote_style: None,
        })
    }

    /// Boxed placeholder operand for the `Expr::Is*` variants, whose operand is never inspected.
    fn operand() -> Box<ast::Expr> {
        Box::new(identifier(""))
    }

    #[test]
    fn test_supported_binary_operator() {
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Gt));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::GtEq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Lt));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::LtEq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Eq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::NotEq));
    }

    #[test]
    fn test_unsupported_binary_operator() {
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Plus));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Minus));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::Multiply
        ));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Divide));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Modulo));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::StringConcat
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::Spaceship
        ));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::And));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Or));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Xor));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseAnd
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseOr
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseXor
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseXor
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseShiftLeft
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseShiftRight
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexIMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexNotIMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexNotMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGCustomBinaryOperator(Vec::new())
        ));
    }

    #[test]
    fn test_supported_expression() {
        // Every `IS ...` form that `CompareOp::from_expr` converts must be reported as supported.
        assert!(is_expression_supported(&ast::Expr::IsNull(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotNull(operand())));
        assert!(is_expression_supported(&ast::Expr::IsTrue(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotTrue(operand())));
        assert!(is_expression_supported(&ast::Expr::IsFalse(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotFalse(operand())));
    }

    #[test]
    fn test_unsupported_expression() {
        assert!(!is_expression_supported(&identifier("a")));
    }

    #[test]
    fn test_from_binary_operator() {
        let value: String = "1".to_string();
        let mut reverse = false;
        let expected_lt = CompareOp::Lt {
            value: value.clone(),
        };
        let expected_lt_eq = CompareOp::LtEq {
            value: value.clone(),
        };
        let expected_gt = CompareOp::Gt {
            value: value.clone(),
        };
        let expected_gt_eq = CompareOp::GtEq {
            value: value.clone(),
        };
        let expected_eq = CompareOp::Eq {
            value: value.clone(),
        };
        let expected_not_eq = CompareOp::NotEq {
            value: value.clone(),
        };
        let expected_error = ParseError::Unsupported {
            message: "the AND operator.".to_string(),
        };

        let op = ast::BinaryOperator::Lt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt, result);

        let op = ast::BinaryOperator::LtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt_eq, result);

        let op = ast::BinaryOperator::Gt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt, result);

        let op = ast::BinaryOperator::GtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt_eq, result);

        // Column on the right-hand side: ordering operators must be mirrored.
        reverse = true;

        let op = ast::BinaryOperator::Gt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt, result);

        let op = ast::BinaryOperator::GtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt_eq, result);

        let op = ast::BinaryOperator::Lt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt, result);

        let op = ast::BinaryOperator::LtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt_eq, result);

        // Equality operators are symmetric and must not be mirrored.
        let op = ast::BinaryOperator::Eq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_eq, result);

        let op = ast::BinaryOperator::NotEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_not_eq, result);

        let op = ast::BinaryOperator::And;
        let result = CompareOp::from_binary_operator(&op, value, reverse).unwrap_err();
        assert_eq!(expected_error, result);
    }

    #[test]
    fn test_from_expr() {
        let cases = [
            (ast::Expr::IsNull(operand()), CompareOp::IsNull),
            (ast::Expr::IsNotNull(operand()), CompareOp::IsNotNull),
            (ast::Expr::IsTrue(operand()), CompareOp::IsTrue),
            (ast::Expr::IsNotTrue(operand()), CompareOp::IsNotTrue),
            (ast::Expr::IsFalse(operand()), CompareOp::IsFalse),
            (ast::Expr::IsNotFalse(operand()), CompareOp::IsNotFalse),
        ];
        for (expr, expected) in cases {
            assert_eq!(expected, CompareOp::from_expr(&expr).unwrap());
        }
    }

    #[test]
    fn test_from_expr_unsupported() {
        // Non-`IS` expressions are rejected; the message embeds the expression's SQL text.
        let expected_error = ParseError::Unsupported {
            message: "the a operator.".to_string(),
        };
        let result = CompareOp::from_expr(&identifier("a")).unwrap_err();
        assert_eq!(expected_error, result);
    }
}