Skip to main content

rigsql_rules/convention/
cv04.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// CV04: flags `COUNT(0)` and `COUNT(1)`, recommending `COUNT(*)` instead.
///
/// `COUNT(*)` is the conventional spelling for "count every row" and states
/// that intent directly, whereas a constant argument merely happens to give
/// the same result.
#[derive(Default, Debug)]
pub struct RuleCV04;
11
12impl Rule for RuleCV04 {
13    fn code(&self) -> &'static str {
14        "CV04"
15    }
16    fn name(&self) -> &'static str {
17        "convention.count"
18    }
19    fn description(&self) -> &'static str {
20        "Use consistent syntax to count all rows."
21    }
22    fn explanation(&self) -> &'static str {
23        "COUNT(*) is the standard and most readable way to count all rows. \
24         COUNT(1) and COUNT(0) produce the same result but are less clear in intent. \
25         Using COUNT(*) consistently makes the code more readable."
26    }
27    fn groups(&self) -> &[RuleGroup] {
28        &[RuleGroup::Convention]
29    }
30    fn is_fixable(&self) -> bool {
31        true
32    }
33
34    fn crawl_type(&self) -> CrawlType {
35        CrawlType::Segment(vec![SegmentType::FunctionCall])
36    }
37
38    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39        let children = ctx.segment.children();
40
41        // Check if function is COUNT
42        let func_name = children.iter().find(|c| !c.segment_type().is_trivia());
43        let is_count = match func_name {
44            Some(Segment::Token(t)) => t.token.text.eq_ignore_ascii_case("COUNT"),
45            _ => false,
46        };
47
48        if !is_count {
49            return vec![];
50        }
51
52        // Find the FunctionArgs node and check its content
53        for child in children {
54            if child.segment_type() == SegmentType::FunctionArgs {
55                let arg_tokens = child.tokens();
56                // Filter to non-trivia, non-paren tokens
57                let args: Vec<_> = arg_tokens
58                    .iter()
59                    .filter(|t| {
60                        !t.kind.is_trivia()
61                            && t.kind != rigsql_core::TokenKind::LParen
62                            && t.kind != rigsql_core::TokenKind::RParen
63                    })
64                    .collect();
65
66                // If single argument is a numeric literal "0" or "1"
67                if args.len() == 1 {
68                    let text = args[0].text.as_str();
69                    if text == "0" || text == "1" {
70                        return vec![LintViolation::new(
71                            self.code(),
72                            format!("Use COUNT(*) instead of COUNT({}).", text),
73                            ctx.segment.span(),
74                        )];
75                    }
76                }
77            }
78        }
79
80        vec![]
81    }
82}