1use std::collections::HashMap;
60
61use harn_hostlib::ast::{api, Language};
62use tree_sitter::Node;
63
/// Capture name that `compile_pattern` attaches to the root node of every compiled pattern.
pub const ROOT_CAPTURE: &str = "__match";

/// Prefix of the placeholder identifiers that `substitute` writes in place of metavariables;
/// the metavariable's first-appearance index is appended to it.
const PLACEHOLDER_STEM: &str = "__harn_hole_";
73
/// Result of compiling a pattern snippet with `compile_pattern`.
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    /// Tree-sitter query source: the snippet's node pattern captured as `@__match`
    /// (`ROOT_CAPTURE`), followed by any `#eq?` predicates the builder collected.
    pub query: String,
    /// Metavariable names (without the `$` sigil) in order of first appearance in the snippet.
    pub metavars: Vec<String>,
}
83
84pub fn compile_pattern(snippet: &str, language: Language) -> Result<CompiledPattern, String> {
92 let sub = substitute(snippet)?;
93
94 let mut metavar_node_patterns: HashMap<String, String> = HashMap::new();
98 for (metavar, constraint) in &sub.metavar_constraints {
99 metavar_node_patterns.insert(metavar.clone(), resolve_constraint(constraint, language)?);
100 }
101
102 let mut last_err: Option<String> = None;
103
104 for (prefix, suffix) in contexts(language) {
105 let wrapped = format!("{prefix}{}{suffix}", sub.text);
106 let tree = api::parse_tree(&wrapped, language).map_err(|err| err.to_string())?;
107 let root = tree.root_node();
108 if root.has_error() {
109 last_err = Some(format!(
110 "snippet did not parse cleanly in `{}`: `{snippet}`",
111 language.name()
112 ));
113 continue;
114 }
115
116 let start = prefix.len();
121 let end = start + sub.text.len();
122 let Some(pattern_root) = root.descendant_for_byte_range(start, end.saturating_sub(1))
123 else {
124 last_err = Some(format!(
125 "could not locate snippet subtree in `{}`",
126 language.name()
127 ));
128 continue;
129 };
130
131 let bytes = wrapped.as_bytes();
132 let mut builder =
133 QueryBuilder::new(bytes, &sub.placeholder_to_metavar, &metavar_node_patterns);
134 let body = builder.build(pattern_root);
135 let predicates = builder.predicates();
136 let query = if predicates.is_empty() {
137 format!("({body} @{ROOT_CAPTURE})")
138 } else {
139 format!("({body} @{ROOT_CAPTURE} {predicates})")
140 };
141 return Ok(CompiledPattern {
142 query,
143 metavars: sub.metavar_order,
144 });
145 }
146
147 Err(last_err.unwrap_or_else(|| format!("snippet did not parse in `{}`", language.name())))
148}
149
150fn contexts(language: Language) -> Vec<(&'static str, &'static str)> {
156 let mut v = vec![("", "")];
157 let wrapper = match language {
158 Language::Rust => Some(("fn __harn_probe() { ", " }")),
159 Language::Go => Some(("package p\nfunc __harn_probe() { ", " }")),
160 Language::Java | Language::CSharp => {
161 Some(("class __HarnProbe { void __harn_probe() { ", " } }"))
162 }
163 Language::C | Language::Cpp => Some(("void __harn_probe() { ", " }")),
164 Language::Kotlin => Some(("fun __harn_probe() { ", " }")),
165 Language::Swift => Some(("func __harn_probe() { ", " }")),
166 Language::Scala => Some(("def __harn_probe() = { ", " }")),
167 _ => None,
168 };
169 v.extend(wrapper);
170 v
171}
172
/// Output of `substitute`: the rewritten snippet plus the bookkeeping needed to map
/// placeholders back to metavariables.
struct Substituted {
    /// The snippet with every metavariable (and its `:kind` suffix, when one was consumed)
    /// replaced by a placeholder identifier.
    text: String,
    /// Placeholder identifier -> metavariable name (without the `$`).
    placeholder_to_metavar: HashMap<String, String>,
    /// Metavariable names in order of first appearance.
    metavar_order: Vec<String>,
    /// Metavariable name -> constraint text written after `:` in the snippet.
    metavar_constraints: HashMap<String, String>,
}
188
189fn substitute(snippet: &str) -> Result<Substituted, String> {
190 let mut text = String::with_capacity(snippet.len());
191 let mut placeholder_to_metavar = HashMap::new();
192 let mut metavar_to_placeholder: HashMap<String, String> = HashMap::new();
193 let mut metavar_order: Vec<String> = Vec::new();
194 let mut metavar_constraints: HashMap<String, String> = HashMap::new();
195
196 let bytes = snippet.as_bytes();
197 let mut i = 0;
198 while i < bytes.len() {
199 if bytes[i] != b'$' {
200 let ch = snippet[i..].chars().next().unwrap();
203 text.push(ch);
204 i += ch.len_utf8();
205 continue;
206 }
207 if snippet[i..].starts_with("$$$") {
208 return Err(
209 "variadic `$$$` metavariables are not yet supported (tracked in #2833)".into(),
210 );
211 }
212 let name_start = i + 1;
214 let mut j = name_start;
215 if j < bytes.len() && is_ident_start(bytes[j]) {
216 j += 1;
217 while j < bytes.len() && is_ident_continue(bytes[j]) {
218 j += 1;
219 }
220 }
221 if j == name_start {
222 text.push('$');
224 i += 1;
225 continue;
226 }
227 let name = &snippet[name_start..j];
228 let mut consumed_end = j;
233 if j < bytes.len() && bytes[j] == b':' {
234 let kind_start = j + 1;
235 if kind_start < bytes.len() && is_ident_start(bytes[kind_start]) {
236 let mut k = kind_start + 1;
237 while k < bytes.len() && is_ident_continue(bytes[k]) {
238 k += 1;
239 }
240 let constraint = &snippet[kind_start..k];
241 match metavar_constraints.get(name) {
242 Some(existing) if existing != constraint => {
243 return Err(format!(
244 "metavariable `${name}` has conflicting type constraints \
245 `:{existing}` and `:{constraint}`"
246 ));
247 }
248 _ => {
249 metavar_constraints.insert(name.to_string(), constraint.to_string());
250 }
251 }
252 consumed_end = k;
253 }
254 }
255 let placeholder = metavar_to_placeholder
256 .entry(name.to_string())
257 .or_insert_with(|| {
258 let placeholder = format!("{PLACEHOLDER_STEM}{}", metavar_order.len());
259 metavar_order.push(name.to_string());
260 placeholder
261 })
262 .clone();
263 placeholder_to_metavar.insert(placeholder.clone(), name.to_string());
264 text.push_str(&placeholder);
265 i = consumed_end;
266 }
267
268 Ok(Substituted {
272 text,
273 placeholder_to_metavar,
274 metavar_order,
275 metavar_constraints,
276 })
277}
278
279fn resolve_constraint(constraint: &str, language: Language) -> Result<String, String> {
284 let ts = language
285 .ts_language()
286 .ok_or_else(|| format!("no grammar for `{}`", language.name()))?;
287 let candidates: Vec<&str> = match constraint {
290 "expr" | "expression" => vec!["expression"],
291 "stmt" | "statement" => vec!["statement"],
292 "ty" | "type" => vec!["type"],
293 "ident" | "identifier" => vec!["identifier"],
294 other => vec![other],
295 };
296 let valid: Vec<String> = candidates
297 .iter()
298 .filter(|kind| ts.id_for_node_kind(kind, true) != 0)
299 .map(|kind| format!("({kind})"))
300 .collect();
301 if valid.is_empty() {
302 return Err(format!(
303 "typed placeholder `:{constraint}` is not a node kind in `{}` \
304 (use an exact tree-sitter kind)",
305 language.name()
306 ));
307 }
308 Ok(if valid.len() == 1 {
309 valid.into_iter().next().unwrap()
310 } else {
311 format!("[{}]", valid.join(" "))
312 })
313}
314
/// Reports whether `b` may begin a metavariable or constraint name: an ASCII letter or `_`.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
318
/// Reports whether `b` may continue a metavariable or constraint name: an ASCII letter,
/// an ASCII digit, or `_`.
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
322
/// Renders a parsed snippet subtree as a tree-sitter query pattern, collecting the `#eq?`
/// predicates the pattern needs along the way.
struct QueryBuilder<'a> {
    /// Source bytes that node byte ranges index into.
    src: &'a [u8],
    /// Placeholder identifier -> metavariable name.
    placeholder_to_metavar: &'a HashMap<String, String>,
    /// Metavariable name -> node pattern to emit for it; metavariables without an entry
    /// are emitted as the wildcard `(_)`.
    metavar_node_patterns: &'a HashMap<String, String>,
    /// How many times each metavariable has been captured so far.
    occurrences: HashMap<String, usize>,
    /// `#eq?` predicates accumulated while building.
    eq_predicates: Vec<String>,
    /// Counter used to number the `__lit_N` captures.
    literal_count: usize,
}
340
341impl<'a> QueryBuilder<'a> {
342 fn new(
343 src: &'a [u8],
344 placeholder_to_metavar: &'a HashMap<String, String>,
345 metavar_node_patterns: &'a HashMap<String, String>,
346 ) -> Self {
347 QueryBuilder {
348 src,
349 placeholder_to_metavar,
350 metavar_node_patterns,
351 occurrences: HashMap::new(),
352 eq_predicates: Vec::new(),
353 literal_count: 0,
354 }
355 }
356
357 fn build(&mut self, node: Node<'_>) -> String {
358 if node.child_count() == 0 {
360 let text = self.node_text(node);
361 if let Some(metavar) = self.placeholder_to_metavar.get(text) {
362 let node_pattern = self
363 .metavar_node_patterns
364 .get(metavar)
365 .map(String::as_str)
366 .unwrap_or("(_)");
367 return format!("{node_pattern} @{}", self.capture_for(metavar));
368 }
369 if node.is_named() {
370 let cap = format!("__lit_{}", self.literal_count);
374 self.literal_count += 1;
375 self.eq_predicates
376 .push(format!("(#eq? @{cap} {})", quote_literal(text)));
377 return format!("({}) @{cap}", node.kind());
378 }
379 return quote_literal(text);
380 }
381
382 let mut parts: Vec<String> = Vec::new();
383 let mut cursor = node.walk();
384 for (i, child) in node.children(&mut cursor).enumerate() {
385 let sub = self.build(child);
386 match node.field_name_for_child(i as u32) {
390 Some(field) if child.is_named() => parts.push(format!("{field}: {sub}")),
391 _ => parts.push(sub),
392 }
393 }
394 format!("({} {})", node.kind(), parts.join(" "))
395 }
396
397 fn capture_for(&mut self, metavar: &str) -> String {
401 let count = self.occurrences.entry(metavar.to_string()).or_insert(0);
402 *count += 1;
403 if *count == 1 {
404 metavar.to_string()
405 } else {
406 let helper = format!("{metavar}.{count}");
407 self.eq_predicates
408 .push(format!("(#eq? @{metavar} @{helper})"));
409 helper
410 }
411 }
412
413 fn predicates(&self) -> String {
414 self.eq_predicates.join(" ")
415 }
416
417 fn node_text(&self, node: Node<'_>) -> &'a str {
418 std::str::from_utf8(&self.src[node.start_byte()..node.end_byte()]).unwrap_or_default()
419 }
420}
421
/// Renders `text` as a double-quoted string literal for a tree-sitter query.
///
/// `"` and `\` are backslash-escaped. Newline, carriage return, tab and NUL are written as
/// `\n`, `\r`, `\t` and `\0`, the escapes the query parser decodes back to those characters.
/// A raw line break inside a quoted string is a syntax error in the query language, so
/// multi-line node text would otherwise produce a query that fails to compile.
fn quote_literal(text: &str) -> String {
    let mut out = String::with_capacity(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' | '\\' => {
                out.push('\\');
                out.push(ch);
            }
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\0' => out.push_str("\\0"),
            _ => out.push(ch),
        }
    }
    out.push('"');
    out
}
436
#[cfg(test)]
mod tests {
    use super::*;
    use streaming_iterator::StreamingIterator;
    use tree_sitter::{Query, QueryCursor};

    /// Compiles `snippet`, runs the resulting query over `code`, and returns the captured
    /// texts as `(capture name, texts)` entries.
    ///
    /// Entries are grouped per match; within one match they come out of a `HashMap`, so their
    /// relative order is unspecified. Panics if the snippet does not compile, the grammar is
    /// missing, the query is rejected, or `code` cannot be parsed.
    fn run(snippet: &str, language: Language, code: &str) -> Vec<(String, Vec<String>)> {
        let compiled = compile_pattern(snippet, language).expect("compiles");
        let ts_language = language.ts_language().expect("grammar");
        let query = Query::new(&ts_language, &compiled.query)
            .unwrap_or_else(|e| panic!("query rejected: {e}\nquery: {}", compiled.query));
        let tree = api::parse_tree(code, language).expect("parse code");
        let names: Vec<&str> = query.capture_names().to_vec();
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
        let mut out = Vec::new();
        while let Some(m) = matches.next() {
            let mut per_capture: HashMap<String, Vec<String>> = HashMap::new();
            for cap in m.captures {
                let name = names[cap.index as usize].to_string();
                let text = code[cap.node.start_byte()..cap.node.end_byte()].to_string();
                per_capture.entry(name).or_default().push(text);
            }
            for (name, texts) in per_capture {
                out.push((name, texts));
            }
        }
        out
    }

    /// Returns the texts of the first entry in `binds` named `name`, or an empty slice when
    /// there is none.
    fn capture<'a>(binds: &'a [(String, Vec<String>)], name: &str) -> &'a [String] {
        binds
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, v)| v.as_slice())
            .unwrap_or(&[])
    }

    // Optional chaining combined with `??` binds all three metavariables in TypeScript.
    #[test]
    fn compiles_destructuring_default_in_typescript() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["SRC", "KEY", "DEFAULT"]);
        let binds = run(
            snippet,
            Language::TypeScript,
            "const a = cfg?.timeout ?? 30;",
        );
        assert_eq!(capture(&binds, "SRC"), ["cfg".to_string()]);
        assert_eq!(capture(&binds, "KEY"), ["timeout".to_string()]);
        assert_eq!(capture(&binds, "DEFAULT"), ["30".to_string()]);
    }

    // The same snippet, compiled and matched with the Harn grammar.
    #[test]
    fn compiles_optional_chain_nil_coalescing_in_harn() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::Harn).expect("compiles");
        assert_eq!(compiled.metavars, vec!["SRC", "KEY", "DEFAULT"]);
        let binds = run(
            snippet,
            Language::Harn,
            "fn main() {\n  let timeout = cfg?.timeout ?? 30\n}\n",
        );
        assert_eq!(capture(&binds, "SRC"), ["cfg".to_string()]);
        assert_eq!(capture(&binds, "KEY"), ["timeout".to_string()]);
        assert_eq!(capture(&binds, "DEFAULT"), ["30".to_string()]);
    }

    // The operator token is part of the pattern: `||` has the same shape as `??` but must
    // not match.
    #[test]
    fn operator_is_constrained_not_just_structure() {
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let binds = run(
            snippet,
            Language::TypeScript,
            "const a = cfg?.timeout || 30;",
        );
        assert!(
            capture(&binds, "SRC").is_empty(),
            "|| must not match the ?? pattern"
        );
    }

    // A pattern rooted at an assignment binds the target as well as the right-hand side.
    #[test]
    fn round_trips_the_assignment_form() {
        let snippet = "$NAME = $SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["NAME", "SRC", "KEY", "DEFAULT"]);
        let binds = run(
            snippet,
            Language::TypeScript,
            "x = src?.userId ?? fallback;",
        );
        assert_eq!(capture(&binds, "NAME"), ["x".to_string()]);
        assert_eq!(capture(&binds, "SRC"), ["src".to_string()]);
        assert_eq!(capture(&binds, "KEY"), ["userId".to_string()]);
        assert_eq!(capture(&binds, "DEFAULT"), ["fallback".to_string()]);
    }

    // A Rust `let` statement pattern matches a statement nested inside a function body.
    #[test]
    fn lifts_metavars_in_rust() {
        let snippet = "let $NAME = $VALUE;";
        let binds = run(snippet, Language::Rust, "fn f() { let total = compute(); }");
        assert_eq!(capture(&binds, "NAME"), ["total".to_string()]);
        assert_eq!(capture(&binds, "VALUE"), ["compute()".to_string()]);
    }

    // A call pattern in Python.
    #[test]
    fn lifts_metavars_in_python() {
        let snippet = "$FN($ARG)";
        let binds = run(snippet, Language::Python, "print(value)");
        assert_eq!(capture(&binds, "FN"), ["print".to_string()]);
        assert_eq!(capture(&binds, "ARG"), ["value".to_string()]);
    }

    // A call pattern in Go, matched inside a function in a full source file.
    #[test]
    fn lifts_metavars_in_go() {
        let snippet = "$FN($ARG)";
        let binds = run(snippet, Language::Go, "package main\nfunc m() { log(err) }");
        assert_eq!(capture(&binds, "FN"), ["log".to_string()]);
        assert_eq!(capture(&binds, "ARG"), ["err".to_string()]);
    }

    // Using one metavariable twice requires both occurrences to have the same text.
    #[test]
    fn repeated_metavar_unifies() {
        let snippet = "$X + $X";
        let same = run(snippet, Language::Rust, "fn f() { let _ = a + a; }");
        assert_eq!(capture(&same, "X"), ["a".to_string()]);
        let different = run(snippet, Language::Rust, "fn f() { let _ = a + b; }");
        assert!(
            capture(&different, "X").is_empty(),
            "unification must reject `a + b`"
        );
    }

    // A snippet that is not valid syntax is reported as a parse failure.
    #[test]
    fn rejects_unparseable_snippet() {
        let err = compile_pattern("$A ?? ?? $B", Language::TypeScript).unwrap_err();
        assert!(err.contains("did not parse"), "got: {err}");
    }

    // The variadic `$$$` form is rejected with a dedicated message.
    #[test]
    fn rejects_variadic_for_now() {
        let err = compile_pattern("foo($$$ARGS)", Language::TypeScript).unwrap_err();
        assert!(err.contains("variadic"), "got: {err}");
    }

    // `$ARG:identifier` only matches arguments whose node kind is `identifier`.
    #[test]
    fn typed_placeholder_narrows_to_kind() {
        let snippet = "$FN($ARG:identifier)";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["FN", "ARG"]);
        let hit = run(snippet, Language::TypeScript, "f(x);");
        assert_eq!(capture(&hit, "ARG"), ["x".to_string()]);
        let miss = run(snippet, Language::TypeScript, "f(g());");
        assert!(
            capture(&miss, "ARG").is_empty(),
            "a call argument must not match `:identifier`: {miss:?}"
        );
    }

    // `:expression` accepts both an identifier argument and a call argument in TypeScript.
    #[test]
    fn typed_placeholder_expression_alias_matches_any_expression() {
        let snippet = "$FN($ARG:expression)";
        let ident = run(snippet, Language::TypeScript, "f(x);");
        assert_eq!(capture(&ident, "ARG"), ["x".to_string()]);
        let call = run(snippet, Language::TypeScript, "f(g());");
        assert_eq!(capture(&call, "ARG"), ["g()".to_string()]);
    }

    // A constraint naming a kind the grammar does not have is a compile error.
    #[test]
    fn typed_placeholder_unknown_kind_is_an_error() {
        let err = compile_pattern("$X:not_a_real_kind", Language::TypeScript).unwrap_err();
        assert!(err.contains("not a node kind"), "got: {err}");
    }

    // `:expression` is expected to be rejected for Rust with the same "not a node kind" error.
    #[test]
    fn typed_placeholder_alias_unavailable_in_grammar_errors() {
        let err = compile_pattern("let $X = $V:expression;", Language::Rust).unwrap_err();
        assert!(err.contains("not a node kind"), "got: {err}");
    }

    // A constraint written on one occurrence does not disable unification with the others.
    #[test]
    fn typed_placeholder_unifies_and_constrains() {
        let snippet = "$X:identifier + $X";
        let same = run(snippet, Language::Rust, "fn f() { let _ = a + a; }");
        assert_eq!(capture(&same, "X"), ["a".to_string()]);
        let different = run(snippet, Language::Rust, "fn f() { let _ = a + b; }");
        assert!(
            capture(&different, "X").is_empty(),
            "unification still holds"
        );
    }

    // `$KEY: $VAL` has a space after the colon, so the colon stays ordinary source text and
    // no constraint is attached to `KEY`.
    #[test]
    fn colon_without_constraint_is_left_literal() {
        let snippet = "{$KEY: $VAL}";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["KEY", "VAL"]);
        let binds = run(snippet, Language::TypeScript, "let o = {a: 1};");
        assert_eq!(capture(&binds, "KEY"), ["a".to_string()]);
        assert_eq!(capture(&binds, "VAL"), ["1".to_string()]);
    }

    // A snippet without metavariables still pins identifier text: `foo()` must not match
    // `bar()`.
    #[test]
    fn literal_pattern_matches_exact_text() {
        let snippet = "foo()";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert!(compiled.metavars.is_empty());
        let hit = run(snippet, Language::TypeScript, "foo();");
        assert!(!hit.is_empty());
        let miss = run(snippet, Language::TypeScript, "bar();");
        assert!(
            miss.is_empty(),
            "bar() must not match foo()'s literal pattern: {miss:?}"
        );
    }
}