Skip to main content

cairo_lint/lints/
int_op_one.rs

1use cairo_lang_defs::ids::{LookupItemId, ModuleId, ModuleItemId, TraitFunctionId};
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::items::functions::GenericFunctionId;
5use cairo_lang_semantic::items::imp::ImplHead;
6use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
7use cairo_lang_syntax::node::ast::{Expr as AstExpr, ExprBinary};
8
9use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
10use if_chain::if_chain;
11
12use crate::context::{CairoLintKind, Lint};
13
14use crate::LinterGroup;
15use crate::fixer::InternalFix;
16use crate::helper::is_item_ancestor_of_module;
17use crate::queries::{get_all_function_bodies, get_all_function_calls};
18use salsa::Database;
19
20pub struct IntegerGreaterEqualPlusOne;
21
22/// ## What it does
23///
24/// Check for unnecessary add operation in integer >= comparison.
25///
26/// ## Example
27///
28/// ```cairo
29/// fn main() {
30///     let x: u32 = 1;
31///     let y: u32 = 1;
32///     if x >= y + 1 {}
33/// }
34/// ```
35///
36/// Can be simplified to:
37///
38/// ```cairo
39/// fn main() {
40///     let x: u32 = 1;
41///     let y: u32 = 1;
42///     if x > y {}
43/// }
44/// ```
45impl Lint for IntegerGreaterEqualPlusOne {
46    fn allowed_name(&self) -> &'static str {
47        "int_ge_plus_one"
48    }
49
50    fn diagnostic_message(&self) -> &'static str {
51        "Unnecessary add operation in integer >= comparison. Use simplified comparison."
52    }
53
54    fn kind(&self) -> CairoLintKind {
55        CairoLintKind::IntGePlusOne
56    }
57
58    fn has_fixer(&self) -> bool {
59        true
60    }
61
62    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
63        fix_int_ge_plus_one(db, node)
64    }
65
66    fn fix_message(&self) -> Option<&'static str> {
67        Some("Replace with simplified '>' comparison")
68    }
69}
70
71pub struct IntegerGreaterEqualMinusOne;
72
73/// ## What it does
74///
75/// Check for unnecessary sub operation in integer >= comparison.
76///
77/// ## Example
78///
79/// ```cairo
80/// fn main() {
81///     let x: u32 = 1;
82///     let y: u32 = 1;
83///     if x - 1 >= y {}
84/// }
85/// ```
86///
87/// Can be simplified to:
88///
89/// ```cairo
90/// fn main() {
91///     let x: u32 = 1;
92///     let y: u32 = 1;
93///     if x > y {}
94/// }
95/// ```
96impl Lint for IntegerGreaterEqualMinusOne {
97    fn allowed_name(&self) -> &'static str {
98        "int_ge_min_one"
99    }
100
101    fn diagnostic_message(&self) -> &'static str {
102        "Unnecessary sub operation in integer >= comparison. Use simplified comparison."
103    }
104
105    fn kind(&self) -> CairoLintKind {
106        CairoLintKind::IntGeMinOne
107    }
108
109    fn has_fixer(&self) -> bool {
110        true
111    }
112
113    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
114        fix_int_ge_min_one(db, node)
115    }
116
117    fn fix_message(&self) -> Option<&'static str> {
118        Some("Replace with simplified '>' comparison")
119    }
120}
121
122pub struct IntegerLessEqualPlusOne;
123
124/// ## What it does
125///
126/// Check for unnecessary add operation in integer <= comparison.
127///
128/// ## Example
129///
130/// ```cairo
131/// fn main() {
132///     let x: u32 = 1;
133///     let y: u32 = 1;
134///     if x + 1 <= y {}
135/// }
136/// ```
137///
138/// Can be simplified to:
139///
140/// ```cairo
141/// fn main() {
142///     let x: u32 = 1;
143///     let y: u32 = 1;
144///     if x < y {}
145/// }
146/// ```
147impl Lint for IntegerLessEqualPlusOne {
148    fn allowed_name(&self) -> &'static str {
149        "int_le_plus_one"
150    }
151
152    fn diagnostic_message(&self) -> &'static str {
153        "Unnecessary add operation in integer <= comparison. Use simplified comparison."
154    }
155
156    fn kind(&self) -> CairoLintKind {
157        CairoLintKind::IntLePlusOne
158    }
159
160    fn has_fixer(&self) -> bool {
161        true
162    }
163
164    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
165        fix_int_le_plus_one(db, node)
166    }
167
168    fn fix_message(&self) -> Option<&'static str> {
169        Some("Replace with simplified '<' comparison")
170    }
171}
172
173pub struct IntegerLessEqualMinusOne;
174
175/// ## What it does
176///
177/// Check for unnecessary sub operation in integer <= comparison.
178///
179/// ## Example
180///
181/// ```cairo
182/// fn main() {
183///     let x: u32 = 1;
184///     let y: u32 = 1;
185///     if x <= y - 1 {}
186/// }
187/// ```
188///
189/// Can be simplified to:
190///
191/// ```cairo
192/// fn main() {
193///     let x: u32 = 1;
194///     let y: u32 = 1;
195///     if x < y {}
196/// }
197/// ```
198impl Lint for IntegerLessEqualMinusOne {
199    fn allowed_name(&self) -> &'static str {
200        "int_le_min_one"
201    }
202
203    fn diagnostic_message(&self) -> &'static str {
204        "Unnecessary sub operation in integer <= comparison. Use simplified comparison."
205    }
206
207    fn kind(&self) -> CairoLintKind {
208        CairoLintKind::IntLeMinOne
209    }
210
211    fn has_fixer(&self) -> bool {
212        true
213    }
214
215    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
216        fix_int_le_min_one(db, node)
217    }
218
219    fn fix_message(&self) -> Option<&'static str> {
220        Some("Replace with simplified '<' comparison")
221    }
222}
223
224#[tracing::instrument(skip_all, level = "trace")]
225pub fn check_int_op_one<'db>(
226    db: &'db dyn Database,
227    item: &ModuleItemId<'db>,
228    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
229) {
230    let function_bodies = get_all_function_bodies(db, item);
231    for function_body in function_bodies.iter() {
232        let function_call_exprs = get_all_function_calls(function_body);
233        let arenas = &function_body.arenas;
234        for function_call_expr in function_call_exprs {
235            check_single_int_op_one(db, &function_call_expr, arenas, diagnostics);
236        }
237    }
238}
239
240fn check_single_int_op_one<'db>(
241    db: &'db dyn Database,
242    function_call_expr: &ExprFunctionCall<'db>,
243    arenas: &Arenas<'db>,
244    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
245) {
246    // Check if the function call is part of the implementation.
247    let GenericFunctionId::Impl(impl_generic_func_id) = function_call_expr
248        .function
249        .get_concrete(db)
250        .generic_function
251    else {
252        return;
253    };
254
255    let corelib_context = db.corelib_context();
256
257    // Check if the function call is the bool greater or equal (>=) or lower or equal (<=).
258    if impl_generic_func_id.function != corelib_context.get_partial_ord_ge_trait_function_id()
259        && impl_generic_func_id.function != corelib_context.get_partial_ord_le_trait_function_id()
260    {
261        return;
262    }
263
264    // Check if the function call is part of the corelib integer module.
265    let is_part_of_corelib_integer =
266        if let Some(ImplHead::Concrete(impl_def_id)) = impl_generic_func_id.impl_id.head(db) {
267            is_item_ancestor_of_module(
268                db,
269                &LookupItemId::ModuleItem(ModuleItemId::Impl(impl_def_id)),
270                ModuleId::Submodule(corelib_context.get_integer_module_id()),
271            )
272        } else {
273            false
274        };
275
276    if !is_part_of_corelib_integer {
277        return;
278    }
279
280    let lhs = &function_call_expr.args[0];
281    let rhs = &function_call_expr.args[1];
282
283    let add_trait_function_id = corelib_context.get_add_trait_function_id();
284    let sub_trait_function_id = corelib_context.get_sub_trait_function_id();
285    let partial_ord_ge_trait_function_id = corelib_context.get_partial_ord_ge_trait_function_id();
286    let partial_ord_le_trait_function_id = corelib_context.get_partial_ord_le_trait_function_id();
287
288    // x >= y + 1
289    if check_is_variable(lhs, arenas)
290        && check_is_add_or_sub_one(
291            db,
292            rhs,
293            arenas,
294            is_part_of_corelib_integer,
295            add_trait_function_id,
296        )
297        && impl_generic_func_id.function == partial_ord_ge_trait_function_id
298    {
299        diagnostics.push(PluginDiagnostic {
300            stable_ptr: function_call_expr.stable_ptr.untyped(),
301            message: IntegerGreaterEqualPlusOne.diagnostic_message().to_string(),
302            severity: Severity::Warning,
303            inner_span: None,
304            error_code: None,
305        })
306    }
307
308    // x - 1 >= y
309    if check_is_add_or_sub_one(
310        db,
311        lhs,
312        arenas,
313        is_part_of_corelib_integer,
314        sub_trait_function_id,
315    ) && check_is_variable(rhs, arenas)
316        && impl_generic_func_id.function == partial_ord_ge_trait_function_id
317    {
318        diagnostics.push(PluginDiagnostic {
319            stable_ptr: function_call_expr.stable_ptr.untyped(),
320            message: IntegerGreaterEqualMinusOne.diagnostic_message().to_string(),
321            severity: Severity::Warning,
322            inner_span: None,
323            error_code: None,
324        })
325    }
326
327    // x + 1 <= y
328    if check_is_add_or_sub_one(
329        db,
330        lhs,
331        arenas,
332        is_part_of_corelib_integer,
333        add_trait_function_id,
334    ) && check_is_variable(rhs, arenas)
335        && impl_generic_func_id.function == partial_ord_le_trait_function_id
336    {
337        diagnostics.push(PluginDiagnostic {
338            stable_ptr: function_call_expr.stable_ptr.untyped(),
339            message: IntegerLessEqualPlusOne.diagnostic_message().to_string(),
340            severity: Severity::Warning,
341            inner_span: None,
342            error_code: None,
343        })
344    }
345
346    // x <= y - 1
347    if check_is_variable(lhs, arenas)
348        && check_is_add_or_sub_one(
349            db,
350            rhs,
351            arenas,
352            is_part_of_corelib_integer,
353            sub_trait_function_id,
354        )
355        && impl_generic_func_id.function == partial_ord_le_trait_function_id
356    {
357        diagnostics.push(PluginDiagnostic {
358            stable_ptr: function_call_expr.stable_ptr.untyped(),
359            message: IntegerLessEqualMinusOne.diagnostic_message().to_string(),
360            severity: Severity::Warning,
361            inner_span: None,
362            error_code: None,
363        })
364    }
365}
366
367fn check_is_variable<'db>(arg: &ExprFunctionCallArg<'db>, arenas: &Arenas<'db>) -> bool {
368    if let ExprFunctionCallArg::Value(val_expr) = arg {
369        matches!(arenas.exprs[*val_expr], Expr::Var(_))
370    } else {
371        false
372    }
373}
374
375fn check_is_add_or_sub_one<'db>(
376    db: &'db dyn Database,
377    arg: &ExprFunctionCallArg<'db>,
378    arenas: &Arenas<'db>,
379    is_part_of_corelib_integer: bool,
380    operation_function_trait_id: TraitFunctionId<'db>,
381) -> bool {
382    let ExprFunctionCallArg::Value(v) = arg else {
383        return false;
384    };
385    let Expr::FunctionCall(ref func_call) = arenas.exprs[*v] else {
386        return false;
387    };
388
389    let GenericFunctionId::Impl(impl_generic_func_id) =
390        func_call.function.get_concrete(db).generic_function
391    else {
392        return false;
393    };
394
395    // Check is addition or substraction
396    if !is_part_of_corelib_integer && impl_generic_func_id.function != operation_function_trait_id
397        || func_call.args.len() != 2
398    {
399        return false;
400    }
401
402    let lhs = &func_call.args[0];
403    let rhs = &func_call.args[1];
404
405    // Check lhs is var
406    if let ExprFunctionCallArg::Value(v) = lhs {
407        let Expr::Var(_) = arenas.exprs[*v] else {
408            return false;
409        };
410    };
411
412    // Check rhs is 1
413    if_chain! {
414        if let ExprFunctionCallArg::Value(v) = rhs;
415        if let Expr::Literal(ref litteral_expr) = arenas.exprs[*v];
416        if litteral_expr.value == 1.into();
417        then {
418            return true;
419        }
420    }
421
422    false
423}
424
425/// Rewrites a manual implementation of int ge plus one x >= y + 1
426#[tracing::instrument(skip_all, level = "trace")]
427pub fn fix_int_ge_plus_one<'db>(
428    db: &'db dyn Database,
429    node: SyntaxNode<'db>,
430) -> Option<InternalFix<'db>> {
431    let node = ExprBinary::from_syntax_node(db, node);
432    let lhs = node.lhs(db).as_syntax_node().get_text(db);
433
434    let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
435        panic!("should be addition")
436    };
437    let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
438
439    let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
440    Some(InternalFix {
441        node: node.as_syntax_node(),
442        suggestion: fix,
443        description: IntegerGreaterEqualPlusOne
444            .fix_message()
445            .unwrap()
446            .to_string(),
447        import_addition_paths: None,
448    })
449}
450
451/// Rewrites a manual implementation of int ge min one x - 1 >= y
452#[tracing::instrument(skip_all, level = "trace")]
453pub fn fix_int_ge_min_one<'db>(
454    db: &'db dyn Database,
455    node: SyntaxNode<'db>,
456) -> Option<InternalFix<'db>> {
457    let node = ExprBinary::from_syntax_node(db, node);
458    let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
459        panic!("should be substraction")
460    };
461    let rhs = node.rhs(db).as_syntax_node().get_text(db);
462
463    let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
464
465    let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
466    Some(InternalFix {
467        node: node.as_syntax_node(),
468        suggestion: fix,
469        description: IntegerGreaterEqualMinusOne
470            .fix_message()
471            .unwrap()
472            .to_string(),
473        import_addition_paths: None,
474    })
475}
476
477/// Rewrites a manual implementation of int le plus one x + 1 <= y
478#[tracing::instrument(skip_all, level = "trace")]
479pub fn fix_int_le_plus_one<'db>(
480    db: &'db dyn Database,
481    node: SyntaxNode<'db>,
482) -> Option<InternalFix<'db>> {
483    let node = ExprBinary::from_syntax_node(db, node);
484    let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
485        panic!("should be addition")
486    };
487    let rhs = node.rhs(db).as_syntax_node().get_text(db);
488
489    let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
490
491    let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
492    Some(InternalFix {
493        node: node.as_syntax_node(),
494        suggestion: fix,
495        description: IntegerLessEqualPlusOne.fix_message().unwrap().to_string(),
496        import_addition_paths: None,
497    })
498}
499
500/// Rewrites a manual implementation of int le min one x <= y -1
501#[tracing::instrument(skip_all, level = "trace")]
502pub fn fix_int_le_min_one<'db>(
503    db: &'db dyn Database,
504    node: SyntaxNode<'db>,
505) -> Option<InternalFix<'db>> {
506    let node = ExprBinary::from_syntax_node(db, node);
507    let lhs = node.lhs(db).as_syntax_node().get_text(db);
508
509    let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
510        panic!("should be substraction")
511    };
512    let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
513
514    let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
515    Some(InternalFix {
516        node: node.as_syntax_node(),
517        suggestion: fix,
518        description: IntegerLessEqualMinusOne.fix_message().unwrap().to_string(),
519        import_addition_paths: None,
520    })
521}