use std::collections::HashMap;
use harn_hostlib::ast::{api, Language};
use tree_sitter::Node;
/// Capture name attached to the root node of every compiled pattern query.
pub const ROOT_CAPTURE: &str = "__match";
/// Prefix of the identifier placeholders that stand in for `$NAME`
/// metavariables while a snippet is parsed; a numeric index is appended.
const PLACEHOLDER_STEM: &str = "__harn_hole_";
/// A code snippet compiled into tree-sitter query source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompiledPattern {
    /// Tree-sitter query source. The whole match is captured as
    /// [`ROOT_CAPTURE`]; the first occurrence of each metavariable `$NAME`
    /// is captured as `@NAME`.
    pub query: String,
    /// Metavariable names (without the leading `$`) in order of first
    /// appearance in the snippet.
    pub metavars: Vec<String>,
}
/// Compiles an ast-grep-style code `snippet` into a tree-sitter query for
/// `language`.
///
/// Each `$NAME` metavariable becomes a wildcard capture `@NAME`. Repeated
/// metavariables are unified with `#eq?` predicates, and other named leaves
/// are pinned to their exact text. The matched node is captured as
/// [`ROOT_CAPTURE`].
///
/// The snippet is parsed on its own first and then inside each wrapper
/// returned by `contexts`. The first context that parses without errors is
/// used. Whitespace surrounding the snippet is ignored when locating the
/// snippet's root node.
///
/// # Errors
///
/// Returns a message in these cases:
/// - metavariable substitution fails (e.g. variadic `$$$`);
/// - the snippet is empty or whitespace-only;
/// - `api::parse_tree` fails;
/// - no context parses the snippet cleanly, or the snippet's node cannot be
///   located.
pub fn compile_pattern(snippet: &str, language: Language) -> Result<CompiledPattern, String> {
    let sub = substitute(snippet)?;
    // Locate the pattern by its trimmed extent. Surrounding whitespace would
    // widen the byte range past the snippet's own node, making
    // `descendant_for_byte_range` return an ancestor (often the root).
    let core = sub.text.trim();
    if core.is_empty() {
        return Err(format!(
            "pattern snippet is empty; nothing to match in `{}`",
            language.name()
        ));
    }
    let leading_ws = sub.text.len() - sub.text.trim_start().len();
    let mut last_err: Option<String> = None;
    for (prefix, suffix) in contexts(language) {
        let wrapped = format!("{prefix}{}{suffix}", sub.text);
        let tree = api::parse_tree(&wrapped, language).map_err(|err| err.to_string())?;
        let root = tree.root_node();
        if root.has_error() {
            last_err = Some(format!(
                "snippet did not parse cleanly in `{}`: `{snippet}`",
                language.name()
            ));
            continue;
        }
        let start = prefix.len() + leading_ws;
        let end = start + core.len();
        // `core` is non-empty, so `end - 1` cannot underflow and is `>= start`.
        let Some(pattern_root) = root.descendant_for_byte_range(start, end - 1) else {
            last_err = Some(format!(
                "could not locate snippet subtree in `{}`",
                language.name()
            ));
            continue;
        };
        let mut builder = QueryBuilder::new(wrapped.as_bytes(), &sub.placeholder_to_metavar);
        let body = builder.build(pattern_root);
        let predicates = builder.predicates();
        let query = if predicates.is_empty() {
            format!("({body} @{ROOT_CAPTURE})")
        } else {
            format!("({body} @{ROOT_CAPTURE} {predicates})")
        };
        return Ok(CompiledPattern {
            query,
            metavars: sub.metavar_order,
        });
    }
    Err(last_err.unwrap_or_else(|| format!("snippet did not parse in `{}`", language.name())))
}
/// Returns the `(prefix, suffix)` pairs a snippet is wrapped in before
/// parsing, in the order they should be tried.
///
/// The bare snippet `("", "")` always comes first. Some languages also get a
/// minimal function-body wrapper, so statement-level snippets (such as a
/// Rust `let`) can be parsed where the grammar would reject them at the top
/// level.
fn contexts(language: Language) -> Vec<(&'static str, &'static str)> {
    let function_body: Option<(&'static str, &'static str)> = match language {
        Language::Rust => Some(("fn __harn_probe() { ", " }")),
        Language::Go => Some(("package p\nfunc __harn_probe() { ", " }")),
        Language::Java | Language::CSharp => {
            Some(("class __HarnProbe { void __harn_probe() { ", " } }"))
        }
        Language::C | Language::Cpp => Some(("void __harn_probe() { ", " }")),
        Language::Kotlin => Some(("fun __harn_probe() { ", " }")),
        Language::Swift => Some(("func __harn_probe() { ", " }")),
        Language::Scala => Some(("def __harn_probe() = { ", " }")),
        _ => None,
    };
    std::iter::once(("", "")).chain(function_body).collect()
}
/// Output of `substitute`: the snippet rewritten with identifier
/// placeholders, plus the bookkeeping that maps placeholders back to
/// metavariables.
struct Substituted {
    /// Metavariable names (without `$`) in order of first appearance.
    metavar_order: Vec<String>,
    /// Placeholder identifier -> metavariable name.
    placeholder_to_metavar: HashMap<String, String>,
    /// The snippet with each `$NAME` replaced by its placeholder identifier.
    text: String,
}
/// Replaces every `$NAME` metavariable in `snippet` with a plain identifier
/// placeholder, so the snippet can be parsed by an ordinary grammar.
///
/// Each distinct metavariable maps to one placeholder (`PLACEHOLDER_STEM`
/// followed by its first-appearance index), and `metavar_order` records the
/// names in first-appearance order. A `$` that is not followed by an ASCII
/// letter or `_` is copied through literally.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the snippet already contains `PLACEHOLDER_STEM`, which would be
///   mistaken for a metavariable;
/// - it uses a variadic `$$$` metavariable;
/// - a metavariable name starts with `__`. That prefix is reserved because
///   such names would collide with the internal `__match` / `__lit_N`
///   capture names.
fn substitute(snippet: &str) -> Result<Substituted, String> {
    if snippet.contains(PLACEHOLDER_STEM) {
        return Err(format!(
            "snippet must not contain the reserved identifier prefix `{PLACEHOLDER_STEM}`"
        ));
    }
    let mut text = String::with_capacity(snippet.len());
    let mut placeholder_to_metavar: HashMap<String, String> = HashMap::new();
    let mut metavar_to_placeholder: HashMap<String, String> = HashMap::new();
    let mut metavar_order: Vec<String> = Vec::new();
    let bytes = snippet.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        // Copy everything up to the next `$` verbatim. `$` is ASCII, so the
        // offset `find` returns is always a char boundary.
        let Some(offset) = snippet[i..].find('$') else {
            text.push_str(&snippet[i..]);
            break;
        };
        text.push_str(&snippet[i..i + offset]);
        i += offset;
        if snippet[i..].starts_with("$$$") {
            return Err(
                "variadic `$$$` metavariables are not yet supported (tracked in #2833)".into(),
            );
        }
        let name_start = i + 1;
        let mut j = name_start;
        if j < bytes.len() && is_ident_start(bytes[j]) {
            j += 1;
            while j < bytes.len() && is_ident_continue(bytes[j]) {
                j += 1;
            }
        }
        if j == name_start {
            // Not a metavariable: keep the `$` as ordinary snippet text.
            text.push('$');
            i += 1;
            continue;
        }
        let name = &snippet[name_start..j];
        if name.starts_with("__") {
            return Err(format!(
                "metavariable `${name}` uses the reserved `__` prefix"
            ));
        }
        let placeholder = match metavar_to_placeholder.get(name) {
            Some(existing) => existing.clone(),
            None => {
                // Index placeholders by first appearance so they line up
                // with `metavar_order`.
                let fresh = format!("{PLACEHOLDER_STEM}{}", metavar_order.len());
                metavar_to_placeholder.insert(name.to_string(), fresh.clone());
                placeholder_to_metavar.insert(fresh.clone(), name.to_string());
                metavar_order.push(name.to_string());
                fresh
            }
        };
        text.push_str(&placeholder);
        i = j;
    }
    Ok(Substituted {
        text,
        placeholder_to_metavar,
        metavar_order,
    })
}
/// Whether byte `b` may begin a metavariable name: an ASCII letter or `_`.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
/// Whether byte `b` may continue a metavariable name: an ASCII letter,
/// ASCII digit, or `_`.
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_')
}
/// Renders a parsed snippet subtree as tree-sitter query source, collecting
/// the `#eq?` predicates the pattern needs along the way.
struct QueryBuilder<'a> {
    /// Bytes of the (wrapped) snippet the subtree was parsed from.
    src: &'a [u8],
    /// Placeholder identifier -> metavariable name, as produced by `substitute`.
    placeholder_to_metavar: &'a HashMap<String, String>,
    /// Number of captures emitted so far for each metavariable.
    occurrences: HashMap<String, usize>,
    /// Accumulated `(#eq? ...)` predicates, in emission order.
    eq_predicates: Vec<String>,
    /// Counter used to mint unique `__lit_N` capture names.
    literal_count: usize,
}

impl<'a> QueryBuilder<'a> {
    /// Creates a builder over `src` with no captures or predicates yet.
    fn new(src: &'a [u8], placeholder_to_metavar: &'a HashMap<String, String>) -> Self {
        QueryBuilder {
            src,
            placeholder_to_metavar,
            occurrences: HashMap::new(),
            eq_predicates: Vec::new(),
            literal_count: 0,
        }
    }

    /// Recursively renders `node` as a query pattern.
    ///
    /// Interior nodes become `(kind child ...)`. Named children that occupy
    /// a grammar field get a `field: ` prefix. Leaves are rendered by
    /// `build_leaf`.
    fn build(&mut self, node: Node<'_>) -> String {
        if node.child_count() == 0 {
            return self.build_leaf(node);
        }
        let mut parts: Vec<String> = Vec::with_capacity(node.child_count());
        let mut cursor = node.walk();
        for (i, child) in node.children(&mut cursor).enumerate() {
            let sub = self.build(child);
            match node.field_name_for_child(i as u32) {
                Some(field) if child.is_named() => parts.push(format!("{field}: {sub}")),
                _ => parts.push(sub),
            }
        }
        format!("({} {})", node.kind(), parts.join(" "))
    }

    /// Renders a leaf node.
    ///
    /// - A placeholder leaf becomes a wildcard `(_)` captured under its
    ///   metavariable.
    /// - Any other named leaf becomes `(kind)`, captured as `__lit_N`, with an
    ///   `#eq?` predicate pinning its exact text.
    /// - An anonymous leaf becomes a quoted literal of its node kind.
    fn build_leaf(&mut self, node: Node<'_>) -> String {
        let text = self.node_text(node);
        if let Some(metavar) = self.placeholder_to_metavar.get(text) {
            return format!("(_) @{}", self.capture_for(metavar));
        }
        if node.is_named() {
            let cap = format!("__lit_{}", self.literal_count);
            self.literal_count += 1;
            self.eq_predicates
                .push(format!("(#eq? @{cap} {})", quote_literal(text)));
            return format!("({}) @{cap}", node.kind());
        }
        // Queries match anonymous nodes by kind, not by source text. The two
        // usually coincide, but they can differ for tokens defined by a regex
        // and aliased to a fixed name. In that case the raw text is not a
        // valid node name and the query would be rejected.
        quote_literal(node.kind())
    }

    /// Returns the capture name for the next occurrence of `metavar`.
    ///
    /// The first occurrence gets the bare name. Later occurrences get
    /// `NAME.2`, `NAME.3`, … and each is tied to the first with an `#eq?`
    /// predicate, so a repeated metavariable must bind identical text.
    fn capture_for(&mut self, metavar: &str) -> String {
        let count = self.occurrences.entry(metavar.to_string()).or_insert(0);
        *count += 1;
        if *count == 1 {
            metavar.to_string()
        } else {
            let helper = format!("{metavar}.{count}");
            self.eq_predicates
                .push(format!("(#eq? @{metavar} @{helper})"));
            helper
        }
    }

    /// All accumulated predicates, space-separated (empty if there are none).
    fn predicates(&self) -> String {
        self.eq_predicates.join(" ")
    }

    /// Source text of `node`, or `""` if its byte range is not valid UTF-8.
    fn node_text(&self, node: Node<'_>) -> &'a str {
        std::str::from_utf8(&self.src[node.start_byte()..node.end_byte()]).unwrap_or_default()
    }
}
/// Renders `text` as a double-quoted tree-sitter query string literal.
///
/// Escapes `"` and `\`. It also escapes newline, carriage return, tab and
/// NUL as `\n`, `\r`, `\t` and `\0`, which the query parser decodes. The
/// parser rejects a raw newline inside a string literal, so multi-line leaf
/// text must be escaped for the query to compile.
fn quote_literal(text: &str) -> String {
    let mut out = String::with_capacity(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\0' => out.push_str("\\0"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use streaming_iterator::StreamingIterator;
    use tree_sitter::{Query, QueryCursor};

    /// Compiles `snippet`, runs the resulting query over `code`, and returns
    /// one `(capture name, captured texts)` entry per capture name per match,
    /// in match order.
    fn run(snippet: &str, language: Language, code: &str) -> Vec<(String, Vec<String>)> {
        let compiled = compile_pattern(snippet, language).expect("compiles");
        let grammar = language.ts_language().expect("grammar");
        let query = Query::new(&grammar, &compiled.query)
            .unwrap_or_else(|e| panic!("query rejected: {e}\nquery: {}", compiled.query));
        let tree = api::parse_tree(code, language).expect("parse code");
        let capture_names = query.capture_names();
        let mut query_cursor = QueryCursor::new();
        let mut found = query_cursor.matches(&query, tree.root_node(), code.as_bytes());
        let mut bindings = Vec::new();
        while let Some(query_match) = found.next() {
            let mut grouped: HashMap<String, Vec<String>> = HashMap::new();
            for cap in query_match.captures {
                let name = capture_names[cap.index as usize].to_string();
                let text = code[cap.node.start_byte()..cap.node.end_byte()].to_string();
                grouped.entry(name).or_default().push(text);
            }
            bindings.extend(grouped);
        }
        bindings
    }

    /// Texts bound to capture `name` in the first entry that has it, or an
    /// empty slice when the capture never fired.
    fn capture<'a>(binds: &'a [(String, Vec<String>)], name: &str) -> &'a [String] {
        binds
            .iter()
            .find_map(|(bound, texts)| (bound == name).then_some(texts.as_slice()))
            .unwrap_or_default()
    }

    #[test]
    fn compiles_destructuring_default_in_typescript() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["SRC", "KEY", "DEFAULT"]);
        let binds = run(snippet, Language::TypeScript, "const a = cfg?.timeout ?? 30;");
        assert_eq!(capture(&binds, "SRC"), ["cfg"]);
        assert_eq!(capture(&binds, "KEY"), ["timeout"]);
        assert_eq!(capture(&binds, "DEFAULT"), ["30"]);
    }

    #[test]
    fn operator_is_constrained_not_just_structure() {
        let binds = run(
            "$SRC?.$KEY ?? $DEFAULT",
            Language::TypeScript,
            "const a = cfg?.timeout || 30;",
        );
        assert!(
            capture(&binds, "SRC").is_empty(),
            "|| must not match the ?? pattern"
        );
    }

    #[test]
    fn round_trips_the_assignment_form() {
        let snippet = "$NAME = $SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["NAME", "SRC", "KEY", "DEFAULT"]);
        let binds = run(snippet, Language::TypeScript, "x = src?.userId ?? fallback;");
        assert_eq!(capture(&binds, "NAME"), ["x"]);
        assert_eq!(capture(&binds, "SRC"), ["src"]);
        assert_eq!(capture(&binds, "KEY"), ["userId"]);
        assert_eq!(capture(&binds, "DEFAULT"), ["fallback"]);
    }

    #[test]
    fn lifts_metavars_in_rust() {
        let binds = run(
            "let $NAME = $VALUE;",
            Language::Rust,
            "fn f() { let total = compute(); }",
        );
        assert_eq!(capture(&binds, "NAME"), ["total"]);
        assert_eq!(capture(&binds, "VALUE"), ["compute()"]);
    }

    #[test]
    fn lifts_metavars_in_python() {
        let binds = run("$FN($ARG)", Language::Python, "print(value)");
        assert_eq!(capture(&binds, "FN"), ["print"]);
        assert_eq!(capture(&binds, "ARG"), ["value"]);
    }

    #[test]
    fn lifts_metavars_in_go() {
        let binds = run("$FN($ARG)", Language::Go, "package main\nfunc m() { log(err) }");
        assert_eq!(capture(&binds, "FN"), ["log"]);
        assert_eq!(capture(&binds, "ARG"), ["err"]);
    }

    #[test]
    fn repeated_metavar_unifies() {
        let snippet = "$X + $X";
        let same = run(snippet, Language::Rust, "fn f() { let _ = a + a; }");
        assert_eq!(capture(&same, "X"), ["a"]);
        let different = run(snippet, Language::Rust, "fn f() { let _ = a + b; }");
        assert!(
            capture(&different, "X").is_empty(),
            "unification must reject `a + b`"
        );
    }

    #[test]
    fn rejects_unparseable_snippet() {
        let err = compile_pattern("$A ?? ?? $B", Language::TypeScript).unwrap_err();
        assert!(err.contains("did not parse"), "got: {err}");
    }

    #[test]
    fn rejects_variadic_for_now() {
        let err = compile_pattern("foo($$$ARGS)", Language::TypeScript).unwrap_err();
        assert!(err.contains("variadic"), "got: {err}");
    }

    #[test]
    fn literal_pattern_matches_exact_text() {
        let snippet = "foo()";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert!(compiled.metavars.is_empty());
        let hit = run(snippet, Language::TypeScript, "foo();");
        assert!(!hit.is_empty());
        let miss = run(snippet, Language::TypeScript, "bar();");
        assert!(
            miss.is_empty(),
            "bar() must not match foo()'s literal pattern: {miss:?}"
        );
    }
}