Skip to main content

rigsql_rules/rigsql/
rg02.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// RG02: Consistent use of NULL in UNION.
///
/// Bare NULL literals in UNION SELECT items should have an explicit type cast
/// for clarity and consistency.
///
/// For example, a select item written as plain `NULL` in a UNION branch is
/// reported, while `CAST(NULL AS INTEGER)` is accepted.
///
/// The rule only reports; it does not offer an automatic fix (`is_fixable`
/// returns `false`).
#[derive(Debug, Default)]
pub struct RuleRG02;
12
13impl Rule for RuleRG02 {
14    fn code(&self) -> &'static str {
15        "RG02"
16    }
17    fn name(&self) -> &'static str {
18        "rigsql.union_null"
19    }
20    fn description(&self) -> &'static str {
21        "Consistent use of NULL in UNION."
22    }
23    fn explanation(&self) -> &'static str {
24        "When using NULL in a UNION query, the type of the NULL column is ambiguous. \
25         Use an explicit CAST (e.g. CAST(NULL AS INTEGER)) to make the intent clear \
26         and avoid potential type-mismatch issues across UNION branches."
27    }
28    fn groups(&self) -> &[RuleGroup] {
29        &[RuleGroup::Convention]
30    }
31    fn is_fixable(&self) -> bool {
32        false
33    }
34
35    fn crawl_type(&self) -> CrawlType {
36        CrawlType::RootOnly
37    }
38
39    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40        let mut violations = Vec::new();
41        find_union_nulls(ctx.root, false, &mut violations);
42        violations
43    }
44}
45
46/// Recursively walk the tree looking for SelectStatements that are part of a
47/// UNION. When we find one, scan its SelectClause items for bare NullLiterals.
48fn find_union_nulls(segment: &Segment, in_union: bool, violations: &mut Vec<LintViolation>) {
49    let children = segment.children();
50
51    let has_union = children.iter().any(|c| {
52        if let Segment::Token(t) = c {
53            t.token.text.eq_ignore_ascii_case("UNION")
54        } else {
55            false
56        }
57    });
58
59    let union_context = in_union || has_union;
60
61    for child in children {
62        if union_context && child.segment_type() == SegmentType::SelectStatement {
63            check_select_for_bare_null(child, violations);
64        }
65
66        if child.segment_type() == SegmentType::SelectClause && union_context {
67            check_clause_for_bare_null(child, violations);
68        }
69
70        find_union_nulls(child, union_context, violations);
71    }
72}
73
74fn check_select_for_bare_null(select: &Segment, violations: &mut Vec<LintViolation>) {
75    select.walk(&mut |seg| {
76        if seg.segment_type() == SegmentType::SelectClause {
77            check_clause_for_bare_null(seg, violations);
78        }
79    });
80}
81
82fn check_clause_for_bare_null(clause: &Segment, violations: &mut Vec<LintViolation>) {
83    for child in clause.children() {
84        find_bare_nulls(child, violations);
85    }
86}
87
88/// Walk looking for NullLiterals that are NOT inside a CastExpression.
89fn find_bare_nulls(segment: &Segment, violations: &mut Vec<LintViolation>) {
90    if segment.segment_type() == SegmentType::CastExpression {
91        return;
92    }
93
94    if segment.segment_type() == SegmentType::NullLiteral {
95        violations.push(LintViolation::new(
96            "RG02",
97            "Bare NULL in UNION. Consider using an explicit CAST.",
98            segment.span(),
99        ));
100        return;
101    }
102
103    for child in segment.children() {
104        find_bare_nulls(child, violations);
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A bare NULL in a query without any UNION must not be reported.
    #[test]
    fn test_rg02_accepts_non_union() {
        let reported = lint_sql("SELECT NULL FROM t", RuleRG02);
        assert!(reported.is_empty());
    }
}