rigsql_rules/references/
rf02.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule RF02 (`references.qualification`).
///
/// Reports column references that carry no table qualifier inside a
/// statement whose FROM clause names two or more tables.
#[derive(Debug, Default)]
pub struct RuleRF02;
13
14impl Rule for RuleRF02 {
15 fn code(&self) -> &'static str {
16 "RF02"
17 }
18 fn name(&self) -> &'static str {
19 "references.qualification"
20 }
21 fn description(&self) -> &'static str {
22 "Columns should be qualified when multiple tables are referenced."
23 }
24 fn explanation(&self) -> &'static str {
25 "When a query references multiple tables (via FROM and JOIN clauses), \
26 all column references should be qualified with a table name or alias \
27 (e.g., 'users.id' instead of 'id') to prevent ambiguity and improve readability."
28 }
29 fn groups(&self) -> &[RuleGroup] {
30 &[RuleGroup::References]
31 }
32 fn is_fixable(&self) -> bool {
33 false
34 }
35
36 fn crawl_type(&self) -> CrawlType {
37 CrawlType::Segment(vec![SegmentType::SelectStatement])
38 }
39
40 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41 let table_count = count_tables(ctx.segment);
42
43 if table_count < 2 {
44 return vec![];
45 }
46
47 let mut violations = Vec::new();
49 collect_unqualified_columns(ctx.segment, &mut violations, self.code(), false);
50 violations
51 }
52}
53
54fn count_tables(stmt: &Segment) -> usize {
56 let mut count = 0;
57 for child in stmt.children() {
58 if child.segment_type() == SegmentType::FromClause {
59 count += count_tables_in_clause(child);
60 }
61 }
62 count
63}
64
65fn count_tables_in_clause(clause: &Segment) -> usize {
66 let mut count = 0;
67 for child in clause.children() {
68 match child.segment_type() {
69 SegmentType::Identifier
70 | SegmentType::QuotedIdentifier
71 | SegmentType::AliasExpression => {
72 count += 1;
73 }
74 SegmentType::QualifiedIdentifier => {
75 count += 1;
77 }
78 SegmentType::JoinClause => {
79 for join_child in child.children() {
80 match join_child.segment_type() {
81 SegmentType::Identifier
82 | SegmentType::QuotedIdentifier
83 | SegmentType::AliasExpression
84 | SegmentType::QualifiedIdentifier => {
85 count += 1;
86 break;
87 }
88 _ => {}
89 }
90 }
91 }
92 _ => {}
93 }
94 }
95 count
96}
97
/// Segment types whose direct `Identifier` children are reported as
/// unqualified column references by `collect_unqualified_columns`.
const COLUMN_CONTEXTS: &[SegmentType] = &[
    SegmentType::SelectClause,
    SegmentType::WhereClause,
    SegmentType::HavingClause,
    SegmentType::OrderByClause,
    SegmentType::GroupByClause,
    SegmentType::OnClause,
    SegmentType::OrderByExpression,
    SegmentType::BinaryExpression,
];
109
/// Segment types that mark a table-source region. While walking inside one
/// of these, `collect_unqualified_columns` skips `QualifiedIdentifier` and
/// `ColumnRef` nodes instead of checking them.
const TABLE_SOURCE_CONTEXTS: &[SegmentType] = &[SegmentType::FromClause, SegmentType::JoinClause];
112
113fn collect_unqualified_columns(
115 segment: &Segment,
116 violations: &mut Vec<LintViolation>,
117 code: &'static str,
118 in_table_source: bool,
119) {
120 if segment.segment_type() == SegmentType::Subquery {
122 return;
123 }
124
125 let st = segment.segment_type();
126 let is_table_source = in_table_source || TABLE_SOURCE_CONTEXTS.contains(&st);
127
128 match st {
130 SegmentType::QualifiedIdentifier | SegmentType::ColumnRef => {
131 if is_table_source {
132 return;
133 }
134 let has_dot = segment
136 .children()
137 .iter()
138 .any(|c| c.segment_type() == SegmentType::Dot);
139 if !has_dot {
140 if let Some(Segment::Token(t)) = segment
142 .children()
143 .iter()
144 .find(|c| c.segment_type() == SegmentType::Identifier)
145 {
146 if t.token.kind == TokenKind::AtSign {
148 return;
149 }
150 violations.push(LintViolation::new(
151 code,
152 format!(
153 "Unqualified column reference '{}' in multi-table query.",
154 t.token.text
155 ),
156 t.token.span,
157 ));
158 }
159 }
160 return;
161 }
162 _ => {}
163 }
164
165 if COLUMN_CONTEXTS.contains(&st) {
167 for child in segment.children() {
168 if child.segment_type() == SegmentType::Identifier {
169 if let Segment::Token(t) = child {
170 if t.token.kind != TokenKind::AtSign {
172 violations.push(LintViolation::new(
173 code,
174 format!(
175 "Unqualified column reference '{}' in multi-table query.",
176 t.token.text
177 ),
178 t.token.span,
179 ));
180 }
181 }
182 } else {
183 collect_unqualified_columns(child, violations, code, is_table_source);
184 }
185 }
186 return;
187 }
188
189 for child in segment.children() {
190 collect_unqualified_columns(child, violations, code, is_table_source);
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A bare select-list column is reported once a JOIN adds a second table.
    #[test]
    fn test_rf02_flags_unqualified_in_multi_table() {
        let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert!(!found.is_empty(), "Should flag unqualified 'id'");
        assert!(found[0].message.contains("id"));
    }

    /// Fully aliased references in a two-table query produce no findings.
    #[test]
    fn test_rf02_accepts_qualified_in_multi_table() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert!(found.is_empty());
    }

    /// With a single table, unqualified columns are accepted.
    #[test]
    fn test_rf02_accepts_single_table() {
        let found = lint_sql("SELECT id FROM users", RuleRF02);
        assert!(found.is_empty());
    }

    /// A bare column in WHERE is reported in a two-table query.
    #[test]
    fn test_rf02_flags_unqualified_in_where() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE status = 1";
        let found = lint_sql(sql, RuleRF02);
        assert!(
            !found.is_empty(),
            "Should flag unqualified 'status' in WHERE"
        );
    }

    /// A schema-qualified table name in FROM produces no findings.
    #[test]
    fn test_rf02_ignores_qualified_table_in_from() {
        let sql = "SELECT name FROM sys.columns WHERE object_id = 1";
        let found = lint_sql(sql, RuleRF02);
        assert!(found.is_empty());
    }

    /// `@`-prefixed T-SQL variables are not reported as columns.
    #[test]
    fn test_rf02_ignores_tsql_variables() {
        let sql = "SELECT t1.a FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x = @SiteName";
        let found = lint_sql(sql, RuleRF02);
        assert!(found.is_empty());
    }
}