rigsql_rules/layout/
lt11.rs1use rigsql_core::{Segment, SegmentType, TokenSegment};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule LT11 (`layout.set_operators`): set operators such as `UNION`,
/// `INTERSECT`, and `EXCEPT` should be surrounded by newlines.
#[derive(Default, Debug)]
pub struct RuleLT11;
11
12impl Rule for RuleLT11 {
13 fn code(&self) -> &'static str {
14 "LT11"
15 }
16 fn name(&self) -> &'static str {
17 "layout.set_operators"
18 }
19 fn description(&self) -> &'static str {
20 "Set operators should be surrounded by newlines."
21 }
22 fn explanation(&self) -> &'static str {
23 "Set operators such as UNION, INTERSECT, and EXCEPT combine the results of \
24 multiple queries. They should be surrounded by newlines to visually separate \
25 the individual queries and improve readability."
26 }
27 fn groups(&self) -> &[RuleGroup] {
28 &[RuleGroup::Layout]
29 }
30 fn is_fixable(&self) -> bool {
31 false
32 }
33
34 fn crawl_type(&self) -> CrawlType {
35 CrawlType::RootOnly
36 }
37
38 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39 let mut violations = Vec::new();
40 let tokens = collect_leaf_tokens(ctx.segment);
41
42 for (i, t) in tokens.iter().enumerate() {
43 if !t.token.text.eq_ignore_ascii_case("UNION")
44 && !t.token.text.eq_ignore_ascii_case("INTERSECT")
45 && !t.token.text.eq_ignore_ascii_case("EXCEPT")
46 {
47 continue;
48 }
49
50 let op_span = t.token.span;
51
52 let has_newline_before = check_adjacent_newline(&tokens, i, Direction::Before);
53
54 let mut end_idx = i;
56 let mut j = i + 1;
57 while j < tokens.len() {
58 if tokens[j].segment_type.is_trivia() {
59 j += 1;
60 } else {
61 if tokens[j].token.text.eq_ignore_ascii_case("ALL") {
62 end_idx = j;
63 }
64 break;
65 }
66 }
67
68 let has_newline_after = check_adjacent_newline(&tokens, end_idx, Direction::After);
69
70 if !has_newline_before {
71 violations.push(LintViolation::new(
72 self.code(),
73 format!(
74 "Expected newline before '{}'.",
75 t.token.text.to_ascii_uppercase()
76 ),
77 op_span,
78 ));
79 }
80
81 if !has_newline_after {
82 violations.push(LintViolation::new(
83 self.code(),
84 format!(
85 "Expected newline after '{}'.",
86 t.token.text.to_ascii_uppercase()
87 ),
88 op_span,
89 ));
90 }
91 }
92
93 violations
94 }
95}
96
/// Which neighbour of a token `check_adjacent_newline` inspects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Walk toward lower token indices (tokens preceding the anchor).
    Before,
    /// Walk toward higher token indices (tokens following the anchor).
    After,
}
101
102fn check_adjacent_newline(tokens: &[TokenSegment], idx: usize, dir: Direction) -> bool {
103 let mut j = match dir {
104 Direction::Before => idx.wrapping_sub(1),
105 Direction::After => idx + 1,
106 };
107 loop {
108 if j >= tokens.len() {
109 return false;
110 }
111 if tokens[j].segment_type == SegmentType::Newline {
112 return true;
113 }
114 if tokens[j].segment_type != SegmentType::Whitespace {
115 return false;
116 }
117 j = match dir {
118 Direction::Before => j.wrapping_sub(1),
119 Direction::After => j + 1,
120 };
121 }
122}
123
124fn collect_leaf_tokens(segment: &Segment) -> Vec<TokenSegment> {
126 let mut out = Vec::new();
127 collect_leaf_tokens_inner(segment, &mut out);
128 out
129}
130
131fn collect_leaf_tokens_inner(segment: &Segment, out: &mut Vec<TokenSegment>) {
132 match segment {
133 Segment::Token(t) => out.push(t.clone()),
134 Segment::Node(n) => {
135 for child in &n.children {
136 collect_leaf_tokens_inner(child, out);
137 }
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_lt11_flags_inline_union() {
        let violations = lint_sql("SELECT 1 UNION SELECT 2", RuleLT11);
        // One violation for the missing newline before UNION, one for after it.
        assert_eq!(violations.len(), 2);
        assert!(violations.iter().all(|v| v.rule_code == "LT11"));
    }

    #[test]
    fn test_lt11_accepts_newlines() {
        let violations = lint_sql("SELECT 1\nUNION\nSELECT 2", RuleLT11);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_lt11_accepts_lowercase_operator_with_newlines() {
        let violations = lint_sql("select 1\nunion\nselect 2", RuleLT11);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_lt11_union_all_needs_newline_after_all() {
        let violations = lint_sql("SELECT 1\nUNION ALL\nSELECT 2", RuleLT11);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_lt11_flags_inline_intersect_and_except() {
        let intersect = lint_sql("SELECT 1 INTERSECT SELECT 2", RuleLT11);
        assert_eq!(intersect.len(), 2);
        let except = lint_sql("SELECT 1 EXCEPT SELECT 2", RuleLT11);
        assert_eq!(except.len(), 2);
    }

    #[test]
    fn test_lt11_flags_only_the_missing_side() {
        let violations = lint_sql("SELECT 1\nUNION SELECT 2", RuleLT11);
        assert_eq!(violations.len(), 1);
    }
}