1use std::collections::HashSet;
2
3use cairo_lang_defs::ids::ModuleItemId;
4use cairo_lang_defs::plugin::PluginDiagnostic;
5use cairo_lang_diagnostics::Severity;
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::{
8 Arenas, Expr, ExprFunctionCall, ExprFunctionCallArg, ExprLogicalOperator, LogicalOperator,
9};
10use cairo_lang_syntax::node::ast::{BinaryOperator, Expr as AstExpr};
11use cairo_lang_syntax::node::db::SyntaxGroup;
12use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode};
13
14use super::function_trait_name_from_fn_id;
15use crate::context::{CairoLintKind, Lint};
16use crate::lints::{EQ, GE, GT, LE, LT};
17use crate::queries::{get_all_function_bodies, get_all_logical_operator_expressions};
18
19pub struct ImpossibleComparison;
20
21impl Lint for ImpossibleComparison {
38 fn allowed_name(&self) -> &'static str {
39 "impossible_comparison"
40 }
41
42 fn diagnostic_message(&self) -> &'static str {
43 "Impossible condition, always false"
44 }
45
46 fn kind(&self) -> CairoLintKind {
47 CairoLintKind::ImpossibleComparison
48 }
49}
50
51pub struct SimplifiableComparison;
52
53impl Lint for SimplifiableComparison {
86 fn allowed_name(&self) -> &'static str {
87 "simplifiable_comparison"
88 }
89
90 fn diagnostic_message(&self) -> &'static str {
91 "This double comparison can be simplified."
92 }
93
94 fn kind(&self) -> CairoLintKind {
95 CairoLintKind::DoubleComparison
96 }
97
98 fn has_fixer(&self) -> bool {
99 true
100 }
101
102 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
103 fix_double_comparison(db, node)
104 }
105}
106
107pub struct RedundantComparison;
108
109impl Lint for RedundantComparison {
137 fn allowed_name(&self) -> &'static str {
138 "redundant_comparison"
139 }
140
141 fn diagnostic_message(&self) -> &'static str {
142 "Redundant double comparison found. Consider simplifying to a single comparison."
143 }
144
145 fn kind(&self) -> CairoLintKind {
146 CairoLintKind::DoubleComparison
147 }
148
149 fn has_fixer(&self) -> bool {
150 true
151 }
152
153 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
154 fix_double_comparison(db, node)
155 }
156}
157
158pub struct ContradictoryComparison;
159
160impl Lint for ContradictoryComparison {
188 fn allowed_name(&self) -> &'static str {
189 "contradictory_comparison"
190 }
191
192 fn diagnostic_message(&self) -> &'static str {
193 "This double comparison is contradictory and always false."
194 }
195
196 fn kind(&self) -> CairoLintKind {
197 CairoLintKind::DoubleComparison
198 }
199
200 fn has_fixer(&self) -> bool {
201 true
202 }
203
204 fn fix(&self, db: &dyn SyntaxGroup, node: SyntaxNode) -> Option<(SyntaxNode, String)> {
205 fix_double_comparison(db, node)
206 }
207}
208
209pub fn check_double_comparison(
210 db: &dyn SemanticGroup,
211 item: &ModuleItemId,
212 diagnostics: &mut Vec<PluginDiagnostic>,
213) {
214 let function_bodies = get_all_function_bodies(db, item);
215 for function_body in function_bodies.iter() {
216 let logical_operator_exprs = get_all_logical_operator_expressions(function_body);
217 let arenas = &function_body.arenas;
218 for logical_operator_expr in logical_operator_exprs.iter() {
219 check_single_double_comparison(db, logical_operator_expr, arenas, diagnostics);
220 }
221 }
222}
223
224fn check_single_double_comparison(
225 db: &dyn SemanticGroup,
226 logical_operator_exprs: &ExprLogicalOperator,
227 arenas: &Arenas,
228 diagnostics: &mut Vec<PluginDiagnostic>,
229) {
230 let Expr::FunctionCall(lhs_comparison) = &arenas.exprs[logical_operator_exprs.lhs] else {
231 return;
232 };
233 if lhs_comparison.args.len() != 2 {
235 return;
236 }
237
238 let Expr::FunctionCall(rhs_comparison) = &arenas.exprs[logical_operator_exprs.rhs] else {
239 return;
240 };
241 if rhs_comparison.args.len() != 2 {
243 return;
244 }
245 let (lhs_fn_trait_name, rhs_fn_trait_name) = (
247 function_trait_name_from_fn_id(db, &lhs_comparison.function),
248 function_trait_name_from_fn_id(db, &rhs_comparison.function),
249 );
250
251 if check_impossible_comparison(
253 lhs_comparison,
254 rhs_comparison,
255 &lhs_fn_trait_name,
256 &rhs_fn_trait_name,
257 logical_operator_exprs,
258 db,
259 arenas,
260 ) {
261 diagnostics.push(PluginDiagnostic {
262 message: ImpossibleComparison.diagnostic_message().to_string(),
263 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
264 severity: Severity::Error,
265 })
266 }
267
268 let (llhs, rlhs) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
270 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
271 (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
272 }
273 _ => {
274 return;
275 }
276 };
277 let (lrhs, rrhs) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
278 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
279 (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id])
280 }
281 _ => return,
282 };
283 let llhs_var = llhs
285 .stable_ptr()
286 .lookup(db.upcast())
287 .as_syntax_node()
288 .get_text(db.upcast());
289 let rlhs_var = rlhs
290 .stable_ptr()
291 .lookup(db.upcast())
292 .as_syntax_node()
293 .get_text(db.upcast());
294 let lrhs_var = lrhs
295 .stable_ptr()
296 .lookup(db.upcast())
297 .as_syntax_node()
298 .get_text(db.upcast());
299 let rrhs_var = rrhs
300 .stable_ptr()
301 .lookup(db.upcast())
302 .as_syntax_node()
303 .get_text(db.upcast());
304 let lhs: HashSet<String> = HashSet::from_iter([llhs_var, rlhs_var]);
306 let rhs: HashSet<String> = HashSet::from_iter([lrhs_var, rrhs_var]);
307 if lhs != rhs {
308 return;
309 }
310
311 let should_return = match (llhs, rlhs) {
313 (Expr::Snapshot(llhs), Expr::Snapshot(rlhs)) => {
314 matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
315 || matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
316 }
317 (Expr::Var(_), Expr::Var(_)) => false,
318 (Expr::Snapshot(llhs), Expr::Var(_)) => {
319 matches!(arenas.exprs[llhs.inner], Expr::FunctionCall(_))
320 }
321 (Expr::Var(_), Expr::Snapshot(rlhs)) => {
322 matches!(arenas.exprs[rlhs.inner], Expr::FunctionCall(_))
323 }
324 _ => return,
325 };
326 if should_return {
327 return;
328 }
329
330 if is_simplifiable_double_comparison(
331 &lhs_fn_trait_name,
332 &rhs_fn_trait_name,
333 &logical_operator_exprs.op,
334 ) {
335 diagnostics.push(PluginDiagnostic {
336 message: SimplifiableComparison.diagnostic_message().to_string(),
337 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
338 severity: Severity::Warning,
339 });
340 } else if is_redundant_double_comparison(
341 &lhs_fn_trait_name,
342 &rhs_fn_trait_name,
343 &logical_operator_exprs.op,
344 ) {
345 diagnostics.push(PluginDiagnostic {
346 message: RedundantComparison.diagnostic_message().to_string(),
347 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
348 severity: Severity::Warning,
349 });
350 } else if is_contradictory_double_comparison(
351 &lhs_fn_trait_name,
352 &rhs_fn_trait_name,
353 &logical_operator_exprs.op,
354 ) {
355 diagnostics.push(PluginDiagnostic {
356 message: ContradictoryComparison.diagnostic_message().to_string(),
357 stable_ptr: logical_operator_exprs.stable_ptr.untyped(),
358 severity: Severity::Error,
359 });
360 }
361}
362
363fn check_impossible_comparison(
364 lhs_comparison: &ExprFunctionCall,
365 rhs_comparison: &ExprFunctionCall,
366 lhs_op: &str,
367 rhs_op: &str,
368 logical_operator_exprs: &ExprLogicalOperator,
369 db: &dyn SemanticGroup,
370 arenas: &Arenas,
371) -> bool {
372 let (lhs_var, lhs_literal) = match (&lhs_comparison.args[0], &lhs_comparison.args[1]) {
373 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
374 match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
375 (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
376 (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
377 _ => {
378 return false;
379 }
380 }
381 }
382 _ => {
383 return false;
384 }
385 };
386 let (rhs_var, rhs_literal) = match (&rhs_comparison.args[0], &rhs_comparison.args[1]) {
387 (ExprFunctionCallArg::Value(l_expr_id), ExprFunctionCallArg::Value(r_expr_id)) => {
388 match (&arenas.exprs[*l_expr_id], &arenas.exprs[*r_expr_id]) {
389 (Expr::Var(var), Expr::Literal(literal)) => (var, literal),
390 (Expr::Literal(literal), Expr::Var(var)) => (var, literal),
391 _ => {
392 return false;
393 }
394 }
395 }
396 _ => {
397 return false;
398 }
399 };
400
401 if lhs_var
402 .stable_ptr
403 .lookup(db.upcast())
404 .as_syntax_node()
405 .get_text(db.upcast())
406 != rhs_var
407 .stable_ptr
408 .lookup(db.upcast())
409 .as_syntax_node()
410 .get_text(db.upcast())
411 {
412 return false;
413 }
414
415 match (lhs_op, &logical_operator_exprs.op, rhs_op) {
416 (GT, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
417 (GT, LogicalOperator::AndAnd, LE) => lhs_literal.value >= rhs_literal.value,
418 (GE, LogicalOperator::AndAnd, LT) => lhs_literal.value >= rhs_literal.value,
419 (GE, LogicalOperator::AndAnd, LE) => lhs_literal.value > rhs_literal.value,
420 (LT, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
421 (LT, LogicalOperator::AndAnd, GE) => lhs_literal.value <= rhs_literal.value,
422 (LE, LogicalOperator::AndAnd, GT) => lhs_literal.value <= rhs_literal.value,
423 (LE, LogicalOperator::AndAnd, GE) => lhs_literal.value < rhs_literal.value,
424 _ => false,
425 }
426}
427
428fn is_simplifiable_double_comparison(
429 lhs_op: &str,
430 rhs_op: &str,
431 middle_op: &LogicalOperator,
432) -> bool {
433 matches!(
434 (lhs_op, middle_op, rhs_op),
435 (LE, LogicalOperator::AndAnd, GE)
436 | (GE, LogicalOperator::AndAnd, LE)
437 | (LT, LogicalOperator::OrOr, EQ)
438 | (EQ, LogicalOperator::OrOr, LT)
439 | (GT, LogicalOperator::OrOr, EQ)
440 | (EQ, LogicalOperator::OrOr, GT)
441 )
442}
443
444fn is_redundant_double_comparison(lhs_op: &str, rhs_op: &str, middle_op: &LogicalOperator) -> bool {
445 matches!(
446 (lhs_op, middle_op, rhs_op),
447 (LE, LogicalOperator::OrOr, GE)
448 | (GE, LogicalOperator::OrOr, LE)
449 | (LT, LogicalOperator::OrOr, GT)
450 | (GT, LogicalOperator::OrOr, LT)
451 )
452}
453
454fn is_contradictory_double_comparison(
455 lhs_op: &str,
456 rhs_op: &str,
457 middle_op: &LogicalOperator,
458) -> bool {
459 matches!(
460 (lhs_op, middle_op, rhs_op),
461 (EQ, LogicalOperator::AndAnd, LT)
462 | (LT, LogicalOperator::AndAnd, EQ)
463 | (EQ, LogicalOperator::AndAnd, GT)
464 | (GT, LogicalOperator::AndAnd, EQ)
465 | (LT, LogicalOperator::AndAnd, GT)
466 | (GT, LogicalOperator::AndAnd, LT)
467 | (GT, LogicalOperator::AndAnd, GE)
468 | (LE, LogicalOperator::AndAnd, GT)
469 )
470}
471
472pub fn fix_double_comparison(
474 db: &dyn SyntaxGroup,
475 node: SyntaxNode,
476) -> Option<(SyntaxNode, String)> {
477 let expr = AstExpr::from_syntax_node(db, node.clone());
478
479 if let AstExpr::Binary(binary_op) = expr {
480 let lhs = binary_op.lhs(db);
481 let rhs = binary_op.rhs(db);
482 let middle_op = binary_op.op(db);
483
484 if let (Some(lhs_op), Some(rhs_op)) = (
485 extract_binary_operator_expr(&lhs, db),
486 extract_binary_operator_expr(&rhs, db),
487 ) {
488 let simplified_op = determine_simplified_operator(&lhs_op, &rhs_op, &middle_op);
489
490 if let Some(simplified_op) = simplified_op {
491 if let Some(operator_to_replace) = operator_to_replace(lhs_op) {
492 let lhs_text = lhs
493 .as_syntax_node()
494 .get_text(db)
495 .replace(operator_to_replace, simplified_op);
496 return Some((node, lhs_text.to_string()));
497 }
498 }
499 }
500 }
501
502 None
503}
504
505fn operator_to_replace(lhs_op: BinaryOperator) -> Option<&'static str> {
506 match lhs_op {
507 BinaryOperator::EqEq(_) => Some("=="),
508 BinaryOperator::GT(_) => Some(">"),
509 BinaryOperator::LT(_) => Some("<"),
510 BinaryOperator::GE(_) => Some(">="),
511 BinaryOperator::LE(_) => Some("<="),
512 _ => None,
513 }
514}
515
516fn determine_simplified_operator(
517 lhs_op: &BinaryOperator,
518 rhs_op: &BinaryOperator,
519 middle_op: &BinaryOperator,
520) -> Option<&'static str> {
521 match (lhs_op, middle_op, rhs_op) {
522 (BinaryOperator::LE(_), BinaryOperator::AndAnd(_), BinaryOperator::GE(_))
523 | (BinaryOperator::GE(_), BinaryOperator::AndAnd(_), BinaryOperator::LE(_)) => Some("=="),
524
525 (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
526 | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("<="),
527
528 (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::EqEq(_))
529 | (BinaryOperator::EqEq(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_)) => Some(">="),
530
531 (BinaryOperator::LT(_), BinaryOperator::OrOr(_), BinaryOperator::GT(_))
532 | (BinaryOperator::GT(_), BinaryOperator::OrOr(_), BinaryOperator::LT(_)) => Some("!="),
533
534 _ => None,
535 }
536}
537
538fn extract_binary_operator_expr(expr: &AstExpr, db: &dyn SyntaxGroup) -> Option<BinaryOperator> {
539 if let AstExpr::Binary(binary_op) = expr {
540 Some(binary_op.op(db))
541 } else {
542 None
543 }
544}