Skip to main content

rigsql_rules/
rule.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::violation::{LintViolation, SourceEdit};
4
/// Category a lint rule can be filed under.
///
/// A rule reports its categories through `Rule::groups`, which returns a
/// slice, so one rule may be listed under several of these at once.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RuleGroup {
    Capitalisation,
    Layout,
    Convention,
    Aliasing,
    Ambiguous,
    References,
    Structure,
}
16
17/// Controls which CST nodes a rule visits.
18#[derive(Debug, Clone)]
19pub enum CrawlType {
20    /// Visit every segment of the listed types.
21    Segment(Vec<SegmentType>),
22    /// Visit the root segment only (whole-file rules).
23    RootOnly,
24}
25
26/// Context passed to a rule during evaluation.
27pub struct RuleContext<'a> {
28    /// The segment being evaluated.
29    pub segment: &'a Segment,
30    /// The parent segment (if any).
31    pub parent: Option<&'a Segment>,
32    /// The root file segment.
33    pub root: &'a Segment,
34    /// Direct children of the parent, for sibling access.
35    pub siblings: &'a [Segment],
36    /// Index of `segment` within `siblings`.
37    pub index_in_parent: usize,
38    /// Full source text.
39    pub source: &'a str,
40    /// SQL dialect name (e.g. "ansi", "postgres", "tsql").
41    pub dialect: &'a str,
42}
43
44impl<'a> RuleContext<'a> {
45    /// Get the next non-trivia sibling after the current segment.
46    pub fn next_non_trivia_sibling(&self) -> Option<&'a Segment> {
47        self.siblings[self.index_in_parent + 1..]
48            .iter()
49            .find(|s| !s.segment_type().is_trivia())
50    }
51
52    /// Get the previous non-trivia sibling before the current segment.
53    pub fn prev_non_trivia_sibling(&self) -> Option<&'a Segment> {
54        self.siblings[..self.index_in_parent]
55            .iter()
56            .rev()
57            .find(|s| !s.segment_type().is_trivia())
58    }
59}
60
61/// Trait that all lint rules must implement.
62pub trait Rule: Send + Sync {
63    /// Rule code, e.g. "LT01".
64    fn code(&self) -> &'static str;
65
66    /// Human-readable name, e.g. "layout.spacing".
67    fn name(&self) -> &'static str;
68
69    /// One-line description.
70    fn description(&self) -> &'static str;
71
72    /// Multi-sentence explanation for AI consumers.
73    fn explanation(&self) -> &'static str;
74
75    /// Rule group.
76    fn groups(&self) -> &[RuleGroup];
77
78    /// Can this rule auto-fix violations?
79    fn is_fixable(&self) -> bool;
80
81    /// Which segments should be visited.
82    fn crawl_type(&self) -> CrawlType;
83
84    /// Evaluate the rule at the given context, returning violations.
85    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;
86
87    /// Configure the rule with key-value settings from config.
88    /// Default implementation is a no-op.
89    fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
90}
91
92/// Run all rules against a parsed CST.
93pub fn lint(
94    root: &Segment,
95    source: &str,
96    rules: &[Box<dyn Rule>],
97    dialect: &str,
98) -> Vec<LintViolation> {
99    let mut violations = Vec::new();
100
101    for rule in rules {
102        match rule.crawl_type() {
103            CrawlType::RootOnly => {
104                let ctx = RuleContext {
105                    segment: root,
106                    parent: None,
107                    root,
108                    siblings: std::slice::from_ref(root),
109                    index_in_parent: 0,
110                    source,
111                    dialect,
112                };
113                violations.extend(rule.eval(&ctx));
114            }
115            CrawlType::Segment(ref types) => {
116                let walker = LintWalker {
117                    root,
118                    source,
119                    dialect,
120                    rule: rule.as_ref(),
121                    types,
122                };
123                walker.walk(root, 0, None, &mut violations);
124            }
125        }
126    }
127
128    violations.sort_by_key(|v| (v.span.start, v.span.end));
129    violations
130}
131
132/// Walks the CST and evaluates a rule at matching segments.
133struct LintWalker<'a> {
134    root: &'a Segment,
135    source: &'a str,
136    dialect: &'a str,
137    rule: &'a dyn Rule,
138    types: &'a [SegmentType],
139}
140
141impl<'a> LintWalker<'a> {
142    fn walk(
143        &self,
144        segment: &'a Segment,
145        index_in_parent: usize,
146        parent: Option<&'a Segment>,
147        violations: &mut Vec<LintViolation>,
148    ) {
149        if self.types.contains(&segment.segment_type()) {
150            let siblings = parent
151                .map(|p| p.children())
152                .unwrap_or(std::slice::from_ref(segment));
153
154            let ctx = RuleContext {
155                segment,
156                parent,
157                root: self.root,
158                siblings,
159                index_in_parent,
160                source: self.source,
161                dialect: self.dialect,
162            };
163            violations.extend(self.rule.eval(&ctx));
164        }
165
166        for (i, child) in segment.children().iter().enumerate() {
167            self.walk(child, i, Some(segment), violations);
168        }
169    }
170}
171
172/// Apply source edits to produce a fixed source string.
173///
174/// Edits are sorted by span start (descending) and applied back-to-front
175/// so that earlier offsets remain valid. Overlapping edits are skipped.
176pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
177    // Collect all edits from all violations
178    let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
179
180    if edits.is_empty() {
181        return source.to_string();
182    }
183
184    // Sort by span start descending, then by span end descending (apply from back)
185    edits.sort_by(|a, b| {
186        b.span
187            .start
188            .cmp(&a.span.start)
189            .then(b.span.end.cmp(&a.span.end))
190    });
191
192    // Deduplicate edits with identical spans
193    edits.dedup_by(|a, b| a.span == b.span);
194
195    let mut result = source.to_string();
196    let mut last_applied_start = u32::MAX;
197
198    for edit in &edits {
199        let start = edit.span.start as usize;
200        let end = edit.span.end as usize;
201
202        // Skip overlapping edits: any edit whose range touches the previously applied one
203        if edit.span.end > last_applied_start {
204            continue;
205        }
206        // Also skip inserts at the same offset as a previously applied edit
207        if edit.span.start >= last_applied_start {
208            continue;
209        }
210
211        if start <= result.len() && end <= result.len() {
212            result.replace_range(start..end, &edit.new_text);
213            last_applied_start = edit.span.start;
214        }
215    }
216
217    result
218}