1use crate::linter::config::LintConfig;
7use crate::linter::rule::{LintContext, LintRule};
8use crate::linter::visit;
9use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit};
10use sqlparser::ast::{Spanned, *};
11use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer, Whitespace};
12
13#[derive(Clone, Copy, Debug, Eq, PartialEq)]
14enum CountPreference {
15 Star,
16 One,
17 Zero,
18}
19
20impl CountPreference {
21 fn from_config(config: &LintConfig) -> Self {
22 let prefer_one = config
23 .rule_option_bool(issue_codes::LINT_CV_004, "prefer_count_1")
24 .unwrap_or(false);
25 let prefer_zero = config
26 .rule_option_bool(issue_codes::LINT_CV_004, "prefer_count_0")
27 .unwrap_or(false);
28
29 if prefer_one {
30 Self::One
31 } else if prefer_zero {
32 Self::Zero
33 } else {
34 Self::Star
35 }
36 }
37
38 fn message(self) -> &'static str {
39 match self {
40 Self::Star => "Use COUNT(*) for row counts.",
41 Self::One => "Use COUNT(1) for row counts.",
42 Self::Zero => "Use COUNT(0) for row counts.",
43 }
44 }
45
46 fn violates(self, kind: CountArgKind) -> bool {
47 match self {
48 Self::Star => matches!(kind, CountArgKind::One | CountArgKind::Zero),
49 Self::One => matches!(kind, CountArgKind::Star | CountArgKind::Zero),
50 Self::Zero => matches!(kind, CountArgKind::Star | CountArgKind::One),
51 }
52 }
53
54 fn replacement(self) -> &'static str {
55 match self {
56 Self::Star => "*",
57 Self::One => "1",
58 Self::Zero => "0",
59 }
60 }
61}
62
/// Classification of the argument list of a `COUNT(...)` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CountArgKind {
    /// A lone `*` argument.
    Star,
    /// A lone integer literal equal to 1 (e.g. `1`, `01`).
    One,
    /// A lone integer literal equal to 0 (e.g. `0`, `00`).
    Zero,
    /// Anything else: column references, expressions, several arguments, ...
    Other,
}
70
71pub struct CountStyle {
72 preference: CountPreference,
73}
74
75impl CountStyle {
76 pub fn from_config(config: &LintConfig) -> Self {
77 Self {
78 preference: CountPreference::from_config(config),
79 }
80 }
81}
82
83impl Default for CountStyle {
84 fn default() -> Self {
85 Self {
86 preference: CountPreference::Star,
87 }
88 }
89}
90
91impl LintRule for CountStyle {
92 fn code(&self) -> &'static str {
93 issue_codes::LINT_CV_004
94 }
95
96 fn name(&self) -> &'static str {
97 "COUNT style"
98 }
99
100 fn description(&self) -> &'static str {
101 "Use consistent syntax to express \"count number of rows\"."
102 }
103
104 fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
105 let tokens =
106 tokenized_for_context(ctx).or_else(|| tokenized(ctx.statement_sql(), ctx.dialect()));
107 let wildcard_spans = tokens
108 .as_deref()
109 .map(collect_count_wildcard_spans)
110 .unwrap_or_default();
111 let numeric_spans = tokens
112 .as_deref()
113 .map(collect_count_numeric_spans)
114 .unwrap_or_default();
115 let mut wildcard_index = 0usize;
116 let mut numeric_index = 0usize;
117
118 let mut issues = Vec::new();
119 visit::visit_expressions(stmt, &mut |expr| {
120 let Expr::Function(func) = expr else {
121 return;
122 };
123 if !func.name.to_string().eq_ignore_ascii_case("COUNT") {
124 return;
125 }
126
127 let kind = count_argument_kind(&func.args);
128 let argument_span = match kind {
129 CountArgKind::Star => {
130 let span = wildcard_spans.get(wildcard_index).copied();
131 wildcard_index = wildcard_index.saturating_add(1);
132 span
133 }
134 CountArgKind::One | CountArgKind::Zero => {
135 let span = numeric_spans.get(numeric_index).copied();
136 numeric_index = numeric_index.saturating_add(1);
137 span.or_else(|| count_numeric_argument_span(ctx, func))
138 }
139 CountArgKind::Other => None,
140 };
141
142 if self.preference.violates(kind) {
143 let mut issue = Issue::info(issue_codes::LINT_CV_004, self.preference.message())
144 .with_statement(ctx.statement_index);
145 if let Some((start, end)) = argument_span {
146 let span = ctx.span_from_statement_offset(start, end);
147 issue = issue.with_span(span).with_autofix_edits(
148 IssueAutofixApplicability::Safe,
149 vec![IssuePatchEdit::new(span, self.preference.replacement())],
150 );
151 }
152 issues.push(issue);
153 }
154 });
155 issues
156 }
157}
158
159fn count_argument_kind(args: &FunctionArguments) -> CountArgKind {
160 let arg_list = match args {
161 FunctionArguments::List(list) => list,
162 _ => return CountArgKind::Other,
163 };
164
165 if arg_list.args.len() != 1 {
166 return CountArgKind::Other;
167 }
168
169 match &arg_list.args[0] {
170 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => CountArgKind::Star,
171 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(ValueWithSpan {
172 value: Value::Number(n, _),
173 ..
174 }))) if numeric_literal_matches(n, 1) => CountArgKind::One,
175 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(ValueWithSpan {
176 value: Value::Number(n, _),
177 ..
178 }))) if numeric_literal_matches(n, 0) => CountArgKind::Zero,
179 _ => CountArgKind::Other,
180 }
181}
182
/// Returns `true` when `raw`, ignoring surrounding whitespace, parses as an
/// unsigned integer equal to `expected`. For example, `"1"`, `"01"` and `" 1 "`
/// all match 1, while `"1.0"` does not.
fn numeric_literal_matches(raw: &str, expected: u8) -> bool {
    matches!(raw.trim().parse::<u64>(), Ok(value) if value == u64::from(expected))
}
189
190fn count_numeric_argument_span(ctx: &LintContext, func: &Function) -> Option<(usize, usize)> {
191 let FunctionArguments::List(arg_list) = &func.args else {
192 return None;
193 };
194 if arg_list.args.len() != 1 {
195 return None;
196 }
197
198 let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = &arg_list.args[0] else {
199 return None;
200 };
201
202 if let Some((start, end)) = expr_span_offsets(ctx.statement_sql(), expr) {
203 return Some((start, end));
204 }
205
206 let (start, end) = expr_span_offsets(ctx.sql, expr)?;
207 if start < ctx.statement_range.start || end > ctx.statement_range.end {
208 return None;
209 }
210
211 Some((
212 start - ctx.statement_range.start,
213 end - ctx.statement_range.start,
214 ))
215}
216
217fn collect_count_wildcard_spans(tokens: &[LocatedToken]) -> Vec<(usize, usize)> {
218 let mut spans = Vec::new();
219 let mut i = 0usize;
220
221 while i < tokens.len() {
222 if !is_count_word(&tokens[i].token) {
223 i += 1;
224 continue;
225 }
226
227 let mut j = i + 1;
228 skip_trivia_tokens(tokens, &mut j);
229 if j >= tokens.len() || !matches!(tokens[j].token, Token::LParen) {
230 i += 1;
231 continue;
232 }
233
234 j += 1;
235 skip_trivia_tokens(tokens, &mut j);
236 if j >= tokens.len() {
237 break;
238 }
239
240 if let Token::Word(word) = &tokens[j].token {
241 if word.value.eq_ignore_ascii_case("ALL") || word.value.eq_ignore_ascii_case("DISTINCT")
242 {
243 j += 1;
244 skip_trivia_tokens(tokens, &mut j);
245 }
246 }
247
248 if j >= tokens.len() || !matches!(tokens[j].token, Token::Mul) {
249 i += 1;
250 continue;
251 }
252
253 let star_start = tokens[j].start;
254 let star_end = tokens[j].end;
255 j += 1;
256 skip_trivia_tokens(tokens, &mut j);
257 if j < tokens.len() && matches!(tokens[j].token, Token::RParen) {
258 spans.push((star_start, star_end));
259 i = j + 1;
260 } else {
261 i += 1;
262 }
263 }
264
265 spans
266}
267
268fn collect_count_numeric_spans(tokens: &[LocatedToken]) -> Vec<(usize, usize)> {
269 let mut spans = Vec::new();
270 let mut i = 0usize;
271
272 while i < tokens.len() {
273 if !is_count_word(&tokens[i].token) {
274 i += 1;
275 continue;
276 }
277
278 let mut j = i + 1;
279 skip_trivia_tokens(tokens, &mut j);
280 if j >= tokens.len() || !matches!(tokens[j].token, Token::LParen) {
281 i += 1;
282 continue;
283 }
284
285 j += 1;
286 skip_trivia_tokens(tokens, &mut j);
287 if j >= tokens.len() {
288 break;
289 }
290
291 if let Token::Word(word) = &tokens[j].token {
292 if word.value.eq_ignore_ascii_case("ALL") || word.value.eq_ignore_ascii_case("DISTINCT")
293 {
294 j += 1;
295 skip_trivia_tokens(tokens, &mut j);
296 }
297 }
298
299 if j >= tokens.len() {
300 break;
301 }
302
303 let Some(raw_number) = token_numeric_literal(&tokens[j].token) else {
304 i += 1;
305 continue;
306 };
307 if !numeric_literal_matches(raw_number, 0) && !numeric_literal_matches(raw_number, 1) {
308 i += 1;
309 continue;
310 }
311
312 let number_start = tokens[j].start;
313 let number_end = tokens[j].end;
314 j += 1;
315 skip_trivia_tokens(tokens, &mut j);
316 if j < tokens.len() && matches!(tokens[j].token, Token::RParen) {
317 spans.push((number_start, number_end));
318 i = j + 1;
319 } else {
320 i += 1;
321 }
322 }
323
324 spans
325}
326
327fn skip_trivia_tokens(tokens: &[LocatedToken], index: &mut usize) {
328 while *index < tokens.len() && is_trivia_token(&tokens[*index].token) {
329 *index += 1;
330 }
331}
332
333fn is_count_word(token: &Token) -> bool {
334 matches!(token, Token::Word(word) if word.value.eq_ignore_ascii_case("COUNT"))
335}
336
337fn token_numeric_literal(token: &Token) -> Option<&str> {
338 match token {
339 Token::Number(raw, _) => Some(raw.as_str()),
340 _ => None,
341 }
342}
343
344fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
345 let span = expr.span();
346 if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
347 {
348 return None;
349 }
350 let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
351 let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
352 if end < start {
353 return None;
354 }
355 Some((start, end))
356}
357
358#[derive(Clone)]
359struct LocatedToken {
360 token: Token,
361 start: usize,
362 end: usize,
363}
364
365fn tokenized(sql: &str, dialect: crate::types::Dialect) -> Option<Vec<LocatedToken>> {
366 let dialect = dialect.to_sqlparser_dialect();
367 let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
368 let tokens = tokenizer.tokenize_with_location().ok()?;
369
370 let mut out = Vec::with_capacity(tokens.len());
371 for token in tokens {
372 let Some((start, end)) = token_with_span_offsets(sql, &token) else {
373 continue;
374 };
375 out.push(LocatedToken {
376 token: token.token,
377 start,
378 end,
379 });
380 }
381 Some(out)
382}
383
384fn tokenized_for_context(ctx: &LintContext) -> Option<Vec<LocatedToken>> {
385 let statement_start = ctx.statement_range.start;
386 let from_document = ctx.with_document_tokens(|tokens| {
387 if tokens.is_empty() {
388 return None;
389 }
390
391 Some(
392 tokens
393 .iter()
394 .filter_map(|token| {
395 let (start, end) = token_with_span_offsets(ctx.sql, token)?;
396 if start < ctx.statement_range.start || end > ctx.statement_range.end {
397 return None;
398 }
399
400 Some(LocatedToken {
401 token: token.token.clone(),
402 start: start - statement_start,
403 end: end - statement_start,
404 })
405 })
406 .collect::<Vec<_>>(),
407 )
408 });
409
410 if let Some(tokens) = from_document {
411 return Some(tokens);
412 }
413
414 tokenized(ctx.statement_sql(), ctx.dialect())
415}
416
417fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
418 let start = line_col_to_offset(
419 sql,
420 token.span.start.line as usize,
421 token.span.start.column as usize,
422 )?;
423 let end = line_col_to_offset(
424 sql,
425 token.span.end.line as usize,
426 token.span.end.column as usize,
427 )?;
428 Some((start, end))
429}
430
431fn is_trivia_token(token: &Token) -> bool {
432 matches!(
433 token,
434 Token::Whitespace(Whitespace::Space | Whitespace::Tab | Whitespace::Newline)
435 | Token::Whitespace(Whitespace::SingleLineComment { .. })
436 | Token::Whitespace(Whitespace::MultiLineComment(_))
437 )
438}
439
/// Converts a 1-based `(line, column)` position into a byte offset in `sql`.
///
/// Lines are separated by `\n`, and every other character, `\r` included,
/// occupies one column. A column may point one past a line's last character:
/// that is the offset of its `\n`, or `sql.len()` on the final line.
///
/// Returns `None` in two cases:
/// - `line` or `column` is 0;
/// - the position does not exist in `sql`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Line `n` begins right after the `(n - 1)`-th newline.
    let mut offset = if line == 1 {
        0
    } else {
        sql.match_indices('\n').nth(line - 2).map(|(idx, _)| idx + 1)?
    };

    for _ in 1..column {
        let ch = sql[offset..].chars().next()?;
        if ch == '\n' {
            return None;
        }
        offset += ch.len_utf8();
    }
    Some(offset)
}
465
// Unit tests for the CV_004 `COUNT` style rule.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;
    use crate::types::IssueAutofixApplicability;

    /// Parses `sql` and runs the default-configured rule over every parsed
    /// statement. A single context covers the whole input as statement 0.
    fn check_sql(sql: &str) -> Vec<Issue> {
        let stmts = parse_sql(sql).unwrap();
        let rule = CountStyle::default();
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        let mut issues = Vec::new();
        for stmt in &stmts {
            issues.extend(rule.check(stmt, &ctx));
        }
        issues
    }

    /// Asserts that `issue` covers `expected_start..expected_end` and carries
    /// exactly one `Safe` autofix edit over that same span whose replacement
    /// text is `expected_replacement`.
    fn assert_single_safe_edit(
        issue: &Issue,
        expected_start: usize,
        expected_end: usize,
        expected_replacement: &str,
    ) {
        let span = issue.span.expect("issue span");
        assert_eq!(span.start, expected_start);
        assert_eq!(span.end, expected_end);

        let autofix = issue.autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        assert_eq!(autofix.edits.len(), 1);
        assert_eq!(autofix.edits[0].span.start, expected_start);
        assert_eq!(autofix.edits[0].span.end, expected_end);
        assert_eq!(autofix.edits[0].replacement, expected_replacement);
    }

    #[test]
    fn test_count_one_detected() {
        let sql = "SELECT COUNT(1) FROM t";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, "LINT_CV_004");

        // The fix rewrites only the `1` argument.
        let one_start = sql.find('1').expect("count literal");
        assert_single_safe_edit(&issues[0], one_start, one_start + 1, "*");
    }

    #[test]
    fn test_count_leading_zero_numeric_literals_are_detected() {
        // `01` and `00` are integer literals equal to 1 and 0; each full
        // two-character literal is replaced.
        let sql = "SELECT COUNT(01), COUNT(00) FROM t";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 2);

        let first_start = sql.find("01").expect("first literal");
        let second_start = sql.find("00").expect("second literal");
        assert_single_safe_edit(&issues[0], first_start, first_start + 2, "*");
        assert_single_safe_edit(&issues[1], second_start, second_start + 2, "*");
    }

    #[test]
    fn test_count_star_ok() {
        let issues = check_sql("SELECT COUNT(*) FROM t");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_count_column_ok() {
        let issues = check_sql("SELECT COUNT(id) FROM t");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_count_zero_detected_with_default_star_preference() {
        let issues = check_sql("SELECT COUNT(0) FROM t");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_count_one_in_having() {
        let issues = check_sql("SELECT col FROM t GROUP BY col HAVING COUNT(1) > 5");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_count_one_in_subquery() {
        let issues =
            check_sql("SELECT * FROM t WHERE id IN (SELECT COUNT(1) FROM t2 GROUP BY col)");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_multiple_count_one() {
        let issues = check_sql("SELECT COUNT(1), COUNT(1) FROM t");
        assert_eq!(issues.len(), 2);
    }

    #[test]
    fn test_count_distinct_ok() {
        let issues = check_sql("SELECT COUNT(DISTINCT id) FROM t");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_count_one_in_cte() {
        let issues = check_sql("WITH cte AS (SELECT COUNT(1) AS cnt FROM t) SELECT * FROM cte");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_count_one_in_qualify() {
        let issues = check_sql("SELECT a FROM t QUALIFY COUNT(1) > 0");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_prefer_count_one_flags_count_star() {
        // Option supplied under the `convention.count_rows` config key.
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.count_rows".to_string(),
                serde_json::json!({"prefer_count_1": true}),
            )]),
        };
        let rule = CountStyle::from_config(&config);
        let sql = "SELECT COUNT(*) FROM t";
        let stmts = parse_sql(sql).unwrap();
        let issues = rule.check(
            &stmts[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);

        // The `*` argument is replaced with `1`.
        let star_start = sql.find('*').expect("star argument");
        assert_single_safe_edit(&issues[0], star_start, star_start + 1, "1");
    }

    #[test]
    fn test_prefer_count_zero_flags_count_one() {
        // Option supplied under the `LINT_CV_004` issue-code key.
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CV_004".to_string(),
                serde_json::json!({"prefer_count_0": true}),
            )]),
        };
        let rule = CountStyle::from_config(&config);
        let sql = "SELECT COUNT(1) FROM t";
        let stmts = parse_sql(sql).unwrap();
        let issues = rule.check(
            &stmts[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);

        // The `1` argument is replaced with `0`.
        let one_start = sql.find('1').expect("count literal");
        assert_single_safe_edit(&issues[0], one_start, one_start + 1, "0");
    }
}