rigsql_rules/layout/
lt11.rs1use rigsql_core::{Segment, SegmentType, TokenSegment};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
6#[derive(Debug, Default)]
10pub struct RuleLT11;
11
12impl Rule for RuleLT11 {
13 fn code(&self) -> &'static str {
14 "LT11"
15 }
16 fn name(&self) -> &'static str {
17 "layout.set_operators"
18 }
19 fn description(&self) -> &'static str {
20 "Set operators should be surrounded by newlines."
21 }
22 fn explanation(&self) -> &'static str {
23 "Set operators such as UNION, INTERSECT, and EXCEPT combine the results of \
24 multiple queries. They should be surrounded by newlines to visually separate \
25 the individual queries and improve readability."
26 }
27 fn groups(&self) -> &[RuleGroup] {
28 &[RuleGroup::Layout]
29 }
30 fn is_fixable(&self) -> bool {
31 false
32 }
33
34 fn crawl_type(&self) -> CrawlType {
35 CrawlType::RootOnly
36 }
37
38 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39 let mut violations = Vec::new();
40 let tokens = collect_leaf_tokens(ctx.segment);
41
42 for (i, t) in tokens.iter().enumerate() {
43 if !t.token.text.eq_ignore_ascii_case("UNION")
44 && !t.token.text.eq_ignore_ascii_case("INTERSECT")
45 && !t.token.text.eq_ignore_ascii_case("EXCEPT")
46 {
47 continue;
48 }
49
50 let op_span = t.token.span;
51
52 let has_newline_before = check_adjacent_newline(&tokens, i, Direction::Before);
53
54 let mut end_idx = i;
56 let mut j = i + 1;
57 while j < tokens.len() {
58 if tokens[j].segment_type.is_trivia() {
59 j += 1;
60 } else {
61 if tokens[j].token.text.eq_ignore_ascii_case("ALL") {
62 end_idx = j;
63 }
64 break;
65 }
66 }
67
68 let has_newline_after = check_adjacent_newline(&tokens, end_idx, Direction::After);
69
70 if !has_newline_before {
71 let operator_text = t.token.text.to_ascii_uppercase();
72 violations.push(LintViolation::with_msg_key(
73 self.code(),
74 format!("Expected newline before '{}'.", operator_text),
75 op_span,
76 "rules.LT11.msg.before",
77 vec![("operator".to_string(), operator_text)],
78 ));
79 }
80
81 if !has_newline_after {
82 let operator_text = t.token.text.to_ascii_uppercase();
83 violations.push(LintViolation::with_msg_key(
84 self.code(),
85 format!("Expected newline after '{}'.", operator_text),
86 op_span,
87 "rules.LT11.msg.after",
88 vec![("operator".to_string(), operator_text)],
89 ));
90 }
91 }
92
93 violations
94 }
95}
96
/// Which side of a token `check_adjacent_newline` walks from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Walk towards the start of the token stream, nearest token first.
    Before,
    /// Walk towards the end of the token stream, nearest token first.
    After,
}
101
102fn check_adjacent_newline(tokens: &[TokenSegment], idx: usize, dir: Direction) -> bool {
103 let mut j = match dir {
104 Direction::Before => idx.wrapping_sub(1),
105 Direction::After => idx + 1,
106 };
107 loop {
108 if j >= tokens.len() {
109 return false;
110 }
111 if tokens[j].segment_type == SegmentType::Newline {
112 return true;
113 }
114 if tokens[j].segment_type != SegmentType::Whitespace
116 && tokens[j].segment_type != SegmentType::LineComment
117 && tokens[j].segment_type != SegmentType::BlockComment
118 {
119 return false;
120 }
121 j = match dir {
122 Direction::Before => j.wrapping_sub(1),
123 Direction::After => j + 1,
124 };
125 }
126}
127
128fn collect_leaf_tokens(segment: &Segment) -> Vec<TokenSegment> {
130 let mut out = Vec::new();
131 collect_leaf_tokens_inner(segment, &mut out);
132 out
133}
134
135fn collect_leaf_tokens_inner(segment: &Segment, out: &mut Vec<TokenSegment>) {
136 match segment {
137 Segment::Token(t) => out.push(t.clone()),
138 Segment::Node(n) => {
139 for child in &n.children {
140 collect_leaf_tokens_inner(child, out);
141 }
142 }
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_lt11_flags_inline_union() {
        let violations = lint_sql("SELECT 1 UNION SELECT 2", RuleLT11);
        assert!(!violations.is_empty());
        assert!(violations.iter().all(|v| v.rule_code == "LT11"));
    }

    #[test]
    fn test_lt11_accepts_newlines() {
        let violations = lint_sql("SELECT 1\nUNION\nSELECT 2", RuleLT11);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_lt11_accepts_union_with_trailing_comment() {
        let violations = lint_sql("SELECT 1\nUNION -- noqa: AM02\nSELECT 2", RuleLT11);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_lt11_accepts_union_with_leading_comment() {
        let violations = lint_sql("SELECT 1\n-- comment\nUNION\nSELECT 2", RuleLT11);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_lt11_flags_missing_newline_before_only() {
        // The newline after the operator is present; the one before it is not.
        let violations = lint_sql("SELECT 1 UNION\nSELECT 2", RuleLT11);
        assert!(!violations.is_empty());
    }

    #[test]
    fn test_lt11_flags_missing_newline_after_only() {
        // The newline before the operator is present; the one after it is not.
        let violations = lint_sql("SELECT 1\nUNION SELECT 2", RuleLT11);
        assert!(!violations.is_empty());
    }

    #[test]
    fn test_lt11_accepts_union_all_on_own_line() {
        // The quantifier belongs to the operator, so no newline is needed before `ALL`.
        let violations = lint_sql("SELECT 1\nUNION ALL\nSELECT 2", RuleLT11);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_lt11_flags_inline_union_all() {
        let violations = lint_sql("SELECT 1 UNION ALL SELECT 2", RuleLT11);
        assert!(!violations.is_empty());
    }

    #[test]
    fn test_lt11_is_case_insensitive() {
        let accepted = lint_sql("select 1\nunion all\nselect 2", RuleLT11);
        assert_eq!(accepted.len(), 0);

        let flagged = lint_sql("select 1 union select 2", RuleLT11);
        assert!(!flagged.is_empty());
    }

    #[test]
    fn test_lt11_checks_intersect_and_except() {
        assert!(!lint_sql("SELECT 1 INTERSECT SELECT 2", RuleLT11).is_empty());
        assert!(!lint_sql("SELECT 1 EXCEPT SELECT 2", RuleLT11).is_empty());
        assert_eq!(lint_sql("SELECT 1\nINTERSECT\nSELECT 2", RuleLT11).len(), 0);
        assert_eq!(lint_sql("SELECT 1\nEXCEPT\nSELECT 2", RuleLT11).len(), 0);
    }
}