// koron_query_parser/comparison.rs

1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::ast;
5use utoipa::ToSchema;
6
7use crate::{
8    error::ParseError, query_metadata::FromClauseIdentifier, support::case_fold_identifier,
9    unsupported,
10};
11
12use super::support::{extract_qualified_column, remove_outer_parens};
13
14#[must_use]
15pub const fn is_binary_operator_supported(op: &ast::BinaryOperator) -> bool {
16    matches!(
17        op,
18        &ast::BinaryOperator::Gt
19            | &ast::BinaryOperator::GtEq
20            | &ast::BinaryOperator::Lt
21            | &ast::BinaryOperator::LtEq
22            | &ast::BinaryOperator::Eq
23            | &ast::BinaryOperator::NotEq
24    )
25}
26
27#[must_use]
28pub const fn is_expression_supported(op: &ast::Expr) -> bool {
29    matches!(
30        op,
31        &ast::Expr::IsNull(..)
32            | &ast::Expr::IsNotNull(..)
33            | &ast::Expr::IsTrue(..)
34            | &ast::Expr::IsNotTrue(..)
35            | &ast::Expr::IsFalse(..)
36            | &ast::Expr::IsNotFalse(..)
37    )
38}
39
/// The comparison operation between the value of an unspecified column and a constant value.
///
/// The column itself is not stored here. Binary comparisons carry their non-column operand,
/// as a string, in `value`. The `Is*` predicates take no operand.
// Serialized as an internally tagged object: the variant name is written to a `type` field
// next to the variant's own fields (`#[serde(tag = "type")]`).
// NOTE(review): `///` comments in this enum are also emitted as `ToSchema` descriptions, so
// keep them user-facing; put implementation notes in `//` comments like this one.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema)]
#[serde(tag = "type")]
pub enum CompareOp {
    /// Check if column's value is less than `value`.
    Lt { value: String },
    /// Check if column's value is less than or equal to `value`.
    LtEq { value: String },
    /// Check if column's value is greater than `value`.
    Gt { value: String },
    /// Check if column's value is greater than or equal to `value`.
    GtEq { value: String },
    /// Check if column's value is equal to `value`.
    Eq { value: String },
    /// Check if column's value is not equal to `value`.
    NotEq { value: String },
    /// Check if column's value is `NULL`. This is the default operation.
    #[default]
    IsNull,
    /// Check if column's value is not `NULL`.
    IsNotNull,
    /// Check if column's value is `true`.
    IsTrue,
    /// Check if column's value is not `true`.
    // NOTE(review): mirrors SQL `IS NOT TRUE`, which in standard SQL also holds for `NULL`;
    // confirm the consumer of this enum evaluates it the same way.
    IsNotTrue,
    /// Check if column's value is `false`.
    IsFalse,
    /// Check if column's value is not `false`.
    // NOTE(review): as with `IsNotTrue`, SQL `IS NOT FALSE` also holds for `NULL` — confirm.
    IsNotFalse,
}
70
71impl Display for CompareOp {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        match self {
74            Self::Lt { value: _ } => write!(f, "Less than"),
75            Self::LtEq { value: _ } => write!(f, "Less than or equal"),
76            Self::Gt { value: _ } => write!(f, "Greater than"),
77            Self::GtEq { value: _ } => write!(f, "Greater than or equal"),
78            Self::Eq { value: _ } => write!(f, "Equal"),
79            Self::NotEq { value: _ } => write!(f, "Not equal"),
80            Self::IsNull => write!(f, "Is null"),
81            Self::IsNotNull => write!(f, "Is not null"),
82            Self::IsTrue => write!(f, "Is true"),
83            Self::IsNotTrue => write!(f, "Is not true"),
84            Self::IsFalse => write!(f, "Is false"),
85            Self::IsNotFalse => write!(f, "Is not false"),
86        }
87    }
88}
89
90impl CompareOp {
91    pub(crate) fn from_binary_operator(
92        op: &ast::BinaryOperator,
93        value: String,
94        reverse: bool,
95    ) -> Result<Self, ParseError> {
96        let comparison = match op {
97            ast::BinaryOperator::Lt if reverse => Self::Gt { value },
98            ast::BinaryOperator::Lt => Self::Lt { value },
99            ast::BinaryOperator::LtEq if reverse => Self::GtEq { value },
100            ast::BinaryOperator::LtEq => Self::LtEq { value },
101            ast::BinaryOperator::Gt if reverse => Self::Lt { value },
102            ast::BinaryOperator::Gt => Self::Gt { value },
103            ast::BinaryOperator::GtEq if reverse => Self::LtEq { value },
104            ast::BinaryOperator::GtEq => Self::GtEq { value },
105            ast::BinaryOperator::Eq => Self::Eq { value },
106            ast::BinaryOperator::NotEq => Self::NotEq { value },
107            _ => {
108                return Err(unsupported!(format!("the {op} operator.")));
109            }
110        };
111        Ok(comparison)
112    }
113
114    pub(crate) fn from_expr(op: &ast::Expr) -> Result<Self, ParseError> {
115        let comparison = match op {
116            ast::Expr::IsNull(_) => Self::IsNull,
117            ast::Expr::IsNotNull(_) => Self::IsNotNull,
118            ast::Expr::IsTrue(_) => Self::IsTrue,
119            ast::Expr::IsNotTrue(_) => Self::IsNotTrue,
120            ast::Expr::IsFalse(_) => Self::IsFalse,
121            ast::Expr::IsNotFalse(_) => Self::IsNotFalse,
122            _ => {
123                return Err(unsupported!(format!("the {op} operator.")));
124            }
125        };
126        Ok(comparison)
127    }
128}
129
/// One side of a binary comparison, classified by whether it refers to a column.
#[derive(Debug)]
pub(crate) enum ComparisonOperand<'a> {
    /// A column reference, holding the column name as a `String`.
    Column(String),
    /// Any operand that is not a column reference, borrowed from the parsed AST.
    /// It can be a static value or another expression.
    Other(&'a ast::Expr),
}
136
137impl<'a> ComparisonOperand<'a> {
138    pub(crate) fn from_expression(
139        from_clause_identifier: FromClauseIdentifier<'_>,
140        expr: &'a ast::Expr,
141    ) -> Result<Self, ParseError> {
142        let expr = remove_outer_parens(expr);
143        match expr {
144            ast::Expr::Identifier(ident) => Ok(Self::Column(case_fold_identifier(ident))),
145            ast::Expr::CompoundIdentifier(name_parts) => {
146                extract_qualified_column(from_clause_identifier, expr, name_parts).map(Self::Column)
147            }
148            _ => Ok(Self::Other(expr)),
149        }
150    }
151}
152
153pub(crate) fn analyze_comparison_operands<'a>(
154    binary_expr: &'a ast::Expr,
155    left: ComparisonOperand<'a>,
156    right: ComparisonOperand<'a>,
157) -> Result<(String, &'a ast::Expr, bool), ParseError> {
158    match (left, right) {
159        (ComparisonOperand::Column(column), ComparisonOperand::Other(value)) => {
160            Ok((column, value, false))
161        }
162        (ComparisonOperand::Other(value), ComparisonOperand::Column(column)) => {
163            // keep on the left the column
164            Ok((column, value, true))
165        }
166        _ => Err(unsupported!(format!(
167            "{binary_expr}. Only comparisons between a column and a constant are supported.",
168        ))),
169    }
170}
171
#[cfg(test)]
mod tests {
    use sqlparser::ast::Ident;

    use crate::{
        comparison::{
            analyze_comparison_operands, is_binary_operator_supported, is_expression_supported,
            CompareOp, ComparisonOperand,
        },
        error::ParseError,
    };

    use super::ast;

    /// Builds an unquoted identifier expression named `name`.
    fn identifier(name: &str) -> ast::Expr {
        ast::Expr::Identifier(Ident {
            value: name.to_string(),
            quote_style: None,
        })
    }

    /// Boxed, unnamed identifier used as the operand of the `IS ...` expressions.
    fn operand() -> Box<ast::Expr> {
        Box::new(identifier(""))
    }

    #[test]
    fn test_supported_binary_operator() {
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Gt));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::GtEq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Lt));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::LtEq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::Eq));
        assert!(is_binary_operator_supported(&ast::BinaryOperator::NotEq));
    }

    #[test]
    fn test_unsupported_binary_operator() {
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Plus));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Minus));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::Multiply
        ));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Divide));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Modulo));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::StringConcat
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::Spaceship
        ));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::And));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Or));
        assert!(!is_binary_operator_supported(&ast::BinaryOperator::Xor));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseAnd
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseOr
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::BitwiseXor
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseXor
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseShiftLeft
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGBitwiseShiftRight
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexIMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexNotIMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGRegexNotMatch
        ));
        assert!(!is_binary_operator_supported(
            &ast::BinaryOperator::PGCustomBinaryOperator(Vec::new())
        ));
    }

    #[test]
    fn test_supported_expression() {
        assert!(is_expression_supported(&ast::Expr::IsNull(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotNull(operand())));
        assert!(is_expression_supported(&ast::Expr::IsTrue(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotTrue(operand())));
        assert!(is_expression_supported(&ast::Expr::IsFalse(operand())));
        assert!(is_expression_supported(&ast::Expr::IsNotFalse(operand())));
    }

    #[test]
    fn test_unsupported_expression() {
        assert!(!is_expression_supported(&identifier("a")));
    }

    #[test]
    fn test_from_binary_operator() {
        let value: String = "1".to_string();
        let mut reverse = false;
        let expected_lt = CompareOp::Lt {
            value: value.clone(),
        };
        let expected_lt_eq = CompareOp::LtEq {
            value: value.clone(),
        };
        let expected_gt = CompareOp::Gt {
            value: value.clone(),
        };
        let expected_gt_eq = CompareOp::GtEq {
            value: value.clone(),
        };
        let expected_eq = CompareOp::Eq {
            value: value.clone(),
        };
        let expected_not_eq = CompareOp::NotEq {
            value: value.clone(),
        };
        let expected_error = ParseError::Unsupported {
            message: "the AND operator.".to_string(),
        };

        let op = ast::BinaryOperator::Lt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt, result);

        let op = ast::BinaryOperator::LtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt_eq, result);

        let op = ast::BinaryOperator::Gt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt, result);

        let op = ast::BinaryOperator::GtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt_eq, result);

        // With the constant on the left, ordering operators are mirrored.
        reverse = true;

        let op = ast::BinaryOperator::Gt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt, result);

        let op = ast::BinaryOperator::GtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_lt_eq, result);

        let op = ast::BinaryOperator::Lt;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt, result);

        let op = ast::BinaryOperator::LtEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_gt_eq, result);

        let op = ast::BinaryOperator::Eq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_eq, result);

        let op = ast::BinaryOperator::NotEq;
        let result = CompareOp::from_binary_operator(&op, value.clone(), reverse).unwrap();
        assert_eq!(expected_not_eq, result);

        let op = ast::BinaryOperator::And;
        let result = CompareOp::from_binary_operator(&op, value, reverse).unwrap_err();
        assert_eq!(expected_error, result);
    }

    #[test]
    fn test_from_expr() {
        let cases = [
            (ast::Expr::IsNull(operand()), CompareOp::IsNull),
            (ast::Expr::IsNotNull(operand()), CompareOp::IsNotNull),
            (ast::Expr::IsTrue(operand()), CompareOp::IsTrue),
            (ast::Expr::IsNotTrue(operand()), CompareOp::IsNotTrue),
            (ast::Expr::IsFalse(operand()), CompareOp::IsFalse),
            (ast::Expr::IsNotFalse(operand()), CompareOp::IsNotFalse),
        ];
        for (expr, expected) in &cases {
            assert_eq!(expected, &CompareOp::from_expr(expr).unwrap());
        }

        // Any other expression is rejected; the message embeds the expression's SQL text.
        let expected_error = ParseError::Unsupported {
            message: "the a operator.".to_string(),
        };
        let result = CompareOp::from_expr(&identifier("a")).unwrap_err();
        assert_eq!(expected_error, result);
    }

    #[test]
    fn test_display_and_default() {
        let value = "1".to_string();
        assert_eq!(
            CompareOp::Lt {
                value: value.clone()
            }
            .to_string(),
            "Less than"
        );
        assert_eq!(
            CompareOp::GtEq {
                value: value.clone()
            }
            .to_string(),
            "Greater than or equal"
        );
        assert_eq!(CompareOp::NotEq { value }.to_string(), "Not equal");
        assert_eq!(CompareOp::IsNull.to_string(), "Is null");
        assert_eq!(CompareOp::IsNotFalse.to_string(), "Is not false");

        assert_eq!(CompareOp::default(), CompareOp::IsNull);
    }

    #[test]
    fn test_analyze_comparison_operands() {
        // Only used to build the error message, so its exact shape does not matter here.
        let binary_expr = identifier("expr");
        let constant = identifier("constant");

        let (column, value, reverse) = analyze_comparison_operands(
            &binary_expr,
            ComparisonOperand::Column("col".to_string()),
            ComparisonOperand::Other(&constant),
        )
        .unwrap();
        assert_eq!(("col", &constant, false), (column.as_str(), value, reverse));

        // Constant first: operands are swapped and the caller is told to mirror the operator.
        let (column, value, reverse) = analyze_comparison_operands(
            &binary_expr,
            ComparisonOperand::Other(&constant),
            ComparisonOperand::Column("col".to_string()),
        )
        .unwrap();
        assert_eq!(("col", &constant, true), (column.as_str(), value, reverse));

        let two_columns = analyze_comparison_operands(
            &binary_expr,
            ComparisonOperand::Column("a".to_string()),
            ComparisonOperand::Column("b".to_string()),
        );
        assert!(matches!(two_columns, Err(ParseError::Unsupported { .. })));

        let two_constants = analyze_comparison_operands(
            &binary_expr,
            ComparisonOperand::Other(&constant),
            ComparisonOperand::Other(&constant),
        );
        assert!(matches!(two_constants, Err(ParseError::Unsupported { .. })));
    }
}