// rigsql_rules/rule.rs

use std::cmp::Reverse;

use rigsql_core::{Segment, SegmentType};

use crate::violation::{LintViolation, SourceEdit};

/// Category a lint rule is filed under.
///
/// A single rule may belong to several groups at once (see [`Rule::groups`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RuleGroup {
    Capitalisation,
    Layout,
    Convention,
    Aliasing,
    Ambiguous,
    References,
    Structure,
}

17/// Controls which CST nodes a rule visits.
18#[derive(Debug, Clone)]
19pub enum CrawlType {
20    /// Visit every segment of the listed types.
21    Segment(Vec<SegmentType>),
22    /// Visit the root segment only (whole-file rules).
23    RootOnly,
24}
25
26/// Context passed to a rule during evaluation.
27pub struct RuleContext<'a> {
28    /// The segment being evaluated.
29    pub segment: &'a Segment,
30    /// The parent segment (if any).
31    pub parent: Option<&'a Segment>,
32    /// The root file segment.
33    pub root: &'a Segment,
34    /// Direct children of the parent, for sibling access.
35    pub siblings: &'a [Segment],
36    /// Index of `segment` within `siblings`.
37    pub index_in_parent: usize,
38    /// Full source text.
39    pub source: &'a str,
40    /// SQL dialect name (e.g. "ansi", "postgres", "tsql").
41    pub dialect: &'a str,
42}
43
44/// Trait that all lint rules must implement.
45pub trait Rule: Send + Sync {
46    /// Rule code, e.g. "LT01".
47    fn code(&self) -> &'static str;
48
49    /// Human-readable name, e.g. "layout.spacing".
50    fn name(&self) -> &'static str;
51
52    /// One-line description.
53    fn description(&self) -> &'static str;
54
55    /// Multi-sentence explanation for AI consumers.
56    fn explanation(&self) -> &'static str;
57
58    /// Rule group.
59    fn groups(&self) -> &[RuleGroup];
60
61    /// Can this rule auto-fix violations?
62    fn is_fixable(&self) -> bool;
63
64    /// Which segments should be visited.
65    fn crawl_type(&self) -> CrawlType;
66
67    /// Evaluate the rule at the given context, returning violations.
68    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;
69
70    /// Configure the rule with key-value settings from config.
71    /// Default implementation is a no-op.
72    fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
73}
74
75/// Run all rules against a parsed CST.
76pub fn lint(
77    root: &Segment,
78    source: &str,
79    rules: &[Box<dyn Rule>],
80    dialect: &str,
81) -> Vec<LintViolation> {
82    let mut violations = Vec::new();
83
84    for rule in rules {
85        match rule.crawl_type() {
86            CrawlType::RootOnly => {
87                let ctx = RuleContext {
88                    segment: root,
89                    parent: None,
90                    root,
91                    siblings: std::slice::from_ref(root),
92                    index_in_parent: 0,
93                    source,
94                    dialect,
95                };
96                violations.extend(rule.eval(&ctx));
97            }
98            CrawlType::Segment(ref types) => {
99                walk_and_lint_indexed(
100                    root,
101                    0,
102                    None,
103                    root,
104                    source,
105                    dialect,
106                    rule.as_ref(),
107                    types,
108                    &mut violations,
109                );
110            }
111        }
112    }
113
114    violations.sort_by_key(|v| (v.span.start, v.span.end));
115    violations
116}
117
118#[allow(clippy::too_many_arguments)]
119fn walk_and_lint_indexed(
120    segment: &Segment,
121    index_in_parent: usize,
122    parent: Option<&Segment>,
123    root: &Segment,
124    source: &str,
125    dialect: &str,
126    rule: &dyn Rule,
127    types: &[SegmentType],
128    violations: &mut Vec<LintViolation>,
129) {
130    if types.contains(&segment.segment_type()) {
131        let siblings = parent
132            .map(|p| p.children())
133            .unwrap_or(std::slice::from_ref(segment));
134
135        let ctx = RuleContext {
136            segment,
137            parent,
138            root,
139            siblings,
140            index_in_parent,
141            source,
142            dialect,
143        };
144        violations.extend(rule.eval(&ctx));
145    }
146
147    let children = segment.children();
148    for (i, child) in children.iter().enumerate() {
149        walk_and_lint_indexed(
150            child,
151            i,
152            Some(segment),
153            root,
154            source,
155            dialect,
156            rule,
157            types,
158            violations,
159        );
160    }
161}
162
163/// Apply source edits to produce a fixed source string.
164///
165/// Edits are sorted by span start (descending) and applied back-to-front
166/// so that earlier offsets remain valid. Overlapping edits are skipped.
167pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
168    // Collect all edits from all violations
169    let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
170
171    if edits.is_empty() {
172        return source.to_string();
173    }
174
175    // Sort by span start descending, then by span end descending (apply from back)
176    edits.sort_by(|a, b| {
177        b.span
178            .start
179            .cmp(&a.span.start)
180            .then(b.span.end.cmp(&a.span.end))
181    });
182
183    // Deduplicate edits with identical spans
184    edits.dedup_by(|a, b| a.span == b.span);
185
186    let mut result = source.to_string();
187    let mut last_applied_start = u32::MAX;
188
189    for edit in &edits {
190        let start = edit.span.start as usize;
191        let end = edit.span.end as usize;
192
193        // Skip overlapping edits: any edit whose range touches the previously applied one
194        if edit.span.end > last_applied_start {
195            continue;
196        }
197        // Also skip inserts at the same offset as a previously applied edit
198        if edit.span.start >= last_applied_start {
199            continue;
200        }
201
202        if start <= result.len() && end <= result.len() {
203            result.replace_range(start..end, &edit.new_text);
204            last_applied_start = edit.span.start;
205        }
206    }
207
208    result
209}