use std::collections::BTreeMap;
use harn_hostlib::ast::{api, Language};
use streaming_iterator::StreamingIterator;
use tree_sitter::{Query, QueryCursor};
use crate::error::RulesError;
use crate::model::{AtomicMatcher, Rule};
use crate::pattern::{compile_pattern, ROOT_CAPTURE};
/// Location of a match (or of a captured sub-node) inside the scanned source.
///
/// Spans are produced either from a tree-sitter node (`Span::of`) or from a
/// regex match (`byte_span`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
/// Byte offset of the first byte covered by the span.
pub start_byte: usize,
/// Byte offset one past the last byte covered by the span.
pub end_byte: usize,
/// Zero-based row (line) on which the span starts.
pub start_row: usize,
/// Zero-based column of the span start within `start_row`.
pub start_col: usize,
/// Zero-based row (line) on which the span ends.
pub end_row: usize,
/// Zero-based column of the span end within `end_row`.
pub end_col: usize,
}
impl Span {
    /// Builds a span from the byte range and start/end positions reported by
    /// a tree-sitter node.
    fn of(node: tree_sitter::Node<'_>) -> Self {
        let (from, to) = (node.start_position(), node.end_position());
        Self {
            start_byte: node.start_byte(),
            end_byte: node.end_byte(),
            start_row: from.row,
            start_col: from.column,
            end_row: to.row,
            end_col: to.column,
        }
    }
}
/// Source text and location captured for a single pattern metavariable.
#[derive(Debug, Clone)]
pub struct Binding {
/// The slice of source text covered by the captured node, copied out.
pub text: String,
/// Where the captured node sits in the source.
pub span: Span,
}
/// One hit produced by running a `CompiledRule` over a source string.
#[derive(Debug, Clone)]
pub struct RuleMatch {
/// Identifier of the rule that produced this match.
pub rule_id: String,
/// Location of the whole match in the source.
pub span: Span,
/// Source text covered by `span`.
pub text: String,
/// Metavariable captures keyed by capture name; empty for kind and regex
/// rules.
pub bindings: BTreeMap<String, Binding>,
}
/// A rule that has been validated and lowered into an executable matcher.
///
/// Built with `CompiledRule::compile` and executed with `CompiledRule::run`.
pub struct CompiledRule {
// Identifier copied from the source rule; stamped onto every match and error.
rule_id: String,
// Language the rule targets; used to parse sources for query-based matchers.
language: Language,
// The executable form of the rule body.
matcher: CompiledMatcher,
}
/// Executable form of a rule body.
enum CompiledMatcher {
// A tree-sitter query plus the capture names that should be surfaced as
// bindings. Used for both pattern rules and kind rules (the latter with no
// metavariables).
Query { query: Query, metavars: Vec<String> },
// A regular expression run directly over the source text.
Regex(regex::Regex),
}
impl CompiledRule {
    /// Compiles a declarative [`Rule`] into an executable matcher.
    ///
    /// Pattern rules and kind rules are lowered to a tree-sitter query; regex
    /// rules are compiled with the `regex` crate and need no grammar.
    ///
    /// # Errors
    ///
    /// * `RulesError::UnknownLanguage` if `rule.language` is not recognised by
    ///   `Language::from_name`.
    /// * `RulesError::GrammarUnavailable` if a pattern or kind rule targets a
    ///   language for which `ts_language()` yields nothing.
    /// * `RulesError::PatternCompile` if the rule body fails to resolve, the
    ///   pattern snippet fails to compile, or the regex is invalid.
    /// * `RulesError::QueryRejected` if tree-sitter rejects the generated
    ///   query text.
    pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
        let language =
            Language::from_name(&rule.language).ok_or_else(|| RulesError::UnknownLanguage {
                rule: rule.id.clone(),
                language: rule.language.clone(),
            })?;

        // Grammar lookup shared by the pattern and kind arms. It is evaluated
        // lazily so that regex rules never require a grammar.
        let grammar = || {
            language
                .ts_language()
                .ok_or_else(|| RulesError::GrammarUnavailable {
                    rule: rule.id.clone(),
                    language: language.name().to_string(),
                })
        };

        let resolved = rule
            .rule
            .resolve()
            .map_err(|message| RulesError::PatternCompile {
                rule: rule.id.clone(),
                message,
            })?;

        let matcher = match resolved {
            AtomicMatcher::Pattern(snippet) => {
                // Check the grammar first so a missing grammar is reported
                // ahead of any pattern-compilation problem.
                let ts_language = grammar()?;
                let compiled = compile_pattern(&snippet, language).map_err(|message| {
                    RulesError::PatternCompile {
                        rule: rule.id.clone(),
                        message,
                    }
                })?;
                let query = Query::new(&ts_language, &compiled.query).map_err(|err| {
                    RulesError::QueryRejected {
                        rule: rule.id.clone(),
                        message: err.to_string(),
                        query: compiled.query.clone(),
                    }
                })?;
                CompiledMatcher::Query {
                    query,
                    metavars: compiled.metavars,
                }
            }
            AtomicMatcher::Kind(kind) => {
                let ts_language = grammar()?;
                // A kind rule is a query that captures every node of that kind
                // as the root; it has no metavariables.
                let query_text = format!("({kind}) @{ROOT_CAPTURE}");
                let query = Query::new(&ts_language, &query_text).map_err(|err| {
                    RulesError::QueryRejected {
                        rule: rule.id.clone(),
                        message: err.to_string(),
                        query: query_text.clone(),
                    }
                })?;
                CompiledMatcher::Query {
                    query,
                    metavars: Vec::new(),
                }
            }
            AtomicMatcher::Regex(pattern) => {
                let regex =
                    regex::Regex::new(&pattern).map_err(|err| RulesError::PatternCompile {
                        rule: rule.id.clone(),
                        message: format!("invalid regex `{pattern}`: {err}"),
                    })?;
                CompiledMatcher::Regex(regex)
            }
        };

        Ok(CompiledRule {
            rule_id: rule.id.clone(),
            language,
            matcher,
        })
    }

    /// Language this rule was compiled for.
    pub fn language(&self) -> Language {
        self.language
    }

    /// Runs the rule over `source` and returns every match.
    ///
    /// Query-based rules return matches sorted by `(start_byte, end_byte)`;
    /// regex rules return matches in the order `find_iter` yields them.
    ///
    /// # Errors
    ///
    /// Query-based rules return `RulesError::SourceParse` when `source`
    /// cannot be parsed, or when a captured node's byte range cannot be
    /// sliced out of `source`. Regex rules do not fail.
    pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
        match &self.matcher {
            CompiledMatcher::Query { query, metavars } => self.run_query(query, metavars, source),
            CompiledMatcher::Regex(regex) => Ok(self.run_regex(regex, source)),
        }
    }

    /// Parses `source`, executes `query` against the tree, and turns each
    /// query match that has a root capture into a [`RuleMatch`].
    ///
    /// Only captures named in `metavars` become bindings; when a name is
    /// captured more than once in a match, the first capture wins.
    fn run_query(
        &self,
        query: &Query,
        metavars: &[String],
        source: &str,
    ) -> Result<Vec<RuleMatch>, RulesError> {
        let tree =
            api::parse_tree(source, self.language).map_err(|err| RulesError::SourceParse {
                rule: self.rule_id.clone(),
                message: err.to_string(),
            })?;
        let names = query.capture_names();
        let mut cursor = QueryCursor::new();
        let mut it = cursor.matches(query, tree.root_node(), source.as_bytes());
        let mut matches = Vec::new();

        // `matches` is a streaming iterator, so it cannot drive a `for` loop.
        while let Some(found) = it.next() {
            let mut root: Option<(Span, &str)> = None;
            let mut bindings: BTreeMap<String, Binding> = BTreeMap::new();

            for cap in found.captures {
                let name = names[cap.index as usize];
                let is_root = name == ROOT_CAPTURE;
                // Skip helper captures before slicing or allocating anything.
                if !is_root && !metavars.iter().any(|var| var == name) {
                    continue;
                }

                let span = Span::of(cap.node);
                // Checked slicing: report a bad range instead of panicking.
                let text = source.get(span.start_byte..span.end_byte).ok_or_else(|| {
                    RulesError::SourceParse {
                        rule: self.rule_id.clone(),
                        message: format!(
                            "capture `{name}` covers bytes {}..{}, which is not a valid slice of the source",
                            span.start_byte, span.end_byte
                        ),
                    }
                })?;

                if is_root {
                    root = Some((span, text));
                } else {
                    bindings
                        .entry(name.to_string())
                        .or_insert_with(|| Binding {
                            text: text.to_string(),
                            span,
                        });
                }
            }

            // A query match without a root capture carries no reportable span.
            if let Some((span, text)) = root {
                matches.push(RuleMatch {
                    rule_id: self.rule_id.clone(),
                    span,
                    text: text.to_string(),
                    bindings,
                });
            }
        }

        matches.sort_by_key(|hit| (hit.span.start_byte, hit.span.end_byte));
        Ok(matches)
    }

    /// Runs `regex` over the raw source text; every hit becomes a
    /// [`RuleMatch`] with no bindings.
    fn run_regex(&self, regex: &regex::Regex, source: &str) -> Vec<RuleMatch> {
        regex
            .find_iter(source)
            .map(|found| RuleMatch {
                rule_id: self.rule_id.clone(),
                span: byte_span(source, found.start(), found.end()),
                text: found.as_str().to_string(),
                bindings: BTreeMap::new(),
            })
            .collect()
    }
}
/// Builds a [`Span`] for the byte range `start..end` of `source`, resolving
/// both endpoints to row/column positions with `row_col`.
fn byte_span(source: &str, start: usize, end: usize) -> Span {
    let begin = row_col(source, start);
    let finish = row_col(source, end);
    Span {
        start_byte: start,
        end_byte: end,
        start_row: begin.0,
        start_col: begin.1,
        end_row: finish.0,
        end_col: finish.1,
    }
}
/// Converts a byte offset in `source` into a zero-based `(row, column)` pair.
///
/// The row is the number of `\n` bytes before `byte`. The column is the
/// number of bytes between the start of that row and `byte`, which is the
/// same unit tree-sitter uses for `Point::column`, so spans built from regex
/// matches agree with spans built from tree-sitter nodes even when a line
/// contains multi-byte characters.
///
/// Offsets past the end of `source` are clamped to `source.len()`.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    let limit = byte.min(source.len());
    // Work on raw bytes: `\n` is never part of a multi-byte UTF-8 sequence,
    // and this stays well-defined even if `byte` is not a char boundary.
    let prefix = &source.as_bytes()[..limit];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, limit - line_start)
}
// Unit tests covering each matcher flavour (pattern, kind, regex) and the
// compile-time error paths.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::Rule;
// Parses a TOML rule definition and compiles it, panicking on either failure.
fn rule(toml: &str) -> CompiledRule {
let parsed = Rule::from_toml_str(toml).expect("rule parses");
CompiledRule::compile(&parsed).expect("rule compiles")
}
// A pattern rule reports one match per occurrence, in source order, with each
// metavariable bound to the text it captured.
#[test]
fn pattern_rule_binds_metavars() {
let compiled = rule(
r#"
id = "destructure-default"
language = "typescript"
fix = "{ $KEY: $SRC }"
[rule]
pattern = "$SRC?.$KEY ?? $DEFAULT"
"#,
);
let matches = compiled
.run("const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n")
.unwrap();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].bindings["SRC"].text, "cfg");
assert_eq!(matches[0].bindings["KEY"].text, "timeout");
assert_eq!(matches[0].bindings["DEFAULT"].text, "30");
assert_eq!(matches[1].bindings["SRC"].text, "opts");
assert_eq!(matches[0].text, "cfg?.timeout ?? 30");
assert_eq!(matches[0].span.start_row, 0);
assert_eq!(matches[1].span.start_row, 1);
}
// A kind rule matches every node of the named kind and produces no bindings.
#[test]
fn kind_rule_matches_node_kind() {
let compiled = rule(
r#"
id = "find-calls"
language = "python"
[rule]
kind = "call"
"#,
);
let matches = compiled.run("print(x)\nlog(y)\n").unwrap();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].text, "print(x)");
assert!(matches[0].bindings.is_empty());
}
// A regex rule matches raw text (case-sensitively here) and reports the row
// of the hit.
#[test]
fn regex_rule_matches_text() {
let compiled = rule(
r#"
id = "todo"
language = "rust"
message = "Found a TODO"
[rule]
regex = "TODO\\(\\w+\\)"
"#,
);
let matches = compiled
.run("fn f() {\n // TODO(ken) fix\n // todo lower\n}\n")
.unwrap();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].text, "TODO(ken)");
assert_eq!(matches[0].span.start_row, 1);
}
// Compiling a rule for a language name that is not recognised yields
// `UnknownLanguage`.
#[test]
fn unknown_language_is_an_error() {
let parsed = Rule::from_toml_str(
r#"
id = "x"
language = "cobol"
[rule]
kind = "foo"
"#,
)
.unwrap();
assert!(matches!(
CompiledRule::compile(&parsed),
Err(RulesError::UnknownLanguage { .. })
));
}
// The pattern `foo($$$ARGS)` parses as a rule but is expected to fail
// compilation with `PatternCompile`.
#[test]
fn invalid_pattern_surfaces_compile_error() {
let parsed = Rule::from_toml_str(
r#"
id = "x"
language = "typescript"
[rule]
pattern = "foo($$$ARGS)"
"#,
)
.unwrap();
assert!(matches!(
CompiledRule::compile(&parsed),
Err(RulesError::PatternCompile { .. })
));
}
}