Skip to main content

rigsql_rules/references/
rf02.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// RF02: Column references should be qualified when multiple tables are present.
///
/// If a SELECT statement reads from two or more tables (the FROM clause plus
/// any JOINs), each column reference should carry a table name or alias
/// prefix so it is unambiguous which table the column comes from.
#[derive(Default, Debug)]
pub struct RuleRF02;
13
14impl Rule for RuleRF02 {
15    fn code(&self) -> &'static str {
16        "RF02"
17    }
18    fn name(&self) -> &'static str {
19        "references.qualification"
20    }
21    fn description(&self) -> &'static str {
22        "Columns should be qualified when multiple tables are referenced."
23    }
24    fn explanation(&self) -> &'static str {
25        "When a query references multiple tables (via FROM and JOIN clauses), \
26         all column references should be qualified with a table name or alias \
27         (e.g., 'users.id' instead of 'id') to prevent ambiguity and improve readability."
28    }
29    fn groups(&self) -> &[RuleGroup] {
30        &[RuleGroup::References]
31    }
32    fn is_fixable(&self) -> bool {
33        false
34    }
35
36    fn crawl_type(&self) -> CrawlType {
37        CrawlType::Segment(vec![SegmentType::SelectStatement])
38    }
39
40    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41        let table_count = count_tables(ctx.segment);
42
43        if table_count < 2 {
44            return vec![];
45        }
46
47        // Find unqualified column references across all relevant clauses
48        let mut violations = Vec::new();
49        collect_unqualified_columns(ctx.segment, &mut violations, self.code(), false);
50        violations
51    }
52}
53
54/// Count tables referenced in FROM and JOIN clauses.
55fn count_tables(stmt: &Segment) -> usize {
56    let mut count = 0;
57    for child in stmt.children() {
58        if child.segment_type() == SegmentType::FromClause {
59            count += count_tables_in_clause(child);
60        }
61    }
62    count
63}
64
65fn count_tables_in_clause(clause: &Segment) -> usize {
66    let mut count = 0;
67    for child in clause.children() {
68        match child.segment_type() {
69            SegmentType::Identifier
70            | SegmentType::QuotedIdentifier
71            | SegmentType::AliasExpression => {
72                count += 1;
73            }
74            SegmentType::QualifiedIdentifier => {
75                // e.g., schema.table — counts as one table
76                count += 1;
77            }
78            SegmentType::JoinClause => {
79                for join_child in child.children() {
80                    match join_child.segment_type() {
81                        SegmentType::Identifier
82                        | SegmentType::QuotedIdentifier
83                        | SegmentType::AliasExpression
84                        | SegmentType::QualifiedIdentifier => {
85                            count += 1;
86                            break;
87                        }
88                        _ => {}
89                    }
90                }
91            }
92            _ => {}
93        }
94    }
95    count
96}
97
98/// Contexts where bare Identifiers are likely column references.
99const COLUMN_CONTEXTS: &[SegmentType] = &[
100    SegmentType::SelectClause,
101    SegmentType::WhereClause,
102    SegmentType::HavingClause,
103    SegmentType::OrderByClause,
104    SegmentType::GroupByClause,
105    SegmentType::OnClause,
106    SegmentType::OrderByExpression,
107    SegmentType::BinaryExpression,
108];
109
110/// Segment types that represent table sources, not column references.
111const TABLE_SOURCE_CONTEXTS: &[SegmentType] = &[SegmentType::FromClause, SegmentType::JoinClause];
112
113/// Recursively find unqualified column references in column-relevant clauses.
114fn collect_unqualified_columns(
115    segment: &Segment,
116    violations: &mut Vec<LintViolation>,
117    code: &'static str,
118    in_table_source: bool,
119) {
120    // Skip subqueries to avoid cross-scope analysis
121    if segment.segment_type() == SegmentType::Subquery {
122        return;
123    }
124
125    let st = segment.segment_type();
126    let is_table_source = in_table_source || TABLE_SOURCE_CONTEXTS.contains(&st);
127
128    // QualifiedIdentifier / ColumnRef in table sources are table names, skip them
129    match st {
130        SegmentType::QualifiedIdentifier | SegmentType::ColumnRef => {
131            if is_table_source {
132                return;
133            }
134            // In column context: qualified refs are fine, only unqualified are violations
135            let has_dot = segment
136                .children()
137                .iter()
138                .any(|c| c.segment_type() == SegmentType::Dot);
139            if !has_dot {
140                // Unqualified column ref
141                if let Some(Segment::Token(t)) = segment
142                    .children()
143                    .iter()
144                    .find(|c| c.segment_type() == SegmentType::Identifier)
145                {
146                    // Skip TSQL variables (@var)
147                    if t.token.kind == TokenKind::AtSign {
148                        return;
149                    }
150                    violations.push(LintViolation::with_msg_key(
151                        code,
152                        format!(
153                            "Unqualified column reference '{}' in multi-table query.",
154                            t.token.text
155                        ),
156                        t.token.span,
157                        "rules.RF02.msg",
158                        vec![("name".to_string(), t.token.text.to_string())],
159                    ));
160                }
161            }
162            return;
163        }
164        _ => {}
165    }
166
167    // In column-relevant contexts, bare Identifiers are likely column references
168    if COLUMN_CONTEXTS.contains(&st) {
169        for child in segment.children() {
170            if child.segment_type() == SegmentType::Identifier {
171                if let Segment::Token(t) = child {
172                    // Skip TSQL variables (@var) — they're not column references
173                    if t.token.kind != TokenKind::AtSign {
174                        violations.push(LintViolation::with_msg_key(
175                            code,
176                            format!(
177                                "Unqualified column reference '{}' in multi-table query.",
178                                t.token.text
179                            ),
180                            t.token.span,
181                            "rules.RF02.msg",
182                            vec![("name".to_string(), t.token.text.to_string())],
183                        ));
184                    }
185                }
186            } else {
187                collect_unqualified_columns(child, violations, code, is_table_source);
188            }
189        }
190        return;
191    }
192
193    for child in segment.children() {
194        collect_unqualified_columns(child, violations, code, is_table_source);
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_rf02_flags_unqualified_in_multi_table() {
        let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
        let violations = lint_sql(sql, RuleRF02);
        assert!(!violations.is_empty(), "Should flag unqualified 'id'");
        let first = &violations[0];
        assert!(first.message.contains("id"));
    }

    #[test]
    fn test_rf02_accepts_qualified_in_multi_table() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let violations = lint_sql(sql, RuleRF02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_accepts_single_table() {
        let sql = "SELECT id FROM users";
        let violations = lint_sql(sql, RuleRF02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_flags_unqualified_in_where() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE status = 1";
        let violations = lint_sql(sql, RuleRF02);
        assert!(
            !violations.is_empty(),
            "Should flag unqualified 'status' in WHERE"
        );
    }

    #[test]
    fn test_rf02_ignores_qualified_table_in_from() {
        // sys.columns is a table name, not a column ref.
        // NOTE(review): with a single table, `eval` likely returns before
        // scanning columns (table_count < 2), so this probably does not
        // exercise the FROM-clause skip; consider a multi-table variant.
        let sql = "SELECT name FROM sys.columns WHERE object_id = 1";
        let violations = lint_sql(sql, RuleRF02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_rf02_ignores_tsql_variables() {
        // @SiteName is a TSQL variable, not a column reference.
        let sql = "SELECT t1.a FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x = @SiteName";
        let violations = lint_sql(sql, RuleRF02);
        assert_eq!(violations.len(), 0);
    }
}