1use cairo_lang_defs::ids::{LookupItemId, ModuleId, ModuleItemId, TraitFunctionId};
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_diagnostics::Severity;
4use cairo_lang_semantic::items::functions::GenericFunctionId;
5use cairo_lang_semantic::items::imp::ImplHead;
6use cairo_lang_semantic::{Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg};
7use cairo_lang_syntax::node::ast::{Expr as AstExpr, ExprBinary};
8
9use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
10use if_chain::if_chain;
11
12use crate::context::{CairoLintKind, Lint};
13
14use crate::LinterGroup;
15use crate::fixer::InternalFix;
16use crate::helper::is_item_ancestor_of_module;
17use crate::queries::{get_all_function_bodies, get_all_function_calls};
18use salsa::Database;
19
20pub struct IntegerGreaterEqualPlusOne;
21
22impl Lint for IntegerGreaterEqualPlusOne {
46 fn allowed_name(&self) -> &'static str {
47 "int_ge_plus_one"
48 }
49
50 fn diagnostic_message(&self) -> &'static str {
51 "Unnecessary add operation in integer >= comparison. Use simplified comparison."
52 }
53
54 fn kind(&self) -> CairoLintKind {
55 CairoLintKind::IntGePlusOne
56 }
57
58 fn has_fixer(&self) -> bool {
59 true
60 }
61
62 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
63 fix_int_ge_plus_one(db, node)
64 }
65
66 fn fix_message(&self) -> Option<&'static str> {
67 Some("Replace with simplified '>' comparison")
68 }
69}
70
71pub struct IntegerGreaterEqualMinusOne;
72
73impl Lint for IntegerGreaterEqualMinusOne {
97 fn allowed_name(&self) -> &'static str {
98 "int_ge_min_one"
99 }
100
101 fn diagnostic_message(&self) -> &'static str {
102 "Unnecessary sub operation in integer >= comparison. Use simplified comparison."
103 }
104
105 fn kind(&self) -> CairoLintKind {
106 CairoLintKind::IntGeMinOne
107 }
108
109 fn has_fixer(&self) -> bool {
110 true
111 }
112
113 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
114 fix_int_ge_min_one(db, node)
115 }
116
117 fn fix_message(&self) -> Option<&'static str> {
118 Some("Replace with simplified '>' comparison")
119 }
120}
121
122pub struct IntegerLessEqualPlusOne;
123
124impl Lint for IntegerLessEqualPlusOne {
148 fn allowed_name(&self) -> &'static str {
149 "int_le_plus_one"
150 }
151
152 fn diagnostic_message(&self) -> &'static str {
153 "Unnecessary add operation in integer <= comparison. Use simplified comparison."
154 }
155
156 fn kind(&self) -> CairoLintKind {
157 CairoLintKind::IntLePlusOne
158 }
159
160 fn has_fixer(&self) -> bool {
161 true
162 }
163
164 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
165 fix_int_le_plus_one(db, node)
166 }
167
168 fn fix_message(&self) -> Option<&'static str> {
169 Some("Replace with simplified '<' comparison")
170 }
171}
172
173pub struct IntegerLessEqualMinusOne;
174
175impl Lint for IntegerLessEqualMinusOne {
199 fn allowed_name(&self) -> &'static str {
200 "int_le_min_one"
201 }
202
203 fn diagnostic_message(&self) -> &'static str {
204 "Unnecessary sub operation in integer <= comparison. Use simplified comparison."
205 }
206
207 fn kind(&self) -> CairoLintKind {
208 CairoLintKind::IntLeMinOne
209 }
210
211 fn has_fixer(&self) -> bool {
212 true
213 }
214
215 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
216 fix_int_le_min_one(db, node)
217 }
218
219 fn fix_message(&self) -> Option<&'static str> {
220 Some("Replace with simplified '<' comparison")
221 }
222}
223
224#[tracing::instrument(skip_all, level = "trace")]
225pub fn check_int_op_one<'db>(
226 db: &'db dyn Database,
227 item: &ModuleItemId<'db>,
228 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
229) {
230 let function_bodies = get_all_function_bodies(db, item);
231 for function_body in function_bodies.iter() {
232 let function_call_exprs = get_all_function_calls(function_body);
233 let arenas = &function_body.arenas;
234 for function_call_expr in function_call_exprs {
235 check_single_int_op_one(db, &function_call_expr, arenas, diagnostics);
236 }
237 }
238}
239
240fn check_single_int_op_one<'db>(
241 db: &'db dyn Database,
242 function_call_expr: &ExprFunctionCall<'db>,
243 arenas: &Arenas<'db>,
244 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
245) {
246 let GenericFunctionId::Impl(impl_generic_func_id) = function_call_expr
248 .function
249 .get_concrete(db)
250 .generic_function
251 else {
252 return;
253 };
254
255 let corelib_context = db.corelib_context();
256
257 if impl_generic_func_id.function != corelib_context.get_partial_ord_ge_trait_function_id()
259 && impl_generic_func_id.function != corelib_context.get_partial_ord_le_trait_function_id()
260 {
261 return;
262 }
263
264 let is_part_of_corelib_integer =
266 if let Some(ImplHead::Concrete(impl_def_id)) = impl_generic_func_id.impl_id.head(db) {
267 is_item_ancestor_of_module(
268 db,
269 &LookupItemId::ModuleItem(ModuleItemId::Impl(impl_def_id)),
270 ModuleId::Submodule(corelib_context.get_integer_module_id()),
271 )
272 } else {
273 false
274 };
275
276 if !is_part_of_corelib_integer {
277 return;
278 }
279
280 let lhs = &function_call_expr.args[0];
281 let rhs = &function_call_expr.args[1];
282
283 let add_trait_function_id = corelib_context.get_add_trait_function_id();
284 let sub_trait_function_id = corelib_context.get_sub_trait_function_id();
285 let partial_ord_ge_trait_function_id = corelib_context.get_partial_ord_ge_trait_function_id();
286 let partial_ord_le_trait_function_id = corelib_context.get_partial_ord_le_trait_function_id();
287
288 if check_is_variable(lhs, arenas)
290 && check_is_add_or_sub_one(
291 db,
292 rhs,
293 arenas,
294 is_part_of_corelib_integer,
295 add_trait_function_id,
296 )
297 && impl_generic_func_id.function == partial_ord_ge_trait_function_id
298 {
299 diagnostics.push(PluginDiagnostic {
300 stable_ptr: function_call_expr.stable_ptr.untyped(),
301 message: IntegerGreaterEqualPlusOne.diagnostic_message().to_string(),
302 severity: Severity::Warning,
303 inner_span: None,
304 error_code: None,
305 })
306 }
307
308 if check_is_add_or_sub_one(
310 db,
311 lhs,
312 arenas,
313 is_part_of_corelib_integer,
314 sub_trait_function_id,
315 ) && check_is_variable(rhs, arenas)
316 && impl_generic_func_id.function == partial_ord_ge_trait_function_id
317 {
318 diagnostics.push(PluginDiagnostic {
319 stable_ptr: function_call_expr.stable_ptr.untyped(),
320 message: IntegerGreaterEqualMinusOne.diagnostic_message().to_string(),
321 severity: Severity::Warning,
322 inner_span: None,
323 error_code: None,
324 })
325 }
326
327 if check_is_add_or_sub_one(
329 db,
330 lhs,
331 arenas,
332 is_part_of_corelib_integer,
333 add_trait_function_id,
334 ) && check_is_variable(rhs, arenas)
335 && impl_generic_func_id.function == partial_ord_le_trait_function_id
336 {
337 diagnostics.push(PluginDiagnostic {
338 stable_ptr: function_call_expr.stable_ptr.untyped(),
339 message: IntegerLessEqualPlusOne.diagnostic_message().to_string(),
340 severity: Severity::Warning,
341 inner_span: None,
342 error_code: None,
343 })
344 }
345
346 if check_is_variable(lhs, arenas)
348 && check_is_add_or_sub_one(
349 db,
350 rhs,
351 arenas,
352 is_part_of_corelib_integer,
353 sub_trait_function_id,
354 )
355 && impl_generic_func_id.function == partial_ord_le_trait_function_id
356 {
357 diagnostics.push(PluginDiagnostic {
358 stable_ptr: function_call_expr.stable_ptr.untyped(),
359 message: IntegerLessEqualMinusOne.diagnostic_message().to_string(),
360 severity: Severity::Warning,
361 inner_span: None,
362 error_code: None,
363 })
364 }
365}
366
367fn check_is_variable<'db>(arg: &ExprFunctionCallArg<'db>, arenas: &Arenas<'db>) -> bool {
368 if let ExprFunctionCallArg::Value(val_expr) = arg {
369 matches!(arenas.exprs[*val_expr], Expr::Var(_))
370 } else {
371 false
372 }
373}
374
375fn check_is_add_or_sub_one<'db>(
376 db: &'db dyn Database,
377 arg: &ExprFunctionCallArg<'db>,
378 arenas: &Arenas<'db>,
379 is_part_of_corelib_integer: bool,
380 operation_function_trait_id: TraitFunctionId<'db>,
381) -> bool {
382 let ExprFunctionCallArg::Value(v) = arg else {
383 return false;
384 };
385 let Expr::FunctionCall(ref func_call) = arenas.exprs[*v] else {
386 return false;
387 };
388
389 let GenericFunctionId::Impl(impl_generic_func_id) =
390 func_call.function.get_concrete(db).generic_function
391 else {
392 return false;
393 };
394
395 if !is_part_of_corelib_integer && impl_generic_func_id.function != operation_function_trait_id
397 || func_call.args.len() != 2
398 {
399 return false;
400 }
401
402 let lhs = &func_call.args[0];
403 let rhs = &func_call.args[1];
404
405 if let ExprFunctionCallArg::Value(v) = lhs {
407 let Expr::Var(_) = arenas.exprs[*v] else {
408 return false;
409 };
410 };
411
412 if_chain! {
414 if let ExprFunctionCallArg::Value(v) = rhs;
415 if let Expr::Literal(ref litteral_expr) = arenas.exprs[*v];
416 if litteral_expr.value == 1.into();
417 then {
418 return true;
419 }
420 }
421
422 false
423}
424
425#[tracing::instrument(skip_all, level = "trace")]
427pub fn fix_int_ge_plus_one<'db>(
428 db: &'db dyn Database,
429 node: SyntaxNode<'db>,
430) -> Option<InternalFix<'db>> {
431 let node = ExprBinary::from_syntax_node(db, node);
432 let lhs = node.lhs(db).as_syntax_node().get_text(db);
433
434 let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
435 panic!("should be addition")
436 };
437 let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
438
439 let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
440 Some(InternalFix {
441 node: node.as_syntax_node(),
442 suggestion: fix,
443 description: IntegerGreaterEqualPlusOne
444 .fix_message()
445 .unwrap()
446 .to_string(),
447 import_addition_paths: None,
448 })
449}
450
451#[tracing::instrument(skip_all, level = "trace")]
453pub fn fix_int_ge_min_one<'db>(
454 db: &'db dyn Database,
455 node: SyntaxNode<'db>,
456) -> Option<InternalFix<'db>> {
457 let node = ExprBinary::from_syntax_node(db, node);
458 let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
459 panic!("should be substraction")
460 };
461 let rhs = node.rhs(db).as_syntax_node().get_text(db);
462
463 let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
464
465 let fix = format!("{} > {} ", lhs.trim(), rhs.trim());
466 Some(InternalFix {
467 node: node.as_syntax_node(),
468 suggestion: fix,
469 description: IntegerGreaterEqualMinusOne
470 .fix_message()
471 .unwrap()
472 .to_string(),
473 import_addition_paths: None,
474 })
475}
476
477#[tracing::instrument(skip_all, level = "trace")]
479pub fn fix_int_le_plus_one<'db>(
480 db: &'db dyn Database,
481 node: SyntaxNode<'db>,
482) -> Option<InternalFix<'db>> {
483 let node = ExprBinary::from_syntax_node(db, node);
484 let AstExpr::Binary(lhs_exp) = node.lhs(db) else {
485 panic!("should be addition")
486 };
487 let rhs = node.rhs(db).as_syntax_node().get_text(db);
488
489 let lhs = lhs_exp.lhs(db).as_syntax_node().get_text(db);
490
491 let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
492 Some(InternalFix {
493 node: node.as_syntax_node(),
494 suggestion: fix,
495 description: IntegerLessEqualPlusOne.fix_message().unwrap().to_string(),
496 import_addition_paths: None,
497 })
498}
499
500#[tracing::instrument(skip_all, level = "trace")]
502pub fn fix_int_le_min_one<'db>(
503 db: &'db dyn Database,
504 node: SyntaxNode<'db>,
505) -> Option<InternalFix<'db>> {
506 let node = ExprBinary::from_syntax_node(db, node);
507 let lhs = node.lhs(db).as_syntax_node().get_text(db);
508
509 let AstExpr::Binary(rhs_exp) = node.rhs(db) else {
510 panic!("should be substraction")
511 };
512 let rhs = rhs_exp.lhs(db).as_syntax_node().get_text(db);
513
514 let fix = format!("{} < {} ", lhs.trim(), rhs.trim());
515 Some(InternalFix {
516 node: node.as_syntax_node(),
517 suggestion: fix,
518 description: IntegerLessEqualMinusOne.fix_message().unwrap().to_string(),
519 import_addition_paths: None,
520 })
521}