use rigsql_core::{Segment, SegmentType};
use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::violation::LintViolation;
/// Lint rule AL03 (`aliasing.expression`).
///
/// Reports select-list items that are "complex" expressions (binary/unary
/// operator expressions, function calls, CASE, CAST and parenthesised
/// expressions) and therefore should carry an explicit alias. The rule only
/// reports; it offers no automatic fix.
#[derive(Default, Debug)]
pub struct RuleAL03;

impl Rule for RuleAL03 {
    fn code(&self) -> &'static str {
        "AL03"
    }

    fn name(&self) -> &'static str {
        "aliasing.expression"
    }

    fn description(&self) -> &'static str {
        "Column expression without alias. Use explicit alias."
    }

    fn explanation(&self) -> &'static str {
        "Complex expressions in SELECT should have an explicit alias using AS. \
         An unlabeled expression like 'SELECT a + b FROM t' is harder to work with \
         than 'SELECT a + b AS total FROM t'. This makes result sets self-documenting."
    }

    fn groups(&self) -> &[RuleGroup] {
        &[RuleGroup::Aliasing]
    }

    fn is_fixable(&self) -> bool {
        false
    }

    fn crawl_type(&self) -> CrawlType {
        // Only SELECT clauses are visited; `eval` then walks the direct
        // children of each visited segment.
        CrawlType::Segment(vec![SegmentType::SelectClause])
    }

    /// Emits one violation per direct child of the visited segment that is a
    /// complex expression without an alias. Each violation covers that child's span.
    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
        let mut found = Vec::new();

        for item in ctx.segment.children() {
            let kind = item.segment_type();

            // Trivia, keywords and commas are not select items, so they are
            // never candidates for an alias.
            let is_separator =
                kind.is_trivia() || kind == SegmentType::Keyword || kind == SegmentType::Comma;
            if is_separator {
                continue;
            }

            // NOTE(review): an already-aliased item presumably parses as an
            // `AliasExpression`, which `is_complex_expression` does not match,
            // so it is skipped here.
            let needs_alias = is_complex_expression(item) && !is_wrapped_in_alias(item, ctx);
            if !needs_alias {
                continue;
            }

            found.push(LintViolation::with_msg_key(
                self.code(),
                "Column expression should have an explicit alias.",
                item.span(),
                "rules.AL03.msg",
                vec![],
            ));
        }

        found
    }
}
/// Returns `true` when `seg` is an expression kind that AL03 requires to be
/// aliased. These kinds are function calls, CAST, CASE, binary, unary and
/// parenthesised expressions. Every other segment type returns `false`,
/// including `AliasExpression`.
fn is_complex_expression(seg: &Segment) -> bool {
    let kind = seg.segment_type();
    matches!(
        kind,
        SegmentType::FunctionCall
            | SegmentType::CastExpression
            | SegmentType::CaseExpression
            | SegmentType::BinaryExpression
            | SegmentType::UnaryExpression
            | SegmentType::ParenExpression
    )
}
/// Returns `true` when `seg` itself is an `AliasExpression` node.
///
/// `_ctx` is currently unused. NOTE(review): `is_complex_expression` never
/// matches `AliasExpression`, so at its only call site in `eval` this check
/// cannot change the result. It is effectively redundant today.
fn is_wrapped_in_alias(seg: &Segment, _ctx: &RuleContext) -> bool {
    matches!(seg.segment_type(), SegmentType::AliasExpression)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with AL03 only and returns the number of violations raised.
    fn al03_hits(sql: &str) -> usize {
        lint_sql(sql, RuleAL03).len()
    }

    #[test]
    fn test_al03_flags_function_without_alias() {
        assert_eq!(al03_hits("SELECT COUNT(*) FROM t"), 1);
    }

    #[test]
    fn test_al03_accepts_function_with_alias() {
        assert_eq!(al03_hits("SELECT COUNT(*) AS cnt FROM t"), 0);
    }

    #[test]
    fn test_al03_accepts_simple_column() {
        assert_eq!(al03_hits("SELECT a FROM t"), 0);
    }
}