cairo_lint/lints/
bool_comparison.rs1use cairo_lang_defs::ids::{ModuleItemId, TopLevelLanguageElementId};
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
5use cairo_lang_semantic::items::imp::ImplHead;
6use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
7use cairo_lang_syntax::node::SyntaxNode;
8use cairo_lang_syntax::node::kind::SyntaxKind;
9use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast::ExprBinary};
10use if_chain::if_chain;
11
12use crate::LinterGroup;
13use crate::context::{CairoLintKind, Lint};
14use crate::fixer::InternalFix;
15use crate::queries::{get_all_function_bodies, get_all_function_calls};
16use salsa::Database;
17
18pub struct BoolComparison;
19
20impl Lint for BoolComparison {
46 fn allowed_name(&self) -> &'static str {
47 "bool_comparison"
48 }
49
50 fn diagnostic_message(&self) -> &'static str {
51 "Unnecessary comparison with a boolean value. Use the variable directly."
52 }
53
54 fn kind(&self) -> CairoLintKind {
55 CairoLintKind::BoolComparison
56 }
57
58 fn has_fixer(&self) -> bool {
59 true
60 }
61
62 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
63 fix_bool_comparison(db, node)
64 }
65
66 fn fix_message(&self) -> Option<&'static str> {
67 Some("Simplify to direct boolean check")
68 }
69}
70
71#[tracing::instrument(skip_all, level = "trace")]
73pub fn check_bool_comparison<'db>(
74 db: &'db dyn Database,
75 item: &ModuleItemId<'db>,
76 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
77) {
78 let function_bodies = get_all_function_bodies(db, item);
79 for function_body in function_bodies.iter() {
80 let function_call_exprs = get_all_function_calls(function_body);
81 let arenas = &function_body.arenas;
82 for function_call_expr in function_call_exprs {
83 check_single_bool_comparison(db, &function_call_expr, arenas, diagnostics);
84 }
85 }
86}
87
88fn check_single_bool_comparison<'db>(
89 db: &'db dyn Database,
90 function_call_expr: &ExprFunctionCall<'db>,
91 arenas: &Arenas<'db>,
92 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
93) {
94 match function_call_expr
96 .function
97 .get_concrete(db)
98 .generic_function
99 {
100 GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, .. }) => {
101 if let Some(ImplHead::Concrete(impl_def_id)) = impl_id.head(db) {
102 if impl_def_id != db.corelib_context().get_bool_partial_eq_impl_id() {
103 return;
104 }
105 } else {
106 return;
107 }
108 }
109 _ => return,
110 }
111
112 for arg in &function_call_expr.args {
115 if_chain! {
116 if let ExprFunctionCallArg::Value(expr) = arg;
117 if let Expr::Snapshot(snap) = &arenas.exprs[*expr];
118 if let Expr::EnumVariantCtor(enum_var) = &arenas.exprs[snap.inner];
119 if enum_var.variant.concrete_enum_id.enum_id(db).full_path(db) == "core::bool";
120 then {
121 diagnostics.push(PluginDiagnostic {
122 stable_ptr: function_call_expr.stable_ptr.untyped(),
123 message: BoolComparison.diagnostic_message().to_string(),
124 severity: Severity::Warning,
125 inner_span: None,
126 error_code: None,
127 });
128 }
129 }
130 }
131}
132
133#[tracing::instrument(skip_all, level = "trace")]
136pub fn fix_bool_comparison<'db>(
137 db: &'db dyn Database,
138 node: SyntaxNode<'db>,
139) -> Option<InternalFix<'db>> {
140 let node = ExprBinary::from_syntax_node(db, node);
141 let lhs = node.lhs(db).as_syntax_node().get_text(db);
142 let rhs = node.rhs(db).as_syntax_node().get_text(db);
143
144 let result = generate_fixed_text_for_comparison(db, lhs, rhs, node.clone());
145 Some(InternalFix {
146 node: node.as_syntax_node(),
147 suggestion: result,
148 description: BoolComparison.fix_message().unwrap().to_string(),
149 import_addition_paths: None,
150 })
151}
152
153fn generate_fixed_text_for_comparison<'db>(
155 db: &'db dyn Database,
156 lhs: &str,
157 rhs: &str,
158 node: ExprBinary<'db>,
159) -> String {
160 let op_kind = node.op(db).as_syntax_node().kind(db);
161 let lhs = lhs.trim();
162 let rhs = rhs.trim();
163
164 match (lhs, rhs, op_kind) {
165 ("false", _, SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("!{rhs} "),
167 ("true", _, SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("{rhs} "),
168 ("false", _, SyntaxKind::TerminalNeq) => format!("{rhs} "),
169 ("true", _, SyntaxKind::TerminalNeq) => format!("!{rhs} "),
170
171 (_, "false", SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("!{lhs} "),
173 (_, "true", SyntaxKind::TerminalEqEq | SyntaxKind::TokenEqEq) => format!("{lhs} "),
174 (_, "false", SyntaxKind::TerminalNeq) => format!("{lhs} "),
175 (_, "true", SyntaxKind::TerminalNeq) => format!("!{lhs} "),
176
177 _ => node.as_syntax_node().get_text(db).to_string(),
178 }
179}