Skip to main content

cairo_lint/lints/
double_comparison.rs

1use std::collections::HashSet;
2
3use cairo_lang_defs::ids::ModuleItemId;
4use cairo_lang_defs::plugin::PluginDiagnostic;
5use cairo_lang_diagnostics::Severity;
6use cairo_lang_semantic::{
7    Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg, ExprLogicalOperator, LogicalOperator,
8};
9use cairo_lang_syntax::node::ast::{BinaryOperator, Expr as AstExpr};
10
11use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
12
13use super::function_trait_name_from_fn_id;
14use crate::context::{CairoLintKind, Lint};
15
16use crate::fixer::InternalFix;
17use crate::lints::{EQ, GE, GT, LE, LT};
18use crate::queries::{get_all_function_bodies, get_all_logical_operator_expressions};
19use salsa::Database;
20
21pub struct ImpossibleComparison;
22
23/// ## What it does
24///
25/// Checks for impossible comparisons. Those ones always return false.
26///
27/// ## Example
28///
29/// Here is an example of impossible comparison:
30///
31/// ```cairo
32/// fn main() {
33///     let x: u32 = 1;
34///     if x > 200 && x < 100 {
35///         //impossible to reach
36///     }
37/// }
38/// ```
39impl Lint for ImpossibleComparison {
40    fn allowed_name(&self) -> &'static str {
41        "impossible_comparison"
42    }
43
44    fn diagnostic_message(&self) -> &'static str {
45        "Impossible condition, always false"
46    }
47
48    fn kind(&self) -> CairoLintKind {
49        CairoLintKind::ImpossibleComparison
50    }
51}
52
53pub struct SimplifiableComparison;
54
55/// ## What it does
56///
57/// Checks for double comparisons that can be simplified.
58/// Those are comparisons that can be simplified to a single comparison.
59///
60/// ## Example
61///
62/// ```cairo
63/// fn main() -> bool {
64///     let x = 5_u32;
65///     let y = 10_u32;
66///     if x == y || x > y {
67///         true
68///     } else {
69///         false
70///     }
71/// }
72/// ```
73///
74/// The above code can be simplified to:
75///
76/// ```cairo
77/// fn main() -> bool {
78///     let x = 5_u32;
79///     let y = 10_u32;
80///     if x >= y {
81///         true
82///     } else {
83///         false
84///     }
85/// }
86/// ```
87impl Lint for SimplifiableComparison {
88    fn allowed_name(&self) -> &'static str {
89        "simplifiable_comparison"
90    }
91
92    fn diagnostic_message(&self) -> &'static str {
93        "This double comparison can be simplified."
94    }
95
96    fn kind(&self) -> CairoLintKind {
97        CairoLintKind::DoubleComparison
98    }
99
100    fn has_fixer(&self) -> bool {
101        true
102    }
103
104    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
105        fix_simplifiable_comparison(db, node)
106    }
107
108    fn fix_message(&self) -> Option<&'static str> {
109        Some("Simplify to single comparison")
110    }
111}
112
113pub struct RedundantComparison;
114
115/// ## What it does
116///
117/// Checks for double comparisons that are redundant. Those are comparisons that can be simplified to a single comparison.
118///
119/// ## Example
120///
121/// ```cairo
122/// fn main() -> bool {
123///     let x = 5_u32;
124///     let y = 10_u32;
125///     if x >= y || x <= y {
126///         true
127///     } else {
128///         false
129///     }
130/// }
131/// ```
132///
133/// Could be simplified to just:
134///
135/// ```cairo
136/// fn main() -> bool {
137///     let x = 5_u32;
138///     let y = 10_u32;
139///     true
140/// }
141/// ```
142impl Lint for RedundantComparison {
143    fn allowed_name(&self) -> &'static str {
144        "redundant_comparison"
145    }
146
147    fn diagnostic_message(&self) -> &'static str {
148        "Redundant double comparison found. Consider simplifying to a single comparison."
149    }
150
151    fn kind(&self) -> CairoLintKind {
152        CairoLintKind::DoubleComparison
153    }
154
155    fn has_fixer(&self) -> bool {
156        true
157    }
158
159    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
160        fix_redundant_comparison(db, node)
161    }
162
163    fn fix_message(&self) -> Option<&'static str> {
164        Some("Remove redundant comparison")
165    }
166}
167
168pub struct ContradictoryComparison;
169
170/// ## What it does
171///
172/// Checks for double comparisons that are contradictory. Those are comparisons that are always false.
173///
174/// ## Example
175///
176/// ```cairo
177/// fn main() -> bool {
178///     let x = 5_u32;
179///     let y = 10_u32;
180///     if x < y && x > y {
181///         true
182///     } else {
183///         false
184///     }
185/// }
186/// ```
187///
188/// Could be simplified to just:
189///
190/// ```cairo
191/// fn main() -> bool {
192///     let x = 5_u32;
193///     let y = 10_u32;
194///     false
195/// }
196/// ```
197impl Lint for ContradictoryComparison {
198    fn allowed_name(&self) -> &'static str {
199        "contradictory_comparison"
200    }
201
202    fn diagnostic_message(&self) -> &'static str {
203        "This double comparison is contradictory and always false."
204    }
205
206    fn kind(&self) -> CairoLintKind {
207        CairoLintKind::DoubleComparison
208    }
209
210    fn has_fixer(&self) -> bool {
211        true
212    }
213
214    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
215        fix_contradictory_comparison(db, node)
216    }
217
218    fn fix_message(&self) -> Option<&'static str> {
219        Some("Replace with 'false' since comparison is impossible")
220    }
221}
222
223#[tracing::instrument(skip_all, level = "trace")]
224pub fn check_double_comparison<'db>(
225    db: &'db dyn Database,
226    item: &ModuleItemId<'db>,
227    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
228) {
229    let function_bodies = get_all_function_bodies(db, item);
230    for function_body in function_bodies {
231        let logical_operator_exprs = get_all_logical_operator_expressions(function_body);
232        let arenas = &function_body.arenas;
233        for logical_operator_expr in logical_operator_exprs.iter() {
234            check_single_double_comparison(db, logical_operator_expr, arenas, diagnostics);
235        }
236    }
237}
238
239fn check_single_double_comparison<'db>(
240    db: &'db dyn Database,
241    logical_operator_exprs: &ExprLogicalOperator<'db>,
242    arenas: &Arenas<'db>,
243    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
244) {
245    let Expr::FunctionCall(lhs_comparison) = &arenas.exprs[logical_operator_exprs.lhs] else {
246        return;
247    };
248    // If it's not 2 args it cannot be a regular comparison
249    if lhs_comparison.args.len() != 2 {
250        return;
251    }
252
253    let Expr::FunctionCall(rhs_comparison) = &arenas.exprs[logical_operator_exprs.rhs] else {
254        return;
255    };
256    // If it's not 2 args it cannot be a regular comparison
257    if rhs_comparison.args.len() != 2 {
258        return;
259    }
260    // Get the full name of the function used (trait name)
261    let (lhs_fn_trait_name, rhs_fn_trait_name) = (
262        function_trait_name_from_fn_id(db, &lhs_comparison.function),
263        function_trait_name_from_fn_id(db, &rhs_comparison.function),
264    );
265
266    // Check the impossible comparison
267    if check_impossible_comparison(
268        lhs_comparison,
269        rhs_comparison,
270        &lhs_fn_trait_name,
271        &rhs_fn_trait_name,
272        logical_operator_exprs,
273        db,
274        arenas,
275    ) {
276        diagnostics.push(PluginDiagnostic {
277            message: ImpossibleComparison.diagnostic_message().to_string(),
278            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
279            severity: Severity::Warning,
280            inner_span: None,
281            error_code: None,
282        })
283    }
284
285    // The comparison functions don't work with refs so should only be value
286    let (llhs, rlhs) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
287        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
288            (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
289        }
290        _ => {
291            return;
292        }
293    };
294    let (lrhs, rrhs) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
295        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
296            (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
297        }
298        _ => return,
299    };
300    // Get all the operands
301    let llhs_var = llhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
302    let rlhs_var = rlhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
303    let lrhs_var = lrhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
304    let rrhs_var = rrhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
305    // Put them in a hashset to check equality without order
306    let lhs: HashSet<_> = HashSet::from_iter([llhs_var, rlhs_var]);
307    let rhs: HashSet<_> = HashSet::from_iter([lrhs_var, rrhs_var]);
308    if lhs != rhs {
309        return;
310    }
311
312    // TODO: support other expressions like tuples and literals
313    let should_return = match (llhs, rlhs) {
314        (Expr::Snapshot(llhs), Expr::Snapshot(rlhs)) => {
315            matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
316                || matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
317        }
318        (Expr::Var(_), Expr::Var(_)) => false,
319        (Expr::Snapshot(llhs), Expr::Var(_)) => {
320            matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
321        }
322        (Expr::Var(_), Expr::Snapshot(rlhs)) => {
323            matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
324        }
325        _ => return,
326    };
327    if should_return {
328        return;
329    }
330
331    if is_simplifiable_double_comparison(
332        &lhs_fn_trait_name,
333        &rhs_fn_trait_name,
334        &logical_operator_exprs.op,
335    ) {
336        diagnostics.push(PluginDiagnostic {
337            message: SimplifiableComparison.diagnostic_message().to_string(),
338            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
339            severity: Severity::Warning,
340            inner_span: None,
341            error_code: None,
342        });
343    } else if is_redundant_double_comparison(
344        &lhs_fn_trait_name,
345        &rhs_fn_trait_name,
346        &logical_operator_exprs.op,
347    ) {
348        diagnostics.push(PluginDiagnostic {
349            message: RedundantComparison.diagnostic_message().to_string(),
350            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
351            severity: Severity::Warning,
352            inner_span: None,
353            error_code: None,
354        });
355    } else if is_contradictory_double_comparison(
356        &lhs_fn_trait_name,
357        &rhs_fn_trait_name,
358        &logical_operator_exprs.op,
359    ) {
360        diagnostics.push(PluginDiagnostic {
361            message: ContradictoryComparison.diagnostic_message().to_string(),
362            stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
363            severity: Severity::Warning,
364            inner_span: None,
365            error_code: None,
366        });
367    }
368}
369
370fn check_impossible_comparison<'db>(
371    lhs_comparison: &ExprFunctionCall<'db>,
372    rhs_comparison: &ExprFunctionCall<'db>,
373    lhs_op: &str,
374    rhs_op: &str,
375    logical_operator_exprs: &ExprLogicalOperator<'db>,
376    db: &'db dyn Database,
377    arenas: &Arenas<'db>,
378) -> bool {
379    let (lhs_var, lhs_literal) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
380        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
381            match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
382                (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
383                (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
384                _ => {
385                    return false;
386                }
387            }
388        }
389        _ => {
390            return false;
391        }
392    };
393    let (rhs_var, rhs_literal) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
394        (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
395            match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
396                (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
397                (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
398                _ => {
399                    return false;
400                }
401            }
402        }
403        _ => {
404            return false;
405        }
406    };
407
408    if lhs_var.stable_ptr.lookup(db).as_syntax_node().get_text(db)
409        != rhs_var.stable_ptr.lookup(db).as_syntax_node().get_text(db)
410    {
411        return false;
412    }
413
414    match (lhs_op, &logical_operator_exprs.op, rhs_op) {
415        (GT, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
416        (GT, LogicalOperator::AndAnd, LE) => lhs_literal.value >= rhs_literal.value,
417        (GE, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
418        (GE, LogicalOperator::AndAnd, LE) => lhs_literal.value > rhs_literal.value,
419        (LT, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
420        (LT, LogicalOperator::AndAnd, GE) => lhs_literal.value <= rhs_literal.value,
421        (LE, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
422        (LE, LogicalOperator::AndAnd, GE) => lhs_literal.value < rhs_literal.value,
423        _ => false,
424    }
425}
426
427fn is_simplifiable_double_comparison(
428    lhs_op: &str,
429    rhs_op: &str,
430    middle_op: &LogicalOperator,
431) -> bool {
432    matches!(
433        (lhs_op, middle_op, rhs_op),
434        (LE, LogicalOperator::AndAnd, GE)
435            | (GE, LogicalOperator::AndAnd, LE)
436            | (LT, LogicalOperator::OrOr, EQ)
437            | (EQ, LogicalOperator::OrOr, LT)
438            | (GT, LogicalOperator::OrOr, EQ)
439            | (EQ, LogicalOperator::OrOr, GT)
440    )
441}
442
443fn is_redundant_double_comparison(lhs_op: &str, rhs_op: &str, middle_op: &LogicalOperator) -> bool {
444    matches!(
445        (lhs_op, middle_op, rhs_op),
446        (LE, LogicalOperator::OrOr, GE)
447            | (GE, LogicalOperator::OrOr, LE)
448            | (LT, LogicalOperator::OrOr, GT)
449            | (GT, LogicalOperator::OrOr, LT)
450    )
451}
452
453fn is_contradictory_double_comparison(
454    lhs_op: &str,
455    rhs_op: &str,
456    middle_op: &LogicalOperator,
457) -> bool {
458    matches!(
459        (lhs_op, middle_op, rhs_op),
460        (EQ, LogicalOperator::AndAnd, LT)
461            | (LT, LogicalOperator::AndAnd, EQ)
462            | (EQ, LogicalOperator::AndAnd, GT)
463            | (GT, LogicalOperator::AndAnd, EQ)
464            | (LT, LogicalOperator::AndAnd, GT)
465            | (GT, LogicalOperator::AndAnd, LT)
466            | (GT, LogicalOperator::AndAnd, GE)
467            | (LE, LogicalOperator::AndAnd, GT)
468    )
469}
470
471#[tracing::instrument(skip_all, level = "trace")]
472pub fn fix_simplifiable_comparison<'db>(
473    db: &'db dyn Database,
474    node: SyntaxNode<'db>,
475) -> Option<InternalFix<'db>> {
476    let lhs_text = fix_double_comparison(db, node)?;
477
478    Some(InternalFix {
479        node,
480        suggestion: lhs_text,
481        description: SimplifiableComparison.fix_message().unwrap().to_string(),
482        import_addition_paths: None,
483    })
484}
485
486#[tracing::instrument(skip_all, level = "trace")]
487pub fn fix_redundant_comparison<'db>(
488    db: &'db dyn Database,
489    node: SyntaxNode<'db>,
490) -> Option<InternalFix<'db>> {
491    let lhs_text = fix_double_comparison(db, node)?;
492
493    Some(InternalFix {
494        node,
495        suggestion: lhs_text,
496        description: RedundantComparison.fix_message().unwrap().to_string(),
497        import_addition_paths: None,
498    })
499}
500
501#[tracing::instrument(skip_all, level = "trace")]
502pub fn fix_contradictory_comparison<'db>(
503    db: &'db dyn Database,
504    node: SyntaxNode<'db>,
505) -> Option<InternalFix<'db>> {
506    let lhs_text = fix_double_comparison(db, node)?;
507
508    Some(InternalFix {
509        node,
510        suggestion: lhs_text,
511        description: ContradictoryComparison.fix_message().unwrap().to_string(),
512        import_addition_paths: None,
513    })
514}
515
516/// Rewrites a double comparison. Ex: `a > b || a == b` to `a >= b`
517pub fn fix_double_comparison<'db>(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<String> {
518    let expr = AstExpr::from_syntax_node(db, node);
519
520    if let AstExpr::Binary(binary_op) = expr {
521        let lhs = binary_op.lhs(db);
522        let rhs = binary_op.rhs(db);
523        let middle_op = binary_op.op(db);
524
525        if let (Some(lhs_op), Some(rhs_op)) = (
526            extract_binary_operator_expr(&lhs, db),
527            extract_binary_operator_expr(&rhs, db),
528        ) && let Some(simplified_op) =
529            determine_simplified_operator(&lhs_op, &rhs_op, &middle_op)
530            && let Some(operator_to_replace) = operator_to_replace(lhs_op)
531        {
532            let lhs_text = lhs
533                .as_syntax_node()
534                .get_text(db)
535                .replace(operator_to_replace, simplified_op);
536            return Some(lhs_text.to_string());
537        }
538    }
539
540    None
541}
542
543fn operator_to_replace(lhs_op: BinaryOperator) -> Option<&'static str> {
544    match lhs_op {
545        BinaryOperator::EqEq(_) => Some("=="),
546        BinaryOperator::GT(_) => Some(">"),
547        BinaryOperator::LT(_) => Some("<"),
548        BinaryOperator::GE(_) => Some(">="),
549        BinaryOperator::LE(_) => Some("<="),
550        _ => None,
551    }
552}
553
554fn determine_simplified_operator(
555    lhs_op: &BinaryOperator,
556    rhs_op: &BinaryOperator,
557    middle_op: &BinaryOperator,
558) -> Option<&'static str> {
559    match (lhs_op, middle_op, rhs_op) {
560        (BinaryOperator::LE(_), BinaryOperator::AndAnd(_), BinaryOperator::GE(_))
561        | (BinaryOperator::GE(_), BinaryOperator::AndAnd(_), BinaryOperator::LE(_)) => Some("=="),
562
563        (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
564        | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("<="),
565
566        (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
567        | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_)) => Some(">="),
568
569        (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_))
570        | (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("!="),
571
572        _ => None,
573    }
574}
575
576fn extract_binary_operator_expr<'db>(
577    expr: &AstExpr<'db>,
578    db: &'db dyn Database,
579) -> Option<BinaryOperator<'db>> {
580    if let AstExpr::Binary(binary_op) = expr {
581        Some(binary_op.op(db))
582    } else {
583        None
584    }
585}