rigsql_rules/convention/
cv05.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::{LintViolation, SourceEdit};
5
/// Lint rule CV05 (`convention.is_null`): comparisons against `NULL` should be
/// written with `IS NULL` / `IS NOT NULL` instead of a comparison operator.
#[derive(Default, Debug)]
pub struct RuleCV05;
11
12impl Rule for RuleCV05 {
13 fn code(&self) -> &'static str {
14 "CV05"
15 }
16 fn name(&self) -> &'static str {
17 "convention.is_null"
18 }
19 fn description(&self) -> &'static str {
20 "Comparisons with NULL should use IS or IS NOT."
21 }
22 fn explanation(&self) -> &'static str {
23 "In SQL, NULL is not a value but the absence of one. Comparison operators \
24 (=, !=, <>) with NULL always return NULL (which is falsy). Use 'IS NULL' or \
25 'IS NOT NULL' instead of '= NULL' or '!= NULL'."
26 }
27 fn groups(&self) -> &[RuleGroup] {
28 &[RuleGroup::Convention]
29 }
30 fn is_fixable(&self) -> bool {
31 true
32 }
33
34 fn crawl_type(&self) -> CrawlType {
35 CrawlType::Segment(vec![SegmentType::BinaryExpression])
36 }
37
38 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39 if ctx
44 .parent
45 .is_some_and(|p| p.segment_type() == SegmentType::SetClause)
46 {
47 return vec![];
48 }
49
50 let children = ctx.segment.children();
51
52 let non_trivia: Vec<_> = children
55 .iter()
56 .filter(|c| !c.segment_type().is_trivia())
57 .collect();
58
59 if non_trivia.len() < 3 {
60 return vec![];
61 }
62
63 let op = non_trivia[1];
65 let is_comparison = matches!(op.segment_type(), SegmentType::ComparisonOperator);
66 if !is_comparison {
67 return vec![];
68 }
69
70 let lhs_is_null = is_null_literal(non_trivia[0]);
72 let rhs_is_null = non_trivia.get(2).is_some_and(|s| is_null_literal(s));
73
74 if lhs_is_null || rhs_is_null {
75 let op_text = op.tokens().first().map(|t| t.text.as_str()).unwrap_or("=");
77 let is_negated = op_text == "!=" || op_text == "<>";
78
79 let op_span = op.span();
83 let null_seg = if rhs_is_null {
84 non_trivia[2]
85 } else {
86 non_trivia[0]
87 };
88 let null_span = null_seg.span();
89
90 let fix = if rhs_is_null {
91 let replace_span = rigsql_core::Span::new(op_span.start, null_span.end);
93 let replacement = if is_negated { "IS NOT NULL" } else { "IS NULL" };
94 vec![SourceEdit::replace(replace_span, replacement)]
95 } else {
96 let is_not = if is_negated { "IS NOT NULL" } else { "IS NULL" };
98 let expr = non_trivia[2];
99 let whole_span = ctx.segment.span();
100 let expr_text = ctx
101 .source
102 .get(expr.span().start as usize..expr.span().end as usize)
103 .unwrap_or("?");
104 vec![SourceEdit::replace(
105 whole_span,
106 format!("{} {}", expr_text, is_not),
107 )]
108 };
109
110 return vec![LintViolation::with_fix_and_msg_key(
111 self.code(),
112 "Use IS NULL or IS NOT NULL instead of comparison operator with NULL.",
113 ctx.segment.span(),
114 fix,
115 "rules.CV05.msg",
116 vec![],
117 )];
118 }
119
120 vec![]
121 }
122}
123
124fn is_null_literal(seg: &Segment) -> bool {
125 match seg {
126 Segment::Token(t) => {
127 t.segment_type == SegmentType::NullLiteral
128 || (t.token.kind == TokenKind::Word && t.token.text.eq_ignore_ascii_case("NULL"))
129 }
130 Segment::Node(n) => n.segment_type == SegmentType::NullLiteral,
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{lint_sql, lint_sql_with_dialect};

    #[test]
    fn test_cv05_flags_equals_null() {
        let violations = lint_sql("SELECT * FROM t WHERE a = NULL", RuleCV05);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        // Only `= NULL` is replaced; the left operand stays where it is.
        assert_eq!(violations[0].fixes[0].new_text, "IS NULL");
    }

    #[test]
    fn test_cv05_flags_not_equals_null() {
        // Both inequality spellings must become the negated test.
        for sql in [
            "SELECT * FROM t WHERE a != NULL",
            "SELECT * FROM t WHERE a <> NULL",
        ] {
            let violations = lint_sql(sql, RuleCV05);
            assert_eq!(violations.len(), 1, "{sql}");
            assert_eq!(violations[0].fixes[0].new_text, "IS NOT NULL", "{sql}");
        }
    }

    #[test]
    fn test_cv05_flags_null_on_left_hand_side() {
        let violations = lint_sql("SELECT * FROM t WHERE NULL = a", RuleCV05);
        assert_eq!(violations.len(), 1);
        // `IS NULL` must follow its operand, so the whole expression is rewritten.
        assert_eq!(violations[0].fixes[0].new_text, "a IS NULL");
    }

    #[test]
    fn test_cv05_accepts_is_null() {
        let violations = lint_sql("SELECT * FROM t WHERE a IS NULL", RuleCV05);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv05_accepts_update_set_null_assignment() {
        // `SET a = NULL` is an assignment, not a comparison.
        let violations = lint_sql("UPDATE t SET a = NULL WHERE id = 1", RuleCV05);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv05_accepts_tsql_update_set_null_assignment() {
        let sql = "UPDATE dbo.TestTable\nSET TestColumn = NULL\nWHERE ID = 1;";
        let violations = lint_sql_with_dialect(sql, RuleCV05, "tsql");
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv05_flags_where_in_update() {
        // The SET exemption must not leak into the WHERE clause of the same UPDATE.
        let violations = lint_sql("UPDATE t SET a = 1 WHERE b = NULL", RuleCV05);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "IS NULL");
    }

    #[test]
    fn test_cv05_flags_multi_column_update_where() {
        // Only the WHERE comparison is flagged; neither SET assignment is.
        let violations = lint_sql("UPDATE t SET a = NULL, b = NULL WHERE c = NULL", RuleCV05);
        assert_eq!(violations.len(), 1);
    }
}