use rigsql_core::{Segment, SegmentType};
use crate::violation::{LintViolation, SourceEdit};
/// Category a lint rule belongs to, as reported by [`Rule::groups`].
///
/// `Hash` is derived alongside `Eq` so groups can be stored in `HashSet`s or used as
/// `HashMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RuleGroup {
    Capitalisation,
    Layout,
    Convention,
    Aliasing,
    Ambiguous,
    References,
    Structure,
}
/// Tells `lint` how to drive a rule over the parse tree.
#[derive(Debug, Clone)]
pub enum CrawlType {
    /// Evaluate the rule exactly once, with the tree root as the current segment.
    RootOnly,
    /// Evaluate the rule on every segment (visited in pre-order) whose type is in this list.
    Segment(Vec<SegmentType>),
}
/// Everything a rule can see when it is evaluated at one segment of the tree.
pub struct RuleContext<'a> {
    // --- Position in the tree ---
    /// The segment the rule is currently being evaluated on.
    pub segment: &'a Segment,
    /// Parent of `segment`; `lint` passes `None` when evaluating at the tree root.
    pub parent: Option<&'a Segment>,
    /// The parent's children (which include `segment`). When there is no parent, `lint`
    /// supplies a one-element slice containing just `segment`.
    pub siblings: &'a [Segment],
    /// Position of `segment` within `siblings`.
    pub index_in_parent: usize,
    /// Root of the whole tree being linted.
    pub root: &'a Segment,

    // --- Inputs forwarded from `lint` ---
    /// The source text given to `lint` (presumably the text `root` was parsed from — confirm).
    pub source: &'a str,
    /// The dialect name given to `lint`.
    pub dialect: &'a str,
}
impl<'a> RuleContext<'a> {
    /// Returns the first sibling after `segment` whose type is not trivia, if any.
    ///
    /// "Trivia" is whatever `SegmentType::is_trivia` reports. Returns `None` when no such
    /// sibling exists, and also (instead of panicking) when `index_in_parent` does not fit
    /// `siblings`. All fields are public, so a hand-built context may be inconsistent.
    pub fn next_non_trivia_sibling(&self) -> Option<&'a Segment> {
        let after = self.index_in_parent.checked_add(1)?;
        self.siblings
            .get(after..)?
            .iter()
            .find(|s| !s.segment_type().is_trivia())
    }

    /// Returns the nearest sibling before `segment` whose type is not trivia, if any.
    ///
    /// Searches backwards from `index_in_parent`. Returns `None` when no such sibling exists,
    /// or when `index_in_parent` is larger than `siblings.len()`.
    pub fn prev_non_trivia_sibling(&self) -> Option<&'a Segment> {
        self.siblings
            .get(..self.index_in_parent)?
            .iter()
            .rev()
            .find(|s| !s.segment_type().is_trivia())
    }
}
/// A lint rule that `lint` can evaluate over a parsed tree.
///
/// Implementors must be `Send + Sync`.
pub trait Rule: Send + Sync {
    // --- Identity and metadata ---

    /// The rule's code.
    fn code(&self) -> &'static str;

    /// The rule's name.
    fn name(&self) -> &'static str;

    /// Short description of the rule.
    fn description(&self) -> &'static str;

    /// Longer explanation of the rule.
    fn explanation(&self) -> &'static str;

    /// The groups this rule belongs to.
    fn groups(&self) -> &[RuleGroup];

    /// Whether the rule is fixable (presumably: whether its violations carry fixes that
    /// `apply_fixes` can apply — confirm against implementations).
    fn is_fixable(&self) -> bool;

    // --- Evaluation ---

    /// Tells `lint` where to call [`Rule::eval`]: once at the root, or on every segment of
    /// the listed types.
    fn crawl_type(&self) -> CrawlType;

    /// Evaluates the rule at `ctx.segment` and returns the violations found there.
    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation>;

    /// Receives string key/value settings. The default implementation ignores them.
    fn configure(&mut self, _settings: &std::collections::HashMap<String, String>) {}
}
/// Runs every rule in `rules` over the tree rooted at `root` and collects the violations.
///
/// How each rule is evaluated depends on its crawl type:
/// - `CrawlType::RootOnly` rules are evaluated once, at the root. The root has no parent and
///   a one-element sibling slice containing only itself.
/// - `CrawlType::Segment` rules are evaluated on every segment whose type is in the rule's
///   list. Each such rule gets its own pre-order traversal of the whole tree.
///
/// The result is sorted by `(span.start, span.end)`. The sort is stable, so violations with
/// identical spans keep the order in which the rules produced them.
pub fn lint(
    root: &Segment,
    source: &str,
    rules: &[Box<dyn Rule>],
    dialect: &str,
) -> Vec<LintViolation> {
    let mut found: Vec<LintViolation> = Vec::new();

    for rule in rules.iter() {
        let crawl = rule.crawl_type();
        match &crawl {
            CrawlType::RootOnly => {
                let root_ctx = RuleContext {
                    segment: root,
                    parent: None,
                    root,
                    siblings: std::slice::from_ref(root),
                    index_in_parent: 0,
                    source,
                    dialect,
                };
                found.extend(rule.eval(&root_ctx));
            }
            CrawlType::Segment(types) => {
                LintWalker {
                    root,
                    source,
                    dialect,
                    rule: rule.as_ref(),
                    types,
                }
                .walk(root, 0, None, &mut found);
            }
        }
    }

    found.sort_by_key(|violation| (violation.span.start, violation.span.end));
    found
}
/// Per-rule state for the recursive traversal `lint` performs for `CrawlType::Segment` rules.
struct LintWalker<'a> {
    /// The rule being evaluated.
    rule: &'a dyn Rule,
    /// Segment types on which `rule` is evaluated.
    types: &'a [SegmentType],
    /// Root of the tree; forwarded into every `RuleContext`.
    root: &'a Segment,
    /// Source text; forwarded into every `RuleContext`.
    source: &'a str,
    /// Dialect name; forwarded into every `RuleContext`.
    dialect: &'a str,
}
impl<'a> LintWalker<'a> {
    /// Visits `segment` and then, recursively, each of its descendants (pre-order).
    ///
    /// The rule is evaluated on every visited segment whose type appears in `self.types`.
    /// Its violations are appended to `violations`. `index_in_parent` is the position of
    /// `segment` among `parent`'s children; `lint` passes `0` and `None` for the root.
    fn walk(
        &self,
        segment: &'a Segment,
        index_in_parent: usize,
        parent: Option<&'a Segment>,
        violations: &mut Vec<LintViolation>,
    ) {
        let wanted = self.types.contains(&segment.segment_type());
        if wanted {
            // With no parent, the segment acts as its own single-element sibling list.
            let siblings = match parent {
                Some(p) => p.children(),
                None => std::slice::from_ref(segment),
            };
            violations.extend(self.rule.eval(&RuleContext {
                segment,
                parent,
                root: self.root,
                siblings,
                index_in_parent,
                source: self.source,
                dialect: self.dialect,
            }));
        }

        segment
            .children()
            .iter()
            .enumerate()
            .for_each(|(child_index, child)| {
                self.walk(child, child_index, Some(segment), violations)
            });
    }
}
/// Applies the fixes attached to `violations` to `source` and returns the patched text.
///
/// Each fix's `span.start..span.end` is used as a byte range into `source` and replaced
/// with its `new_text`. Conflicts and invalid fixes are resolved by skipping, never by
/// panicking:
/// - Edits are applied from the end of the text towards the start, so earlier offsets stay
///   valid after each replacement.
/// - Among edits with an identical span, only the first one (in `violations` order) is kept.
/// - An edit is skipped if it extends past the start of an already-applied edit, or starts
///   at or after it. This also drops a zero-width insertion at the start offset of an
///   applied replacement.
/// - An edit is skipped if its span is inverted (`start > end`), extends past the end of the
///   text, or does not fall on UTF-8 character boundaries.
///
/// Returns an unchanged copy of `source` when there are no fixes.
pub fn apply_fixes(source: &str, violations: &[LintViolation]) -> String {
    let mut edits: Vec<&SourceEdit> = violations.iter().flat_map(|v| v.fixes.iter()).collect();
    if edits.is_empty() {
        return source.to_string();
    }

    // Descending by start (then by end) so we patch back-to-front. `sort_by` is stable, so
    // equal spans keep their original relative order for the dedup below.
    edits.sort_by(|a, b| {
        b.span
            .start
            .cmp(&a.span.start)
            .then(b.span.end.cmp(&a.span.end))
    });
    // `dedup_by` keeps the first element of each run of equal spans.
    edits.dedup_by(|a, b| a.span == b.span);

    let mut result = source.to_string();
    // Start offset of the most recently applied edit. Everything before it in `result` is
    // still byte-identical to `source`.
    let mut last_applied_start = u32::MAX;
    for edit in edits {
        // Overlaps (or starts at/after) an edit that has already been applied.
        if edit.span.end > last_applied_start || edit.span.start >= last_applied_start {
            continue;
        }

        let start = edit.span.start as usize;
        let end = edit.span.end as usize;
        // `replace_range` panics on inverted ranges and on non-char-boundary offsets.
        // `is_char_boundary` also returns false for offsets past the end of the string.
        let valid = start <= end && result.is_char_boundary(start) && result.is_char_boundary(end);
        if !valid {
            continue;
        }

        result.replace_range(start..end, &edit.new_text);
        last_applied_start = edit.span.start;
    }
    result
}