1use rigsql_core::{Segment, SegmentType};
2
3use crate::violation::{LintViolation, SourceEdit};
4
/// Category a rule belongs to; a rule lists its categories via `Rule::groups`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RuleGroup {
    /// Capitalisation rules.
    Capitalisation,
    /// Layout rules.
    Layout,
    /// Convention rules.
    Convention,
    /// Aliasing rules.
    Aliasing,
    /// Ambiguity rules.
    Ambiguous,
    /// Reference rules.
    References,
    /// Structure rules.
    Structure,
}
16
17#[derive(Debug, Clone)]
19pub enum CrawlType {
20 Segment(Vec<SegmentType>),
22 RootOnly,
24}
25
26pub struct RuleContext<'a> {
28 pub segment: &'a Segment,
30 pub parent: Option<&'a Segment>,
32 pub root: &'a Segment,
34 pub siblings: &'a [Segment],
36 pub index_in_parent: usize,
38 pub source: &'a str,
40 pub dialect: &'a str,
42}
43
44pub trait Rule: Send + Sync {
46 fn code(&self) -> &'static str;
48
49 fn name(&self) -> &'static str;
51
52 fn description(&self) -> &'static str;
54
55 fn explanation(&self) -> &'static str;
57
58 fn groups(&self) -> &[RuleGroup];
60
61 fn is_fixable(&self) -> bool;
63
64 fn crawl_type(&self) -> CrawlType;
66
67 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;
69
70 fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
73}
74
75pub fn lint(
77 root: &Segment,
78 source: &str,
79 rules: &[Box<dyn Rule>],
80 dialect: &str,
81) -> Vec<LintViolation> {
82 let mut violations = Vec::new();
83
84 for rule in rules {
85 match rule.crawl_type() {
86 CrawlType::RootOnly => {
87 let ctx = RuleContext {
88 segment: root,
89 parent: None,
90 root,
91 siblings: std::slice::from_ref(root),
92 index_in_parent: 0,
93 source,
94 dialect,
95 };
96 violations.extend(rule.eval(&ctx));
97 }
98 CrawlType::Segment(ref types) => {
99 walk_and_lint_indexed(
100 root,
101 0,
102 None,
103 root,
104 source,
105 dialect,
106 rule.as_ref(),
107 types,
108 &mut violations,
109 );
110 }
111 }
112 }
113
114 violations.sort_by_key(|v| (v.span.start, v.span.end));
115 violations
116}
117
118#[allow(clippy::too_many_arguments)]
119fn walk_and_lint_indexed(
120 segment: &Segment,
121 index_in_parent: usize,
122 parent: Option<&Segment>,
123 root: &Segment,
124 source: &str,
125 dialect: &str,
126 rule: &dyn Rule,
127 types: &[SegmentType],
128 violations: &mut Vec<LintViolation>,
129) {
130 if types.contains(&segment.segment_type()) {
131 let siblings = parent
132 .map(|p| p.children())
133 .unwrap_or(std::slice::from_ref(segment));
134
135 let ctx = RuleContext {
136 segment,
137 parent,
138 root,
139 siblings,
140 index_in_parent,
141 source,
142 dialect,
143 };
144 violations.extend(rule.eval(&ctx));
145 }
146
147 let children = segment.children();
148 for (i, child) in children.iter().enumerate() {
149 walk_and_lint_indexed(
150 child,
151 i,
152 Some(segment),
153 root,
154 source,
155 dialect,
156 rule,
157 types,
158 violations,
159 );
160 }
161}
162
163pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
168 let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
170
171 if edits.is_empty() {
172 return source.to_string();
173 }
174
175 edits.sort_by(|a, b| {
177 b.span
178 .start
179 .cmp(&a.span.start)
180 .then(b.span.end.cmp(&a.span.end))
181 });
182
183 edits.dedup_by(|a, b| a.span == b.span);
185
186 let mut result = source.to_string();
187 let mut last_applied_start = u32::MAX;
188
189 for edit in &edits {
190 let start = edit.span.start as usize;
191 let end = edit.span.end as usize;
192
193 if edit.span.end > last_applied_start {
195 continue;
196 }
197 if edit.span.start >= last_applied_start {
199 continue;
200 }
201
202 if start <= result.len() && end <= result.len() {
203 result.replace_range(start..end, &edit.new_text);
204 last_applied_start = edit.span.start;
205 }
206 }
207
208 result
209}