1use rigsql_core::{Segment, SegmentType};
2
3use crate::violation::{LintViolation, SourceEdit};
4
/// Category a lint rule belongs to, as reported by [`Rule::groups`].
///
/// `Hash` is derived alongside `Eq` so groups can be collected into a
/// `HashSet` or used as `HashMap` keys (e.g. when selecting rules by group).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RuleGroup {
    Capitalisation,
    Layout,
    Convention,
    Aliasing,
    Ambiguous,
    References,
    Structure,
}
16
17#[derive(Debug, Clone)]
19pub enum CrawlType {
20 Segment(Vec<SegmentType>),
22 RootOnly,
24}
25
26pub struct RuleContext<'a> {
28 pub segment: &'a Segment,
30 pub parent: Option<&'a Segment>,
32 pub root: &'a Segment,
34 pub siblings: &'a [Segment],
36 pub index_in_parent: usize,
38 pub source: &'a str,
40}
41
42pub trait Rule: Send + Sync {
44 fn code(&self) -> &'static str;
46
47 fn name(&self) -> &'static str;
49
50 fn description(&self) -> &'static str;
52
53 fn explanation(&self) -> &'static str;
55
56 fn groups(&self) -> &[RuleGroup];
58
59 fn is_fixable(&self) -> bool;
61
62 fn crawl_type(&self) -> CrawlType;
64
65 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;
67
68 fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
71}
72
73pub fn lint(root: &Segment, source: &str, rules: &[Box<dyn Rule>]) -> Vec<LintViolation> {
75 let mut violations = Vec::new();
76
77 for rule in rules {
78 match rule.crawl_type() {
79 CrawlType::RootOnly => {
80 let ctx = RuleContext {
81 segment: root,
82 parent: None,
83 root,
84 siblings: std::slice::from_ref(root),
85 index_in_parent: 0,
86 source,
87 };
88 violations.extend(rule.eval(&ctx));
89 }
90 CrawlType::Segment(ref types) => {
91 walk_and_lint_indexed(
92 root,
93 0,
94 None,
95 root,
96 source,
97 rule.as_ref(),
98 types,
99 &mut violations,
100 );
101 }
102 }
103 }
104
105 violations.sort_by_key(|v| (v.span.start, v.span.end));
106 violations
107}
108
109#[allow(clippy::too_many_arguments)]
110fn walk_and_lint_indexed(
111 segment: &Segment,
112 index_in_parent: usize,
113 parent: Option<&Segment>,
114 root: &Segment,
115 source: &str,
116 rule: &dyn Rule,
117 types: &[SegmentType],
118 violations: &mut Vec<LintViolation>,
119) {
120 if types.contains(&segment.segment_type()) {
121 let siblings = parent
122 .map(|p| p.children())
123 .unwrap_or(std::slice::from_ref(segment));
124
125 let ctx = RuleContext {
126 segment,
127 parent,
128 root,
129 siblings,
130 index_in_parent,
131 source,
132 };
133 violations.extend(rule.eval(&ctx));
134 }
135
136 let children = segment.children();
137 for (i, child) in children.iter().enumerate() {
138 walk_and_lint_indexed(
139 child,
140 i,
141 Some(segment),
142 root,
143 source,
144 rule,
145 types,
146 violations,
147 );
148 }
149}
150
151pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
156 let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
158
159 if edits.is_empty() {
160 return source.to_string();
161 }
162
163 edits.sort_by(|a, b| {
165 b.span
166 .start
167 .cmp(&a.span.start)
168 .then(b.span.end.cmp(&a.span.end))
169 });
170
171 edits.dedup_by(|a, b| a.span == b.span);
173
174 let mut result = source.to_string();
175 let mut last_applied_start = u32::MAX;
176
177 for edit in &edits {
178 let start = edit.span.start as usize;
179 let end = edit.span.end as usize;
180
181 if edit.span.end > last_applied_start {
183 continue;
184 }
185 if edit.span.start >= last_applied_start {
187 continue;
188 }
189
190 if start <= result.len() && end <= result.len() {
191 result.replace_range(start..end, &edit.new_text);
192 last_applied_start = edit.span.start;
193 }
194 }
195
196 result
197}