1use std::collections::HashSet;
2
3use cairo_lang_defs::ids::ModuleItemId;
4use cairo_lang_defs::plugin::PluginDiagnostic;
5use cairo_lang_diagnostics::Severity;
6use cairo_lang_semantic::{
7 Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg, ExprLogicalOperator, LogicalOperator,
8};
9use cairo_lang_syntax::node::ast::{BinaryOperator, Expr as AstExpr};
10
11use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
12
13use super::function_trait_name_from_fn_id;
14use crate::context::{CairoLintKind, Lint};
15
16use crate::fixer::InternalFix;
17use crate::lints::{EQ, GE, GT, LE, LT};
18use crate::queries::{get_all_function_bodies, get_all_logical_operator_expressions};
19use salsa::Database;
20
21pub struct ImpossibleComparison;
22
23impl Lint for ImpossibleComparison {
40 fn allowed_name(&self) -> &'static str {
41 "impossible_comparison"
42 }
43
44 fn diagnostic_message(&self) -> &'static str {
45 "Impossible condition, always false"
46 }
47
48 fn kind(&self) -> CairoLintKind {
49 CairoLintKind::ImpossibleComparison
50 }
51}
52
53pub struct SimplifiableComparison;
54
55impl Lint for SimplifiableComparison {
88 fn allowed_name(&self) -> &'static str {
89 "simplifiable_comparison"
90 }
91
92 fn diagnostic_message(&self) -> &'static str {
93 "This double comparison can be simplified."
94 }
95
96 fn kind(&self) -> CairoLintKind {
97 CairoLintKind::DoubleComparison
98 }
99
100 fn has_fixer(&self) -> bool {
101 true
102 }
103
104 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
105 fix_simplifiable_comparison(db, node)
106 }
107
108 fn fix_message(&self) -> Option<&'static str> {
109 Some("Simplify to single comparison")
110 }
111}
112
113pub struct RedundantComparison;
114
115impl Lint for RedundantComparison {
143 fn allowed_name(&self) -> &'static str {
144 "redundant_comparison"
145 }
146
147 fn diagnostic_message(&self) -> &'static str {
148 "Redundant double comparison found. Consider simplifying to a single comparison."
149 }
150
151 fn kind(&self) -> CairoLintKind {
152 CairoLintKind::DoubleComparison
153 }
154
155 fn has_fixer(&self) -> bool {
156 true
157 }
158
159 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
160 fix_redundant_comparison(db, node)
161 }
162
163 fn fix_message(&self) -> Option<&'static str> {
164 Some("Remove redundant comparison")
165 }
166}
167
168pub struct ContradictoryComparison;
169
170impl Lint for ContradictoryComparison {
198 fn allowed_name(&self) -> &'static str {
199 "contradictory_comparison"
200 }
201
202 fn diagnostic_message(&self) -> &'static str {
203 "This double comparison is contradictory and always false."
204 }
205
206 fn kind(&self) -> CairoLintKind {
207 CairoLintKind::DoubleComparison
208 }
209
210 fn has_fixer(&self) -> bool {
211 true
212 }
213
214 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
215 fix_contradictory_comparison(db, node)
216 }
217
218 fn fix_message(&self) -> Option<&'static str> {
219 Some("Replace with 'false' since comparison is impossible")
220 }
221}
222
223#[tracing::instrument(skip_all, level = "trace")]
224pub fn check_double_comparison<'db>(
225 db: &'db dyn Database,
226 item: &ModuleItemId<'db>,
227 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
228) {
229 let function_bodies = get_all_function_bodies(db, item);
230 for function_body in function_bodies {
231 let logical_operator_exprs = get_all_logical_operator_expressions(function_body);
232 let arenas = &function_body.arenas;
233 for logical_operator_expr in logical_operator_exprs.iter() {
234 check_single_double_comparison(db, logical_operator_expr, arenas, diagnostics);
235 }
236 }
237}
238
239fn check_single_double_comparison<'db>(
240 db: &'db dyn Database,
241 logical_operator_exprs: &ExprLogicalOperator<'db>,
242 arenas: &Arenas<'db>,
243 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
244) {
245 let Expr::FunctionCall(lhs_comparison) = &arenas.exprs[logical_operator_exprs.lhs] else {
246 return;
247 };
248 if lhs_comparison.args.len() != 2 {
250 return;
251 }
252
253 let Expr::FunctionCall(rhs_comparison) = &arenas.exprs[logical_operator_exprs.rhs] else {
254 return;
255 };
256 if rhs_comparison.args.len() != 2 {
258 return;
259 }
260 let (lhs_fn_trait_name, rhs_fn_trait_name) = (
262 function_trait_name_from_fn_id(db, &lhs_comparison.function),
263 function_trait_name_from_fn_id(db, &rhs_comparison.function),
264 );
265
266 if check_impossible_comparison(
268 lhs_comparison,
269 rhs_comparison,
270 &lhs_fn_trait_name,
271 &rhs_fn_trait_name,
272 logical_operator_exprs,
273 db,
274 arenas,
275 ) {
276 diagnostics.push(PluginDiagnostic {
277 message: ImpossibleComparison.diagnostic_message().to_string(),
278 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
279 severity: Severity::Warning,
280 inner_span: None,
281 error_code: None,
282 })
283 }
284
285 let (llhs, rlhs) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
287 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
288 (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
289 }
290 _ => {
291 return;
292 }
293 };
294 let (lrhs, rrhs) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
295 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
296 (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
297 }
298 _ => return,
299 };
300 let llhs_var = llhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
302 let rlhs_var = rlhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
303 let lrhs_var = lrhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
304 let rrhs_var = rrhs.stable_ptr().lookup(db).as_syntax_node().get_text(db);
305 let lhs: HashSet<_> = HashSet::from_iter([llhs_var, rlhs_var]);
307 let rhs: HashSet<_> = HashSet::from_iter([lrhs_var, rrhs_var]);
308 if lhs != rhs {
309 return;
310 }
311
312 let should_return = match (llhs, rlhs) {
314 (Expr::Snapshot(llhs), Expr::Snapshot(rlhs)) => {
315 matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
316 || matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
317 }
318 (Expr::Var(_), Expr::Var(_)) => false,
319 (Expr::Snapshot(llhs), Expr::Var(_)) => {
320 matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
321 }
322 (Expr::Var(_), Expr::Snapshot(rlhs)) => {
323 matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
324 }
325 _ => return,
326 };
327 if should_return {
328 return;
329 }
330
331 if is_simplifiable_double_comparison(
332 &lhs_fn_trait_name,
333 &rhs_fn_trait_name,
334 &logical_operator_exprs.op,
335 ) {
336 diagnostics.push(PluginDiagnostic {
337 message: SimplifiableComparison.diagnostic_message().to_string(),
338 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
339 severity: Severity::Warning,
340 inner_span: None,
341 error_code: None,
342 });
343 } else if is_redundant_double_comparison(
344 &lhs_fn_trait_name,
345 &rhs_fn_trait_name,
346 &logical_operator_exprs.op,
347 ) {
348 diagnostics.push(PluginDiagnostic {
349 message: RedundantComparison.diagnostic_message().to_string(),
350 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
351 severity: Severity::Warning,
352 inner_span: None,
353 error_code: None,
354 });
355 } else if is_contradictory_double_comparison(
356 &lhs_fn_trait_name,
357 &rhs_fn_trait_name,
358 &logical_operator_exprs.op,
359 ) {
360 diagnostics.push(PluginDiagnostic {
361 message: ContradictoryComparison.diagnostic_message().to_string(),
362 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
363 severity: Severity::Warning,
364 inner_span: None,
365 error_code: None,
366 });
367 }
368}
369
370fn check_impossible_comparison<'db>(
371 lhs_comparison: &ExprFunctionCall<'db>,
372 rhs_comparison: &ExprFunctionCall<'db>,
373 lhs_op: &str,
374 rhs_op: &str,
375 logical_operator_exprs: &ExprLogicalOperator<'db>,
376 db: &'db dyn Database,
377 arenas: &Arenas<'db>,
378) -> bool {
379 let (lhs_var, lhs_literal) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
380 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
381 match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
382 (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
383 (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
384 _ => {
385 return false;
386 }
387 }
388 }
389 _ => {
390 return false;
391 }
392 };
393 let (rhs_var, rhs_literal) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
394 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
395 match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
396 (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
397 (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
398 _ => {
399 return false;
400 }
401 }
402 }
403 _ => {
404 return false;
405 }
406 };
407
408 if lhs_var.stable_ptr.lookup(db).as_syntax_node().get_text(db)
409 != rhs_var.stable_ptr.lookup(db).as_syntax_node().get_text(db)
410 {
411 return false;
412 }
413
414 match (lhs_op, &logical_operator_exprs.op, rhs_op) {
415 (GT, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
416 (GT, LogicalOperator::AndAnd, LE) => lhs_literal.value >= rhs_literal.value,
417 (GE, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
418 (GE, LogicalOperator::AndAnd, LE) => lhs_literal.value > rhs_literal.value,
419 (LT, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
420 (LT, LogicalOperator::AndAnd, GE) => lhs_literal.value <= rhs_literal.value,
421 (LE, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
422 (LE, LogicalOperator::AndAnd, GE) => lhs_literal.value < rhs_literal.value,
423 _ => false,
424 }
425}
426
427fn is_simplifiable_double_comparison(
428 lhs_op: &str,
429 rhs_op: &str,
430 middle_op: &LogicalOperator,
431) -> bool {
432 matches!(
433 (lhs_op, middle_op, rhs_op),
434 (LE, LogicalOperator::AndAnd, GE)
435 | (GE, LogicalOperator::AndAnd, LE)
436 | (LT, LogicalOperator::OrOr, EQ)
437 | (EQ, LogicalOperator::OrOr, LT)
438 | (GT, LogicalOperator::OrOr, EQ)
439 | (EQ, LogicalOperator::OrOr, GT)
440 )
441}
442
443fn is_redundant_double_comparison(lhs_op: &str, rhs_op: &str, middle_op: &LogicalOperator) -> bool {
444 matches!(
445 (lhs_op, middle_op, rhs_op),
446 (LE, LogicalOperator::OrOr, GE)
447 | (GE, LogicalOperator::OrOr, LE)
448 | (LT, LogicalOperator::OrOr, GT)
449 | (GT, LogicalOperator::OrOr, LT)
450 )
451}
452
453fn is_contradictory_double_comparison(
454 lhs_op: &str,
455 rhs_op: &str,
456 middle_op: &LogicalOperator,
457) -> bool {
458 matches!(
459 (lhs_op, middle_op, rhs_op),
460 (EQ, LogicalOperator::AndAnd, LT)
461 | (LT, LogicalOperator::AndAnd, EQ)
462 | (EQ, LogicalOperator::AndAnd, GT)
463 | (GT, LogicalOperator::AndAnd, EQ)
464 | (LT, LogicalOperator::AndAnd, GT)
465 | (GT, LogicalOperator::AndAnd, LT)
466 | (GT, LogicalOperator::AndAnd, GE)
467 | (LE, LogicalOperator::AndAnd, GT)
468 )
469}
470
471#[tracing::instrument(skip_all, level = "trace")]
472pub fn fix_simplifiable_comparison<'db>(
473 db: &'db dyn Database,
474 node: SyntaxNode<'db>,
475) -> Option<InternalFix<'db>> {
476 let lhs_text = fix_double_comparison(db, node)?;
477
478 Some(InternalFix {
479 node,
480 suggestion: lhs_text,
481 description: SimplifiableComparison.fix_message().unwrap().to_string(),
482 import_addition_paths: None,
483 })
484}
485
486#[tracing::instrument(skip_all, level = "trace")]
487pub fn fix_redundant_comparison<'db>(
488 db: &'db dyn Database,
489 node: SyntaxNode<'db>,
490) -> Option<InternalFix<'db>> {
491 let lhs_text = fix_double_comparison(db, node)?;
492
493 Some(InternalFix {
494 node,
495 suggestion: lhs_text,
496 description: RedundantComparison.fix_message().unwrap().to_string(),
497 import_addition_paths: None,
498 })
499}
500
501#[tracing::instrument(skip_all, level = "trace")]
502pub fn fix_contradictory_comparison<'db>(
503 db: &'db dyn Database,
504 node: SyntaxNode<'db>,
505) -> Option<InternalFix<'db>> {
506 let lhs_text = fix_double_comparison(db, node)?;
507
508 Some(InternalFix {
509 node,
510 suggestion: lhs_text,
511 description: ContradictoryComparison.fix_message().unwrap().to_string(),
512 import_addition_paths: None,
513 })
514}
515
516pub fn fix_double_comparison<'db>(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<String> {
518 let expr = AstExpr::from_syntax_node(db, node);
519
520 if let AstExpr::Binary(binary_op) = expr {
521 let lhs = binary_op.lhs(db);
522 let rhs = binary_op.rhs(db);
523 let middle_op = binary_op.op(db);
524
525 if let (Some(lhs_op), Some(rhs_op)) = (
526 extract_binary_operator_expr(&lhs, db),
527 extract_binary_operator_expr(&rhs, db),
528 ) && let Some(simplified_op) =
529 determine_simplified_operator(&lhs_op, &rhs_op, &middle_op)
530 && let Some(operator_to_replace) = operator_to_replace(lhs_op)
531 {
532 let lhs_text = lhs
533 .as_syntax_node()
534 .get_text(db)
535 .replace(operator_to_replace, simplified_op);
536 return Some(lhs_text.to_string());
537 }
538 }
539
540 None
541}
542
543fn operator_to_replace(lhs_op: BinaryOperator) -> Option<&'static str> {
544 match lhs_op {
545 BinaryOperator::EqEq(_) => Some("=="),
546 BinaryOperator::GT(_) => Some(">"),
547 BinaryOperator::LT(_) => Some("<"),
548 BinaryOperator::GE(_) => Some(">="),
549 BinaryOperator::LE(_) => Some("<="),
550 _ => None,
551 }
552}
553
554fn determine_simplified_operator(
555 lhs_op: &BinaryOperator,
556 rhs_op: &BinaryOperator,
557 middle_op: &BinaryOperator,
558) -> Option<&'static str> {
559 match (lhs_op, middle_op, rhs_op) {
560 (BinaryOperator::LE(_), BinaryOperator::AndAnd(_), BinaryOperator::GE(_))
561 | (BinaryOperator::GE(_), BinaryOperator::AndAnd(_), BinaryOperator::LE(_)) => Some("=="),
562
563 (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
564 | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("<="),
565
566 (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
567 | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_)) => Some(">="),
568
569 (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_))
570 | (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("!="),
571
572 _ => None,
573 }
574}
575
576fn extract_binary_operator_expr<'db>(
577 expr: &AstExpr<'db>,
578 db: &'db dyn Database,
579) -> Option<BinaryOperator<'db>> {
580 if let AstExpr::Binary(binary_op) = expr {
581 Some(binary_op.op(db))
582 } else {
583 None
584 }
585}