// cairo_lint_core/lints/int_op_one.rs

1use cairo_lang_defs::ids::ModuleItemId;
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::db::SemanticGroup;
5use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
6use cairo_lang_syntax::node::ast::{Expr as AstExpr, ExprBinary};
7use cairo_lang_syntax::node::db::SyntaxGroup;
8use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
9use if_chain::if_chain;
10
11use crate::context::{CairoLintKind, Lint};
12use crate::queries::{get_all_function_bodies, get_all_function_calls};
13
14pub struct IntegerGreaterEqualPlusOne;
15
16/// ## What it does
17///
18/// Check for unnecessary add operation in integer >= comparison.
19///
20/// ## Example
21///
22/// ```cairo
23/// fn main() {
24///     let x: u32 = 1;
25///     let y: u32 = 1;
26///     if x >= y + 1 {}
27/// }
28/// ```
29///
30/// Can be simplified to:
31///
32/// ```cairo
33/// fn main() {
34///     let x: u32 = 1;
35///     let y: u32 = 1;
36///     if x > y {}
37/// }
38/// ```
39impl Lint for IntegerGreaterEqualPlusOne {
40    fn allowed_name(&self) -> &'static str {
41        "int_ge_plus_one"
42    }
43
44    fn diagnostic_message(&self) -> &'static str {
45        "Unnecessary add operation in integer >= comparison. Use simplified comparison."
46    }
47
48    fn kind(&self) -> CairoLintKind {
49        CairoLintKind::IntGePlusOne
50    }
51
52    fn has_fixer(&self) -> bool {
53        true
54    }
55
56    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
57        fix_int_ge_plus_one(db, node)
58    }
59}
60
61pub struct IntegerGreaterEqualMinusOne;
62
63/// ## What it does
64///
65/// Check for unnecessary sub operation in integer >= comparison.
66///
67/// ## Example
68///
69/// ```cairo
70/// fn main() {
71///     let x: u32 = 1;
72///     let y: u32 = 1;
73///     if x - 1 >= y {}
74/// }
75/// ```
76///
77/// Can be simplified to:
78///
79/// ```cairo
80/// fn main() {
81///     let x: u32 = 1;
82///     let y: u32 = 1;
83///     if x > y {}
84/// }
85/// ```
86impl Lint for IntegerGreaterEqualMinusOne {
87    fn allowed_name(&self) -> &'static str {
88        "int_ge_min_one"
89    }
90
91    fn diagnostic_message(&self) -> &'static str {
92        "Unnecessary sub operation in integer >= comparison. Use simplified comparison."
93    }
94
95    fn kind(&self) -> CairoLintKind {
96        CairoLintKind::IntGeMinOne
97    }
98
99    fn has_fixer(&self) -> bool {
100        true
101    }
102
103    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
104        fix_int_ge_min_one(db, node)
105    }
106}
107
108pub struct IntegerLessEqualPlusOne;
109
110/// ## What it does
111///
112/// Check for unnecessary add operation in integer <= comparison.
113///
114/// ## Example
115///
116/// ```cairo
117/// fn main() {
118///     let x: u32 = 1;
119///     let y: u32 = 1;
120///     if x + 1 <= y {}
121/// }
122/// ```
123///
124/// Can be simplified to:
125///
126/// ```cairo
127/// fn main() {
128///     let x: u32 = 1;
129///     let y: u32 = 1;
130///     if x < y {}
131/// }
132/// ```
133impl Lint for IntegerLessEqualPlusOne {
134    fn allowed_name(&self) -> &'static str {
135        "int_le_plus_one"
136    }
137
138    fn diagnostic_message(&self) -> &'static str {
139        "Unnecessary add operation in integer <= comparison. Use simplified comparison."
140    }
141
142    fn kind(&self) -> CairoLintKind {
143        CairoLintKind::IntLePlusOne
144    }
145
146    fn has_fixer(&self) -> bool {
147        true
148    }
149
150    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
151        fix_int_le_plus_one(db, node)
152    }
153}
154
155pub struct IntegerLessEqualMinusOne;
156
157/// ## What it does
158///
159/// Check for unnecessary sub operation in integer <= comparison.
160///
161/// ## Example
162///
163/// ```cairo
164/// fn main() {
165///     let x: u32 = 1;
166///     let y: u32 = 1;
167///     if x <= y - 1 {}
168/// }
169/// ```
170///
171/// Can be simplified to:
172///
173/// ```cairo
174/// fn main() {
175///     let x: u32 = 1;
176///     let y: u32 = 1;
177///     if x < y {}
178/// }
179/// ```
180impl Lint for IntegerLessEqualMinusOne {
181    fn allowed_name(&self) -> &'static str {
182        "int_le_min_one"
183    }
184
185    fn diagnostic_message(&self) -> &'static str {
186        "Unnecessary sub operation in integer <= comparison. Use simplified comparison."
187    }
188
189    fn kind(&self) -> CairoLintKind {
190        CairoLintKind::IntLeMinOne
191    }
192
193    fn has_fixer(&self) -> bool {
194        true
195    }
196
197    fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
198        fix_int_le_min_one(db, node)
199    }
200}
201
202pub fn check_int_op_one(
203    db: &dyn SemanticGroup,
204    item: &ModuleItemId,
205    diagnostics: &mut Vec<PluginDiagnostic>,
206) {
207    let function_bodies = get_all_function_bodies(db, item);
208    for function_body in function_bodies.iter() {
209        let function_call_exprs = get_all_function_calls(function_body);
210        let arenas = &function_body.arenas;
211        for function_call_expr in function_call_exprs.iter() {
212            check_single_int_op_one(db, function_call_expr, arenas, diagnostics);
213        }
214    }
215}
216
217fn check_single_int_op_one(
218    db: &dyn SemanticGroup,
219    function_call_expr: &ExprFunctionCall,
220    arenas: &Arenas,
221    diagnostics: &mut Vec<PluginDiagnostic>,
222) {
223    // Check if the function call is the bool greater or equal (>=) or lower or equal (<=).
224    let full_name = function_call_expr.function.full_path(db);
225    if !full_name.contains("core::integer::")
226        || (!full_name.contains("PartialOrd::ge") && !full_name.contains("PartialOrd::le"))
227    {
228        return;
229    }
230
231    let lhs = &function_call_expr.args[0];
232    let rhs = &function_call_expr.args[1];
233
234    // x >= y + 1
235    if check_is_variable(lhs, arenas)
236        && check_is_add_or_sub_one(db, rhs, arenas, "::add")
237        && function_call_expr.function.full_path(db).contains("::ge")
238    {
239        diagnostics.push(PluginDiagnostic {
240            stable_ptr: function_call_expr.stable_ptr.untyped(),
241            message: IntegerGreaterEqualPlusOne.diagnostic_message().to_string(),
242            severity: Severity::Warning,
243        })
244    }
245
246    // x - 1 >= y
247    if check_is_add_or_sub_one(db, lhs, arenas, "::sub")
248        && check_is_variable(rhs, arenas)
249        && function_call_expr.function.full_path(db).contains("::ge")
250    {
251        diagnostics.push(PluginDiagnostic {
252            stable_ptr: function_call_expr.stable_ptr.untyped(),
253            message: IntegerGreaterEqualMinusOne.diagnostic_message().to_string(),
254            severity: Severity::Warning,
255        })
256    }
257
258    // x + 1 <= y
259    if check_is_add_or_sub_one(db, lhs, arenas, "::add")
260        && check_is_variable(rhs, arenas)
261        && function_call_expr.function.full_path(db).contains("::le")
262    {
263        diagnostics.push(PluginDiagnostic {
264            stable_ptr: function_call_expr.stable_ptr.untyped(),
265            message: IntegerLessEqualPlusOne.diagnostic_message().to_string(),
266            severity: Severity::Warning,
267        })
268    }
269
270    // x <= y - 1
271    if check_is_variable(lhs, arenas)
272        && check_is_add_or_sub_one(db, rhs, arenas, "::sub")
273        && function_call_expr.function.full_path(db).contains("::le")
274    {
275        diagnostics.push(PluginDiagnostic {
276            stable_ptr: function_call_expr.stable_ptr.untyped(),
277            message: IntegerLessEqualMinusOne.diagnostic_message().to_string(),
278            severity: Severity::Warning,
279        })
280    }
281}
282
283fn check_is_variable(arg: &ExprFunctionCallArg, arenas: &Arenas) -> bool {
284    if let ExprFunctionCallArg::Value(val_expr) = arg {
285        matches!(arenas.exprs[*val_expr], Expr::Var(_))
286    } else {
287        false
288    }
289}
290
291fn check_is_add_or_sub_one(
292    db: &dyn SemanticGroup,
293    arg: &ExprFunctionCallArg,
294    arenas: &Arenas,
295    operation: &str,
296) -> bool {
297    let ExprFunctionCallArg::Value(v) = arg else {
298        return false;
299    };
300    let Expr::FunctionCall(ref func_call) = arenas.exprs[*v] else {
301        return false;
302    };
303
304    // Check is addition or substraction
305    let full_name = func_call.function.full_path(db);
306    if !full_name.contains("core::integer::") && !full_name.contains(operation)
307        || func_call.args.len() != 2
308    {
309        return false;
310    }
311
312    let lhs = &func_call.args[0];
313    let rhs = &func_call.args[1];
314
315    // Check lhs is var
316    if let ExprFunctionCallArg::Value(v) = lhs {
317        let Expr::Var(_) = arenas.exprs[*v] else {
318            return false;
319        };
320    };
321
322    // Check rhs is 1
323    if_chain! {
324        if let ExprFunctionCallArg::Value(v) = rhs;
325        if let Expr::Literal(ref litteral_expr) = arenas.exprs[*v];
326        if litteral_expr.value == 1.into();
327        then {
328            return true;
329        }
330    }
331
332    false
333}
334
335/// Rewrites a manual implementation of int ge plus one x >= y + 1
336pub fn fix_int_ge_plus_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
337    let node = ExprBinary::from_syntax_node(db, node);
338    let lhs = node.lhs(db).as_syntax_node().get_text(db);
339
340    let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
341        panic!("should be addition")
342    };
343    let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
344
345    let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
346    Some((node.as_syntax_node(), fix))
347}
348
349/// Rewrites a manual implementation of int ge min one x - 1 >= y
350pub fn fix_int_ge_min_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
351    let node = ExprBinary::from_syntax_node(db, node);
352    let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
353        panic!("should be substraction")
354    };
355    let rhs = node.rhs(db).as_syntax_node().get_text(db);
356
357    let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
358
359    let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
360    Some((node.as_syntax_node(), fix))
361}
362
363/// Rewrites a manual implementation of int le plus one x + 1 <= y
364pub fn fix_int_le_plus_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
365    let node = ExprBinary::from_syntax_node(db, node);
366    let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
367        panic!("should be addition")
368    };
369    let rhs = node.rhs(db).as_syntax_node().get_text(db);
370
371    let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
372
373    let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
374    Some((node.as_syntax_node(), fix))
375}
376
377/// Rewrites a manual implementation of int le min one x <= y -1
378pub fn fix_int_le_min_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
379    let node = ExprBinary::from_syntax_node(db, node);
380    let lhs = node.lhs(db).as_syntax_node().get_text(db);
381
382    let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
383        panic!("should be substraction")
384    };
385    let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
386
387    let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
388    Some((node.as_syntax_node(), fix))
389}