1use rigsql_core::{Segment, SegmentType};
2
3use crate::violation::{LintViolation, SourceEdit};
4
/// Category a lint rule belongs to; each rule reports its categories via
/// [`Rule::groups`].
///
/// `Hash` is derived alongside `Eq` so groups can be used as keys in a
/// `HashSet`/`HashMap` (for example when selecting or filtering rules by
/// group).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RuleGroup {
    Capitalisation,
    Layout,
    Convention,
    Aliasing,
    Ambiguous,
    References,
    Structure,
}
16
17#[derive(Debug, Clone)]
19pub enum CrawlType {
20 Segment(Vec<SegmentType>),
22 RootOnly,
24}
25
26pub struct RuleContext<'a> {
28 pub segment: &'a Segment,
30 pub parent: Option<&'a Segment>,
32 pub root: &'a Segment,
34 pub siblings: &'a [Segment],
36 pub index_in_parent: usize,
38 pub source: &'a str,
40 pub dialect: &'a str,
42}
43
44impl<'a> RuleContext<'a> {
45 pub fn next_non_trivia_sibling(&self) -> Option<&'a Segment> {
47 self.siblings[self.index_in_parent + 1..]
48 .iter()
49 .find(|s| !s.segment_type().is_trivia())
50 }
51
52 pub fn prev_non_trivia_sibling(&self) -> Option<&'a Segment> {
54 self.siblings[..self.index_in_parent]
55 .iter()
56 .rev()
57 .find(|s| !s.segment_type().is_trivia())
58 }
59}
60
61pub trait Rule: Send + Sync {
63 fn code(&self) -> &'static str;
65
66 fn name(&self) -> &'static str;
68
69 fn description(&self) -> &'static str;
71
72 fn explanation(&self) -> &'static str;
74
75 fn groups(&self) -> &[RuleGroup];
77
78 fn is_fixable(&self) -> bool;
80
81 fn crawl_type(&self) -> CrawlType;
83
84 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;
86
87 fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
90}
91
92pub fn lint(
94 root: &Segment,
95 source: &str,
96 rules: &[Box<dyn Rule>],
97 dialect: &str,
98) -> Vec<LintViolation> {
99 let mut violations = Vec::new();
100
101 for rule in rules {
102 match rule.crawl_type() {
103 CrawlType::RootOnly => {
104 let ctx = RuleContext {
105 segment: root,
106 parent: None,
107 root,
108 siblings: std::slice::from_ref(root),
109 index_in_parent: 0,
110 source,
111 dialect,
112 };
113 violations.extend(rule.eval(&ctx));
114 }
115 CrawlType::Segment(ref types) => {
116 let walker = LintWalker {
117 root,
118 source,
119 dialect,
120 rule: rule.as_ref(),
121 types,
122 };
123 walker.walk(root, 0, None, &mut violations);
124 }
125 }
126 }
127
128 violations.sort_by_key(|v| (v.span.start, v.span.end));
129 violations
130}
131
132struct LintWalker<'a> {
134 root: &'a Segment,
135 source: &'a str,
136 dialect: &'a str,
137 rule: &'a dyn Rule,
138 types: &'a [SegmentType],
139}
140
141impl<'a> LintWalker<'a> {
142 fn walk(
143 &self,
144 segment: &'a Segment,
145 index_in_parent: usize,
146 parent: Option<&'a Segment>,
147 violations: &mut Vec<LintViolation>,
148 ) {
149 if self.types.contains(&segment.segment_type()) {
150 let siblings = parent
151 .map(|p| p.children())
152 .unwrap_or(std::slice::from_ref(segment));
153
154 let ctx = RuleContext {
155 segment,
156 parent,
157 root: self.root,
158 siblings,
159 index_in_parent,
160 source: self.source,
161 dialect: self.dialect,
162 };
163 violations.extend(self.rule.eval(&ctx));
164 }
165
166 for (i, child) in segment.children().iter().enumerate() {
167 self.walk(child, i, Some(segment), violations);
168 }
169 }
170}
171
172pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
177 let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
179
180 if edits.is_empty() {
181 return source.to_string();
182 }
183
184 edits.sort_by(|a, b| {
186 b.span
187 .start
188 .cmp(&a.span.start)
189 .then(b.span.end.cmp(&a.span.end))
190 });
191
192 edits.dedup_by(|a, b| a.span == b.span);
194
195 let mut result = source.to_string();
196 let mut last_applied_start = u32::MAX;
197
198 for edit in &edits {
199 let start = edit.span.start as usize;
200 let end = edit.span.end as usize;
201
202 if edit.span.end > last_applied_start {
204 continue;
205 }
206 if edit.span.start >= last_applied_start {
208 continue;
209 }
210
211 if start <= result.len() && end <= result.len() {
212 result.replace_range(start..end, &edit.new_text);
213 last_applied_start = edit.span.start;
214 }
215 }
216
217 result
218}