Skip to main content

flowscope_core/linter/rules/
cv_004.rs

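//! LINT_CV_004: Use one consistent form for row counts.
//!
//! `COUNT(*)`, `COUNT(1)` and `COUNT(0)` are interchangeable row counts in mainstream
//! databases. By default the rule prefers `COUNT(*)`, the standard convention that most
//! clearly expresses intent. The `prefer_count_1` / `prefer_count_0` options select
//! `COUNT(1)` or `COUNT(0)` instead.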
5
6use crate::linter::config::LintConfig;
7use crate::linter::rule::{LintContext, LintRule};
8use crate::linter::visit;
9use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit};
10use sqlparser::ast::{Spanned, *};
11use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer, Whitespace};
12
/// The argument form the rule treats as canonical for row counts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CountPreference {
    /// `COUNT(*)` — used when neither preference option is enabled.
    Star,
    /// `COUNT(1)`, selected by the `prefer_count_1` option.
    One,
    /// `COUNT(0)`, selected by the `prefer_count_0` option.
    Zero,
}
19
20impl CountPreference {
21    fn from_config(config: &LintConfig) -> Self {
22        let prefer_one = config
23            .rule_option_bool(issue_codes::LINT_CV_004, "prefer_count_1")
24            .unwrap_or(false);
25        let prefer_zero = config
26            .rule_option_bool(issue_codes::LINT_CV_004, "prefer_count_0")
27            .unwrap_or(false);
28
29        if prefer_one {
30            Self::One
31        } else if prefer_zero {
32            Self::Zero
33        } else {
34            Self::Star
35        }
36    }
37
38    fn message(self) -> &'static str {
39        match self {
40            Self::Star => "Use COUNT(*) for row counts.",
41            Self::One => "Use COUNT(1) for row counts.",
42            Self::Zero => "Use COUNT(0) for row counts.",
43        }
44    }
45
46    fn violates(self, kind: CountArgKind) -> bool {
47        match self {
48            Self::Star => matches!(kind, CountArgKind::One | CountArgKind::Zero),
49            Self::One => matches!(kind, CountArgKind::Star | CountArgKind::Zero),
50            Self::Zero => matches!(kind, CountArgKind::Star | CountArgKind::One),
51        }
52    }
53
54    fn replacement(self) -> &'static str {
55        match self {
56            Self::Star => "*",
57            Self::One => "1",
58            Self::Zero => "0",
59        }
60    }
61}
62
/// Classification of the argument list of a `COUNT(...)` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CountArgKind {
    /// A single `*` argument.
    Star,
    /// A single unsigned integer literal equal to 1 (e.g. `1`, `01`).
    One,
    /// A single unsigned integer literal equal to 0 (e.g. `0`, `00`).
    Zero,
    /// Anything else: a column, another expression, or not exactly one argument.
    Other,
}
70
71pub struct CountStyle {
72    preference: CountPreference,
73}
74
75impl CountStyle {
76    pub fn from_config(config: &LintConfig) -> Self {
77        Self {
78            preference: CountPreference::from_config(config),
79        }
80    }
81}
82
83impl Default for CountStyle {
84    fn default() -> Self {
85        Self {
86            preference: CountPreference::Star,
87        }
88    }
89}
90
91impl LintRule for CountStyle {
92    fn code(&self) -> &'static str {
93        issue_codes::LINT_CV_004
94    }
95
96    fn name(&self) -> &'static str {
97        "COUNT style"
98    }
99
100    fn description(&self) -> &'static str {
101        "Use consistent syntax to express \"count number of rows\"."
102    }
103
104    fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
105        let tokens =
106            tokenized_for_context(ctx).or_else(|| tokenized(ctx.statement_sql(), ctx.dialect()));
107        let wildcard_spans = tokens
108            .as_deref()
109            .map(collect_count_wildcard_spans)
110            .unwrap_or_default();
111        let numeric_spans = tokens
112            .as_deref()
113            .map(collect_count_numeric_spans)
114            .unwrap_or_default();
115        let mut wildcard_index = 0usize;
116        let mut numeric_index = 0usize;
117
118        let mut issues = Vec::new();
119        visit::visit_expressions(stmt, &mut |expr| {
120            let Expr::Function(func) = expr else {
121                return;
122            };
123            if !func.name.to_string().eq_ignore_ascii_case("COUNT") {
124                return;
125            }
126
127            let kind = count_argument_kind(&func.args);
128            let argument_span = match kind {
129                CountArgKind::Star => {
130                    let span = wildcard_spans.get(wildcard_index).copied();
131                    wildcard_index = wildcard_index.saturating_add(1);
132                    span
133                }
134                CountArgKind::One | CountArgKind::Zero => {
135                    let span = numeric_spans.get(numeric_index).copied();
136                    numeric_index = numeric_index.saturating_add(1);
137                    span.or_else(|| count_numeric_argument_span(ctx, func))
138                }
139                CountArgKind::Other => None,
140            };
141
142            if self.preference.violates(kind) {
143                let mut issue = Issue::info(issue_codes::LINT_CV_004, self.preference.message())
144                    .with_statement(ctx.statement_index);
145                if let Some((start, end)) = argument_span {
146                    let span = ctx.span_from_statement_offset(start, end);
147                    issue = issue.with_span(span).with_autofix_edits(
148                        IssueAutofixApplicability::Safe,
149                        vec![IssuePatchEdit::new(span, self.preference.replacement())],
150                    );
151                }
152                issues.push(issue);
153            }
154        });
155        issues
156    }
157}
158
159fn count_argument_kind(args: &FunctionArguments) -> CountArgKind {
160    let arg_list = match args {
161        FunctionArguments::List(list) => list,
162        _ => return CountArgKind::Other,
163    };
164
165    if arg_list.args.len() != 1 {
166        return CountArgKind::Other;
167    }
168
169    match &arg_list.args[0] {
170        FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => CountArgKind::Star,
171        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(ValueWithSpan {
172            value: Value::Number(n, _),
173            ..
174        }))) if numeric_literal_matches(n, 1) => CountArgKind::One,
175        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(ValueWithSpan {
176            value: Value::Number(n, _),
177            ..
178        }))) if numeric_literal_matches(n, 0) => CountArgKind::Zero,
179        _ => CountArgKind::Other,
180    }
181}
182
/// Returns `true` when `raw` (surrounding whitespace ignored) parses as an
/// unsigned 64-bit integer equal to `expected`.
///
/// Leading zeros are accepted (`"01"` matches 1). Decimals, exponents,
/// negative numbers and out-of-range values never match.
fn numeric_literal_matches(raw: &str, expected: u8) -> bool {
    matches!(raw.trim().parse::<u64>(), Ok(parsed) if parsed == u64::from(expected))
}
189
190fn count_numeric_argument_span(ctx: &LintContext, func: &Function) -> Option<(usize, usize)> {
191    let FunctionArguments::List(arg_list) = &func.args else {
192        return None;
193    };
194    if arg_list.args.len() != 1 {
195        return None;
196    }
197
198    let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = &arg_list.args[0] else {
199        return None;
200    };
201
202    if let Some((start, end)) = expr_span_offsets(ctx.statement_sql(), expr) {
203        return Some((start, end));
204    }
205
206    let (start, end) = expr_span_offsets(ctx.sql, expr)?;
207    if start < ctx.statement_range.start || end > ctx.statement_range.end {
208        return None;
209    }
210
211    Some((
212        start - ctx.statement_range.start,
213        end - ctx.statement_range.start,
214    ))
215}
216
217fn collect_count_wildcard_spans(tokens: &[LocatedToken]) -> Vec<(usize, usize)> {
218    let mut spans = Vec::new();
219    let mut i = 0usize;
220
221    while i < tokens.len() {
222        if !is_count_word(&tokens[i].token) {
223            i += 1;
224            continue;
225        }
226
227        let mut j = i + 1;
228        skip_trivia_tokens(tokens, &mut j);
229        if j >= tokens.len() || !matches!(tokens[j].token, Token::LParen) {
230            i += 1;
231            continue;
232        }
233
234        j += 1;
235        skip_trivia_tokens(tokens, &mut j);
236        if j >= tokens.len() {
237            break;
238        }
239
240        if let Token::Word(word) = &tokens[j].token {
241            if word.value.eq_ignore_ascii_case("ALL") || word.value.eq_ignore_ascii_case("DISTINCT")
242            {
243                j += 1;
244                skip_trivia_tokens(tokens, &mut j);
245            }
246        }
247
248        if j >= tokens.len() || !matches!(tokens[j].token, Token::Mul) {
249            i += 1;
250            continue;
251        }
252
253        let star_start = tokens[j].start;
254        let star_end = tokens[j].end;
255        j += 1;
256        skip_trivia_tokens(tokens, &mut j);
257        if j < tokens.len() && matches!(tokens[j].token, Token::RParen) {
258            spans.push((star_start, star_end));
259            i = j + 1;
260        } else {
261            i += 1;
262        }
263    }
264
265    spans
266}
267
268fn collect_count_numeric_spans(tokens: &[LocatedToken]) -> Vec<(usize, usize)> {
269    let mut spans = Vec::new();
270    let mut i = 0usize;
271
272    while i < tokens.len() {
273        if !is_count_word(&tokens[i].token) {
274            i += 1;
275            continue;
276        }
277
278        let mut j = i + 1;
279        skip_trivia_tokens(tokens, &mut j);
280        if j >= tokens.len() || !matches!(tokens[j].token, Token::LParen) {
281            i += 1;
282            continue;
283        }
284
285        j += 1;
286        skip_trivia_tokens(tokens, &mut j);
287        if j >= tokens.len() {
288            break;
289        }
290
291        if let Token::Word(word) = &tokens[j].token {
292            if word.value.eq_ignore_ascii_case("ALL") || word.value.eq_ignore_ascii_case("DISTINCT")
293            {
294                j += 1;
295                skip_trivia_tokens(tokens, &mut j);
296            }
297        }
298
299        if j >= tokens.len() {
300            break;
301        }
302
303        let Some(raw_number) = token_numeric_literal(&tokens[j].token) else {
304            i += 1;
305            continue;
306        };
307        if !numeric_literal_matches(raw_number, 0) && !numeric_literal_matches(raw_number, 1) {
308            i += 1;
309            continue;
310        }
311
312        let number_start = tokens[j].start;
313        let number_end = tokens[j].end;
314        j += 1;
315        skip_trivia_tokens(tokens, &mut j);
316        if j < tokens.len() && matches!(tokens[j].token, Token::RParen) {
317            spans.push((number_start, number_end));
318            i = j + 1;
319        } else {
320            i += 1;
321        }
322    }
323
324    spans
325}
326
327fn skip_trivia_tokens(tokens: &[LocatedToken], index: &mut usize) {
328    while *index < tokens.len() && is_trivia_token(&tokens[*index].token) {
329        *index += 1;
330    }
331}
332
333fn is_count_word(token: &Token) -> bool {
334    matches!(token, Token::Word(word) if word.value.eq_ignore_ascii_case("COUNT"))
335}
336
337fn token_numeric_literal(token: &Token) -> Option<&str> {
338    match token {
339        Token::Number(raw, _) => Some(raw.as_str()),
340        _ => None,
341    }
342}
343
344fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
345    let span = expr.span();
346    if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
347    {
348        return None;
349    }
350    let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
351    let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
352    if end < start {
353        return None;
354    }
355    Some((start, end))
356}
357
358#[derive(Clone)]
359struct LocatedToken {
360    token: Token,
361    start: usize,
362    end: usize,
363}
364
365fn tokenized(sql: &str, dialect: crate::types::Dialect) -> Option<Vec<LocatedToken>> {
366    let dialect = dialect.to_sqlparser_dialect();
367    let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
368    let tokens = tokenizer.tokenize_with_location().ok()?;
369
370    let mut out = Vec::with_capacity(tokens.len());
371    for token in tokens {
372        let Some((start, end)) = token_with_span_offsets(sql, &token) else {
373            continue;
374        };
375        out.push(LocatedToken {
376            token: token.token,
377            start,
378            end,
379        });
380    }
381    Some(out)
382}
383
384fn tokenized_for_context(ctx: &LintContext) -> Option<Vec<LocatedToken>> {
385    let statement_start = ctx.statement_range.start;
386    let from_document = ctx.with_document_tokens(|tokens| {
387        if tokens.is_empty() {
388            return None;
389        }
390
391        Some(
392            tokens
393                .iter()
394                .filter_map(|token| {
395                    let (start, end) = token_with_span_offsets(ctx.sql, token)?;
396                    if start < ctx.statement_range.start || end > ctx.statement_range.end {
397                        return None;
398                    }
399
400                    Some(LocatedToken {
401                        token: token.token.clone(),
402                        start: start - statement_start,
403                        end: end - statement_start,
404                    })
405                })
406                .collect::<Vec<_>>(),
407        )
408    });
409
410    if let Some(tokens) = from_document {
411        return Some(tokens);
412    }
413
414    tokenized(ctx.statement_sql(), ctx.dialect())
415}
416
417fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
418    let start = line_col_to_offset(
419        sql,
420        token.span.start.line as usize,
421        token.span.start.column as usize,
422    )?;
423    let end = line_col_to_offset(
424        sql,
425        token.span.end.line as usize,
426        token.span.end.column as usize,
427    )?;
428    Some((start, end))
429}
430
431fn is_trivia_token(token: &Token) -> bool {
432    matches!(
433        token,
434        Token::Whitespace(Whitespace::Space | Whitespace::Tab | Whitespace::Newline)
435            | Token::Whitespace(Whitespace::SingleLineComment { .. })
436            | Token::Whitespace(Whitespace::MultiLineComment(_))
437    )
438}
439
/// Converts a 1-based `(line, column)` position into a byte offset in `sql`.
///
/// Lines are separated by `'\n'`, and columns count characters (not bytes)
/// within a line. The column just past a line's last character maps to the
/// offset of that line's `'\n'`, or to `sql.len()` on the final line. Any
/// position further out, and any zero line or column, yields `None`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Byte offset where the requested line begins.
    let line_start = if line == 1 {
        0
    } else {
        sql.match_indices('\n').nth(line - 2).map(|(newline, _)| newline + 1)?
    };

    let rest = &sql[line_start..];
    let line_text = rest.find('\n').map_or(rest, |newline| &rest[..newline]);
    let char_index = column - 1;

    match line_text.char_indices().nth(char_index) {
        Some((byte_index, _)) => Some(line_start + byte_index),
        None if line_text.chars().count() == char_index => Some(line_start + line_text.len()),
        None => None,
    }
}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;
    use crate::types::IssueAutofixApplicability;

    /// Runs `rule` over every statement parsed from `sql`, treating the whole
    /// input as one statement range with statement index 0.
    fn run_rule(rule: &CountStyle, sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).unwrap();
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        statements
            .iter()
            .flat_map(|statement| rule.check(statement, &ctx))
            .collect()
    }

    /// Lints `sql` with the default `COUNT(*)` preference.
    fn check_sql(sql: &str) -> Vec<Issue> {
        run_rule(&CountStyle::default(), sql)
    }

    /// Builds the rule from a config that sets the boolean `option` to `true`
    /// under `rule_key`.
    fn rule_with_option(rule_key: &str, option: &str) -> CountStyle {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                rule_key.to_string(),
                serde_json::json!({ option: true }),
            )]),
        };
        CountStyle::from_config(&config)
    }

    /// Asserts that `issue` spans `expected_start..expected_end` and carries
    /// exactly one safe edit over that range with `expected_replacement`.
    fn assert_single_safe_edit(
        issue: &Issue,
        expected_start: usize,
        expected_end: usize,
        expected_replacement: &str,
    ) {
        let span = issue.span.expect("issue span");
        assert_eq!(span.start, expected_start);
        assert_eq!(span.end, expected_end);

        let autofix = issue.autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        assert_eq!(autofix.edits.len(), 1);
        let edit = &autofix.edits[0];
        assert_eq!(edit.span.start, expected_start);
        assert_eq!(edit.span.end, expected_end);
        assert_eq!(edit.replacement, expected_replacement);
    }

    #[test]
    fn test_count_one_detected() {
        let sql = "SELECT COUNT(1) FROM t";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, "LINT_CV_004");

        let literal_at = sql.find('1').expect("count literal");
        assert_single_safe_edit(&issues[0], literal_at, literal_at + 1, "*");
    }

    #[test]
    fn test_count_leading_zero_numeric_literals_are_detected() {
        let sql = "SELECT COUNT(01), COUNT(00) FROM t";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 2);

        let first_at = sql.find("01").expect("first literal");
        let second_at = sql.find("00").expect("second literal");
        assert_single_safe_edit(&issues[0], first_at, first_at + 2, "*");
        assert_single_safe_edit(&issues[1], second_at, second_at + 2, "*");
    }

    #[test]
    fn test_count_star_ok() {
        assert!(check_sql("SELECT COUNT(*) FROM t").is_empty());
    }

    #[test]
    fn test_count_column_ok() {
        assert!(check_sql("SELECT COUNT(id) FROM t").is_empty());
    }

    // --- Edge cases ---

    #[test]
    fn test_count_zero_detected_with_default_star_preference() {
        assert_eq!(check_sql("SELECT COUNT(0) FROM t").len(), 1);
    }

    #[test]
    fn test_count_one_in_having() {
        let issues = check_sql("SELECT col FROM t GROUP BY col HAVING COUNT(1) > 5");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_count_one_in_subquery() {
        let issues =
            check_sql("SELECT * FROM t WHERE id IN (SELECT COUNT(1) FROM t2 GROUP BY col)");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_multiple_count_one() {
        assert_eq!(check_sql("SELECT COUNT(1), COUNT(1) FROM t").len(), 2);
    }

    #[test]
    fn test_count_distinct_ok() {
        assert!(check_sql("SELECT COUNT(DISTINCT id) FROM t").is_empty());
    }

    #[test]
    fn test_count_one_in_cte() {
        let issues = check_sql("WITH cte AS (SELECT COUNT(1) AS cnt FROM t) SELECT * FROM cte");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_count_one_in_qualify() {
        assert_eq!(check_sql("SELECT a FROM t QUALIFY COUNT(1) > 0").len(), 1);
    }

    #[test]
    fn test_prefer_count_one_flags_count_star() {
        let rule = rule_with_option("convention.count_rows", "prefer_count_1");
        let sql = "SELECT COUNT(*) FROM t";
        let issues = run_rule(&rule, sql);
        assert_eq!(issues.len(), 1);

        let star_at = sql.find('*').expect("star argument");
        assert_single_safe_edit(&issues[0], star_at, star_at + 1, "1");
    }

    #[test]
    fn test_prefer_count_zero_flags_count_one() {
        let rule = rule_with_option("LINT_CV_004", "prefer_count_0");
        let sql = "SELECT COUNT(1) FROM t";
        let issues = run_rule(&rule, sql);
        assert_eq!(issues.len(), 1);

        let literal_at = sql.find('1').expect("count literal");
        assert_single_safe_edit(&issues[0], literal_at, literal_at + 1, "0");
    }
}