Skip to main content

harn_rules/
pattern.rs

1//! The pattern compiler: a code snippet with metavariable holes → a
2//! tree-sitter query.
3//!
4//! This is the atomic-tier `pattern` form. The idea (from ast-grep) is to
5//! let rule authors write a *snippet of real code* with `$VAR` holes
6//! instead of hand-authoring a tree-sitter S-expression query:
7//!
8//! ```text
9//!   $SRC?.$KEY ?? $DEFAULT
10//! ```
11//!
12//! compiles to
13//!
14//! ```text
15//!   ((binary_expression
16//!      left: (member_expression object: (_) @SRC (optional_chain) property: (_) @KEY)
17//!      "??"
18//!      right: (_) @DEFAULT) @__match)
19//! ```
20//!
21//! ## How it works
22//!
23//! 1. Each `$VAR` is replaced with a unique placeholder identifier so the
24//!    snippet parses as ordinary code in the target grammar.
25//! 2. We parse the substituted snippet — bare, then in a per-language
26//!    wrapper context (e.g. a function body) when the fragment is not a
27//!    valid compilation unit — and locate the snippet's own subtree by its
28//!    byte range in the parsed source.
29//! 3. We walk that subtree and mirror it into a query: every named child is
30//!    emitted with its field name, every anonymous token (operators,
31//!    keywords, punctuation) is emitted as a quoted literal so the structure
32//!    is matched precisely, and every placeholder becomes a `(_) @VAR`
33//!    wildcard capture.
34//! 4. Repeated metavariables unify: the second and later occurrences get
35//!    helper captures plus an `(#eq? …)` predicate so `$X … $X` only matches
36//!    when both holes carry identical text.
37//!
38//! Variadic `$$$` holes are not yet supported (tracked for the relational
39//! tier, #2833); they compile to a clear error.
40
41use std::collections::HashMap;
42
43use harn_hostlib::ast::{api, Language};
44use tree_sitter::Node;
45
/// Capture name attached to the root of every compiled pattern; it is the
/// handle used for range extraction. The `__` prefix keeps it apart from
/// user metavars, which are uppercase by convention and never start with
/// `__`.
pub const ROOT_CAPTURE: &str = "__match";

/// Stem of the stand-in identifier that replaces each `$VAR` before the
/// snippet is parsed. All-lowercase with a `__` prefix, so it stays a valid
/// identifier across grammars and is unlikely to clash with text the
/// snippet author actually wrote.
const PLACEHOLDER_STEM: &str = "__harn_hole_";
55
/// A snippet pattern compiled to a tree-sitter query string.
///
/// Two compiled patterns compare equal when both the query text and the
/// metavar list are identical, which makes the type usable with
/// `assert_eq!` and in deduplication.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompiledPattern {
    /// The generated S-expression query. Always binds the pattern root to
    /// `@__match` ([`ROOT_CAPTURE`]).
    pub query: String,
    /// Metavar names in first-appearance order (without the leading `$`).
    pub metavars: Vec<String>,
}
65
66/// Compile a `pattern` snippet for `language` into a tree-sitter query.
67///
68/// A snippet is often a *fragment* (`a + a`, `foo(bar)`) that is not a
69/// valid compilation unit on its own. We therefore try the snippet bare
70/// first (works for expression-statement languages like TS/JS/Python),
71/// then in a small set of per-language wrapper contexts (e.g. a function
72/// body for Rust/Go), and locate the snippet's own subtree by byte range.
73pub fn compile_pattern(snippet: &str, language: Language) -> Result<CompiledPattern, String> {
74    let sub = substitute(snippet)?;
75    let mut last_err: Option<String> = None;
76
77    for (prefix, suffix) in contexts(language) {
78        let wrapped = format!("{prefix}{}{suffix}", sub.text);
79        let tree = api::parse_tree(&wrapped, language).map_err(|err| err.to_string())?;
80        let root = tree.root_node();
81        if root.has_error() {
82            last_err = Some(format!(
83                "snippet did not parse cleanly in `{}`: `{snippet}`",
84                language.name()
85            ));
86            continue;
87        }
88
89        // The snippet occupies `[start, end)` inside the wrapped source; the
90        // deepest node spanning that range is its own subtree (no need to
91        // descend wrappers — and no risk of over-descending a single-child
92        // node like a unary expression).
93        let start = prefix.len();
94        let end = start + sub.text.len();
95        let Some(pattern_root) = root.descendant_for_byte_range(start, end.saturating_sub(1))
96        else {
97            last_err = Some(format!(
98                "could not locate snippet subtree in `{}`",
99                language.name()
100            ));
101            continue;
102        };
103
104        let bytes = wrapped.as_bytes();
105        let mut builder = QueryBuilder::new(bytes, &sub.placeholder_to_metavar);
106        let body = builder.build(pattern_root);
107        let predicates = builder.predicates();
108        let query = if predicates.is_empty() {
109            format!("({body} @{ROOT_CAPTURE})")
110        } else {
111            format!("({body} @{ROOT_CAPTURE} {predicates})")
112        };
113        return Ok(CompiledPattern {
114            query,
115            metavars: sub.metavar_order,
116        });
117    }
118
119    Err(last_err.unwrap_or_else(|| format!("snippet did not parse in `{}`", language.name())))
120}
121
122/// Candidate parse contexts for a snippet, tried in order. The bare context
123/// (`""`, `""`) comes first; item-required languages add a wrapper that
124/// makes an expression/statement fragment parse. Languages whose top level
125/// already accepts expression statements (TS/JS/Python/Ruby/…) only need
126/// the bare context.
127fn contexts(language: Language) -> Vec<(&'static str, &'static str)> {
128    let mut v = vec![("", "")];
129    let wrapper = match language {
130        Language::Rust => Some(("fn __harn_probe() { ", " }")),
131        Language::Go => Some(("package p\nfunc __harn_probe() { ", " }")),
132        Language::Java | Language::CSharp => {
133            Some(("class __HarnProbe { void __harn_probe() { ", " } }"))
134        }
135        Language::C | Language::Cpp => Some(("void __harn_probe() { ", " }")),
136        Language::Kotlin => Some(("fun __harn_probe() { ", " }")),
137        Language::Swift => Some(("func __harn_probe() { ", " }")),
138        Language::Scala => Some(("def __harn_probe() = { ", " }")),
139        _ => None,
140    };
141    v.extend(wrapper);
142    v
143}
144
145// ---------------------------------------------------------------------------
146// Step 1: metavar substitution
147// ---------------------------------------------------------------------------
148
/// Result of the metavar substitution step.
struct Substituted {
    /// The snippet text after every `$VAR` was swapped for its placeholder
    /// identifier.
    text: String,
    /// Lookup from placeholder identifier back to the metavar name it
    /// stands for.
    placeholder_to_metavar: HashMap<String, String>,
    /// Metavar names, ordered by where each first appears in the snippet.
    metavar_order: Vec<String>,
}
157
158fn substitute(snippet: &str) -> Result<Substituted, String> {
159    let mut text = String::with_capacity(snippet.len());
160    let mut placeholder_to_metavar = HashMap::new();
161    let mut metavar_to_placeholder: HashMap<String, String> = HashMap::new();
162    let mut metavar_order: Vec<String> = Vec::new();
163
164    let bytes = snippet.as_bytes();
165    let mut i = 0;
166    while i < bytes.len() {
167        if bytes[i] != b'$' {
168            // Copy this UTF-8 scalar verbatim. Indexing the &str at byte
169            // boundaries is safe because we only special-case ASCII `$`.
170            let ch = snippet[i..].chars().next().unwrap();
171            text.push(ch);
172            i += ch.len_utf8();
173            continue;
174        }
175        if snippet[i..].starts_with("$$$") {
176            return Err(
177                "variadic `$$$` metavariables are not yet supported (tracked in #2833)".into(),
178            );
179        }
180        // Parse `$NAME` where NAME is `[A-Za-z_][A-Za-z0-9_]*`.
181        let name_start = i + 1;
182        let mut j = name_start;
183        if j < bytes.len() && is_ident_start(bytes[j]) {
184            j += 1;
185            while j < bytes.len() && is_ident_continue(bytes[j]) {
186                j += 1;
187            }
188        }
189        if j == name_start {
190            // A lone `$` that is not a metavar — keep it literal.
191            text.push('$');
192            i += 1;
193            continue;
194        }
195        let name = &snippet[name_start..j];
196        let placeholder = metavar_to_placeholder
197            .entry(name.to_string())
198            .or_insert_with(|| {
199                let placeholder = format!("{PLACEHOLDER_STEM}{}", metavar_order.len());
200                metavar_order.push(name.to_string());
201                placeholder
202            })
203            .clone();
204        placeholder_to_metavar.insert(placeholder.clone(), name.to_string());
205        text.push_str(&placeholder);
206        i = j;
207    }
208
209    // A pattern with no metavars is a valid *literal* pattern (it matches a
210    // fixed structure), so we do not require one.
211
212    Ok(Substituted {
213        text,
214        placeholder_to_metavar,
215        metavar_order,
216    })
217}
218
/// Whether byte `b` may open a metavar name: an ASCII letter or `_`.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'_' | b'a'..=b'z' | b'A'..=b'Z')
}
222
/// Whether byte `b` may continue a metavar name: an ASCII letter, an ASCII
/// digit, or `_`.
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
226
227// ---------------------------------------------------------------------------
228// Step 2: walk the located subtree into a query
229// ---------------------------------------------------------------------------
230
231struct QueryBuilder<'a> {
232    src: &'a [u8],
233    placeholder_to_metavar: &'a HashMap<String, String>,
234    /// occurrence count per metavar, to mint unification helper captures.
235    occurrences: HashMap<String, usize>,
236    /// `(#eq? …)` predicates for repeated metavars and literal leaves.
237    eq_predicates: Vec<String>,
238    /// counter for literal-leaf text-constraint captures.
239    literal_count: usize,
240}
241
242impl<'a> QueryBuilder<'a> {
243    fn new(src: &'a [u8], placeholder_to_metavar: &'a HashMap<String, String>) -> Self {
244        QueryBuilder {
245            src,
246            placeholder_to_metavar,
247            occurrences: HashMap::new(),
248            eq_predicates: Vec::new(),
249            literal_count: 0,
250        }
251    }
252
253    fn build(&mut self, node: Node<'_>) -> String {
254        // A placeholder leaf is a metavar hole.
255        if node.child_count() == 0 {
256            let text = self.node_text(node);
257            if let Some(metavar) = self.placeholder_to_metavar.get(text) {
258                return format!("(_) @{}", self.capture_for(metavar));
259            }
260            if node.is_named() {
261                // A literal named leaf (a specific identifier / literal in
262                // the snippet): constrain it to its exact text so `foo()`
263                // matches calls to `foo`, not any call.
264                let cap = format!("__lit_{}", self.literal_count);
265                self.literal_count += 1;
266                self.eq_predicates
267                    .push(format!("(#eq? @{cap} {})", quote_literal(text)));
268                return format!("({}) @{cap}", node.kind());
269            }
270            return quote_literal(text);
271        }
272
273        let mut parts: Vec<String> = Vec::new();
274        let mut cursor = node.walk();
275        for (i, child) in node.children(&mut cursor).enumerate() {
276            let sub = self.build(child);
277            // Field names only attach to named children; an anonymous token
278            // in a field slot is matched positionally as a literal, which
279            // tree-sitter accepts where `field: "literal"` may not.
280            match node.field_name_for_child(i as u32) {
281                Some(field) if child.is_named() => parts.push(format!("{field}: {sub}")),
282                _ => parts.push(sub),
283            }
284        }
285        format!("({} {})", node.kind(), parts.join(" "))
286    }
287
288    /// Mint the capture name for this occurrence of `metavar`. The first
289    /// occurrence is `@NAME`; later ones are `@NAME.k` plus an `(#eq? …)`
290    /// predicate tying them to the first (metavar unification).
291    fn capture_for(&mut self, metavar: &str) -> String {
292        let count = self.occurrences.entry(metavar.to_string()).or_insert(0);
293        *count += 1;
294        if *count == 1 {
295            metavar.to_string()
296        } else {
297            let helper = format!("{metavar}.{count}");
298            self.eq_predicates
299                .push(format!("(#eq? @{metavar} @{helper})"));
300            helper
301        }
302    }
303
304    fn predicates(&self) -> String {
305        self.eq_predicates.join(" ")
306    }
307
308    fn node_text(&self, node: Node<'_>) -> &'a str {
309        std::str::from_utf8(&self.src[node.start_byte()..node.end_byte()]).unwrap_or_default()
310    }
311}
312
/// Render `text` as a tree-sitter query string literal.
///
/// Used both for anonymous-token literals and for the text operand of
/// `(#eq? …)` predicates. `"` and `\` are backslash-escaped. Line feed,
/// carriage return and tab are written as `\n`, `\r` and `\t`: the query
/// parser rejects a raw newline inside a string, so a multi-line literal
/// would otherwise make the whole generated query unparseable.
fn quote_literal(text: &str) -> String {
    let mut out = String::with_capacity(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' | '\\' => {
                out.push('\\');
                out.push(ch);
            }
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            _ => out.push(ch),
        }
    }
    out.push('"');
    out
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use streaming_iterator::StreamingIterator;
    use tree_sitter::{Query, QueryCursor};

    /// Compile `snippet`, execute the resulting query over `code`, and
    /// collect `(capture name, captured texts)` pairs. Pairs are appended
    /// match by match; the order of pairs within a single match is
    /// unspecified.
    fn run(snippet: &str, language: Language, code: &str) -> Vec<(String, Vec<String>)> {
        let compiled = compile_pattern(snippet, language).expect("compiles");
        let ts_language = language.ts_language().expect("grammar");
        let query = match Query::new(&ts_language, &compiled.query) {
            Ok(query) => query,
            Err(e) => panic!("query rejected: {e}\nquery: {}", compiled.query),
        };
        let tree = api::parse_tree(code, language).expect("parse code");
        let names: Vec<&str> = query.capture_names().to_vec();
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
        let mut out = Vec::new();
        while let Some(m) = matches.next() {
            // Group this match's captured texts by capture name.
            let mut per_capture: HashMap<String, Vec<String>> = HashMap::new();
            for cap in m.captures {
                let span = cap.node.start_byte()..cap.node.end_byte();
                per_capture
                    .entry(names[cap.index as usize].to_string())
                    .or_default()
                    .push(code[span].to_string());
            }
            out.extend(per_capture);
        }
        out
    }

    /// Texts of the first pair in `binds` whose capture name is `name`, or
    /// an empty slice when there is none.
    fn capture<'a>(binds: &'a [(String, Vec<String>)], name: &str) -> &'a [String] {
        for (bound, texts) in binds {
            if bound == name {
                return texts;
            }
        }
        &[]
    }

    #[test]
    fn compiles_destructuring_default_in_typescript() {
        // The #2824 / burin-code#1629 codemod shape.
        let snippet = "$SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["SRC", "KEY", "DEFAULT"]);
        // The optional-chain object/property and the fallback are captured.
        let code = "const a = cfg?.timeout ?? 30;";
        let binds = run(snippet, Language::TypeScript, code);
        assert_eq!(capture(&binds, "SRC"), ["cfg"]);
        assert_eq!(capture(&binds, "KEY"), ["timeout"]);
        assert_eq!(capture(&binds, "DEFAULT"), ["30"]);
    }

    #[test]
    fn operator_is_constrained_not_just_structure() {
        // The `??` literal in the query must reject a `||` with the same
        // structural shape — otherwise the codemod would be unsound.
        let code = "const a = cfg?.timeout || 30;";
        let binds = run("$SRC?.$KEY ?? $DEFAULT", Language::TypeScript, code);
        assert!(
            capture(&binds, "SRC").is_empty(),
            "|| must not match the ?? pattern"
        );
    }

    #[test]
    fn round_trips_the_assignment_form() {
        // The literal acceptance pattern: `$NAME = $SRC?.$KEY ?? $DEFAULT`.
        let snippet = "$NAME = $SRC?.$KEY ?? $DEFAULT";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert_eq!(compiled.metavars, vec!["NAME", "SRC", "KEY", "DEFAULT"]);
        let code = "x = src?.userId ?? fallback;";
        let binds = run(snippet, Language::TypeScript, code);
        assert_eq!(capture(&binds, "NAME"), ["x"]);
        assert_eq!(capture(&binds, "SRC"), ["src"]);
        assert_eq!(capture(&binds, "KEY"), ["userId"]);
        assert_eq!(capture(&binds, "DEFAULT"), ["fallback"]);
    }

    #[test]
    fn lifts_metavars_in_rust() {
        let code = "fn f() { let total = compute(); }";
        let binds = run("let $NAME = $VALUE;", Language::Rust, code);
        assert_eq!(capture(&binds, "NAME"), ["total"]);
        assert_eq!(capture(&binds, "VALUE"), ["compute()"]);
    }

    #[test]
    fn lifts_metavars_in_python() {
        let binds = run("$FN($ARG)", Language::Python, "print(value)");
        assert_eq!(capture(&binds, "FN"), ["print"]);
        assert_eq!(capture(&binds, "ARG"), ["value"]);
    }

    #[test]
    fn lifts_metavars_in_go() {
        let code = "package main\nfunc m() { log(err) }";
        let binds = run("$FN($ARG)", Language::Go, code);
        assert_eq!(capture(&binds, "FN"), ["log"]);
        assert_eq!(capture(&binds, "ARG"), ["err"]);
    }

    #[test]
    fn repeated_metavar_unifies() {
        // `$X + $X` must match `a + a` but not `a + b`.
        let snippet = "$X + $X";
        let same = run(snippet, Language::Rust, "fn f() { let _ = a + a; }");
        assert_eq!(capture(&same, "X"), ["a"]);
        let different = run(snippet, Language::Rust, "fn f() { let _ = a + b; }");
        assert!(
            capture(&different, "X").is_empty(),
            "unification must reject `a + b`"
        );
    }

    #[test]
    fn rejects_unparseable_snippet() {
        let err = compile_pattern("$A ?? ?? $B", Language::TypeScript).unwrap_err();
        assert!(err.contains("did not parse"), "got: {err}");
    }

    #[test]
    fn rejects_variadic_for_now() {
        let err = compile_pattern("foo($$$ARGS)", Language::TypeScript).unwrap_err();
        assert!(err.contains("variadic"), "got: {err}");
    }

    #[test]
    fn literal_pattern_matches_exact_text() {
        // A metavar-free pattern is a literal pattern: `foo()` matches calls
        // to `foo`, not to other functions.
        let snippet = "foo()";
        let compiled = compile_pattern(snippet, Language::TypeScript).expect("compiles");
        assert!(compiled.metavars.is_empty());
        // `foo()` is a hit …
        let hit = run(snippet, Language::TypeScript, "foo();");
        assert!(!hit.is_empty());
        // … while `bar()` is not (the literal identifier is constrained).
        let miss = run(snippet, Language::TypeScript, "bar();");
        assert!(
            miss.is_empty(),
            "bar() must not match foo()'s literal pattern: {miss:?}"
        );
    }
}
484}