// rigsql_rules/convention/cv04.rs
use rigsql_core::{Segment, SegmentType};

use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::utils::first_non_trivia;
use crate::violation::{LintViolation, SourceEdit};

/// Lint rule CV04 (`convention.count`).
///
/// Prefers `COUNT(*)` over `COUNT(0)` / `COUNT(1)` when counting all rows.
/// The rule carries no configuration, so it is a plain unit struct.
#[derive(Default, Debug)]
pub struct RuleCV04;

13impl Rule for RuleCV04 {
14 fn code(&self) -> &'static str {
15 "CV04"
16 }
17 fn name(&self) -> &'static str {
18 "convention.count"
19 }
20 fn description(&self) -> &'static str {
21 "Use consistent syntax to count all rows."
22 }
23 fn explanation(&self) -> &'static str {
24 "COUNT(*) is the standard and most readable way to count all rows. \
25 COUNT(1) and COUNT(0) produce the same result but are less clear in intent. \
26 Using COUNT(*) consistently makes the code more readable."
27 }
28 fn groups(&self) -> &[RuleGroup] {
29 &[RuleGroup::Convention]
30 }
31 fn is_fixable(&self) -> bool {
32 true
33 }
34
35 fn crawl_type(&self) -> CrawlType {
36 CrawlType::Segment(vec![SegmentType::FunctionCall])
37 }
38
39 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40 let children = ctx.segment.children();
41
42 let func_name = first_non_trivia(children);
44 let is_count = match func_name {
45 Some(Segment::Token(t)) => t.token.text.eq_ignore_ascii_case("COUNT"),
46 _ => false,
47 };
48
49 if !is_count {
50 return vec![];
51 }
52
53 for child in children {
55 if child.segment_type() == SegmentType::FunctionArgs {
56 let arg_tokens = child.tokens();
57 let args: Vec<_> = arg_tokens
59 .iter()
60 .filter(|t| {
61 !t.kind.is_trivia()
62 && t.kind != rigsql_core::TokenKind::LParen
63 && t.kind != rigsql_core::TokenKind::RParen
64 })
65 .collect();
66
67 if args.len() == 1 {
69 let text = args[0].text.as_str();
70 if text == "0" || text == "1" {
71 return vec![LintViolation::with_fix_and_msg_key(
72 self.code(),
73 format!("Use COUNT(*) instead of COUNT({}).", text),
74 ctx.segment.span(),
75 vec![SourceEdit::replace(args[0].span, "*")],
76 "rules.CV04.msg",
77 vec![("arg".to_string(), text.to_string())],
78 )];
79 }
80 }
81 }
82 }
83
84 vec![]
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with CV04 through the shared test helper and returns how
    /// many violations were reported.
    fn cv04_violation_count(sql: &str) -> usize {
        lint_sql(sql, RuleCV04).len()
    }

    #[test]
    fn test_cv04_count_1_not_detected_yet() {
        // Per the test name, COUNT(1) is a known gap: it is not reported yet.
        let found = cv04_violation_count("SELECT COUNT(1) FROM t");
        assert_eq!(found, 0);
    }

    #[test]
    fn test_cv04_count_0_not_detected_yet() {
        // Per the test name, COUNT(0) is a known gap: it is not reported yet.
        let found = cv04_violation_count("SELECT COUNT(0) FROM t");
        assert_eq!(found, 0);
    }

    #[test]
    fn test_cv04_accepts_count_star() {
        // COUNT(*) is the preferred form and must never be reported.
        let found = cv04_violation_count("SELECT COUNT(*) FROM t");
        assert_eq!(found, 0);
    }
}