Skip to main content

rigsql_rules/references/
rf02.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// RF02 (`references.qualification`): require qualified column references in
/// multi-table queries.
///
/// Once a `SELECT` names two or more tables through its `FROM` clause (JOINs
/// included), a bare column such as `id` no longer says which table it comes
/// from. Every column reference should then carry a table name or alias prefix
/// (e.g. `users.id`). Columns inside subqueries are not inspected, and the rule
/// offers no automatic fix.
#[derive(Default, Debug)]
pub struct RuleRF02;
13
14impl Rule for RuleRF02 {
15    fn code(&self) -> &'static str {
16        "RF02"
17    }
18    fn name(&self) -> &'static str {
19        "references.qualification"
20    }
21    fn description(&self) -> &'static str {
22        "Columns should be qualified when multiple tables are referenced."
23    }
24    fn explanation(&self) -> &'static str {
25        "When a query references multiple tables (via FROM and JOIN clauses), \
26         all column references should be qualified with a table name or alias \
27         (e.g., 'users.id' instead of 'id') to prevent ambiguity and improve readability."
28    }
29    fn groups(&self) -> &[RuleGroup] {
30        &[RuleGroup::References]
31    }
32    fn is_fixable(&self) -> bool {
33        false
34    }
35
36    fn crawl_type(&self) -> CrawlType {
37        CrawlType::Segment(vec![SegmentType::SelectStatement])
38    }
39
40    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41        let table_count = count_tables(ctx.segment);
42
43        if table_count < 2 {
44            return vec![];
45        }
46
47        // Find unqualified column references across all relevant clauses
48        let mut violations = Vec::new();
49        collect_unqualified_columns(ctx.segment, &mut violations, self.code(), false);
50        violations
51    }
52}
53
54/// Count tables referenced in FROM and JOIN clauses.
55fn count_tables(stmt: &Segment) -> usize {
56    let mut count = 0;
57    for child in stmt.children() {
58        if child.segment_type() == SegmentType::FromClause {
59            count += count_tables_in_clause(child);
60        }
61    }
62    count
63}
64
65fn count_tables_in_clause(clause: &Segment) -> usize {
66    let mut count = 0;
67    for child in clause.children() {
68        match child.segment_type() {
69            SegmentType::Identifier
70            | SegmentType::QuotedIdentifier
71            | SegmentType::AliasExpression => {
72                count += 1;
73            }
74            SegmentType::QualifiedIdentifier => {
75                // e.g., schema.table — counts as one table
76                count += 1;
77            }
78            SegmentType::JoinClause => {
79                for join_child in child.children() {
80                    match join_child.segment_type() {
81                        SegmentType::Identifier
82                        | SegmentType::QuotedIdentifier
83                        | SegmentType::AliasExpression
84                        | SegmentType::QualifiedIdentifier => {
85                            count += 1;
86                            break;
87                        }
88                        _ => {}
89                    }
90                }
91            }
92            _ => {}
93        }
94    }
95    count
96}
97
98/// Contexts where bare Identifiers are likely column references.
99const COLUMN_CONTEXTS: &[SegmentType] = &[
100    SegmentType::SelectClause,
101    SegmentType::WhereClause,
102    SegmentType::HavingClause,
103    SegmentType::OrderByClause,
104    SegmentType::GroupByClause,
105    SegmentType::OnClause,
106    SegmentType::OrderByExpression,
107    SegmentType::BinaryExpression,
108];
109
110/// Segment types that represent table sources, not column references.
111const TABLE_SOURCE_CONTEXTS: &[SegmentType] = &[SegmentType::FromClause, SegmentType::JoinClause];
112
113/// Recursively find unqualified column references in column-relevant clauses.
114fn collect_unqualified_columns(
115    segment: &Segment,
116    violations: &mut Vec<LintViolation>,
117    code: &'static str,
118    in_table_source: bool,
119) {
120    // Skip subqueries to avoid cross-scope analysis
121    if segment.segment_type() == SegmentType::Subquery {
122        return;
123    }
124
125    let st = segment.segment_type();
126    let is_table_source = in_table_source || TABLE_SOURCE_CONTEXTS.contains(&st);
127
128    // QualifiedIdentifier / ColumnRef in table sources are table names, skip them
129    match st {
130        SegmentType::QualifiedIdentifier | SegmentType::ColumnRef => {
131            if is_table_source {
132                return;
133            }
134            // In column context: qualified refs are fine, only unqualified are violations
135            let has_dot = segment
136                .children()
137                .iter()
138                .any(|c| c.segment_type() == SegmentType::Dot);
139            if !has_dot {
140                // Unqualified column ref
141                if let Some(Segment::Token(t)) = segment
142                    .children()
143                    .iter()
144                    .find(|c| c.segment_type() == SegmentType::Identifier)
145                {
146                    // Skip TSQL variables (@var)
147                    if t.token.kind == TokenKind::AtSign {
148                        return;
149                    }
150                    violations.push(LintViolation::new(
151                        code,
152                        format!(
153                            "Unqualified column reference '{}' in multi-table query.",
154                            t.token.text
155                        ),
156                        t.token.span,
157                    ));
158                }
159            }
160            return;
161        }
162        _ => {}
163    }
164
165    // In column-relevant contexts, bare Identifiers are likely column references
166    if COLUMN_CONTEXTS.contains(&st) {
167        for child in segment.children() {
168            if child.segment_type() == SegmentType::Identifier {
169                if let Segment::Token(t) = child {
170                    // Skip TSQL variables (@var) — they're not column references
171                    if t.token.kind != TokenKind::AtSign {
172                        violations.push(LintViolation::new(
173                            code,
174                            format!(
175                                "Unqualified column reference '{}' in multi-table query.",
176                                t.token.text
177                            ),
178                            t.token.span,
179                        ));
180                    }
181                }
182            } else {
183                collect_unqualified_columns(child, violations, code, is_table_source);
184            }
185        }
186        return;
187    }
188
189    for child in segment.children() {
190        collect_unqualified_columns(child, violations, code, is_table_source);
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_rf02_flags_unqualified_in_multi_table() {
        let violations = lint_sql(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            RuleRF02,
        );
        assert!(!violations.is_empty(), "Should flag unqualified 'id'");
        assert!(violations[0].message.contains("id"));
    }

    #[test]
    fn test_rf02_accepts_qualified_in_multi_table() {
        let violations = lint_sql(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
            RuleRF02,
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_accepts_single_table() {
        let violations = lint_sql("SELECT id FROM users", RuleRF02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_flags_unqualified_in_where() {
        let violations = lint_sql(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE status = 1",
            RuleRF02,
        );
        assert!(
            !violations.is_empty(),
            "Should flag unqualified 'status' in WHERE"
        );
    }

    #[test]
    fn test_rf02_ignores_qualified_table_in_from() {
        // sys.columns is a table name, not a column ref.
        // Note: with a single table the rule returns before inspecting any
        // references; the multi-table test below exercises the active path.
        let violations = lint_sql("SELECT name FROM sys.columns WHERE object_id = 1", RuleRF02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_ignores_qualified_tables_in_multi_table_query() {
        // Two table sources make the rule active; schema-qualified table
        // names in FROM/JOIN must still not be reported as column references.
        let violations = lint_sql(
            "SELECT c.name FROM sys.columns c JOIN sys.objects o ON c.object_id = o.object_id",
            RuleRF02,
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_ignores_tsql_variables() {
        // @SiteName is a TSQL variable, not a column reference
        let violations = lint_sql(
            "SELECT t1.a FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x = @SiteName",
            RuleRF02,
        );
        assert_eq!(violations.len(), 0);
    }
}