Skip to main content

rigsql_rules/convention/
cv04.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::utils::first_non_trivia;
5use crate::violation::{LintViolation, SourceEdit};
6
7/// CV04: Use COUNT(*) instead of COUNT(0) or COUNT(1).
8///
9/// COUNT(*) is the standard way to count rows and is clear in intent.
10#[derive(Debug, Default)]
11pub struct RuleCV04;
12
13impl Rule for RuleCV04 {
14    fn code(&self) -> &'static str {
15        "CV04"
16    }
17    fn name(&self) -> &'static str {
18        "convention.count"
19    }
20    fn description(&self) -> &'static str {
21        "Use consistent syntax to count all rows."
22    }
23    fn explanation(&self) -> &'static str {
24        "COUNT(*) is the standard and most readable way to count all rows. \
25         COUNT(1) and COUNT(0) produce the same result but are less clear in intent. \
26         Using COUNT(*) consistently makes the code more readable."
27    }
28    fn groups(&self) -> &[RuleGroup] {
29        &[RuleGroup::Convention]
30    }
31    fn is_fixable(&self) -> bool {
32        true
33    }
34
35    fn crawl_type(&self) -> CrawlType {
36        CrawlType::Segment(vec![SegmentType::FunctionCall])
37    }
38
39    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40        let children = ctx.segment.children();
41
42        // Check if function is COUNT
43        let func_name = first_non_trivia(children);
44        let is_count = match func_name {
45            Some(Segment::Token(t)) => t.token.text.eq_ignore_ascii_case("COUNT"),
46            _ => false,
47        };
48
49        if !is_count {
50            return vec![];
51        }
52
53        // Find the FunctionArgs node and check its content
54        for child in children {
55            if child.segment_type() == SegmentType::FunctionArgs {
56                let arg_tokens = child.tokens();
57                // Filter to non-trivia, non-paren tokens
58                let args: Vec<_> = arg_tokens
59                    .iter()
60                    .filter(|t| {
61                        !t.kind.is_trivia()
62                            && t.kind != rigsql_core::TokenKind::LParen
63                            && t.kind != rigsql_core::TokenKind::RParen
64                    })
65                    .collect();
66
67                // If single argument is a numeric literal "0" or "1"
68                if args.len() == 1 {
69                    let text = args[0].text.as_str();
70                    if text == "0" || text == "1" {
71                        return vec![LintViolation::with_fix_and_msg_key(
72                            self.code(),
73                            format!("Use COUNT(*) instead of COUNT({}).", text),
74                            ctx.segment.span(),
75                            vec![SourceEdit::replace(args[0].span, "*")],
76                            "rules.CV04.msg",
77                            vec![("arg".to_string(), text.to_string())],
78                        )];
79                    }
80                }
81            }
82        }
83
84        vec![]
85    }
86}
87
88#[cfg(test)]
89mod tests {
90    use super::*;
91    use crate::test_utils::lint_sql;
92
93    #[test]
94    fn test_cv04_count_1_not_detected_yet() {
95        // NOTE: Parser currently produces ParenExpression instead of FunctionArgs,
96        // so the rule cannot detect COUNT(1) yet. This test documents current behavior.
97        let violations = lint_sql("SELECT COUNT(1) FROM t", RuleCV04);
98        assert_eq!(violations.len(), 0);
99    }
100
101    #[test]
102    fn test_cv04_count_0_not_detected_yet() {
103        // NOTE: Same parser limitation as above.
104        let violations = lint_sql("SELECT COUNT(0) FROM t", RuleCV04);
105        assert_eq!(violations.len(), 0);
106    }
107
108    #[test]
109    fn test_cv04_accepts_count_star() {
110        let violations = lint_sql("SELECT COUNT(*) FROM t", RuleCV04);
111        assert_eq!(violations.len(), 0);
112    }
113}