Skip to main content

rigsql_rules/ambiguous/am06.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// AM06 (`ambiguous.column_references`): a GROUP BY or ORDER BY clause mixes
/// positional references (`GROUP BY 1, 2`) with explicit ones (`GROUP BY foo, bar`).
///
/// Each clause should stick to a single reference style.
#[derive(Default, Debug)]
pub struct RuleAM06;
12
13impl Rule for RuleAM06 {
14    fn code(&self) -> &'static str {
15        "AM06"
16    }
17    fn name(&self) -> &'static str {
18        "ambiguous.column_references"
19    }
20    fn description(&self) -> &'static str {
21        "Inconsistent column references in GROUP BY/ORDER BY."
22    }
23    fn explanation(&self) -> &'static str {
24        "GROUP BY and ORDER BY clauses should use a consistent style for column references: \
25         either all positional (e.g., GROUP BY 1, 2) or all explicit column names \
26         (e.g., GROUP BY foo, bar). Mixing styles like GROUP BY foo, 2 is ambiguous \
27         and hard to maintain."
28    }
29    fn groups(&self) -> &[RuleGroup] {
30        &[RuleGroup::Ambiguous]
31    }
32    fn is_fixable(&self) -> bool {
33        false
34    }
35
36    fn crawl_type(&self) -> CrawlType {
37        CrawlType::Segment(vec![SegmentType::GroupByClause, SegmentType::OrderByClause])
38    }
39
40    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41        let mut positional = Vec::new();
42        let mut named = Vec::new();
43
44        collect_ref_styles(ctx.segment, &mut positional, &mut named);
45
46        // Only flag if there's a mix of styles
47        if !positional.is_empty() && !named.is_empty() {
48            let clause_name = match ctx.segment.segment_type() {
49                SegmentType::GroupByClause => "GROUP BY",
50                SegmentType::OrderByClause => "ORDER BY",
51                _ => "Clause",
52            };
53
54            // Flag the minority style references.
55            // If there are more positional than named, flag named ones and vice versa.
56            let (targets, style) = if positional.len() >= named.len() {
57                (&named, "explicit")
58            } else {
59                (&positional, "positional")
60            };
61
62            return targets
63                .iter()
64                .map(|span| {
65                    LintViolation::with_msg_key(
66                        self.code(),
67                        format!(
68                            "Mixed positional and explicit references in {}. Found {} reference.",
69                            clause_name, style
70                        ),
71                        *span,
72                        "rules.AM06.msg",
73                        vec![
74                            ("clause".to_string(), clause_name.to_string()),
75                            ("style".to_string(), style.to_string()),
76                        ],
77                    )
78                })
79                .collect();
80        }
81
82        vec![]
83    }
84}
85
86/// Classify references in a GROUP BY or ORDER BY clause as positional (numeric)
87/// or named (identifier/expression).
88fn collect_ref_styles(
89    segment: &Segment,
90    positional: &mut Vec<rigsql_core::Span>,
91    named: &mut Vec<rigsql_core::Span>,
92) {
93    for child in segment.children() {
94        let st = child.segment_type();
95        match st {
96            // Skip keywords (GROUP, BY, ORDER, ASC, DESC), trivia, commas
97            SegmentType::Keyword
98            | SegmentType::Whitespace
99            | SegmentType::Newline
100            | SegmentType::Comma
101            | SegmentType::LineComment
102            | SegmentType::BlockComment => {}
103
104            // A bare NumberLiteral is a positional reference
105            SegmentType::NumericLiteral => {
106                positional.push(child.span());
107            }
108
109            // OrderByExpression wraps an expression + optional ASC/DESC
110            SegmentType::OrderByExpression => {
111                collect_ref_styles(child, positional, named);
112            }
113
114            // An expression node: check if it contains only a NumberLiteral
115            SegmentType::Expression => {
116                if is_single_number_literal(child) {
117                    positional.push(child.span());
118                } else {
119                    named.push(child.span());
120                }
121            }
122
123            // Identifiers, ColumnRef, QualifiedIdentifier, FunctionCall, etc. → named
124            _ => {
125                if !child.children().is_empty() {
126                    // Node type — check if it's a wrapper around a single number
127                    if is_single_number_literal(child) {
128                        positional.push(child.span());
129                    } else {
130                        named.push(child.span());
131                    }
132                } else {
133                    // Leaf token that's not a keyword/trivia → named reference
134                    named.push(child.span());
135                }
136            }
137        }
138    }
139}
140
141/// Check if a segment is (or contains only) a single NumberLiteral.
142fn is_single_number_literal(segment: &Segment) -> bool {
143    match segment {
144        Segment::Token(t) => t.segment_type == SegmentType::NumericLiteral,
145        Segment::Node(n) => {
146            let mut non_trivia = n.children.iter().filter(|c| !c.segment_type().is_trivia());
147            match (non_trivia.next(), non_trivia.next()) {
148                (Some(only), None) => is_single_number_literal(only),
149                _ => false,
150            }
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_am06_flags_mixed_group_by() {
        // Mixing named 'foo' with positional '2'. On a tie the explicit reference
        // is the one reported, so exactly one violation is expected.
        let violations = lint_sql("SELECT foo, bar, SUM(baz) FROM t GROUP BY foo, 2", RuleAM06);
        assert_eq!(violations.len(), 1, "Should flag mixed GROUP BY styles once");
    }

    #[test]
    fn test_am06_accepts_all_explicit_group_by() {
        let violations = lint_sql(
            "SELECT foo, bar, SUM(baz) FROM t GROUP BY foo, bar",
            RuleAM06,
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am06_accepts_all_positional_group_by() {
        let violations = lint_sql("SELECT foo, bar, SUM(baz) FROM t GROUP BY 1, 2", RuleAM06);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am06_flags_only_minority_named_in_group_by() {
        // Two positional vs one named: only the named reference is reported.
        let violations = lint_sql("SELECT foo, bar, baz FROM t GROUP BY foo, 2, 3", RuleAM06);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_am06_flags_mixed_order_by() {
        let violations = lint_sql("SELECT a, b FROM t ORDER BY a, 2", RuleAM06);
        assert_eq!(violations.len(), 1, "Should flag mixed ORDER BY styles once");
    }

    #[test]
    fn test_am06_flags_only_minority_positional_in_order_by() {
        // Two named vs one positional: only the positional reference is reported.
        let violations = lint_sql("SELECT a, b, c FROM t ORDER BY a, b, 3", RuleAM06);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_am06_accepts_all_explicit_order_by() {
        let violations = lint_sql("SELECT a, b FROM t ORDER BY a, b", RuleAM06);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am06_accepts_all_positional_order_by() {
        let violations = lint_sql("SELECT a, b FROM t ORDER BY 1, 2", RuleAM06);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am06_ignores_sort_direction_keywords() {
        // ASC/DESC are keywords, not references; the clause is consistently positional.
        let violations = lint_sql("SELECT a, b FROM t ORDER BY 1 DESC, 2 ASC", RuleAM06);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am06_flags_mixed_order_by_with_directions() {
        let violations = lint_sql("SELECT a, b FROM t ORDER BY a DESC, 2 ASC", RuleAM06);
        assert_eq!(violations.len(), 1);
    }
}