1use cairo_lang_defs::ids::ModuleItemId;
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::db::SemanticGroup;
5use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
6use cairo_lang_syntax::node::ast::{Expr as AstExpr, ExprBinary};
7use cairo_lang_syntax::node::db::SyntaxGroup;
8use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
9use if_chain::if_chain;
10
11use crate::context::{CairoLintKind, Lint};
12use crate::queries::{get_all_function_bodies, get_all_function_calls};
13
14pub struct IntegerGreaterEqualPlusOne;
15
16impl Lint for IntegerGreaterEqualPlusOne {
40 fn allowed_name(&self) -> &'static str {
41 "int_ge_plus_one"
42 }
43
44 fn diagnostic_message(&self) -> &'static str {
45 "Unnecessary add operation in integer >= comparison. Use simplified comparison."
46 }
47
48 fn kind(&self) -> CairoLintKind {
49 CairoLintKind::IntGePlusOne
50 }
51
52 fn has_fixer(&self) -> bool {
53 true
54 }
55
56 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
57 fix_int_ge_plus_one(db, node)
58 }
59}
60
61pub struct IntegerGreaterEqualMinusOne;
62
63impl Lint for IntegerGreaterEqualMinusOne {
87 fn allowed_name(&self) -> &'static str {
88 "int_ge_min_one"
89 }
90
91 fn diagnostic_message(&self) -> &'static str {
92 "Unnecessary sub operation in integer >= comparison. Use simplified comparison."
93 }
94
95 fn kind(&self) -> CairoLintKind {
96 CairoLintKind::IntGeMinOne
97 }
98
99 fn has_fixer(&self) -> bool {
100 true
101 }
102
103 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
104 fix_int_ge_min_one(db, node)
105 }
106}
107
108pub struct IntegerLessEqualPlusOne;
109
110impl Lint for IntegerLessEqualPlusOne {
134 fn allowed_name(&self) -> &'static str {
135 "int_le_plus_one"
136 }
137
138 fn diagnostic_message(&self) -> &'static str {
139 "Unnecessary add operation in integer <= comparison. Use simplified comparison."
140 }
141
142 fn kind(&self) -> CairoLintKind {
143 CairoLintKind::IntLePlusOne
144 }
145
146 fn has_fixer(&self) -> bool {
147 true
148 }
149
150 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
151 fix_int_le_plus_one(db, node)
152 }
153}
154
155pub struct IntegerLessEqualMinusOne;
156
157impl Lint for IntegerLessEqualMinusOne {
181 fn allowed_name(&self) -> &'static str {
182 "int_le_min_one"
183 }
184
185 fn diagnostic_message(&self) -> &'static str {
186 "Unnecessary sub operation in integer <= comparison. Use simplified comparison."
187 }
188
189 fn kind(&self) -> CairoLintKind {
190 CairoLintKind::IntLeMinOne
191 }
192
193 fn has_fixer(&self) -> bool {
194 true
195 }
196
197 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
198 fix_int_le_min_one(db, node)
199 }
200}
201
202pub fn check_int_op_one(
203 db: &dyn SemanticGroup,
204 item: &ModuleItemId,
205 diagnostics: &mut Vec<PluginDiagnostic>,
206) {
207 let function_bodies = get_all_function_bodies(db, item);
208 for function_body in function_bodies.iter() {
209 let function_call_exprs = get_all_function_calls(function_body);
210 let arenas = &function_body.arenas;
211 for function_call_expr in function_call_exprs.iter() {
212 check_single_int_op_one(db, function_call_expr, arenas, diagnostics);
213 }
214 }
215}
216
217fn check_single_int_op_one(
218 db: &dyn SemanticGroup,
219 function_call_expr: &ExprFunctionCall,
220 arenas: &Arenas,
221 diagnostics: &mut Vec<PluginDiagnostic>,
222) {
223 let full_name = function_call_expr.function.full_path(db);
225 if !full_name.contains("core::integer::")
226 || (!full_name.contains("PartialOrd::ge") && !full_name.contains("PartialOrd::le"))
227 {
228 return;
229 }
230
231 let lhs = &function_call_expr.args[0];
232 let rhs = &function_call_expr.args[1];
233
234 if check_is_variable(lhs, arenas)
236 && check_is_add_or_sub_one(db, rhs, arenas, "::add")
237 && function_call_expr.function.full_path(db).contains("::ge")
238 {
239 diagnostics.push(PluginDiagnostic {
240 stable_ptr: function_call_expr.stable_ptr.untyped(),
241 message: IntegerGreaterEqualPlusOne.diagnostic_message().to_string(),
242 severity: Severity::Warning,
243 })
244 }
245
246 if check_is_add_or_sub_one(db, lhs, arenas, "::sub")
248 && check_is_variable(rhs, arenas)
249 && function_call_expr.function.full_path(db).contains("::ge")
250 {
251 diagnostics.push(PluginDiagnostic {
252 stable_ptr: function_call_expr.stable_ptr.untyped(),
253 message: IntegerGreaterEqualMinusOne.diagnostic_message().to_string(),
254 severity: Severity::Warning,
255 })
256 }
257
258 if check_is_add_or_sub_one(db, lhs, arenas, "::add")
260 && check_is_variable(rhs, arenas)
261 && function_call_expr.function.full_path(db).contains("::le")
262 {
263 diagnostics.push(PluginDiagnostic {
264 stable_ptr: function_call_expr.stable_ptr.untyped(),
265 message: IntegerLessEqualPlusOne.diagnostic_message().to_string(),
266 severity: Severity::Warning,
267 })
268 }
269
270 if check_is_variable(lhs, arenas)
272 && check_is_add_or_sub_one(db, rhs, arenas, "::sub")
273 && function_call_expr.function.full_path(db).contains("::le")
274 {
275 diagnostics.push(PluginDiagnostic {
276 stable_ptr: function_call_expr.stable_ptr.untyped(),
277 message: IntegerLessEqualMinusOne.diagnostic_message().to_string(),
278 severity: Severity::Warning,
279 })
280 }
281}
282
283fn check_is_variable(arg: &ExprFunctionCallArg, arenas: &Arenas) -> bool {
284 if let ExprFunctionCallArg::Value(val_expr) = arg {
285 matches!(arenas.exprs[*val_expr], Expr::Var(_))
286 } else {
287 false
288 }
289}
290
291fn check_is_add_or_sub_one(
292 db: &dyn SemanticGroup,
293 arg: &ExprFunctionCallArg,
294 arenas: &Arenas,
295 operation: &str,
296) -> bool {
297 let ExprFunctionCallArg::Value(v) = arg else {
298 return false;
299 };
300 let Expr::FunctionCall(ref func_call) = arenas.exprs[*v] else {
301 return false;
302 };
303
304 let full_name = func_call.function.full_path(db);
306 if !full_name.contains("core::integer::") && !full_name.contains(operation)
307 || func_call.args.len() != 2
308 {
309 return false;
310 }
311
312 let lhs = &func_call.args[0];
313 let rhs = &func_call.args[1];
314
315 if let ExprFunctionCallArg::Value(v) = lhs {
317 let Expr::Var(_) = arenas.exprs[*v] else {
318 return false;
319 };
320 };
321
322 if_chain! {
324 if let ExprFunctionCallArg::Value(v) = rhs;
325 if let Expr::Literal(ref litteral_expr) = arenas.exprs[*v];
326 if litteral_expr.value == 1.into();
327 then {
328 return true;
329 }
330 }
331
332 false
333}
334
335pub fn fix_int_ge_plus_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
337 let node = ExprBinary::from_syntax_node(db, node);
338 let lhs = node.lhs(db).as_syntax_node().get_text(db);
339
340 let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
341 panic!("should be addition")
342 };
343 let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
344
345 let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
346 Some((node.as_syntax_node(), fix))
347}
348
349pub fn fix_int_ge_min_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
351 let node = ExprBinary::from_syntax_node(db, node);
352 let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
353 panic!("should be substraction")
354 };
355 let rhs = node.rhs(db).as_syntax_node().get_text(db);
356
357 let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
358
359 let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
360 Some((node.as_syntax_node(), fix))
361}
362
363pub fn fix_int_le_plus_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
365 let node = ExprBinary::from_syntax_node(db, node);
366 let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
367 panic!("should be addition")
368 };
369 let rhs = node.rhs(db).as_syntax_node().get_text(db);
370
371 let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
372
373 let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
374 Some((node.as_syntax_node(), fix))
375}
376
377pub fn fix_int_le_min_one(db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
379 let node = ExprBinary::from_syntax_node(db, node);
380 let lhs = node.lhs(db).as_syntax_node().get_text(db);
381
382 let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
383 panic!("should be substraction")
384 };
385 let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
386
387 let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
388 Some((node.as_syntax_node(), fix))
389}