rigsql_rules/rigsql/
rg02.rs1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule RG02 (`rigsql.union_null`): reports bare `NULL` literals that
/// appear in the select clauses of `UNION` queries.
///
/// The rule carries no state; all of its behaviour lives in its `Rule` impl.
#[derive(Default, Debug)]
pub struct RuleRG02;
12
13impl Rule for RuleRG02 {
14 fn code(&self) -> &'static str {
15 "RG02"
16 }
17 fn name(&self) -> &'static str {
18 "rigsql.union_null"
19 }
20 fn description(&self) -> &'static str {
21 "Consistent use of NULL in UNION."
22 }
23 fn explanation(&self) -> &'static str {
24 "When using NULL in a UNION query, the type of the NULL column is ambiguous. \
25 Use an explicit CAST (e.g. CAST(NULL AS INTEGER)) to make the intent clear \
26 and avoid potential type-mismatch issues across UNION branches."
27 }
28 fn groups(&self) -> &[RuleGroup] {
29 &[RuleGroup::Convention]
30 }
31 fn is_fixable(&self) -> bool {
32 false
33 }
34
35 fn crawl_type(&self) -> CrawlType {
36 CrawlType::RootOnly
37 }
38
39 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40 let mut violations = Vec::new();
41 find_union_nulls(ctx.root, false, &mut violations);
42 violations
43 }
44}
45
46fn find_union_nulls(segment: &Segment, in_union: bool, violations: &mut Vec<LintViolation>) {
49 let children = segment.children();
50
51 let has_union = children.iter().any(|c| {
52 if let Segment::Token(t) = c {
53 t.token.text.eq_ignore_ascii_case("UNION")
54 } else {
55 false
56 }
57 });
58
59 let union_context = in_union || has_union;
60
61 for child in children {
62 if union_context && child.segment_type() == SegmentType::SelectStatement {
63 check_select_for_bare_null(child, violations);
64 }
65
66 if child.segment_type() == SegmentType::SelectClause && union_context {
67 check_clause_for_bare_null(child, violations);
68 }
69
70 find_union_nulls(child, union_context, violations);
71 }
72}
73
74fn check_select_for_bare_null(select: &Segment, violations: &mut Vec<LintViolation>) {
75 select.walk(&mut |seg| {
76 if seg.segment_type() == SegmentType::SelectClause {
77 check_clause_for_bare_null(seg, violations);
78 }
79 });
80}
81
82fn check_clause_for_bare_null(clause: &Segment, violations: &mut Vec<LintViolation>) {
83 for child in clause.children() {
84 find_bare_nulls(child, violations);
85 }
86}
87
88fn find_bare_nulls(segment: &Segment, violations: &mut Vec<LintViolation>) {
90 if segment.segment_type() == SegmentType::CastExpression {
91 return;
92 }
93
94 if segment.segment_type() == SegmentType::NullLiteral {
95 violations.push(LintViolation::new(
96 "RG02",
97 "Bare NULL in UNION. Consider using an explicit CAST.",
98 segment.span(),
99 ));
100 return;
101 }
102
103 for child in segment.children() {
104 find_bare_nulls(child, violations);
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A bare `NULL` in a plain SELECT (no UNION) must not be reported.
    #[test]
    fn test_rg02_accepts_non_union() {
        let found = lint_sql("SELECT NULL FROM t", RuleRG02);
        assert!(found.is_empty());
    }
}