Skip to main content

rigsql_rules/ambiguous/
am06.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// AM06: Inconsistent column references in GROUP BY/ORDER BY.
///
/// Within a single GROUP BY or ORDER BY clause every item should be referenced
/// the same way: either all positionally (`GROUP BY 1, 2`) or all explicitly by
/// name/expression (`GROUP BY foo, bar`). A clause that mixes the two styles is
/// reported.
#[derive(Default, Debug)]
pub struct RuleAM06;
12
13impl Rule for RuleAM06 {
14    fn code(&self) -> &'static str {
15        "AM06"
16    }
17    fn name(&self) -> &'static str {
18        "ambiguous.column_references"
19    }
20    fn description(&self) -> &'static str {
21        "Inconsistent column references in GROUP BY/ORDER BY."
22    }
23    fn explanation(&self) -> &'static str {
24        "GROUP BY and ORDER BY clauses should use a consistent style for column references: \
25         either all positional (e.g., GROUP BY 1, 2) or all explicit column names \
26         (e.g., GROUP BY foo, bar). Mixing styles like GROUP BY foo, 2 is ambiguous \
27         and hard to maintain."
28    }
29    fn groups(&self) -> &[RuleGroup] {
30        &[RuleGroup::Ambiguous]
31    }
32    fn is_fixable(&self) -> bool {
33        false
34    }
35
36    fn crawl_type(&self) -> CrawlType {
37        CrawlType::Segment(vec![SegmentType::GroupByClause, SegmentType::OrderByClause])
38    }
39
40    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41        let mut positional = Vec::new();
42        let mut named = Vec::new();
43
44        collect_ref_styles(ctx.segment, &mut positional, &mut named);
45
46        // Only flag if there's a mix of styles
47        if !positional.is_empty() && !named.is_empty() {
48            let clause_name = match ctx.segment.segment_type() {
49                SegmentType::GroupByClause => "GROUP BY",
50                SegmentType::OrderByClause => "ORDER BY",
51                _ => "Clause",
52            };
53
54            // Flag the minority style references.
55            // If there are more positional than named, flag named ones and vice versa.
56            let (targets, style) = if positional.len() >= named.len() {
57                (&named, "explicit")
58            } else {
59                (&positional, "positional")
60            };
61
62            return targets
63                .iter()
64                .map(|span| {
65                    LintViolation::new(
66                        self.code(),
67                        format!(
68                            "Mixed positional and explicit references in {}. Found {} reference.",
69                            clause_name, style
70                        ),
71                        *span,
72                    )
73                })
74                .collect();
75        }
76
77        vec![]
78    }
79}
80
81/// Classify references in a GROUP BY or ORDER BY clause as positional (numeric)
82/// or named (identifier/expression).
83fn collect_ref_styles(
84    segment: &Segment,
85    positional: &mut Vec<rigsql_core::Span>,
86    named: &mut Vec<rigsql_core::Span>,
87) {
88    for child in segment.children() {
89        let st = child.segment_type();
90        match st {
91            // Skip keywords (GROUP, BY, ORDER, ASC, DESC), trivia, commas
92            SegmentType::Keyword
93            | SegmentType::Whitespace
94            | SegmentType::Newline
95            | SegmentType::Comma
96            | SegmentType::LineComment
97            | SegmentType::BlockComment => {}
98
99            // A bare NumberLiteral is a positional reference
100            SegmentType::NumericLiteral => {
101                positional.push(child.span());
102            }
103
104            // OrderByExpression wraps an expression + optional ASC/DESC
105            SegmentType::OrderByExpression => {
106                collect_ref_styles(child, positional, named);
107            }
108
109            // An expression node: check if it contains only a NumberLiteral
110            SegmentType::Expression => {
111                if is_single_number_literal(child) {
112                    positional.push(child.span());
113                } else {
114                    named.push(child.span());
115                }
116            }
117
118            // Identifiers, ColumnRef, QualifiedIdentifier, FunctionCall, etc. → named
119            _ => {
120                if !child.children().is_empty() {
121                    // Node type — check if it's a wrapper around a single number
122                    if is_single_number_literal(child) {
123                        positional.push(child.span());
124                    } else {
125                        named.push(child.span());
126                    }
127                } else {
128                    // Leaf token that's not a keyword/trivia → named reference
129                    named.push(child.span());
130                }
131            }
132        }
133    }
134}
135
136/// Check if a segment is (or contains only) a single NumberLiteral.
137fn is_single_number_literal(segment: &Segment) -> bool {
138    match segment {
139        Segment::Token(t) => t.segment_type == SegmentType::NumericLiteral,
140        Segment::Node(n) => {
141            let mut non_trivia = n.children.iter().filter(|c| !c.segment_type().is_trivia());
142            match (non_trivia.next(), non_trivia.next()) {
143                (Some(only), None) => is_single_number_literal(only),
144                _ => false,
145            }
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lint `sql` with only AM06 enabled and return the number of violations.
    fn am06_count(sql: &str) -> usize {
        lint_sql(sql, RuleAM06).len()
    }

    #[test]
    fn test_am06_flags_mixed_group_by() {
        // Named `foo` next to positional `2`.
        let count = am06_count("SELECT foo, bar, SUM(baz) FROM t GROUP BY foo, 2");
        assert!(count > 0, "Should flag mixed GROUP BY styles");
    }

    #[test]
    fn test_am06_accepts_all_explicit_group_by() {
        assert_eq!(
            am06_count("SELECT foo, bar, SUM(baz) FROM t GROUP BY foo, bar"),
            0
        );
    }

    #[test]
    fn test_am06_accepts_all_positional_group_by() {
        assert_eq!(
            am06_count("SELECT foo, bar, SUM(baz) FROM t GROUP BY 1, 2"),
            0
        );
    }

    #[test]
    fn test_am06_flags_mixed_order_by() {
        let count = am06_count("SELECT a, b FROM t ORDER BY a, 2");
        assert!(count > 0, "Should flag mixed ORDER BY styles");
    }

    #[test]
    fn test_am06_accepts_all_explicit_order_by() {
        assert_eq!(am06_count("SELECT a, b FROM t ORDER BY a, b"), 0);
    }

    #[test]
    fn test_am06_accepts_all_positional_order_by() {
        assert_eq!(am06_count("SELECT a, b FROM t ORDER BY 1, 2"), 0);
    }
}