use std::collections::BTreeMap;
use harn_hostlib::ast::Language;
use crate::constraint::CompiledConstraint;
use crate::error::RulesError;
use crate::evaluator::CompiledRuleTree;
use crate::fix::{interpolate, splice, AppliedEdit};
use crate::model::{Applicability, Rule, Safety, Severity};
use crate::transform::CompiledTransform;
/// Location of a region of source text, recorded both as a byte range and as
/// row/column positions for its two ends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
/// Byte offset at which the region starts.
pub start_byte: usize,
/// Byte offset at which the region ends.
pub end_byte: usize,
/// Row of the start position; the first line of the source is row 0.
pub start_row: usize,
/// Column of the start position within `start_row`.
pub start_col: usize,
/// Row of the end position; the first line of the source is row 0.
pub end_row: usize,
/// Column of the end position within `end_row`.
pub end_col: usize,
}
impl Span {
    /// Builds a [`Span`] from a tree-sitter node by copying its byte range
    /// and its start/end points.
    pub(crate) fn of(node: tree_sitter::Node<'_>) -> Self {
        let (from, to) = (node.start_position(), node.end_position());
        Self {
            start_byte: node.start_byte(),
            end_byte: node.end_byte(),
            start_row: from.row,
            start_col: from.column,
            end_row: to.row,
            end_col: to.column,
        }
    }
}
/// The source fragment captured for one metavariable of a match.
#[derive(Debug, Clone)]
pub struct Binding {
/// Captured source text.
pub text: String,
/// Where the captured text sits in the source.
pub span: Span,
}
/// One occurrence of a rule in a source text.
#[derive(Debug, Clone)]
pub struct RuleMatch {
/// Id of the rule that produced this match.
pub rule_id: String,
/// Location of the matched text.
pub span: Span,
/// The matched source text.
pub text: String,
/// Captured metavariables, keyed by name. Always empty for matches
/// produced by a pure-regex rule.
pub bindings: BTreeMap<String, Binding>,
}
/// Outcome of applying a rule's `fix` template to a source text.
#[derive(Debug, Clone)]
pub struct CodemodResult {
/// The source after every edit has been spliced in.
pub rewritten: String,
/// The edits that were applied, one per match.
pub edits: Vec<AppliedEdit>,
/// `true` when `rewritten` differs from the input source.
pub changed: bool,
/// Safety level of the rule that produced this result.
pub safety: Safety,
/// Applicability derived from the rule's safety level.
pub applicability: Applicability,
/// `true` when rewriting `rewritten` a second time leaves it unchanged.
pub idempotent: bool,
}
/// A [`Rule`] compiled into a form that can be run repeatedly against source
/// texts. Built with [`CompiledRule::compile`].
pub struct CompiledRule {
// Id copied from the source rule; stamped onto matches and errors.
rule_id: String,
// Language the rule targets, resolved from the rule's language name.
language: Language,
// How matches are located: plain regex over the text, or a compiled rule tree.
execution: Execution,
// `where` constraints; a match is kept only if all of them hold.
constraints: Vec<CompiledConstraint>,
// Named transforms as (output variable name, transform) pairs.
transforms: Vec<(String, CompiledTransform)>,
// Optional replacement template; required by `apply` and friends.
fix: Option<String>,
// Safety level declared by the rule.
safety: Safety,
// Message attached to every diagnostic.
message: String,
// Severity attached to every diagnostic.
severity: Severity,
}
/// A reportable finding for a single rule match.
#[derive(Debug, Clone)]
pub struct Diagnostic {
/// Id of the rule that fired.
pub rule_id: String,
/// The rule's message.
pub message: String,
/// The rule's severity.
pub severity: Severity,
/// Location of the match.
pub span: Span,
/// Applicability derived from the rule's safety level.
pub applicability: Applicability,
/// The rule's `fix` template interpolated for this match, or `None` when
/// the rule has no `fix`.
pub fix: Option<String>,
}
// Strategy used to locate matches for a compiled rule.
enum Execution {
// The rule is a pure regex: it is run directly over the source text.
SourceRegex(regex::Regex),
// Any other rule: matches are found through the compiled rule tree.
Tree(Box<CompiledRuleTree>),
}
impl CompiledRule {
/// Compiles a parsed [`Rule`] into a runnable [`CompiledRule`].
///
/// Pure-regex rules are compiled to a [`regex::Regex`]; every other rule is
/// compiled into a rule tree. `where` constraints and transforms are compiled
/// afterwards, in declaration order.
///
/// # Errors
///
/// * [`RulesError::UnknownLanguage`] when the rule's language name is not
///   recognised.
/// * [`RulesError::PatternCompile`] when a pure-regex rule's regex is invalid.
/// * Any error reported while compiling the rule tree, a constraint or a
///   transform.
pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
    let Some(language) = Language::from_name(&rule.language) else {
        return Err(RulesError::UnknownLanguage {
            rule: rule.id.clone(),
            language: rule.language.clone(),
        });
    };

    let execution = if rule.rule.is_pure_regex() {
        let pattern = rule.rule.regex.as_ref().expect("pure regex");
        match regex::Regex::new(pattern) {
            Ok(compiled) => Execution::SourceRegex(compiled),
            Err(err) => {
                return Err(RulesError::PatternCompile {
                    rule: rule.id.clone(),
                    message: format!("invalid regex `{pattern}`: {err}"),
                });
            }
        }
    } else {
        let tree = CompiledRuleTree::compile(&rule.id, language, &rule.rule, &rule.utils)?;
        Execution::Tree(Box::new(tree))
    };

    // Stop at the first constraint that fails to compile.
    let mut constraints = Vec::new();
    for constraint in rule.where_constraints.iter() {
        constraints.push(CompiledConstraint::compile(&rule.id, language, constraint)?);
    }

    // Keep each transform paired with the variable name it produces.
    let mut transforms = Vec::new();
    for (name, transform) in rule.transform.iter() {
        let compiled = CompiledTransform::compile(&rule.id, name, transform)?;
        transforms.push((name.clone(), compiled));
    }

    Ok(CompiledRule {
        rule_id: rule.id.clone(),
        language,
        execution,
        constraints,
        transforms,
        fix: rule.fix.clone(),
        safety: rule.safety,
        message: rule.message.clone(),
        severity: rule.severity,
    })
}
/// Returns the language this rule was compiled for.
pub fn language(&self) -> Language {
self.language
}
/// Returns the safety level declared by the rule.
pub fn safety(&self) -> Safety {
self.safety
}
/// Returns the applicability derived from the rule's safety level.
pub fn applicability(&self) -> Applicability {
self.safety.applicability()
}
/// Returns the rule's id.
pub fn id(&self) -> &str {
&self.rule_id
}
/// Finds every match of this rule in `source`.
///
/// Matches are located by the rule's execution strategy and then filtered
/// down to those satisfying all `where` constraints.
///
/// # Errors
///
/// Propagates any error raised while searching with the rule tree.
pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
    let candidates: Vec<RuleMatch> = match &self.execution {
        Execution::SourceRegex(regex) => self.run_regex(regex, source),
        Execution::Tree(tree) => {
            let found = tree.find(&self.rule_id, self.language, source)?;
            let mut converted = Vec::new();
            for hit in found {
                converted.push(RuleMatch {
                    rule_id: self.rule_id.clone(),
                    span: hit.span,
                    text: hit.text,
                    bindings: hit.bindings,
                });
            }
            converted
        }
    };

    // Nothing to filter on: hand the candidates back untouched.
    if self.constraints.is_empty() {
        return Ok(candidates);
    }

    Ok(candidates
        .into_iter()
        .filter(|candidate| self.satisfies_constraints(candidate))
        .collect())
}
/// Returns `true` when every constraint holds for `m`.
///
/// A constraint whose metavariable is not bound in the match counts as
/// failed.
fn satisfies_constraints(&self, m: &RuleMatch) -> bool {
    for constraint in &self.constraints {
        let Some(bound) = m.bindings.get(&constraint.metavar) else {
            return false;
        };
        if !constraint.evaluate(&bound.text) {
            return false;
        }
    }
    true
}
/// Rewrites `source` with the rule's `fix` template and reports the outcome.
///
/// The rewritten text is rewritten a second time to record whether the rule
/// is idempotent on this input.
///
/// # Errors
///
/// Fails when the rule has no `fix` template or when matching fails.
pub fn apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
    let (rewritten, edits) = self.rewrite(source)?;

    // Second pass over our own output: a stable result means idempotent.
    let second_pass = self.rewrite(&rewritten)?.0;
    let idempotent = second_pass == rewritten;
    let changed = rewritten.as_str() != source;

    Ok(CodemodResult {
        rewritten,
        edits,
        changed,
        safety: self.safety,
        applicability: self.applicability(),
        idempotent,
    })
}
/// Like [`CompiledRule::apply`], but refuses to run unless the rule's safety
/// level permits automatic application.
///
/// # Errors
///
/// Returns [`RulesError::NotAutoApplicable`] when the safety level is not
/// auto-applicable; otherwise the same errors as `apply`.
pub fn auto_apply(&self, source: &str) -> Result<CodemodResult, RulesError> {
    if self.safety.is_auto_applicable() {
        self.apply(source)
    } else {
        Err(RulesError::NotAutoApplicable {
            rule: self.rule_id.clone(),
            safety: format!("{:?}", self.safety),
        })
    }
}
/// Like [`CompiledRule::apply`], but rejects results that are not idempotent.
///
/// # Errors
///
/// Returns [`RulesError::NotIdempotent`] when a second rewrite would change
/// the output again; otherwise the same errors as `apply`.
pub fn apply_checked(&self, source: &str) -> Result<CodemodResult, RulesError> {
    let outcome = self.apply(source)?;
    if outcome.idempotent {
        Ok(outcome)
    } else {
        Err(RulesError::NotIdempotent {
            rule: self.rule_id.clone(),
        })
    }
}
/// Produces one [`Diagnostic`] per match of this rule in `source`.
///
/// When the rule has a `fix` template, each diagnostic carries the template
/// interpolated with that match's metavariables.
///
/// # Errors
///
/// Propagates any error from [`CompiledRule::run`].
pub fn diagnostics(&self, source: &str) -> Result<Vec<Diagnostic>, RulesError> {
    let applicability = self.applicability();
    let found = self.run(source)?;

    let mut report = Vec::with_capacity(found.len());
    for hit in &found {
        let fix = self
            .fix
            .as_ref()
            .map(|template| interpolate(template, &self.metavars_for(hit)));
        report.push(Diagnostic {
            rule_id: self.rule_id.clone(),
            message: self.message.clone(),
            severity: self.severity,
            span: hit.span,
            applicability,
            fix,
        });
    }
    Ok(report)
}
/// Computes one edit per match and splices them all into `source`.
///
/// Returns the rewritten text together with the edits that produced it.
///
/// # Errors
///
/// Returns [`RulesError::PatternCompile`] when the rule has no `fix`
/// template, and propagates any error from [`CompiledRule::run`].
fn rewrite(&self, source: &str) -> Result<(String, Vec<AppliedEdit>), RulesError> {
    let Some(template) = self.fix.as_ref() else {
        return Err(RulesError::PatternCompile {
            rule: self.rule_id.clone(),
            message: "apply requires a `fix` template; this rule has none".into(),
        });
    };

    let mut edits: Vec<AppliedEdit> = Vec::new();
    for found in self.run(source)? {
        // Interpolate before moving the match's text into the edit.
        let replacement = interpolate(template, &self.metavars_for(&found));
        edits.push(AppliedEdit {
            span: found.span,
            before: found.text,
            replacement,
        });
    }

    let rewritten = splice(source, &edits);
    Ok((rewritten, edits))
}
/// Builds the variable table used to interpolate the `fix` template for `m`.
///
/// The table starts with the text of every bound metavariable. Each transform
/// then adds (or overwrites) its own entry, computed from the text bound to
/// its source metavariable, or from the empty string when that metavariable
/// is not bound.
fn metavars_for(&self, m: &RuleMatch) -> BTreeMap<String, String> {
    let mut table = BTreeMap::new();
    for (name, binding) in &m.bindings {
        table.insert(name.clone(), binding.text.clone());
    }

    for (output, transform) in &self.transforms {
        let input = m
            .bindings
            .get(&transform.source)
            .map_or("", |bound| bound.text.as_str());
        table.insert(output.clone(), transform.apply(input));
    }

    table
}
fn run_regex(&self, regex: ®ex::Regex, source: &str) -> Vec<RuleMatch> {
let mut matches = Vec::new();
for m in regex.find_iter(source) {
let span = byte_span(source, m.start(), m.end());
matches.push(RuleMatch {
rule_id: self.rule_id.clone(),
span,
text: m.as_str().to_string(),
bindings: BTreeMap::new(),
});
}
matches
}
}
/// Builds a [`Span`] for the byte range `start..end` of `source`, resolving
/// the row/column of each end with [`row_col`].
fn byte_span(source: &str, start: usize, end: usize) -> Span {
    let from = row_col(source, start);
    let to = row_col(source, end);
    Span {
        start_byte: start,
        end_byte: end,
        start_row: from.0,
        start_col: from.1,
        end_row: to.0,
        end_col: to.1,
    }
}
/// Converts a byte offset into a zero-based `(row, column)` position.
///
/// The row is the number of `'\n'` bytes before `byte`. The column is the
/// number of *bytes* between the start of that row and `byte`. Columns are
/// measured in bytes rather than characters so that spans built here agree
/// with the ones `Span::of` copies from tree-sitter, whose point columns are
/// byte offsets within the row.
///
/// Offsets past the end of `source` are clamped to `source.len()`.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    // Work on raw bytes: '\n' can never occur inside a multi-byte UTF-8
    // sequence, so no decoding is needed and any offset is safe to slice at.
    let prefix = &source.as_bytes()[..byte.min(source.len())];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, prefix.len() - line_start)
}
// Unit tests covering rule compilation and matching for pattern, kind and
// regex rules, plus the compile-time error paths.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::Rule;
// Parses a rule from TOML and compiles it, panicking if either step fails.
fn rule(toml: &str) -> CompiledRule {
let parsed = Rule::from_toml_str(toml).expect("rule parses");
CompiledRule::compile(&parsed).expect("rule compiles")
}
// A pattern rule yields one match per occurrence, binds each metavariable
// to its source text, and reports the row each match starts on.
#[test]
fn pattern_rule_binds_metavars() {
let compiled = rule(
r#"
id = "destructure-default"
language = "typescript"
fix = "{ $KEY: $SRC }"
[rule]
pattern = "$SRC?.$KEY ?? $DEFAULT"
"#,
);
let matches = compiled
.run("const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n")
.unwrap();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].bindings["SRC"].text, "cfg");
assert_eq!(matches[0].bindings["KEY"].text, "timeout");
assert_eq!(matches[0].bindings["DEFAULT"].text, "30");
assert_eq!(matches[1].bindings["SRC"].text, "opts");
assert_eq!(matches[0].text, "cfg?.timeout ?? 30");
assert_eq!(matches[0].span.start_row, 0);
assert_eq!(matches[1].span.start_row, 1);
}
// A `kind` rule matches every node of that kind and binds no metavariables.
#[test]
fn kind_rule_matches_node_kind() {
let compiled = rule(
r#"
id = "find-calls"
language = "python"
[rule]
kind = "call"
"#,
);
let matches = compiled.run("print(x)\nlog(y)\n").unwrap();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].text, "print(x)");
assert!(matches[0].bindings.is_empty());
}
// A pure-regex rule matches raw source text; the lowercase `todo` line
// must not match the case-sensitive pattern.
#[test]
fn regex_rule_matches_text() {
let compiled = rule(
r#"
id = "todo"
language = "rust"
message = "Found a TODO"
[rule]
regex = "TODO\\(\\w+\\)"
"#,
);
let matches = compiled
.run("fn f() {\n // TODO(ken) fix\n // todo lower\n}\n")
.unwrap();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].text, "TODO(ken)");
assert_eq!(matches[0].span.start_row, 1);
}
// A rule naming an unrecognised language parses but fails to compile with
// `UnknownLanguage`.
#[test]
fn unknown_language_is_an_error() {
let parsed = Rule::from_toml_str(
r#"
id = "x"
language = "cobol"
[rule]
kind = "foo"
"#,
)
.unwrap();
assert!(matches!(
CompiledRule::compile(&parsed),
Err(RulesError::UnknownLanguage { .. })
));
}
// A pattern the rule-tree compiler rejects surfaces as `PatternCompile`.
// NOTE(review): presumably `$$$ARGS` is what gets rejected here — confirm
// against the pattern compiler.
#[test]
fn invalid_pattern_surfaces_compile_error() {
let parsed = Rule::from_toml_str(
r#"
id = "x"
language = "typescript"
[rule]
pattern = "foo($$$ARGS)"
"#,
)
.unwrap();
assert!(matches!(
CompiledRule::compile(&parsed),
Err(RulesError::PatternCompile { .. })
));
}
}