Skip to main content

rigsql_rules/capitalisation/
cp03.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::{LintViolation, SourceEdit};
5
/// CP03: Function names must be consistently capitalised.
///
/// A function name written entirely in upper case (`COUNT`) or entirely in
/// lower case (`count`) is accepted. Only names that mix both cases
/// (e.g. `Count`) are reported, and the attached fix rewrites the name to
/// upper case. Only ASCII letters are considered when deciding the case.
#[derive(Debug, Default)]
pub struct RuleCP03;
11
12impl Rule for RuleCP03 {
13    fn code(&self) -> &'static str {
14        "CP03"
15    }
16    fn name(&self) -> &'static str {
17        "capitalisation.functions"
18    }
19    fn description(&self) -> &'static str {
20        "Function names must be consistently capitalised."
21    }
22    fn explanation(&self) -> &'static str {
23        "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
24         Whether upper or lower depends on your team's convention."
25    }
26    fn groups(&self) -> &[RuleGroup] {
27        &[RuleGroup::Capitalisation]
28    }
29    fn is_fixable(&self) -> bool {
30        true
31    }
32
33    fn crawl_type(&self) -> CrawlType {
34        CrawlType::Segment(vec![SegmentType::FunctionCall])
35    }
36
37    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
38        // FunctionCall's first child should be the function name (Identifier)
39        let children = ctx.segment.children();
40        if children.is_empty() {
41            return vec![];
42        }
43
44        // Walk to find the function name token
45        let name_seg = Self::find_function_name(children);
46        let Some(Segment::Token(t)) = name_seg else {
47            return vec![];
48        };
49        if t.token.kind != TokenKind::Word {
50            return vec![];
51        }
52
53        // Check: function names should be consistent (default: lower)
54        let text = t.token.text.as_str();
55        // Skip if it's all upper or all lower (both are acceptable in many configs)
56        // Default: we don't enforce function name case (many projects use either)
57        // Only flag mixed case
58        let is_all_upper = text
59            .chars()
60            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase());
61        let is_all_lower = text
62            .chars()
63            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_lowercase());
64        if is_all_upper || is_all_lower {
65            return vec![];
66        }
67
68        vec![LintViolation::with_fix(
69            self.code(),
70            format!(
71                "Function name '{}' has inconsistent capitalisation. Use all upper or all lower case.",
72                text
73            ),
74            t.token.span,
75            vec![SourceEdit::replace(t.token.span, text.to_ascii_uppercase())],
76        )]
77    }
78}
79
80impl RuleCP03 {
81    fn find_function_name(children: &[Segment]) -> Option<&Segment> {
82        for child in children {
83            match child.segment_type() {
84                SegmentType::Identifier => return Some(child),
85                SegmentType::ColumnRef => {
86                    // qualified function: schema.func — get last identifier
87                    let inner = child.children();
88                    return inner
89                        .iter()
90                        .rev()
91                        .find(|s| s.segment_type() == SegmentType::Identifier);
92                }
93                _ if child.segment_type().is_trivia() => continue,
94                _ => break,
95            }
96        }
97        None
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with CP03 and returns how many violations were reported.
    fn violation_count(sql: &str) -> usize {
        lint_sql(sql, RuleCP03).len()
    }

    /// A mixed-case function name is reported exactly once.
    #[test]
    fn test_cp03_flags_mixed_case() {
        assert_eq!(violation_count("SELECT Count(*) FROM t"), 1);
    }

    /// An all-upper-case function name is accepted.
    #[test]
    fn test_cp03_accepts_all_upper() {
        assert_eq!(violation_count("SELECT COUNT(*) FROM t"), 0);
    }

    /// An all-lower-case function name is accepted.
    #[test]
    fn test_cp03_accepts_all_lower() {
        assert_eq!(violation_count("SELECT count(*) FROM t"), 0);
    }
}