use rigsql_core::{Segment, SegmentType, TokenKind};
use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::violation::LintViolation;
/// Lint rule RF02 (`references.qualification`).
///
/// Reports column references that carry no table/alias qualifier inside a
/// `SELECT` statement whose FROM clause contributes two or more tables.
#[derive(Debug, Default)]
pub struct RuleRF02;
impl Rule for RuleRF02 {
    /// Short rule code used to tag every violation this rule emits.
    fn code(&self) -> &'static str {
        "RF02"
    }

    /// Dotted, human-readable rule name.
    fn name(&self) -> &'static str {
        "references.qualification"
    }

    /// One-line summary of what the rule enforces.
    fn description(&self) -> &'static str {
        "Columns should be qualified when multiple tables are referenced."
    }

    /// Longer rationale for the rule.
    fn explanation(&self) -> &'static str {
        "When a query references multiple tables (via FROM and JOIN clauses), \
         all column references should be qualified with a table name or alias \
         (e.g., 'users.id' instead of 'id') to prevent ambiguity and improve readability."
    }

    /// Rule groups this rule belongs to.
    fn groups(&self) -> &[RuleGroup] {
        &[RuleGroup::References]
    }

    /// This rule only reports; it offers no automatic fix.
    fn is_fixable(&self) -> bool {
        false
    }

    /// The rule is evaluated once per `SelectStatement` segment.
    fn crawl_type(&self) -> CrawlType {
        CrawlType::Segment(vec![SegmentType::SelectStatement])
    }

    /// Returns one violation per unqualified column reference found in the
    /// statement, or an empty list when fewer than two tables are counted.
    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
        let mut found = Vec::new();
        // Below two tables there is nothing to disambiguate, so skip the walk.
        if count_tables(ctx.segment) >= 2 {
            collect_unqualified_columns(ctx.segment, &mut found, self.code(), false);
        }
        found
    }
}
fn count_tables(stmt: &Segment) -> usize {
let mut count = 0;
for child in stmt.children() {
if child.segment_type() == SegmentType::FromClause {
count += count_tables_in_clause(child);
}
}
count
}
/// Counts table references among the direct children of `clause`.
///
/// A direct child counts once when it is an identifier, quoted identifier,
/// alias expression or qualified identifier. A `JoinClause` child counts once
/// when at least one of its own direct children is of one of those types.
fn count_tables_in_clause(clause: &Segment) -> usize {
    /// True for the segment types that are counted as a table reference.
    fn is_table_reference(seg: &Segment) -> bool {
        matches!(
            seg.segment_type(),
            SegmentType::Identifier
                | SegmentType::QuotedIdentifier
                | SegmentType::AliasExpression
                | SegmentType::QualifiedIdentifier
        )
    }

    clause
        .children()
        .iter()
        .filter(|child| match child.segment_type() {
            // A join adds at most one table, however many matching children it has.
            SegmentType::JoinClause => child.children().iter().any(is_table_reference),
            _ => is_table_reference(child),
        })
        .count()
}
/// Segment types whose direct `Identifier` token children are reported as
/// unqualified column references by `collect_unqualified_columns`
/// (tokens of kind `AtSign` are exempt there).
const COLUMN_CONTEXTS: &[SegmentType] = &[
SegmentType::SelectClause,
SegmentType::WhereClause,
SegmentType::HavingClause,
SegmentType::OrderByClause,
SegmentType::GroupByClause,
SegmentType::OnClause,
SegmentType::OrderByExpression,
SegmentType::BinaryExpression,
];
/// Segment types that open a table-source subtree: `QualifiedIdentifier` and
/// `ColumnRef` nodes at or below one of these are skipped by
/// `collect_unqualified_columns` instead of being reported.
const TABLE_SOURCE_CONTEXTS: &[SegmentType] = &[SegmentType::FromClause, SegmentType::JoinClause];
/// Walks `segment` depth-first and appends one violation to `violations` for
/// every unqualified column reference it finds.
///
/// * `code` - rule code recorded on each violation.
/// * `in_table_source` - whether an ancestor was one of `TABLE_SOURCE_CONTEXTS`;
///   once set it stays set for the whole subtree.
///
/// `Subquery` segments are not descended into.
fn collect_unqualified_columns(
    segment: &Segment,
    violations: &mut Vec<LintViolation>,
    code: &'static str,
    in_table_source: bool,
) {
    let kind = segment.segment_type();
    if kind == SegmentType::Subquery {
        return;
    }

    let within_table_source = in_table_source || TABLE_SOURCE_CONTEXTS.contains(&kind);

    // Reference nodes are leaves for this walk: report or skip, never descend.
    if matches!(
        kind,
        SegmentType::QualifiedIdentifier | SegmentType::ColumnRef
    ) {
        // NOTE(review): the flag is inherited by everything below a
        // FromClause/JoinClause, so reference nodes nested there (e.g. in an ON
        // clause, if the parser nests it under the join) are skipped too --
        // confirm this is intended.
        if within_table_source {
            return;
        }
        let parts = segment.children();
        // A dot among the children means the reference is already qualified.
        if parts.iter().any(|p| p.segment_type() == SegmentType::Dot) {
            return;
        }
        let first_identifier = parts
            .iter()
            .find(|p| p.segment_type() == SegmentType::Identifier);
        if let Some(Segment::Token(t)) = first_identifier {
            // Identifier tokens of kind `AtSign` are never reported.
            if t.token.kind != TokenKind::AtSign {
                violations.push(LintViolation::with_msg_key(
                    code,
                    format!(
                        "Unqualified column reference '{}' in multi-table query.",
                        t.token.text
                    ),
                    t.token.span,
                    "rules.RF02.msg",
                    vec![("name".to_string(), t.token.text.to_string())],
                ));
            }
        }
        return;
    }

    // Inside a column context, bare identifier children are reported directly
    // instead of being recursed into; every other child is walked normally.
    let reports_bare_identifiers = COLUMN_CONTEXTS.contains(&kind);
    for child in segment.children() {
        if reports_bare_identifiers && child.segment_type() == SegmentType::Identifier {
            if let Segment::Token(t) = child {
                if t.token.kind != TokenKind::AtSign {
                    violations.push(LintViolation::with_msg_key(
                        code,
                        format!(
                            "Unqualified column reference '{}' in multi-table query.",
                            t.token.text
                        ),
                        t.token.span,
                        "rules.RF02.msg",
                        vec![("name".to_string(), t.token.text.to_string())],
                    ));
                }
            }
            continue;
        }
        collect_unqualified_columns(child, violations, code, within_table_source);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// An unqualified select-list column in a two-table query is reported.
    #[test]
    fn test_rf02_flags_unqualified_in_multi_table() {
        let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert!(!found.is_empty(), "Should flag unqualified 'id'");
        assert!(found[0].message.contains("id"));
    }

    /// Fully qualified references in a two-table query produce no violations.
    #[test]
    fn test_rf02_accepts_qualified_in_multi_table() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// With a single table, unqualified columns are accepted.
    #[test]
    fn test_rf02_accepts_single_table() {
        let sql = "SELECT id FROM users";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// An unqualified column in the WHERE clause of a two-table query is reported.
    #[test]
    fn test_rf02_flags_unqualified_in_where() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE status = 1";
        let found = lint_sql(sql, RuleRF02);
        assert!(
            !found.is_empty(),
            "Should flag unqualified 'status' in WHERE"
        );
    }

    /// A schema-qualified table name in FROM is neither reported nor makes the
    /// query count as multi-table.
    #[test]
    fn test_rf02_ignores_qualified_table_in_from() {
        let sql = "SELECT name FROM sys.columns WHERE object_id = 1";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }

    /// An `@`-prefixed variable in a two-table query is not reported.
    #[test]
    fn test_rf02_ignores_tsql_variables() {
        let sql = "SELECT t1.a FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x = @SiteName";
        let found = lint_sql(sql, RuleRF02);
        assert_eq!(found.len(), 0);
    }
}