rigsql_rules/references/
rf02.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule RF02 (`references.qualification`).
///
/// Reports column references that carry no table name or alias when the
/// enclosing `SELECT` statement references two or more tables. The rule is
/// stateless, so the type is a unit struct.
#[derive(Debug, Default)]
pub struct RuleRF02;
13
14impl Rule for RuleRF02 {
15 fn code(&self) -> &'static str {
16 "RF02"
17 }
18 fn name(&self) -> &'static str {
19 "references.qualification"
20 }
21 fn description(&self) -> &'static str {
22 "Columns should be qualified when multiple tables are referenced."
23 }
24 fn explanation(&self) -> &'static str {
25 "When a query references multiple tables (via FROM and JOIN clauses), \
26 all column references should be qualified with a table name or alias \
27 (e.g., 'users.id' instead of 'id') to prevent ambiguity and improve readability."
28 }
29 fn groups(&self) -> &[RuleGroup] {
30 &[RuleGroup::References]
31 }
32 fn is_fixable(&self) -> bool {
33 false
34 }
35
36 fn crawl_type(&self) -> CrawlType {
37 CrawlType::Segment(vec![SegmentType::SelectStatement])
38 }
39
40 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
41 let table_count = count_tables(ctx.segment);
42
43 if table_count < 2 {
44 return vec![];
45 }
46
47 let mut violations = Vec::new();
49 collect_unqualified_columns(ctx.segment, &mut violations, self.code(), false);
50 violations
51 }
52}
53
54fn count_tables(stmt: &Segment) -> usize {
56 let mut count = 0;
57 for child in stmt.children() {
58 if child.segment_type() == SegmentType::FromClause {
59 count += count_tables_in_clause(child);
60 }
61 }
62 count
63}
64
65fn count_tables_in_clause(clause: &Segment) -> usize {
66 let mut count = 0;
67 for child in clause.children() {
68 match child.segment_type() {
69 SegmentType::Identifier
70 | SegmentType::QuotedIdentifier
71 | SegmentType::AliasExpression => {
72 count += 1;
73 }
74 SegmentType::QualifiedIdentifier => {
75 count += 1;
77 }
78 SegmentType::JoinClause => {
79 for join_child in child.children() {
80 match join_child.segment_type() {
81 SegmentType::Identifier
82 | SegmentType::QuotedIdentifier
83 | SegmentType::AliasExpression
84 | SegmentType::QualifiedIdentifier => {
85 count += 1;
86 break;
87 }
88 _ => {}
89 }
90 }
91 }
92 _ => {}
93 }
94 }
95 count
96}
97
98const COLUMN_CONTEXTS: &[SegmentType] = &[
100 SegmentType::SelectClause,
101 SegmentType::WhereClause,
102 SegmentType::HavingClause,
103 SegmentType::OrderByClause,
104 SegmentType::GroupByClause,
105 SegmentType::OnClause,
106 SegmentType::OrderByExpression,
107 SegmentType::BinaryExpression,
108];
109
110const TABLE_SOURCE_CONTEXTS: &[SegmentType] = &[SegmentType::FromClause, SegmentType::JoinClause];
112
113fn collect_unqualified_columns(
115 segment: &Segment,
116 violations: &mut Vec<LintViolation>,
117 code: &'static str,
118 in_table_source: bool,
119) {
120 if segment.segment_type() == SegmentType::Subquery {
122 return;
123 }
124
125 let st = segment.segment_type();
126 let is_table_source = in_table_source || TABLE_SOURCE_CONTEXTS.contains(&st);
127
128 match st {
130 SegmentType::QualifiedIdentifier | SegmentType::ColumnRef => {
131 if is_table_source {
132 return;
133 }
134 let has_dot = segment
136 .children()
137 .iter()
138 .any(|c| c.segment_type() == SegmentType::Dot);
139 if !has_dot {
140 if let Some(Segment::Token(t)) = segment
142 .children()
143 .iter()
144 .find(|c| c.segment_type() == SegmentType::Identifier)
145 {
146 if t.token.kind == TokenKind::AtSign {
148 return;
149 }
150 violations.push(LintViolation::with_msg_key(
151 code,
152 format!(
153 "Unqualified column reference '{}' in multi-table query.",
154 t.token.text
155 ),
156 t.token.span,
157 "rules.RF02.msg",
158 vec![("name".to_string(), t.token.text.to_string())],
159 ));
160 }
161 }
162 return;
163 }
164 _ => {}
165 }
166
167 if COLUMN_CONTEXTS.contains(&st) {
169 for child in segment.children() {
170 if child.segment_type() == SegmentType::Identifier {
171 if let Segment::Token(t) = child {
172 if t.token.kind != TokenKind::AtSign {
174 violations.push(LintViolation::with_msg_key(
175 code,
176 format!(
177 "Unqualified column reference '{}' in multi-table query.",
178 t.token.text
179 ),
180 t.token.span,
181 "rules.RF02.msg",
182 vec![("name".to_string(), t.token.text.to_string())],
183 ));
184 }
185 }
186 } else {
187 collect_unqualified_columns(child, violations, code, is_table_source);
188 }
189 }
190 return;
191 }
192
193 for child in segment.children() {
194 collect_unqualified_columns(child, violations, code, is_table_source);
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A bare column in the select list of a two-table join is reported.
    #[test]
    fn test_rf02_flags_unqualified_in_multi_table() {
        let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert!(!found.is_empty(), "Should flag unqualified 'id'");
        assert!(found[0].message.contains("id"));
    }

    /// Fully qualified columns in a two-table join produce no violations.
    #[test]
    fn test_rf02_accepts_qualified_in_multi_table() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// With a single table, unqualified columns are accepted.
    #[test]
    fn test_rf02_accepts_single_table() {
        let found = lint_sql("SELECT id FROM users", RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// A bare column in the WHERE clause of a two-table join is reported.
    #[test]
    fn test_rf02_flags_unqualified_in_where() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE status = 1";
        let found = lint_sql(sql, RuleRF02);
        assert!(
            !found.is_empty(),
            "Should flag unqualified 'status' in WHERE"
        );
    }

    /// A schema-qualified table name in FROM yields no violations.
    #[test]
    fn test_rf02_ignores_qualified_table_in_from() {
        let sql = "SELECT name FROM sys.columns WHERE object_id = 1";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// `@`-prefixed variables are not reported as unqualified columns.
    #[test]
    fn test_rf02_ignores_tsql_variables() {
        let sql = "SELECT t1.a FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x = @SiteName";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }
}