Skip to main content

cairo_lint/lints/loops/
loop_match_pop_front.rs

1use cairo_lang_defs::ids::{ModuleItemId, TopLevelLanguageElementId};
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::{
5    Arenas, Expr, ExprBlock, ExprId, ExprLoop, ExprMatch, Pattern, PatternEnumVariant, Statement,
6};
7use cairo_lang_syntax::node::SyntaxNode;
8
9use cairo_lang_syntax::node::{
10    TypedStablePtr, TypedSyntaxNode,
11    ast::{
12        Expr as AstExpr, ExprLoop as AstExprLoop, OptionPatternEnumInnerPattern,
13        Pattern as AstPattern, Statement as AstStatement,
14    },
15};
16use if_chain::if_chain;
17
18use crate::context::{CairoLintKind, Lint};
19
20use crate::fixer::InternalFix;
21use crate::helper::indent_snippet;
22use crate::lints::{NONE, SOME, function_trait_name_from_fn_id};
23use crate::queries::{get_all_function_bodies, get_all_loop_expressions};
24use salsa::Database;
25
/// Fully qualified name of the corelib `SpanTrait::pop_front` trait function. The call being
/// matched on inside a candidate loop must resolve to this function for the lint to fire.
const POP_FRONT_SPAN_TRAIT_FUNCTION: &str = "core::array::SpanTrait::pop_front";

/// Lint flagging `loop`s that drain a span through `match span.pop_front()` and that could be
/// written as a `for` loop instead.
pub struct LoopMatchPopFront;
29
30/// ## What it does
31///
32/// Checks for loops that are used to iterate over a span using `pop_front`.
33///
34/// ## Example
35///
36/// ```cairo
37/// let a: Span<u32> = array![1, 2, 3].span();
38/// loop {
39///     match a.pop_front() {
40///         Option::Some(val) => {do_smth(val); },
41///         Option::None => { break; }
42///     }
43/// }
44/// ```
45///
46/// Which can be rewritten as
47///
48/// ```cairo
49/// let a: Span<u32> = array![1, 2, 3].span();
50/// for val in a {
51///     do_smth(val);
52/// }
53/// ```
54impl Lint for LoopMatchPopFront {
55    fn allowed_name(&self) -> &'static str {
56        "loop_match_pop_front"
57    }
58
59    fn diagnostic_message(&self) -> &'static str {
60        "you seem to be trying to use `loop` for iterating over a span. Consider using `for in`"
61    }
62
63    fn kind(&self) -> CairoLintKind {
64        CairoLintKind::LoopMatchPopFront
65    }
66
67    fn has_fixer(&self) -> bool {
68        true
69    }
70
71    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
72        fix_loop_match_pop_front(db, node)
73    }
74
75    fn fix_message(&self) -> Option<&'static str> {
76        Some("Replace `loop` with `for in` for iterating over spans")
77    }
78}
79
80#[tracing::instrument(skip_all, level = "trace")]
81pub fn check_loop_match_pop_front<'db>(
82    db: &'db dyn Database,
83    item: &ModuleItemId<'db>,
84    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
85) {
86    let function_bodies = get_all_function_bodies(db, item);
87    for function_body in function_bodies {
88        let loop_exprs = get_all_loop_expressions(function_body);
89        let arenas = &function_body.arenas;
90        for loop_expr in loop_exprs.iter() {
91            check_single_loop_match_pop_front(db, loop_expr, diagnostics, arenas);
92        }
93    }
94}
95
96fn check_single_loop_match_pop_front<'db>(
97    db: &'db dyn Database,
98    loop_expr: &ExprLoop<'db>,
99    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
100    arenas: &Arenas<'db>,
101) {
102    // Checks that the loop doesn't return anything
103    if !loop_expr.ty.is_unit(db) {
104        return;
105    }
106    let Expr::Block(expr_block) = &arenas.exprs[loop_expr.body] else {
107        return;
108    };
109
110    // Case where there's no statements only an expression in the tail.
111    if_chain! {
112        if expr_block.statements.is_empty();
113        if let Some(tail) = &expr_block.tail;
114        // Get the function call and check that it's the span match pop front function from the corelib
115        if let Expr::Match(expr_match) = &arenas.exprs[*tail];
116        if let Expr::FunctionCall(func_call) = &arenas.exprs[expr_match.matched_expr];
117        if function_trait_name_from_fn_id(db, &func_call.function) == POP_FRONT_SPAN_TRAIT_FUNCTION;
118        then {
119            // Check that something is done only in the Some branch of the match
120            if !check_single_match(db, expr_match, arenas) {
121                return;
122            }
123            diagnostics.push(PluginDiagnostic {
124                stable_ptr: loop_expr.stable_ptr.into(),
125                message: LoopMatchPopFront.diagnostic_message().to_owned(),
126                severity: Severity::Warning,
127                error_code: None,
128                inner_span: None
129            });
130            return;
131        }
132    }
133
134    // If the loop contains multiple statements.
135    if_chain! {
136        if !expr_block.statements.is_empty();
137        // If the first statement is the match we're looking for. the order is important
138        if let Statement::Expr(stmt_expr) = &arenas.statements[expr_block.statements[0]];
139        if let Expr::Match(expr_match) = &arenas.exprs[stmt_expr.expr];
140        then {
141            // Checks that we're only doing something in the some branch
142            if !check_single_match(db, expr_match, arenas) {
143                return;
144            }
145            let Expr::FunctionCall(func_call) = &arenas.exprs[expr_match.matched_expr] else {
146                return;
147            };
148            if function_trait_name_from_fn_id(db, &func_call.function) == POP_FRONT_SPAN_TRAIT_FUNCTION {
149                diagnostics.push(PluginDiagnostic {
150                    stable_ptr: loop_expr.stable_ptr.into(),
151                    message: LoopMatchPopFront.diagnostic_message().to_owned(),
152                    severity: Severity::Warning,
153                    error_code: None,
154                    inner_span: None
155                })
156            }
157        }
158    }
159}
160
161const OPTION_TYPE: &str = "core::option::Option::<";
162
163fn check_single_match<'db>(
164    db: &dyn Database,
165    match_expr: &ExprMatch<'db>,
166    arenas: &Arenas<'db>,
167) -> bool {
168    let arms = &match_expr.arms;
169
170    // Check that we're in a setup with 2 arms that return unit
171    if arms.len() == 2 && match_expr.ty.is_unit(db) {
172        let first_arm = &arms[0];
173        let second_arm = &arms[1];
174        let is_first_arm_correct = if let Some(pattern) = first_arm.patterns.first() {
175            match &arenas.patterns[*pattern] {
176                // If the first arm is `_ => smth` it's incorrect
177                Pattern::Otherwise(_) => false,
178                // Check if the variant is of type option and if it's `None` checks that it only contains `{ break; }`
179                // without comments`
180                Pattern::EnumVariant(enum_pat) => {
181                    check_enum_pattern(db, enum_pat, arenas, first_arm.expression)
182                }
183                _ => false,
184            }
185        } else {
186            false
187        };
188        let is_second_arm_correct = if let Some(pattern) = second_arm.patterns.first() {
189            match &arenas.patterns[*pattern] {
190                // If the 2nd arm is `_ => smth`, checks that smth is `{ break; }`
191                Pattern::Otherwise(_) => {
192                    if let Expr::Block(expr_block) = &arenas.exprs[second_arm.expression] {
193                        check_block_is_break(db, expr_block, arenas)
194                    } else {
195                        return false;
196                    }
197                }
198                // Check if the variant is of type option and if it's `None` checks that it only contains `{ break; }`
199                // without comments`
200                Pattern::EnumVariant(enum_pat) => {
201                    check_enum_pattern(db, enum_pat, arenas, second_arm.expression)
202                }
203                _ => false,
204            }
205        } else {
206            false
207        };
208        is_first_arm_correct && is_second_arm_correct
209    } else {
210        false
211    }
212}
213fn check_enum_pattern<'db>(
214    db: &'db dyn Database,
215    enum_pat: &PatternEnumVariant<'db>,
216    arenas: &Arenas<'db>,
217    arm_expression: ExprId,
218) -> bool {
219    // Checks that the variant is from the option type.
220    if !enum_pat.ty.format(db).starts_with(OPTION_TYPE) {
221        return false;
222    }
223
224    // Check if the variant is the None variant
225    if_chain! {
226        if enum_pat.variant.id.full_path(db) == NONE;
227        // Get the expression of the None variant and checks if it's a block expression.
228        if let Expr::Block(expr_block) = &arenas.exprs[arm_expression];
229        // If it's a block expression checks that it only contains `break;`
230        if check_block_is_break(db, expr_block, arenas);
231      then {
232          return true;
233      }
234    }
235    enum_pat.variant.id.full_path(db) == SOME
236}
237/// Checks that the block only contains `break;` without comments
238fn check_block_is_break(db: &dyn Database, expr_block: &ExprBlock, arenas: &Arenas) -> bool {
239    if_chain! {
240        if expr_block.statements.len() == 1;
241        if let Statement::Break(break_stmt) = &arenas.statements[expr_block.statements[0]];
242        then {
243            let break_node = break_stmt.stable_ptr.lookup(db).as_syntax_node();
244            // Checks that the trimmed text == the text without trivia which would mean that there is no comment
245            let break_text = break_node.get_text(db).trim().to_string();
246            if break_text == break_node.get_text_without_trivia(db).to_string(db)
247                && (break_text == "break;" || break_text == "break ();")
248            {
249                return true;
250            }
251        }
252    }
253    false
254}
255
256/// Rewrites this:
257///
258/// ```ignore
259/// loop {
260///     match some_span.pop_front() {
261///         Option::Some(val) => do_smth(val),
262///         Option::None => break;
263///     }
264/// }
265/// ```
266/// to this:
267/// ```ignore
268/// for val in span {
269///     do_smth(val);
270/// };
271/// ```
272#[tracing::instrument(skip_all, level = "trace")]
273pub fn fix_loop_match_pop_front<'db>(
274    db: &'db dyn Database,
275    node: SyntaxNode<'db>,
276) -> Option<InternalFix<'db>> {
277    let expr_loop = AstExprLoop::from_syntax_node(db, node);
278    let body = expr_loop.body(db);
279    let Some(AstStatement::Expr(expr)) = &body.statements(db).elements(db).next() else {
280        panic!(
281            "Wrong statement type. This is probably a bug in the lint detection. Please report it"
282        )
283    };
284    let AstExpr::Match(expr_match) = expr.expr(db) else {
285        panic!(
286            "Wrong expression type. This is probably a bug in the lint detection. Please report it"
287        )
288    };
289    let val = expr_match.expr(db);
290    let span_name = match val {
291        AstExpr::FunctionCall(func_call) => func_call
292            .arguments(db)
293            .arguments(db)
294            .elements(db)
295            .next()
296            .expect("Expected at least one argument for the function call")
297            .arg_clause(db)
298            .as_syntax_node()
299            .get_text(db),
300        AstExpr::Binary(dot_call) => dot_call.lhs(db).as_syntax_node().get_text(db),
301        _ => panic!(
302            "Wrong expression type. This is probably a bug in the lint detection. Please report it"
303        ),
304    };
305    let mut elt_name = "";
306    let mut some_arm = "";
307    let arms = expr_match.arms(db).elements(db);
308
309    let mut loop_span = node.span(db);
310    loop_span.end = node.span_start_without_trivia(db);
311    let indent = node
312        .get_text(db)
313        .chars()
314        .take_while(|c| c.is_whitespace())
315        .collect::<String>();
316    let trivia = node.get_text_of_span(db, loop_span).trim().to_string();
317    let trivia = if trivia.is_empty() {
318        trivia
319    } else {
320        format!("{indent}{trivia}\n")
321    };
322    for arm in arms {
323        if_chain! {
324            if let Some(AstPattern::Enum(enum_pattern)) = &arm.patterns(db).elements(db).next();
325            if let OptionPatternEnumInnerPattern::PatternEnumInnerPattern(var) = enum_pattern.pattern(db);
326            then {
327                elt_name = var.pattern(db).as_syntax_node().get_text(db);
328                some_arm = if let AstExpr::Block(block_expr) = arm.expression(db) {
329                    block_expr.statements(db).as_syntax_node().get_text(db)
330                } else {
331                    arm.expression(db).as_syntax_node().get_text(db)
332                }
333            }
334        }
335    }
336    Some(InternalFix {
337        node,
338        suggestion: indent_snippet(
339            &format!("{trivia}for {elt_name} in {span_name} {{\n{some_arm}\n}};\n"),
340            indent.len() / 4,
341        ),
342        description: LoopMatchPopFront.fix_message().unwrap().to_string(),
343        import_addition_paths: None,
344    })
345}