Skip to main content

cairo_lint/lints/
bool_comparison.rs

1use cairo_lang_defs::ids::{ModuleItemId, TopLevelLanguageElementId};
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
5use cairo_lang_semantic::items::imp::ImplHead;
6use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
7use cairo_lang_syntax::node::SyntaxNode;
8use cairo_lang_syntax::node::kind::SyntaxKind;
9use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast::ExprBinary};
10use if_chain::if_chain;
11
12use crate::LinterGroup;
13use crate::context::{CairoLintKind, Lint};
14use crate::fixer::InternalFix;
15use crate::queries::{get_all_function_bodies, get_all_function_calls};
16use salsa::Database;
17
18pub struct BoolComparison;
19
20/// ## What it does
21///
22/// Checks for direct variable with boolean literal like `a == true` or `a == false`.
23///
24/// ## Example
25///
26/// ```cairo
27/// fn main() {
28///     let x = true;
29///     if x == true {
30///         println!("x is true");
31///     }
32/// }
33/// ```
34///
35/// Can be rewritten as:
36///
37/// ```cairo
38/// fn main() {
39///    let x = true;
40///    if x {
41///        println!("x is true");
42///    }
43/// }
44/// ```
45impl Lint for BoolComparison {
46    fn allowed_name(&self) -> &'static str {
47        "bool_comparison"
48    }
49
50    fn diagnostic_message(&self) -> &'static str {
51        "Unnecessary comparison with a boolean value. Use the variable directly."
52    }
53
54    fn kind(&self) -> CairoLintKind {
55        CairoLintKind::BoolComparison
56    }
57
58    fn has_fixer(&self) -> bool {
59        true
60    }
61
62    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
63        fix_bool_comparison(db, node)
64    }
65
66    fn fix_message(&self) -> Option<&'static str> {
67        Some("Simplify to direct boolean check")
68    }
69}
70
71/// Checks for ` a == true`. Bool comparisons are useless and can be rewritten more clearly.
72#[tracing::instrument(skip_all, level = "trace")]
73pub fn check_bool_comparison<'db>(
74    db: &'db dyn Database,
75    item: &ModuleItemId<'db>,
76    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
77) {
78    let function_bodies = get_all_function_bodies(db, item);
79    for function_body in function_bodies.iter() {
80        let function_call_exprs = get_all_function_calls(function_body);
81        let arenas = &function_body.arenas;
82        for function_call_expr in function_call_exprs {
83            check_single_bool_comparison(db, &function_call_expr, arenas, diagnostics);
84        }
85    }
86}
87
88fn check_single_bool_comparison<'db>(
89    db: &'db dyn Database,
90    function_call_expr: &ExprFunctionCall<'db>,
91    arenas: &Arenas<'db>,
92    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
93) {
94    // Check if the function call is the bool partial eq function (==).
95    match function_call_expr
96        .function
97        .get_concrete(db)
98        .generic_function
99    {
100        GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, .. }) => {
101            if let Some(ImplHead::Concrete(impl_def_id)) = impl_id.head(db) {
102                if impl_def_id != db.corelib_context().get_bool_partial_eq_impl_id() {
103                    return;
104                }
105            } else {
106                return;
107            }
108        }
109        _ => return,
110    }
111
112    // Extract the args of the function call. This function expects snapshots hence we need to
113    // destructure that. Also the boolean type in cairo is an enum hence the enum ctor.
114    for arg in &function_call_expr.args {
115        if_chain! {
116            if let ExprFunctionCallArg::Value(expr) = arg;
117            if let Expr::Snapshot(snap) = &arenas.exprs[*expr];
118            if let Expr::EnumVariantCtor(enum_var) = &arenas.exprs[snap.inner];
119            if enum_var.variant.concrete_enum_id.enum_id(db).full_path(db) == "core::bool";
120            then {
121                diagnostics.push(PluginDiagnostic {
122                    stable_ptr: function_call_expr.stable_ptr.untyped(),
123                    message: BoolComparison.diagnostic_message().to_string(),
124                    severity: Severity::Warning,
125                    inner_span: None,
126                    error_code: None,
127                });
128            }
129        }
130    }
131}
132
133/// Rewrites a bool comparison to a simple bool. Ex: `some_bool == false` would be rewritten to
134/// `!some_bool`
135#[tracing::instrument(skip_all, level = "trace")]
136pub fn fix_bool_comparison<'db>(
137    db: &'db dyn Database,
138    node: SyntaxNode<'db>,
139) -> Option<InternalFix<'db>> {
140    let node = ExprBinary::from_syntax_node(db, node);
141    let lhs = node.lhs(db).as_syntax_node().get_text(db);
142    let rhs = node.rhs(db).as_syntax_node().get_text(db);
143
144    let result = generate_fixed_text_for_comparison(db, lhs, rhs, node.clone());
145    Some(InternalFix {
146        node: node.as_syntax_node(),
147        suggestion: result,
148        description: BoolComparison.fix_message().unwrap().to_string(),
149        import_addition_paths: None,
150    })
151}
152
153/// Generates the fixed boolean for a boolean comparison. It will transform `x == false` to `!x`
154fn generate_fixed_text_for_comparison<'db>(
155    db: &'db dyn Database,
156    lhs: &str,
157    rhs: &str,
158    node: ExprBinary<'db>,
159) -> String {
160    let op_kind = node.op(db).as_syntax_node().kind(db);
161    let lhs = lhs.trim();
162    let rhs = rhs.trim();
163
164    match (lhs, rhs, op_kind) {
165        // lhs
166        ("false", _, SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("!{rhs} "),
167        ("true", _, SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("{rhs} "),
168        ("false", _, SyntaxKind::TerminalNeq) => format!("{rhs} "),
169        ("true", _, SyntaxKind::TerminalNeq) => format!("!{rhs} "),
170
171        // rhs
172        (_, "false", SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("!{lhs} "),
173        (_, "true", SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("{lhs} "),
174        (_, "false", SyntaxKind::TerminalNeq) => format!("{lhs} "),
175        (_, "true", SyntaxKind::TerminalNeq) => format!("!{lhs} "),
176
177        _ => node.as_syntax_node().get_text(db).to_string(),
178    }
179}