cairo_lint_core/lints/
double_comparison.rs

1use std::collections::HashSet;
2
3use cairo_lang_defs::ids::ModuleItemId;
4use cairo_lang_defs::plugin::PluginDiagnostic;
5use cairo_lang_diagnostics::Severity;
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::{
8    Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg, ExprLogicalOperator, LogicalOperator,
9};
10use cairo_lang_syntax::node::ast::{BinaryOperator, Expr as AstExpr};
11use cairo_lang_syntax::node::db::SyntaxGroup;
12use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
13
14use super::function_trait_name_from_fn_id;
15use crate::context::{CairoLintKind, Lint};
16use crate::lints::{EQ, GE, GT, LE, LT};
17use crate::queries::{get_all_function_bodies, get_all_logical_operator_expressions};
18
19pub struct ImpossibleComparison;
20
21/// ## What it does
22///
23/// Checks for impossible comparisons. Those ones always return false.
24///
25/// ## Example
26///
27/// Here is an example of impossible comparison:
28///
29/// ```cairo
30/// fn main() {
31///     let x: u32 = 1;
32///     if x > 200 && x < 100 {
33///         //impossible to reach
34///     }
35/// }
36/// ```
37impl Lint for ImpossibleComparison {
38    fn allowed_name(&self) -> &'static str {
39        "impossible_comparison"
40    }
41
42    fn diagnostic_message(&self) -> &'static str {
43        "Impossible condition, always false"
44    }
45
46    fn kind(&self) -> CairoLintKind {
47        CairoLintKind::ImpossibleComparison
48    }
49}
50
51pub struct SimplifiableComparison;
52
53/// ## What it does
54///
55/// Checks for double comparisons that can be simplified.
56/// Those are comparisons that can be simplified to a single comparison.
57///
58/// ## Example
59///
60/// ```cairo
61/// fn main() -> bool {
62///     let x = 5_u32;
63///     let y = 10_u32;
64///     if x == y || x > y {
65///         true
66///     } else {
67///         false
68///     }
69/// }
70/// ```
71///
72/// The above code can be simplified to:
73///
74/// ```cairo
75/// fn main() -> bool {
76///     let x = 5_u32;
77///     let y = 10_u32;
78///     if x >= y {
79///         true
80///     } else {
81///         false
82///     }
83/// }
84/// ```
85impl Lint for SimplifiableComparison {
86    fn allowed_name(&self) -> &'static str {
87        "simplifiable_comparison"
88    }
89
90    fn diagnostic_message(&self) -> &'static str {
91        "This double comparison can be simplified."
92    }
93
94    fn kind(&self) -> CairoLintKind {
95        CairoLintKind::DoubleComparison
96    }
97
98    fn has_fixer(&self) -> bool {
99        true
100    }
101
102    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
103        fix_double_comparison(db, node)
104    }
105}
106
107pub struct RedundantComparison;
108
109/// ## What it does
110///
111/// Checks for double comparisons that are redundant. Those are comparisons that can be simplified to a single comparison.
112///
113/// ## Example
114///
115/// ```cairo
116/// fn main() -> bool {
117///     let x = 5_u32;
118///     let y = 10_u32;
119///     if x >= y || x <= y {
120///         true
121///     } else {
122///         false
123///     }
124/// }
125/// ```
126///
127/// Could be simplified to just:
128///
129/// ```cairo
130/// fn main() -> bool {
131///     let x = 5_u32;
132///     let y = 10_u32;
133///     true
134/// }
135/// ```
136impl Lint for RedundantComparison {
137    fn allowed_name(&self) -> &'static str {
138        "redundant_comparison"
139    }
140
141    fn diagnostic_message(&self) -> &'static str {
142        "Redundant double comparison found. Consider simplifying to a single comparison."
143    }
144
145    fn kind(&self) -> CairoLintKind {
146        CairoLintKind::DoubleComparison
147    }
148
149    fn has_fixer(&self) -> bool {
150        true
151    }
152
153    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
154        fix_double_comparison(db, node)
155    }
156}
157
158pub struct ContradictoryComparison;
159
160/// ## What it does
161///
162/// Checks for double comparisons that are contradictory. Those are comparisons that are always false.
163///
164/// ## Example
165///
166/// ```cairo
167/// fn main() -> bool {
168///     let x = 5_u32;
169///     let y = 10_u32;
170///     if x < y && x > y {
171///         true
172///     } else {
173///         false
174///     }
175/// }
176/// ```
177///
178/// Could be simplified to just:
179///
180/// ```cairo
181/// fn main() -> bool {
182///     let x = 5_u32;
183///     let y = 10_u32;
184///     false
185/// }
186/// ```
187impl Lint for ContradictoryComparison {
188    fn allowed_name(&self) -> &'static str {
189        "contradictory_comparison"
190    }
191
192    fn diagnostic_message(&self) -> &'static str {
193        "This double comparison is contradictory and always false."
194    }
195
196    fn kind(&self) -> CairoLintKind {
197        CairoLintKind::DoubleComparison
198    }
199
200    fn has_fixer(&self) -> bool {
201        true
202    }
203
204    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
205        fix_double_comparison(db, node)
206    }
207}
208
209pub fn check_double_comparison(
210    db: &dyn SemanticGroup,
211    item: &ModuleItemId,
212    diagnostics: &mut Vec<PluginDiagnostic>,
213) {
214    let function_bodies = get_all_function_bodies(db, item);
215    for function_body in function_bodies.iter() {
216        let logical_operator_exprs = get_all_logical_operator_expressions(function_body);
217        let arenas = &function_body.arenas;
218        for logical_operator_expr in logical_operator_exprs.iter() {
219            check_single_double_comparison(db, logical_operator_expr, arenas, diagnostics);
220        }
221    }
222}
223
224fn check_single_double_comparison(
225    db: &dyn SemanticGroup,
226    logical_operator_exprs: &ExprLogicalOperator,
227    arenas: &Arenas,
228    diagnostics: &mut Vec<PluginDiagnostic>,
229) {
230    let Expr::FunctionCall(lhs_comparison) = &arenas.exprs[logical_operator_exprs.lhs] else {
231        return;
232    };
233    // If it's not 2 args it cannot be a regular comparison
234    if lhs_comparison.args.len() != 2 {
235        return;
236    }
237
238    let Expr::FunctionCall(rhs_comparison) = &arenas.exprs[logical_operator_exprs.rhs] else {
239        return;
240    };
241    // If it's not 2 args it cannot be a regular comparison
242    if rhs_comparison.args.len() != 2 {
243        return;
244    }
245    // Get the full name of the function used (trait name)
246    let (lhs_fn_trait_name, rhs_fn_trait_name) = (
247        function_trait_name_from_fn_id(db, &lhs_comparison.function),
248        function_trait_name_from_fn_id(db, &rhs_comparison.function),
249    );
250
251    // Check the impossible comparison
252    if check_impossible_comparison(
253        lhs_comparison,
254        rhs_comparison,
255        &lhs_fn_trait_name,
256        &rhs_fn_trait_name,
257        logical_operator_exprs,
258        db,
259        arenas,
260    ) {
261        diagnostics.push(PluginDiagnostic {
262            message: ImpossibleComparison.diagnostic_message().to_string(),
263            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
264            severity: Severity::Error,
265        })
266    }
267
268    // The comparison functions don't work with refs so should only be value
269    let (llhs, rlhs) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
270        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
271            (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
272        }
273        _ => {
274            return;
275        }
276    };
277    let (lrhs, rrhs) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
278        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
279            (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
280        }
281        _ => return,
282    };
283    // Get all the operands
284    let llhs_var = llhs
285        .stable_ptr()
286        .lookup(db.upcast())
287        .as_syntax_node()
288        .get_text(db.upcast());
289    let rlhs_var = rlhs
290        .stable_ptr()
291        .lookup(db.upcast())
292        .as_syntax_node()
293        .get_text(db.upcast());
294    let lrhs_var = lrhs
295        .stable_ptr()
296        .lookup(db.upcast())
297        .as_syntax_node()
298        .get_text(db.upcast());
299    let rrhs_var = rrhs
300        .stable_ptr()
301        .lookup(db.upcast())
302        .as_syntax_node()
303        .get_text(db.upcast());
304    // Put them in a hashset to check equality without order
305    let lhs: HashSet<String> = HashSet::from_iter([llhs_var, rlhs_var]);
306    let rhs: HashSet<String> = HashSet::from_iter([lrhs_var, rrhs_var]);
307    if lhs != rhs {
308        return;
309    }
310
311    // TODO: support other expressions like tuples and literals
312    let should_return = match (llhs, rlhs) {
313        (Expr::Snapshot(llhs), Expr::Snapshot(rlhs)) => {
314            matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
315                || matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
316        }
317        (Expr::Var(_), Expr::Var(_)) => false,
318        (Expr::Snapshot(llhs), Expr::Var(_)) => {
319            matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
320        }
321        (Expr::Var(_), Expr::Snapshot(rlhs)) => {
322            matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
323        }
324        _ => return,
325    };
326    if should_return {
327        return;
328    }
329
330    if is_simplifiable_double_comparison(
331        &lhs_fn_trait_name,
332        &rhs_fn_trait_name,
333        &logical_operator_exprs.op,
334    ) {
335        diagnostics.push(PluginDiagnostic {
336            message: SimplifiableComparison.diagnostic_message().to_string(),
337            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
338            severity: Severity::Warning,
339        });
340    } else if is_redundant_double_comparison(
341        &lhs_fn_trait_name,
342        &rhs_fn_trait_name,
343        &logical_operator_exprs.op,
344    ) {
345        diagnostics.push(PluginDiagnostic {
346            message: RedundantComparison.diagnostic_message().to_string(),
347            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
348            severity: Severity::Warning,
349        });
350    } else if is_contradictory_double_comparison(
351        &lhs_fn_trait_name,
352        &rhs_fn_trait_name,
353        &logical_operator_exprs.op,
354    ) {
355        diagnostics.push(PluginDiagnostic {
356            message: ContradictoryComparison.diagnostic_message().to_string(),
357            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
358            severity: Severity::Error,
359        });
360    }
361}
362
363fn check_impossible_comparison(
364    lhs_comparison: &ExprFunctionCall,
365    rhs_comparison: &ExprFunctionCall,
366    lhs_op: &str,
367    rhs_op: &str,
368    logical_operator_exprs: &ExprLogicalOperator,
369    db: &dyn SemanticGroup,
370    arenas: &Arenas,
371) -> bool {
372    let (lhs_var, lhs_literal) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
373        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
374            match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
375                (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
376                (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
377                _ => {
378                    return false;
379                }
380            }
381        }
382        _ => {
383            return false;
384        }
385    };
386    let (rhs_var, rhs_literal) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
387        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
388            match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
389                (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
390                (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
391                _ => {
392                    return false;
393                }
394            }
395        }
396        _ => {
397            return false;
398        }
399    };
400
401    if lhs_var
402        .stable_ptr
403        .lookup(db.upcast())
404        .as_syntax_node()
405        .get_text(db.upcast())
406        != rhs_var
407            .stable_ptr
408            .lookup(db.upcast())
409            .as_syntax_node()
410            .get_text(db.upcast())
411    {
412        return false;
413    }
414
415    match (lhs_op, &logical_operator_exprs.op, rhs_op) {
416        (GT, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
417        (GT, LogicalOperator::AndAnd, LE) => lhs_literal.value >= rhs_literal.value,
418        (GE, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
419        (GE, LogicalOperator::AndAnd, LE) => lhs_literal.value > rhs_literal.value,
420        (LT, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
421        (LT, LogicalOperator::AndAnd, GE) => lhs_literal.value <= rhs_literal.value,
422        (LE, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
423        (LE, LogicalOperator::AndAnd, GE) => lhs_literal.value < rhs_literal.value,
424        _ => false,
425    }
426}
427
428fn is_simplifiable_double_comparison(
429    lhs_op: &str,
430    rhs_op: &str,
431    middle_op: &LogicalOperator,
432) -> bool {
433    matches!(
434        (lhs_op, middle_op, rhs_op),
435        (LE, LogicalOperator::AndAnd, GE)
436            | (GE, LogicalOperator::AndAnd, LE)
437            | (LT, LogicalOperator::OrOr, EQ)
438            | (EQ, LogicalOperator::OrOr, LT)
439            | (GT, LogicalOperator::OrOr, EQ)
440            | (EQ, LogicalOperator::OrOr, GT)
441    )
442}
443
444fn is_redundant_double_comparison(lhs_op: &str, rhs_op: &str, middle_op: &LogicalOperator) -> bool {
445    matches!(
446        (lhs_op, middle_op, rhs_op),
447        (LE, LogicalOperator::OrOr, GE)
448            | (GE, LogicalOperator::OrOr, LE)
449            | (LT, LogicalOperator::OrOr, GT)
450            | (GT, LogicalOperator::OrOr, LT)
451    )
452}
453
454fn is_contradictory_double_comparison(
455    lhs_op: &str,
456    rhs_op: &str,
457    middle_op: &LogicalOperator,
458) -> bool {
459    matches!(
460        (lhs_op, middle_op, rhs_op),
461        (EQ, LogicalOperator::AndAnd, LT)
462            | (LT, LogicalOperator::AndAnd, EQ)
463            | (EQ, LogicalOperator::AndAnd, GT)
464            | (GT, LogicalOperator::AndAnd, EQ)
465            | (LT, LogicalOperator::AndAnd, GT)
466            | (GT, LogicalOperator::AndAnd, LT)
467            | (GT, LogicalOperator::AndAnd, GE)
468            | (LE, LogicalOperator::AndAnd, GT)
469    )
470}
471
472/// Rewrites a double comparison. Ex: `a > b || a == b` to `a >= b`
473pub fn fix_double_comparison(
474    db: &dyn SyntaxGroup,
475    node: SyntaxNode,
476) -> Option<(SyntaxNode, String)> {
477    let expr = AstExpr::from_syntax_node(db, node.clone());
478
479    if let AstExpr::Binary(binary_op) = expr {
480        let lhs = binary_op.lhs(db);
481        let rhs = binary_op.rhs(db);
482        let middle_op = binary_op.op(db);
483
484        if let (Some(lhs_op), Some(rhs_op)) = (
485            extract_binary_operator_expr(&lhs, db),
486            extract_binary_operator_expr(&rhs, db),
487        ) {
488            let simplified_op = determine_simplified_operator(&lhs_op, &rhs_op, &middle_op);
489
490            if let Some(simplified_op) = simplified_op {
491                if let Some(operator_to_replace) = operator_to_replace(lhs_op) {
492                    let lhs_text = lhs
493                        .as_syntax_node()
494                        .get_text(db)
495                        .replace(operator_to_replace, simplified_op);
496                    return Some((node, lhs_text.to_string()));
497                }
498            }
499        }
500    }
501
502    None
503}
504
505fn operator_to_replace(lhs_op: BinaryOperator) -> Option<&'static str> {
506    match lhs_op {
507        BinaryOperator::EqEq(_) => Some("=="),
508        BinaryOperator::GT(_) => Some(">"),
509        BinaryOperator::LT(_) => Some("<"),
510        BinaryOperator::GE(_) => Some(">="),
511        BinaryOperator::LE(_) => Some("<="),
512        _ => None,
513    }
514}
515
516fn determine_simplified_operator(
517    lhs_op: &BinaryOperator,
518    rhs_op: &BinaryOperator,
519    middle_op: &BinaryOperator,
520) -> Option<&'static str> {
521    match (lhs_op, middle_op, rhs_op) {
522        (BinaryOperator::LE(_), BinaryOperator::AndAnd(_), BinaryOperator::GE(_))
523        | (BinaryOperator::GE(_), BinaryOperator::AndAnd(_), BinaryOperator::LE(_)) => Some("=="),
524
525        (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
526        | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("<="),
527
528        (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
529        | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_)) => Some(">="),
530
531        (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_))
532        | (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("!="),
533
534        _ => None,
535    }
536}
537
538fn extract_binary_operator_expr(expr: &AstExpr, db: &dyn SyntaxGroup) -> Option<BinaryOperator> {
539    if let AstExpr::Binary(binary_op) = expr {
540        Some(binary_op.op(db))
541    } else {
542        None
543    }
544}