1use cairo_lang_defs::ids::ModuleItemId;
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::items::enm::EnumSemantic;
5use cairo_lang_semantic::{Arenas, ExprMatch, Pattern};
6use cairo_lang_syntax::node::ast::{Expr as AstExpr, ExprBlock, ExprListParenthesized, Statement};
7
8use cairo_lang_syntax::node::{
9 SyntaxNode, TypedStablePtr, TypedSyntaxNode,
10 ast::{ExprMatch as AstExprMatch, Pattern as AstPattern},
11};
12use if_chain::if_chain;
13
14use crate::context::{CairoLintKind, Lint};
15use crate::fixer::InternalFix;
16use crate::helper::indent_snippet;
17use crate::queries::{get_all_function_bodies, get_all_match_expressions};
18use salsa::Database;
19
20pub struct DestructMatch;
21
22impl Lint for DestructMatch {
44 fn allowed_name(&self) -> &'static str {
45 "destruct_match"
46 }
47
48 fn diagnostic_message(&self) -> &'static str {
49 "you seem to be trying to use `match` for destructuring a single pattern. Consider using `if let`"
50 }
51
52 fn kind(&self) -> CairoLintKind {
53 CairoLintKind::DestructMatch
54 }
55
56 fn has_fixer(&self) -> bool {
57 true
58 }
59
60 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
61 fix_destruct_match(db, node)
62 }
63
64 fn fix_message(&self) -> Option<&'static str> {
65 Some("Convert to 'if let' pattern matching")
66 }
67}
68
69pub struct EqualityMatch;
70
71impl Lint for EqualityMatch {
92 fn allowed_name(&self) -> &'static str {
93 "equality_match"
94 }
95
96 fn diagnostic_message(&self) -> &'static str {
97 "you seem to be trying to use `match` for an equality check. Consider using `if`"
98 }
99
100 fn kind(&self) -> CairoLintKind {
101 CairoLintKind::MatchForEquality
102 }
103}
104
105#[tracing::instrument(skip_all, level = "trace")]
106pub fn check_single_matches<'db>(
107 db: &'db dyn Database,
108 item: &ModuleItemId<'db>,
109 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
110) {
111 let function_bodies = get_all_function_bodies(db, item);
112 for function_body in function_bodies {
113 let match_exprs = get_all_match_expressions(function_body);
114 let arenas = &function_body.arenas;
115 for match_expr in match_exprs.iter() {
116 check_single_match(db, match_expr, arenas, diagnostics);
117 }
118 }
119}
120
121fn check_single_match<'db>(
122 db: &'db dyn Database,
123 match_expr: &ExprMatch<'db>,
124 arenas: &Arenas<'db>,
125 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
126) {
127 let arms = &match_expr.arms;
128 let mut is_single_armed = false;
129 let mut is_complete = false;
130 let mut is_destructuring = false;
131
132 if arms.len() != 2 || !match_expr.ty.is_unit(db) {
135 return;
136 }
137
138 let first_arm = &arms[0];
139 let second_arm = &arms[1];
140 let mut enum_len = None;
141 if let Some(pattern) = first_arm.patterns.first() {
142 match &arenas.patterns[*pattern] {
143 Pattern::Otherwise(_) => return,
145 Pattern::EnumVariant(enum_pat) => {
147 enum_len = Some(
148 db.enum_variants(enum_pat.variant.concrete_enum_id.enum_id(db))
149 .unwrap()
150 .len(),
151 );
152 is_destructuring = enum_pat.inner_pattern.is_some();
154 }
155 Pattern::Struct(_) => {
156 is_destructuring = true;
158 }
159 _ => (),
160 };
161 };
162 if let Some(pattern) = second_arm.patterns.first() {
163 match &arenas.patterns[*pattern] {
164 Pattern::Otherwise(_) => {
166 is_complete = true;
167 }
168 Pattern::EnumVariant(_) => {
169 if enum_len == Some(2) {
171 is_complete = true;
172 }
173 }
174 _ => (),
175 };
176
177 is_single_armed = is_expr_unit(
179 arenas.exprs[second_arm.expression].stable_ptr().lookup(db),
180 db,
181 ) && is_complete;
182 };
183
184 match (is_single_armed, is_destructuring) {
185 (true, false) => diagnostics.push(PluginDiagnostic {
186 stable_ptr: match_expr.stable_ptr.into(),
187 message: EqualityMatch.diagnostic_message().to_string(),
188 severity: Severity::Warning,
189 inner_span: None,
190 error_code: None,
191 }),
192 (true, true) => diagnostics.push(PluginDiagnostic {
193 stable_ptr: match_expr.stable_ptr.into(),
194 message: DestructMatch.diagnostic_message().to_string(),
195 severity: Severity::Warning,
196 inner_span: None,
197 error_code: None,
198 }),
199 (_, _) => (),
200 }
201}
202
203fn is_expr_list_parenthesised_unit(expr: &ExprListParenthesized, db: &dyn Database) -> bool {
205 expr.expressions(db).elements(db).len() == 0
206}
207
208fn is_block_expr_unit_without_comment(block_expr: &ExprBlock, db: &dyn Database) -> bool {
210 let mut statements = block_expr.statements(db).elements(db);
211 if statements.len() == 0
213 && block_expr
214 .rbrace(db)
215 .leading_trivia(db)
216 .node
217 .get_text(db)
218 .trim()
219 .is_empty()
220 {
221 return true;
222 }
223
224 if_chain! {
226 if statements.len() == 1;
227 if let Some(Statement::Expr(statement_expr)) = &statements.next();
228 if let AstExpr::Tuple(tuple_expr) = statement_expr.expr(db);
229 then {
230 let tuple_node = tuple_expr.as_syntax_node();
231 if tuple_node.span(db).start != tuple_node.span_start_without_trivia(db) {
232 return false;
233 }
234 return is_expr_list_parenthesised_unit(&tuple_expr, db);
235 }
236 }
237 false
238}
239
240pub fn is_expr_unit(expr: AstExpr, db: &dyn Database) -> bool {
243 match expr {
244 AstExpr::Block(block_expr) => is_block_expr_unit_without_comment(&block_expr, db),
245 AstExpr::Tuple(tuple_expr) => is_expr_list_parenthesised_unit(&tuple_expr, db),
246 _ => false,
247 }
248}
249
250#[tracing::instrument(skip_all, level = "trace")]
268pub fn fix_destruct_match<'db>(
269 db: &'db dyn Database,
270 node: SyntaxNode<'db>,
271) -> Option<InternalFix<'db>> {
272 let match_expr = AstExprMatch::from_syntax_node(db, node);
273 let mut arms = match_expr.arms(db).elements(db);
274 let first_arm = &arms
275 .next()
276 .expect("Expected a `match` with at least one arm.");
277 let second_arm = &arms.next().expect("Expected a `match` with second arm.");
278 let (pattern, first_expr) = match (
279 &first_arm
280 .patterns(db)
281 .elements(db)
282 .next()
283 .expect("Expected a pattern in the first arm."),
284 &second_arm
285 .patterns(db)
286 .elements(db)
287 .next()
288 .expect("Expected a pattern in the second arm."),
289 ) {
290 (AstPattern::Underscore(_), AstPattern::Enum(pat)) => (pat.as_syntax_node(), second_arm),
291 (AstPattern::Enum(pat), AstPattern::Underscore(_)) => (pat.as_syntax_node(), first_arm),
292 (AstPattern::Underscore(_), AstPattern::Struct(pat)) => (pat.as_syntax_node(), second_arm),
293 (AstPattern::Struct(pat), AstPattern::Underscore(_)) => (pat.as_syntax_node(), first_arm),
294 (AstPattern::Enum(pat1), AstPattern::Enum(pat2)) => {
295 if is_expr_unit(second_arm.expression(db), db) {
296 (pat1.as_syntax_node(), first_arm)
297 } else {
298 (pat2.as_syntax_node(), second_arm)
299 }
300 }
301 (_, _) => panic!("Incorrect diagnostic"),
302 };
303 let mut pattern_span = pattern.span(db);
304 pattern_span.end = pattern.span_start_without_trivia(db);
305 let indent = node
306 .get_text(db)
307 .chars()
308 .take_while(|c| c.is_whitespace())
309 .collect::<String>();
310 let trivia = pattern.get_text_of_span(db, pattern_span);
311 Some(InternalFix {
312 node,
313 suggestion: indent_snippet(
314 &format!(
315 "{trivia}{indent}if let {} = {} {{\n{}\n}}",
316 pattern.get_text_without_trivia(db).long(db).as_str(),
317 match_expr
318 .expr(db)
319 .as_syntax_node()
320 .get_text_without_trivia(db)
321 .long(db)
322 .as_str(),
323 first_expr.expression(db).as_syntax_node().get_text(db),
324 ),
325 indent.len() / 4,
326 ),
327 description: DestructMatch.fix_message().unwrap().to_string(),
328 import_addition_paths: None,
329 })
330}