1use std::collections::HashMap;
42
43use harn_hostlib::ast::{api, Language};
44use tree_sitter::Node;
45
/// Capture name attached to the root node of every compiled pattern query.
///
/// `compile_pattern` emits each query as `(<body> @__match …)`, so every match
/// carries this capture on the node that the whole snippet matched.
pub const ROOT_CAPTURE: &str = "__match";

/// Prefix of the identifiers substituted for `$NAME` metavariables before the
/// snippet is parsed; the N-th distinct metavariable becomes `__harn_hole_N`.
const PLACEHOLDER_STEM: &str = "__harn_hole_";
55
/// A metavariable snippet lowered into a tree-sitter query.
///
/// Equality is structural (same query text, same metavariables in the same
/// order), which makes compiled patterns easy to compare and cache.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompiledPattern {
    /// Tree-sitter query source. The node matched by the whole snippet is
    /// captured as `ROOT_CAPTURE`. Each metavariable `$NAME` is captured as
    /// `@NAME` on its first occurrence. Later occurrences use `@NAME.2`,
    /// `@NAME.3`, … and are tied back to `@NAME` with an `#eq?` predicate.
    pub query: String,
    /// Metavariable names (without the leading `$`), in order of first
    /// appearance in the snippet.
    pub metavars: Vec<String>,
}
65
66pub fn compile_pattern(snippet: &str, language: Language) -> Result<CompiledPattern, String> {
74 let sub = substitute(snippet)?;
75 let mut last_err: Option<String> = None;
76
77 for (prefix, suffix) in contexts(language) {
78 let wrapped = format!("{prefix}{}{suffix}", sub.text);
79 let tree = api::parse_tree(&wrapped, language).map_err(|err| err.to_string())?;
80 let root = tree.root_node();
81 if root.has_error() {
82 last_err = Some(format!(
83 "snippet did not parse cleanly in `{}`: `{snippet}`",
84 language.name()
85 ));
86 continue;
87 }
88
89 let start = prefix.len();
94 let end = start + sub.text.len();
95 let Some(pattern_root) = root.descendant_for_byte_range(start, end.saturating_sub(1))
96 else {
97 last_err = Some(format!(
98 "could not locate snippet subtree in `{}`",
99 language.name()
100 ));
101 continue;
102 };
103
104 let bytes = wrapped.as_bytes();
105 let mut builder = QueryBuilder::new(bytes, &sub.placeholder_to_metavar);
106 let body = builder.build(pattern_root);
107 let predicates = builder.predicates();
108 let query = if predicates.is_empty() {
109 format!("({body} @{ROOT_CAPTURE})")
110 } else {
111 format!("({body} @{ROOT_CAPTURE} {predicates})")
112 };
113 return Ok(CompiledPattern {
114 query,
115 metavars: sub.metavar_order,
116 });
117 }
118
119 Err(last_err.unwrap_or_else(|| format!("snippet did not parse in `{}`", language.name())))
120}
121
122fn contexts(language: Language) -> Vec<(&'static str, &'static str)> {
128 let mut v = vec![("", "")];
129 let wrapper = match language {
130 Language::Rust => Some(("fn __harn_probe() { ", " }")),
131 Language::Go => Some(("package p\nfunc __harn_probe() { ", " }")),
132 Language::Java | Language::CSharp => {
133 Some(("class __HarnProbe { void __harn_probe() { ", " } }"))
134 }
135 Language::C | Language::Cpp => Some(("void __harn_probe() { ", " }")),
136 Language::Kotlin => Some(("fun __harn_probe() { ", " }")),
137 Language::Swift => Some(("func __harn_probe() { ", " }")),
138 Language::Scala => Some(("def __harn_probe() = { ", " }")),
139 _ => None,
140 };
141 v.extend(wrapper);
142 v
143}
144
/// Result of `substitute`: the snippet with its metavariables swapped for
/// parseable placeholder identifiers, plus the bookkeeping needed to map
/// placeholders back to metavariable names.
struct Substituted {
    /// Metavariable names (without `$`), ordered by first appearance.
    metavar_order: Vec<String>,
    /// Placeholder identifier → the metavariable name it stands for.
    placeholder_to_metavar: HashMap<String, String>,
    /// The snippet with every `$NAME` replaced by its placeholder.
    text: String,
}
157
158fn substitute(snippet: &str) -> Result<Substituted, String> {
159 let mut text = String::with_capacity(snippet.len());
160 let mut placeholder_to_metavar = HashMap::new();
161 let mut metavar_to_placeholder: HashMap<String, String> = HashMap::new();
162 let mut metavar_order: Vec<String> = Vec::new();
163
164 let bytes = snippet.as_bytes();
165 let mut i = 0;
166 while i < bytes.len() {
167 if bytes[i] != b'$' {
168 let ch = snippet[i..].chars().next().unwrap();
171 text.push(ch);
172 i += ch.len_utf8();
173 continue;
174 }
175 if snippet[i..].starts_with("$$$") {
176 return Err(
177 "variadic `$$$` metavariables are not yet supported (tracked in #2833)".into(),
178 );
179 }
180 let name_start = i + 1;
182 let mut j = name_start;
183 if j < bytes.len() && is_ident_start(bytes[j]) {
184 j += 1;
185 while j < bytes.len() && is_ident_continue(bytes[j]) {
186 j += 1;
187 }
188 }
189 if j == name_start {
190 text.push('$');
192 i += 1;
193 continue;
194 }
195 let name = &snippet[name_start..j];
196 let placeholder = metavar_to_placeholder
197 .entry(name.to_string())
198 .or_insert_with(|| {
199 let placeholder = format!("{PLACEHOLDER_STEM}{}", metavar_order.len());
200 metavar_order.push(name.to_string());
201 placeholder
202 })
203 .clone();
204 placeholder_to_metavar.insert(placeholder.clone(), name.to_string());
205 text.push_str(&placeholder);
206 i = j;
207 }
208
209 Ok(Substituted {
213 text,
214 placeholder_to_metavar,
215 metavar_order,
216 })
217}
218
/// Whether `b` may begin a metavariable name: an ASCII letter or `_`.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}

/// Whether `b` may continue a metavariable name: an ASCII letter, an ASCII
/// digit, or `_`.
fn is_ident_continue(b: u8) -> bool {
    is_ident_start(b) || b.is_ascii_digit()
}
226
227struct QueryBuilder<'a> {
232 src: &'a [u8],
233 placeholder_to_metavar: &'a HashMap<String, String>,
234 occurrences: HashMap<String, usize>,
236 eq_predicates: Vec<String>,
238 literal_count: usize,
240}
241
242impl<'a> QueryBuilder<'a> {
243 fn new(src: &'a [u8], placeholder_to_metavar: &'a HashMap<String, String>) -> Self {
244 QueryBuilder {
245 src,
246 placeholder_to_metavar,
247 occurrences: HashMap::new(),
248 eq_predicates: Vec::new(),
249 literal_count: 0,
250 }
251 }
252
253 fn build(&mut self, node: Node<'_>) -> String {
254 if node.child_count() == 0 {
256 let text = self.node_text(node);
257 if let Some(metavar) = self.placeholder_to_metavar.get(text) {
258 return format!("(_) @{}", self.capture_for(metavar));
259 }
260 if node.is_named() {
261 let cap = format!("__lit_{}", self.literal_count);
265 self.literal_count += 1;
266 self.eq_predicates
267 .push(format!("(#eq? @{cap} {})", quote_literal(text)));
268 return format!("({}) @{cap}", node.kind());
269 }
270 return quote_literal(text);
271 }
272
273 let mut parts: Vec<String> = Vec::new();
274 let mut cursor = node.walk();
275 for (i, child) in node.children(&mut cursor).enumerate() {
276 let sub = self.build(child);
277 match node.field_name_for_child(i as u32) {
281 Some(field) if child.is_named() => parts.push(format!("{field}: {sub}")),
282 _ => parts.push(sub),
283 }
284 }
285 format!("({} {})", node.kind(), parts.join(" "))
286 }
287
288 fn capture_for(&mut self, metavar: &str) -> String {
292 let count = self.occurrences.entry(metavar.to_string()).or_insert(0);
293 *count += 1;
294 if *count == 1 {
295 metavar.to_string()
296 } else {
297 let helper = format!("{metavar}.{count}");
298 self.eq_predicates
299 .push(format!("(#eq? @{metavar} @{helper})"));
300 helper
301 }
302 }
303
304 fn predicates(&self) -> String {
305 self.eq_predicates.join(" ")
306 }
307
308 fn node_text(&self, node: Node<'_>) -> &'a str {
309 std::str::from_utf8(&self.src[node.start_byte()..node.end_byte()]).unwrap_or_default()
310 }
311}
312
/// Render `text` as a double-quoted tree-sitter query string literal.
///
/// What gets escaped:
/// - `"` and `\` are backslash-escaped.
/// - Newline, carriage return, tab and NUL become the `\n`, `\r`, `\t` and
///   `\0` escapes that the query parser decodes.
///
/// The control-character escapes are required: tree-sitter rejects a raw
/// newline inside a query string. Without them, any leaf spanning several
/// lines (multi-line strings, block comments, template literals) would yield a
/// query that `Query::new` refuses.
fn quote_literal(text: &str) -> String {
    let mut out = String::with_capacity(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\0' => out.push_str("\\0"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use streaming_iterator::StreamingIterator;
    use tree_sitter::{Query, QueryCursor};

    /// Compile `snippet`, run the resulting query over `code`, and return one
    /// `(capture name, captured texts)` entry per capture name per match.
    /// Panics if the snippet fails to compile or tree-sitter rejects the query.
    fn run(snippet: &str, language: Language, code: &str) -> Vec<(String, Vec<String>)> {
        let compiled = compile_pattern(snippet, language).expect("compiles");
        let ts_language = language.ts_language().expect("grammar");
        let query = Query::new(&ts_language, &compiled.query)
            .unwrap_or_else(|e| panic!("query rejected: {e}\nquery: {}", compiled.query));
        let tree = api::parse_tree(code, language).expect("parse code");
        let names: Vec<&str> = query.capture_names().to_vec();
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
        let mut out = Vec::new();
        // Matches are advanced with `StreamingIterator::next` (imported above).
        while let Some(m) = matches.next() {
            let mut per_capture: HashMap<String, Vec<String>> = HashMap::new();
            for cap in m.captures {
                let name = names[cap.index as usize].to_string();
                let text = code[cap.node.start_byte()..cap.node.end_byte()].to_string();
                per_capture.entry(name).or_default().push(text);
            }
            for (name, texts) in per_capture {
                out.push((name, texts));
            }
        }
        out
    }

    /// Texts of the first entry named `name` in `binds`, or an empty slice if
    /// no match produced that capture.
    fn capture<'a>(binds: &'a [(String, Vec<String>)], name: &str) -> &'a [String] {
        binds
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, v)| v.as_slice())
            .unwrap_or(&[])
    }

    /// A nullish-coalescing snippet reports its metavariables in order and
    /// binds each one against TypeScript code.
    #[test]
    fn compiles_destructuring_default_in_typescript() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["SRC", "KEY", "DEFAULT"]);
        let binds = run(
            snippet,
            Language::TypeScript,
            "const a = cfg?.timeout ?? 30;",
        );
        assert_eq!(capture(&binds, "SRC"), ["cfg".to_string()]);
        assert_eq!(capture(&binds, "KEY"), ["timeout".to_string()]);
        assert_eq!(capture(&binds, "DEFAULT"), ["30".to_string()]);
    }

    /// The operator token is part of the query, so `||` must not match a
    /// pattern written with `??`.
    #[test]
    fn operator_is_constrained_not_just_structure() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let binds = run(
            snippet,
            Language::TypeScript,
            "const a = cfg?.timeout || 30;",
        );
        assert!(
            capture(&binds, "SRC").is_empty(),
            "|| must not match the ?? pattern"
        );
    }

    /// An assignment around the nullish-coalescing form binds all four
    /// metavariables, including the assignment target.
    #[test]
    fn round_trips_the_assignment_form() {
        let snippet = "$NAME = $SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["NAME", "SRC", "KEY", "DEFAULT"]);
        let binds = run(
            snippet,
            Language::TypeScript,
            "x = src?.userId ?? fallback;",
        );
        assert_eq!(capture(&binds, "NAME"), ["x".to_string()]);
        assert_eq!(capture(&binds, "SRC"), ["src".to_string()]);
        assert_eq!(capture(&binds, "KEY"), ["userId".to_string()]);
        assert_eq!(capture(&binds, "DEFAULT"), ["fallback".to_string()]);
    }

    /// A Rust `let` statement binds its pattern, and its initializer binds a
    /// whole call expression.
    #[test]
    fn lifts_metavars_in_rust() {
        let snippet = "let $NAME = $VALUE;";
        let binds = run(snippet, Language::Rust, "fn f() { let total = compute(); }");
        assert_eq!(capture(&binds, "NAME"), ["total".to_string()]);
        assert_eq!(capture(&binds, "VALUE"), ["compute()".to_string()]);
    }

    /// A call pattern binds callee and argument in Python.
    #[test]
    fn lifts_metavars_in_python() {
        let snippet = "$FN($ARG)";
        let binds = run(snippet, Language::Python, "print(value)");
        assert_eq!(capture(&binds, "FN"), ["print".to_string()]);
        assert_eq!(capture(&binds, "ARG"), ["value".to_string()]);
    }

    /// A call pattern binds callee and argument inside a Go function body.
    #[test]
    fn lifts_metavars_in_go() {
        let snippet = "$FN($ARG)";
        let binds = run(snippet, Language::Go, "package main\nfunc m() { log(err) }");
        assert_eq!(capture(&binds, "FN"), ["log".to_string()]);
        assert_eq!(capture(&binds, "ARG"), ["err".to_string()]);
    }

    /// A metavariable used twice must bind identical text at both sites.
    #[test]
    fn repeated_metavar_unifies() {
        let snippet = "$X + $X";
        let same = run(snippet, Language::Rust, "fn f() { let _ = a + a; }");
        assert_eq!(capture(&same, "X"), ["a".to_string()]);
        let different = run(snippet, Language::Rust, "fn f() { let _ = a + b; }");
        assert!(
            capture(&different, "X").is_empty(),
            "unification must reject `a + b`"
        );
    }

    /// A snippet the grammar cannot parse is reported as an error.
    #[test]
    fn rejects_unparseable_snippet() {
        let err = compile_pattern("$A ?? ?? $B", Language::TypeScript).unwrap_err();
        assert!(err.contains("did not parse"), "got: {err}");
    }

    /// `$$$` variadic metavariables are rejected for now.
    #[test]
    fn rejects_variadic_for_now() {
        let err = compile_pattern("foo($$$ARGS)", Language::TypeScript).unwrap_err();
        assert!(err.contains("variadic"), "got: {err}");
    }

    /// Without metavariables, named leaves are pinned to their exact text, so
    /// `foo()` matches `foo();` but not `bar();`.
    #[test]
    fn literal_pattern_matches_exact_text() {
        let snippet = "foo()";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert!(compiled.metavars.is_empty());
        let hit = run(snippet, Language::TypeScript, "foo();");
        assert!(!hit.is_empty());
        let miss = run(snippet, Language::TypeScript, "bar();");
        assert!(
            miss.is_empty(),
            "bar() must not match foo()'s literal pattern: {miss:?}"
        );
    }
}