Skip to main content

recast_core/
structural.rs

1//! Tree-sitter-backed structural rewrites (feature `structural`).
2//!
3//! Pattern syntax is tree-sitter's S-expression Query language.
4//! Captures (`@name`) feed the rewrite template, which can reference
5//! them as `$name`. The capture named `@root` (or, if absent, the
6//! outermost match node) defines the byte range that gets replaced.
7
8use std::path::Path;
9
10use rayon::prelude::*;
11use tree_sitter::{Language as TsLanguage, Node, Parser, Query, QueryCursor, StreamingIterator};
12
13use crate::error::{Error, Result};
14use crate::plan::{
15    FileChange, Plan, PlanOptions, PlanOutcome, check_match_counts, read_text_or_skip_binary,
16};
17use crate::rewrite::{label_for_path, unified_diff};
18use crate::search::{
19    SearchFile, SearchMatch, SearchOptions, SearchPlan, collect, scan, truncate_snippet,
20};
21use crate::walker::walk_paths;
22
/// Prefix that `substitute_metavars` writes in place of a single-node
/// metavariable in a friendly `--ast` pattern: `$NAME` becomes
/// `__RECAST_VAR_NAME__`, an identifier-like token the target grammar can
/// parse. NOTE(review): the query emitter presumably recognizes the
/// mangled form again by this prefix (that code is outside this view).
const METAVAR_PREFIX: &str = "__RECAST_VAR_";
/// Prefix written in place of an ellipsis metavariable (`$$$NAME`, a
/// variable-shape subtree): `$$$NAME` becomes `__RECAST_ELLIPSIS_NAME__`.
const ELLIPSIS_PREFIX: &str = "__RECAST_ELLIPSIS_";
/// Closing marker appended after the name for both metavariable forms.
const METAVAR_SUFFIX: &str = "__";
26
/// Language registry for structural rewrites and the syntax-regression
/// guard. Each variant is gated by the matching `lang-*` cargo feature;
/// build with `--features lang-all` to enable every grammar shipped today.
/// `#[non_exhaustive]` lets new grammars be added without breaking
/// downstream `match`es.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Language {
    /// Rust (`tree_sitter_rust`).
    #[cfg(feature = "lang-rust")]
    Rust,
    /// TypeScript (`tree_sitter_typescript::LANGUAGE_TYPESCRIPT`).
    #[cfg(feature = "lang-ts")]
    TypeScript,
    /// TSX (`tree_sitter_typescript::LANGUAGE_TSX`); shares the `lang-ts` feature.
    #[cfg(feature = "lang-ts")]
    Tsx,
    /// JavaScript (`tree_sitter_javascript`).
    #[cfg(feature = "lang-js")]
    JavaScript,
    /// Python (`tree_sitter_python`).
    #[cfg(feature = "lang-python")]
    Python,
    /// Bash / POSIX-ish shell (`tree_sitter_bash`).
    #[cfg(feature = "lang-bash")]
    Bash,
    /// Go (`tree_sitter_go`).
    #[cfg(feature = "lang-go")]
    Go,
    /// JSON (`tree_sitter_json`).
    #[cfg(feature = "lang-json")]
    Json,
    /// Markdown (`tree_sitter_md`).
    #[cfg(feature = "lang-md")]
    Markdown,
}
52
53impl Language {
54    /// Resolve a CLI-friendly name (case-insensitive) to a language.
55    /// Returns [`Error::UnknownLanguage`] for names that aren't
56    /// recognized or whose `lang-*` feature wasn't compiled in.
57    pub fn from_name(name: &str) -> Result<Self> {
58        match name.to_ascii_lowercase().as_str() {
59            #[cfg(feature = "lang-rust")]
60            "rust" | "rs" => Ok(Language::Rust),
61            #[cfg(feature = "lang-ts")]
62            "typescript" | "ts" => Ok(Language::TypeScript),
63            #[cfg(feature = "lang-ts")]
64            "tsx" => Ok(Language::Tsx),
65            #[cfg(feature = "lang-js")]
66            "javascript" | "js" | "jsx" => Ok(Language::JavaScript),
67            #[cfg(feature = "lang-python")]
68            "python" | "py" => Ok(Language::Python),
69            #[cfg(feature = "lang-bash")]
70            "bash" | "sh" | "shell" => Ok(Language::Bash),
71            #[cfg(feature = "lang-go")]
72            "go" | "golang" => Ok(Language::Go),
73            #[cfg(feature = "lang-json")]
74            "json" => Ok(Language::Json),
75            #[cfg(feature = "lang-md")]
76            "markdown" | "md" => Ok(Language::Markdown),
77            _ => Err(Error::UnknownLanguage(name.to_owned())),
78        }
79    }
80
81    /// Infer the grammar from a file extension. Returns `None` for
82    /// extensions without a compiled grammar — the syntax-regression
83    /// guard is skipped for those files, so a `--no-default-features`
84    /// build that drops every `lang-*` feature keeps working unchecked.
85    pub fn from_path(path: &Path) -> Option<Self> {
86        match path.extension()?.to_str()? {
87            #[cfg(feature = "lang-rust")]
88            "rs" => Some(Language::Rust),
89            #[cfg(feature = "lang-ts")]
90            "ts" => Some(Language::TypeScript),
91            #[cfg(feature = "lang-ts")]
92            "tsx" => Some(Language::Tsx),
93            #[cfg(feature = "lang-js")]
94            "js" | "mjs" | "cjs" | "jsx" => Some(Language::JavaScript),
95            #[cfg(feature = "lang-python")]
96            "py" | "pyi" => Some(Language::Python),
97            #[cfg(feature = "lang-bash")]
98            "sh" | "bash" => Some(Language::Bash),
99            #[cfg(feature = "lang-go")]
100            "go" => Some(Language::Go),
101            #[cfg(feature = "lang-json")]
102            "json" => Some(Language::Json),
103            #[cfg(feature = "lang-md")]
104            "md" | "markdown" => Some(Language::Markdown),
105            _ => None,
106        }
107    }
108
109    /// Stable lowercase name for diagnostics (matches the canonical
110    /// `from_name` alias).
111    pub(crate) fn name(self) -> &'static str {
112        match self {
113            #[cfg(feature = "lang-rust")]
114            Language::Rust => "rust",
115            #[cfg(feature = "lang-ts")]
116            Language::TypeScript => "typescript",
117            #[cfg(feature = "lang-ts")]
118            Language::Tsx => "tsx",
119            #[cfg(feature = "lang-js")]
120            Language::JavaScript => "javascript",
121            #[cfg(feature = "lang-python")]
122            Language::Python => "python",
123            #[cfg(feature = "lang-bash")]
124            Language::Bash => "bash",
125            #[cfg(feature = "lang-go")]
126            Language::Go => "go",
127            #[cfg(feature = "lang-json")]
128            Language::Json => "json",
129            #[cfg(feature = "lang-md")]
130            Language::Markdown => "markdown",
131        }
132    }
133
134    fn ts_language(self) -> TsLanguage {
135        match self {
136            #[cfg(feature = "lang-rust")]
137            Language::Rust => tree_sitter_rust::LANGUAGE.into(),
138            #[cfg(feature = "lang-ts")]
139            Language::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
140            #[cfg(feature = "lang-ts")]
141            Language::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
142            #[cfg(feature = "lang-js")]
143            Language::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
144            #[cfg(feature = "lang-python")]
145            Language::Python => tree_sitter_python::LANGUAGE.into(),
146            #[cfg(feature = "lang-bash")]
147            Language::Bash => tree_sitter_bash::LANGUAGE.into(),
148            #[cfg(feature = "lang-go")]
149            Language::Go => tree_sitter_go::LANGUAGE.into(),
150            #[cfg(feature = "lang-json")]
151            Language::Json => tree_sitter_json::LANGUAGE.into(),
152            #[cfg(feature = "lang-md")]
153            Language::Markdown => tree_sitter_md::LANGUAGE.into(),
154        }
155    }
156}
157
158/// Parse `src` with `lang`'s grammar and count `ERROR` + `MISSING`
159/// nodes. A parse that yields no tree (grammar load failure) counts
160/// as zero — the guard compares pre/post deltas, so an unparsable
161/// language degrades to "no regression detected" rather than a false
162/// positive.
163pub(crate) fn count_error_nodes(lang: Language, src: &str) -> usize {
164    let mut parser = Parser::new();
165    if parser.set_language(&lang.ts_language()).is_err() {
166        return 0;
167    }
168    let Some(tree) = parser.parse(src, None) else {
169        return 0;
170    };
171    let mut count = 0usize;
172    let mut stack = vec![tree.root_node()];
173    while let Some(node) = stack.pop() {
174        if node.is_error() || node.is_missing() {
175            count += 1;
176        }
177        let mut c = node.walk();
178        for child in node.children(&mut c) {
179            stack.push(child);
180        }
181    }
182    count
183}
184
185/// Reject a rewrite whose post-image introduces *new* syntax errors.
186/// Files whose extension maps to no compiled grammar pass through
187/// unchecked. The comparison is a count delta — a file that was
188/// already unparsable stays acceptable as long as the rewrite doesn't
189/// make it worse. Shared by the regex/script planner ([`crate::plan`])
190/// and the structural planner.
191pub(crate) fn guard_syntax(path: &Path, before: &str, after: &str) -> Result<()> {
192    let Some(lang) = Language::from_path(path) else {
193        return Ok(());
194    };
195    let new_errors = count_error_nodes(lang, after).saturating_sub(count_error_nodes(lang, before));
196    if new_errors > 0 {
197        return Err(Error::SyntaxRegression {
198            path: path.to_path_buf(),
199            lang: lang.name(),
200            new_errors,
201        });
202    }
203    Ok(())
204}
205
/// Result of [`structural_rewrite`] (and of each per-file application in
/// the multi-file planner): the rewritten source text plus the number of
/// matches that were actually spliced in.
///
/// A match skipped because it overlapped an earlier, already-applied
/// match is not counted.
#[derive(Debug, Clone)]
pub struct StructuralOutcome {
    /// Full source text after all applied replacements.
    pub text: String,
    /// Number of disjoint matches that were rewritten.
    pub matches: usize,
}
213
/// One slice of a parsed rewrite template, produced by `parse_template`.
/// Runs of literal text between capture references are stored pre-joined.
/// Capture references are resolved once, at compile time, to an index in
/// the compiled query's capture table.
enum TemplatePart {
    /// Text copied verbatim into every rendered replacement (`$$` has
    /// already been unescaped to `$`).
    Literal(String),
    /// Reference to a query capture. `index` drives the lookup while
    /// rendering; `name` is kept for the error message when the capture
    /// did not bind in a particular match.
    Capture { index: usize, name: String },
}
221
/// One match collected by [`CompiledStructural::apply`]: the byte range
/// the rewrite replaces in the source plus the rendered replacement text.
/// Hits are sorted before splicing, and a hit that overlaps an
/// already-applied one is skipped.
struct Hit {
    /// Start of the replaced half-open range `start..end`. It may have been
    /// extended backward over leading attributes/doc comments.
    start: usize,
    /// End of the replaced range: the primary node's `end_byte()`.
    end: usize,
    /// Template output for this match.
    replacement: String,
}
230
/// A compiled structural-rewrite job: grammar, query, the index of the
/// `@root` capture, and the rewrite template pre-resolved into
/// literal/capture parts.
///
/// It is built once per invocation and shared read-only by every candidate
/// file; only the tree-sitter `Parser` / `QueryCursor` are per-worker.
/// Pulling parsing out of the per-file loop is the whole point.
struct CompiledStructural {
    /// Grammar used to build per-worker parsers.
    ts_lang: TsLanguage,
    /// Compiled tree-sitter query.
    query: Query,
    /// Index of the capture named `root`, if the query defines one. When
    /// `None`, the outermost capture of each match is used instead.
    root_capture_idx: Option<usize>,
    /// Rewrite template split into literal and capture parts (empty for
    /// search-only jobs, which compile with an empty template).
    template_parts: Vec<TemplatePart>,
    /// Whether replacement ranges extend backward over leading attributes
    /// and doc comments.
    include_leading_attrs: bool,
}
243
244impl CompiledStructural {
245    fn compile(
246        lang: Language,
247        query_src: &str,
248        template: &str,
249        include_leading_attrs: bool,
250    ) -> Result<Self> {
251        let ts_lang = lang.ts_language();
252        // Probe the language by configuring a throwaway parser. Catches
253        // ABI mismatch up front so the per-thread workers can rely on
254        // `set_language` succeeding without surfacing late errors.
255        let mut probe = Parser::new();
256        probe.set_language(&ts_lang).map_err(|e| Error::StructuralQuery(e.to_string()))?;
257
258        let query = Query::new(&ts_lang, query_src)
259            .map_err(|e| Error::StructuralQuery(format_query_error(query_src, &e)))?;
260        let capture_names: Vec<&str> = query.capture_names().to_vec();
261        let root_capture_idx = capture_names.iter().position(|n| *n == "root");
262        let template_parts = parse_template(template, &capture_names)?;
263
264        Ok(Self { ts_lang, query, root_capture_idx, template_parts, include_leading_attrs })
265    }
266
267    fn new_parser(&self) -> Parser {
268        let mut parser = Parser::new();
269        // Language ABI was validated in `compile`, so this call is
270        // infallible in practice. If it somehow does fail, the parser
271        // stays in its unset state and the next `parse()` returns None,
272        // surfacing as Error::StructuralParse — no panic, defined
273        // behavior.
274        let _ = parser.set_language(&self.ts_lang);
275        parser
276    }
277
278    fn apply(
279        &self,
280        parser: &mut Parser,
281        cursor: &mut QueryCursor,
282        source: &str,
283    ) -> Result<StructuralOutcome> {
284        let tree = parser.parse(source, None).ok_or(Error::StructuralParse)?;
285        let bytes = source.as_bytes();
286
287        let mut hits: Vec<Hit> = Vec::new();
288        let mut iter = cursor.matches(&self.query, tree.root_node(), bytes);
289        while let Some(m) = iter.next() {
290            let primary = match self.root_capture_idx {
291                Some(idx) => {
292                    m.captures.iter().find(|c| c.index as usize == idx).ok_or_else(|| {
293                        Error::StructuralQuery(format!(
294                            "match did not bind primary capture index {idx}"
295                        ))
296                    })?
297                }
298                None => outermost_capture(m.captures)
299                    .ok_or_else(|| Error::StructuralQuery("match bound no captures".into()))?,
300            };
301            let replacement = self.render(source, m.captures)?;
302            let start = if self.include_leading_attrs {
303                extend_start_over_attrs(primary.node, source)
304            } else {
305                primary.node.start_byte()
306            };
307            hits.push(Hit { start, end: primary.node.end_byte(), replacement });
308        }
309        hits.sort_by_key(|h| h.start);
310
311        // Reserve source.len() plus the per-hit (replacement - range) delta
312        // so the splice loop doesn't realloc when replacements grow the text.
313        let extra: usize =
314            hits.iter().map(|h| h.replacement.len().saturating_sub(h.end - h.start)).sum();
315        let mut out = String::with_capacity(source.len() + extra);
316        let mut cursor_byte = 0usize;
317        let mut applied = 0usize;
318        for hit in &hits {
319            if hit.start < cursor_byte {
320                continue;
321            }
322            out.push_str(&source[cursor_byte..hit.start]);
323            out.push_str(&hit.replacement);
324            cursor_byte = hit.end;
325            applied += 1;
326        }
327        out.push_str(&source[cursor_byte..]);
328        Ok(StructuralOutcome { text: out, matches: applied })
329    }
330
331    pub(crate) fn search(
332        &self,
333        parser: &mut Parser,
334        cursor: &mut QueryCursor,
335        source: &str,
336    ) -> Result<Vec<SearchMatch>> {
337        let tree = parser.parse(source, None).ok_or(Error::StructuralParse)?;
338        let bytes = source.as_bytes();
339        let capture_names = self.query.capture_names();
340
341        let mut hits: Vec<SearchMatch> = Vec::new();
342        let mut iter = cursor.matches(&self.query, tree.root_node(), bytes);
343        while let Some(m) = iter.next() {
344            let primary = match self.root_capture_idx {
345                Some(idx) => {
346                    m.captures.iter().find(|c| c.index as usize == idx).ok_or_else(|| {
347                        Error::StructuralQuery(format!(
348                            "match did not bind primary capture index {idx}"
349                        ))
350                    })?
351                }
352                None => outermost_capture(m.captures)
353                    .ok_or_else(|| Error::StructuralQuery("match bound no captures".into()))?,
354            };
355            let pos = primary.node.start_position();
356            let capture_name =
357                capture_names.get(primary.index as usize).copied().map(ToOwned::to_owned);
358            let snippet =
359                truncate_snippet(&source[primary.node.start_byte()..primary.node.end_byte()]);
360            hits.push(SearchMatch {
361                line: pos.row + 1,
362                column: pos.column + 1,
363                snippet,
364                capture: capture_name,
365            });
366        }
367        hits.sort_by_key(|h| (h.line, h.column));
368        Ok(hits)
369    }
370
371    fn render(&self, source: &str, captures: &[tree_sitter::QueryCapture<'_>]) -> Result<String> {
372        let mut out = String::with_capacity(self.template_size_hint());
373        for part in &self.template_parts {
374            match part {
375                TemplatePart::Literal(s) => out.push_str(s),
376                TemplatePart::Capture { index, name } => {
377                    let cap =
378                        captures.iter().find(|c| c.index as usize == *index).ok_or_else(|| {
379                            Error::StructuralTemplate(format!(
380                                "capture `${name}` did not bind in this match"
381                            ))
382                        })?;
383                    out.push_str(&source[cap.node.start_byte()..cap.node.end_byte()]);
384                }
385            }
386        }
387        Ok(out)
388    }
389
390    fn template_size_hint(&self) -> usize {
391        self.template_parts
392            .iter()
393            .map(|p| match p {
394                TemplatePart::Literal(s) => s.len(),
395                TemplatePart::Capture { .. } => 16,
396            })
397            .sum()
398    }
399}
400
401fn parse_template(template: &str, capture_names: &[&str]) -> Result<Vec<TemplatePart>> {
402    use crate::template_scan::{scan_braced_name, scan_meta_name, utf8_char_len};
403
404    let mut parts: Vec<TemplatePart> = Vec::new();
405    let mut literal = String::new();
406    let bytes = template.as_bytes();
407    let mut i = 0;
408    while i < bytes.len() {
409        let b = bytes[i];
410        if b == b'$' && i + 1 < bytes.len() {
411            let next = bytes[i + 1];
412            if next == b'$' {
413                literal.push('$');
414                i += 2;
415                continue;
416            }
417            if next == b'{' {
418                let (name_start, name_end, after) =
419                    scan_braced_name(template, i).ok_or_else(|| {
420                        Error::StructuralTemplate("unterminated `${...}` in template".into())
421                    })?;
422                let name = &template[name_start..name_end];
423                push_capture(&mut parts, &mut literal, capture_names, name, true)?;
424                i = after;
425                continue;
426            }
427            if let Some((name_start, name_end, after)) = scan_meta_name(bytes, i) {
428                let name = &template[name_start..name_end];
429                push_capture(&mut parts, &mut literal, capture_names, name, false)?;
430                i = after;
431                continue;
432            }
433        }
434        let ch_len = utf8_char_len(b);
435        literal.push_str(&template[i..i + ch_len]);
436        i += ch_len;
437    }
438    flush_literal(&mut literal, &mut parts);
439    Ok(parts)
440}
441
442fn push_capture(
443    parts: &mut Vec<TemplatePart>,
444    literal: &mut String,
445    capture_names: &[&str],
446    name: &str,
447    braced: bool,
448) -> Result<()> {
449    let cap_idx = capture_names.iter().position(|n| *n == name).ok_or_else(|| {
450        if braced {
451            Error::StructuralTemplate(format!("no capture named `${{{name}}}` in query"))
452        } else {
453            Error::StructuralTemplate(format!("no capture named `${name}` in query"))
454        }
455    })?;
456    flush_literal(literal, parts);
457    parts.push(TemplatePart::Capture { index: cap_idx, name: name.to_owned() });
458    Ok(())
459}
460
461fn flush_literal(literal: &mut String, parts: &mut Vec<TemplatePart>) {
462    if !literal.is_empty() {
463        parts.push(TemplatePart::Literal(std::mem::take(literal)));
464    }
465}
466
467/// Run a tree-sitter Query against `source`, substitute captures into
468/// `template` per match, and splice the resulting text into the source
469/// at each match's replacement range. Overlapping match ranges are
470/// resolved greedy-first: the first match by start offset wins, later
471/// overlaps are skipped.
472pub fn structural_rewrite(
473    lang: Language,
474    source: &str,
475    query_src: &str,
476    template: &str,
477) -> Result<StructuralOutcome> {
478    structural_rewrite_attrs(lang, source, query_src, template, false)
479}
480
481/// Variant of [`structural_rewrite`] that, when `include_leading_attrs`
482/// is set, extends each replacement range backward over the contiguous
483/// run of preceding `attribute_item` / doc-comment siblings — so
484/// deleting a function also removes its `#[test]` / `///` lines instead
485/// of orphaning them. A blank-line gap or a non-attr/non-doc sibling
486/// ends the run. Node kinds are tree-sitter-rust's; languages without
487/// those kinds simply never extend (no-op).
488pub(crate) fn structural_rewrite_attrs(
489    lang: Language,
490    source: &str,
491    query_src: &str,
492    template: &str,
493    include_leading_attrs: bool,
494) -> Result<StructuralOutcome> {
495    let compiled = CompiledStructural::compile(lang, query_src, template, include_leading_attrs)?;
496    let mut parser = compiled.new_parser();
497    let mut cursor = QueryCursor::new();
498    compiled.apply(&mut parser, &mut cursor, source)
499}
500
501/// Walk backward from `node` over the contiguous run of preceding
502/// `attribute_item` / doc-comment siblings, returning the start byte of
503/// the earliest one in the run (or `node.start_byte()` if none). A
504/// blank line (two or more newlines in the inter-sibling gap) breaks the
505/// run, preserving intentional separation.
506fn extend_start_over_attrs(node: Node, source: &str) -> usize {
507    let bytes = source.as_bytes();
508    let mut start = node.start_byte();
509    let mut anchor = node;
510    while let Some(prev) = anchor.prev_sibling() {
511        if !is_swallowable_sibling(&prev, bytes) {
512            break;
513        }
514        let gap = &source[prev.end_byte()..anchor.start_byte()];
515        if gap.matches('\n').count() >= 2 {
516            break;
517        }
518        start = prev.start_byte();
519        anchor = prev;
520    }
521    start
522}
523
524/// True for an `attribute_item` or a doc-style comment (`///`, `//!`,
525/// `/**`, `/*!`). Plain `//` / `/* */` comments are left in place.
526fn is_swallowable_sibling(node: &Node, source: &[u8]) -> bool {
527    match node.kind() {
528        "attribute_item" => true,
529        "line_comment" | "block_comment" => {
530            let text = &source[node.start_byte()..node.end_byte()];
531            text.starts_with(b"///")
532                || text.starts_with(b"//!")
533                || text.starts_with(b"/**")
534                || text.starts_with(b"/*!")
535        }
536        _ => false,
537    }
538}
539
540/// Multi-file structural pipeline. Walks `roots`, applies
541/// [`structural_rewrite`] per file, and folds the results into a
542/// [`Plan`] that callers can pipe into [`crate::apply_changes`]. Honors
543/// `walk_options`, `max_files`, `max_bytes`, and the `at_least` /
544/// `at_most` match-count guard from `opts`. The convergence check and
545/// scripted-callback variants don't apply here — structural rewrites
546/// aren't re-probed against their own output.
547///
548/// The compiled query, capture-index table, and parsed rewrite template
549/// are built once and shared read-only across the per-file workers; only
550/// the tree-sitter `Parser` and `QueryCursor` are per-thread.
551pub fn plan_structural_rewrite<P: AsRef<Path>>(
552    lang: Language,
553    query: &str,
554    template: &str,
555    roots: &[P],
556    opts: &PlanOptions,
557    include_leading_attrs: bool,
558) -> Result<Plan> {
559    let files = walk_paths(roots, &opts.walk_options)?;
560    if files.len() > opts.max_files {
561        return Err(Error::TooManyFiles { count: files.len(), limit: opts.max_files });
562    }
563    let files_scanned = files.len();
564
565    let compiled = CompiledStructural::compile(lang, query, template, include_leading_attrs)?;
566
567    let results: Vec<Result<Option<FileChange>>> = files
568        .par_iter()
569        .map_init(
570            || (compiled.new_parser(), QueryCursor::new()),
571            |(parser, cursor), path| plan_one(&compiled, parser, cursor, path, opts),
572        )
573        .collect();
574
575    let mut changes: Vec<FileChange> = Vec::with_capacity(files_scanned);
576    for r in results {
577        if let Some(change) = r? {
578            changes.push(change);
579        }
580    }
581
582    let total_matches: usize = changes.iter().map(|c| c.matches).sum();
583    if total_matches == 0 {
584        return Ok(Plan {
585            changes: Vec::new(),
586            total_matches: 0,
587            files_scanned,
588            outcome: PlanOutcome::AlreadyApplied,
589        });
590    }
591    check_match_counts(total_matches, opts.at_least, opts.at_most)?;
592    Ok(Plan { changes, total_matches, files_scanned, outcome: PlanOutcome::Changes })
593}
594
595fn plan_one(
596    compiled: &CompiledStructural,
597    parser: &mut Parser,
598    cursor: &mut QueryCursor,
599    path: &Path,
600    opts: &PlanOptions,
601) -> Result<Option<FileChange>> {
602    let (before, permissions) = match read_text_or_skip_binary(path, opts.max_bytes)? {
603        Some(pair) => pair,
604        None => return Ok(None),
605    };
606    let outcome = compiled.apply(parser, cursor, &before)?;
607    if outcome.text == before {
608        return Ok(None);
609    }
610    if !opts.allow_syntax_errors {
611        guard_syntax(path, &before, &outcome.text)?;
612    }
613    let label = label_for_path(path);
614    let diff = unified_diff(&label, &before, &outcome.text);
615    Ok(Some(FileChange {
616        path: path.to_path_buf(),
617        matches: outcome.matches,
618        after: outcome.text,
619        diff,
620        permissions: Some(permissions),
621    }))
622}
623
624/// Friendlier counterpart to [`structural_rewrite`]: `pattern_source`
625/// is written in the target language with `$NAME` placeholders. The
626/// pattern is compiled to a tree-sitter Query under the hood; the
627/// rewrite template uses the same `$NAME` / `${NAME}` substitution as
628/// the raw API.
629///
630/// Example for Rust:
631///
632/// ```text
633/// pattern:  "fn $NAME() {}"
634/// template: "fn ${NAME}_v2() {}"
635/// ```
636///
637/// Metavariables match a single AST node at the position where the
638/// `$NAME` placeholder appeared in the parsed pattern (`(_)` wildcard
639/// in the underlying query). Capture names are the placeholder name
640/// minus the leading `$`.
641pub fn structural_rewrite_friendly(
642    lang: Language,
643    source: &str,
644    pattern_source: &str,
645    template: &str,
646) -> Result<StructuralOutcome> {
647    let query = compile_friendly_query(lang, pattern_source)?;
648    structural_rewrite(lang, source, &query, template)
649}
650
651/// Run a tree-sitter query against `source` and return match locations without rewriting.
652pub fn structural_search(
653    lang: Language,
654    source: &str,
655    query_src: &str,
656) -> Result<Vec<SearchMatch>> {
657    let compiled = CompiledStructural::compile(lang, query_src, "", false)?;
658    let mut parser = compiled.new_parser();
659    let mut cursor = QueryCursor::new();
660    compiled.search(&mut parser, &mut cursor, source)
661}
662
663/// Multi-file structural search. Walk `roots`, run the query per file, fold into `SearchPlan`.
664pub fn plan_structural_search<P: AsRef<std::path::Path>>(
665    lang: Language,
666    query_src: &str,
667    roots: &[P],
668    opts: &SearchOptions,
669) -> Result<SearchPlan> {
670    let files = scan(roots, opts)?;
671    let files_scanned = files.len();
672    let compiled = CompiledStructural::compile(lang, query_src, "", false)?;
673
674    let results: Vec<Result<Option<SearchFile>>> = files
675        .par_iter()
676        .map_init(
677            || (compiled.new_parser(), QueryCursor::new()),
678            |(parser, cursor), path| {
679                let (source, _) = match read_text_or_skip_binary(path, opts.max_bytes)? {
680                    Some(pair) => pair,
681                    None => return Ok(None),
682                };
683                let matches = compiled.search(parser, cursor, &source)?;
684                if matches.is_empty() {
685                    return Ok(None);
686                }
687                Ok(Some(SearchFile { path: path.to_path_buf(), matches }))
688            },
689        )
690        .collect();
691
692    let found = collect(results)?;
693    let total_matches: usize = found.iter().map(|f| f.matches.len()).sum();
694    check_match_counts(total_matches, opts.at_least, opts.at_most)?;
695    Ok(SearchPlan { files: found, total_matches, files_scanned })
696}
697
698/// Compile a friendly pattern (target-language source with `$NAME`
699/// placeholders) into a tree-sitter Query string. Exposed for callers
700/// that want to inspect or further manipulate the query.
701pub fn compile_friendly_query(lang: Language, pattern: &str) -> Result<String> {
702    let substituted = substitute_metavars(pattern);
703    let ts_lang = lang.ts_language();
704    let mut parser = Parser::new();
705    parser.set_language(&ts_lang).map_err(|e| Error::StructuralQuery(e.to_string()))?;
706    let tree = parser.parse(&substituted, None).ok_or_else(|| {
707        Error::StructuralQuery(format!(
708            "could not parse `--ast` pattern with the requested grammar; check that the pattern is valid {} syntax with `$NAME` / `$$$NAME` metavars in node positions",
709            ts_lang.name().unwrap_or("source")
710        ))
711    })?;
712    let root = tree.root_node();
713    if root.has_error() {
714        let snippet = pattern.lines().next().unwrap_or(pattern);
715        return Err(Error::StructuralQuery(format!(
716            "`--ast` pattern is not valid {} source after metavar substitution: `{snippet}`. \
717             Metavars (`$NAME`, `$$$NAME`) can only appear where an identifier-like token is \
718             legal in the target language.",
719            ts_lang.name().unwrap_or("source")
720        )));
721    }
722    // Tree-sitter wraps top-level items in a `source_file` (or similar)
723    // container; unwrap so the user-visible pattern matches the actual
724    // item, not the whole file.
725    let effective = if root.kind() == "source_file" && root.named_child_count() >= 1 {
726        root.named_child(0).ok_or_else(|| Error::StructuralQuery("empty pattern".into()))?
727    } else {
728        root
729    };
730
731    let mut buf = String::new();
732    let mut predicates: Vec<String> = Vec::new();
733    let mut lit_counter: usize = 0;
734    emit_node(&mut buf, &mut predicates, &mut lit_counter, effective, substituted.as_bytes());
735    let trimmed = buf.trim_start();
736    if predicates.is_empty() {
737        Ok(format!("{trimmed} @root"))
738    } else {
739        Ok(format!("({trimmed} @root {})", predicates.join(" ")))
740    }
741}
742
743fn substitute_metavars(pattern: &str) -> String {
744    use crate::template_scan::{scan_ellipsis_name, scan_meta_name, utf8_char_len};
745
746    let mut out = String::with_capacity(pattern.len());
747    let bytes = pattern.as_bytes();
748    let mut i = 0;
749    while i < bytes.len() {
750        let b = bytes[i];
751        if b == b'$' {
752            // $$$NAME — ellipsis metavar (variable-shape subtree)
753            if let Some((name_start, name_end, after)) = scan_ellipsis_name(bytes, i) {
754                out.push_str(ELLIPSIS_PREFIX);
755                out.push_str(&pattern[name_start..name_end]);
756                out.push_str(METAVAR_SUFFIX);
757                i = after;
758                continue;
759            }
760            if let Some((name_start, name_end, after)) = scan_meta_name(bytes, i) {
761                out.push_str(METAVAR_PREFIX);
762                out.push_str(&pattern[name_start..name_end]);
763                out.push_str(METAVAR_SUFFIX);
764                i = after;
765                continue;
766            }
767        }
768        let ch_len = utf8_char_len(b);
769        out.push_str(&pattern[i..i + ch_len]);
770        i += ch_len;
771    }
772    out
773}
774
775fn emit_node(
776    buf: &mut String,
777    predicates: &mut Vec<String>,
778    lit_counter: &mut usize,
779    node: Node<'_>,
780    src: &[u8],
781) {
782    use std::fmt::Write as _;
783
784    // Iterative: user `--ast` pattern depth is unbounded — recursion
785    // would give it a stack-overflow vector.
786    enum Frame<'tree> {
787        Open { node: Node<'tree>, field: Option<&'static str> },
788        Close,
789    }
790
791    let mut stack: Vec<Frame<'_>> = vec![Frame::Open { node, field: None }];
792    while let Some(frame) = stack.pop() {
793        match frame {
794            Frame::Close => buf.push(')'),
795            Frame::Open { node, field } => {
796                if !node.is_named() {
797                    continue;
798                }
799                if let Some(name) = field {
800                    buf.push(' ');
801                    buf.push_str(name);
802                    buf.push(':');
803                }
804                if let Some(ellipsis) = subtree_ellipsis_capture(node, src) {
805                    buf.push_str(" (_) @");
806                    buf.push_str(&ellipsis);
807                    continue;
808                }
809                if let Some(meta) = metavar_at(node, src) {
810                    buf.push_str(" (_) @");
811                    buf.push_str(&meta);
812                    continue;
813                }
814                // Terminal named leaves (identifier, integer_literal, etc.)
815                // are constrained to exact text via `#eq?` predicates so a
816                // literal in the pattern doesn't match every same-kind
817                // sibling in the source.
818                if node.named_child_count() == 0
819                    && let Ok(text) = node.utf8_text(src)
820                {
821                    let n = *lit_counter;
822                    *lit_counter += 1;
823                    let _ = write!(buf, " ({}) @__lit{n}", node.kind());
824                    let mut pred = String::new();
825                    let _ = write!(pred, "(#eq? @__lit{n} \"{}\")", escape_query_string(text));
826                    predicates.push(pred);
827                    continue;
828                }
829                buf.push_str(" (");
830                buf.push_str(node.kind());
831                stack.push(Frame::Close);
832                // Push children in reverse so the LIFO stack visits them
833                // in source order. `named_child(i)` indexes the same set
834                // `named_children()` iterates, so the named-child index
835                // doubles as the argument to `field_name_for_named_child`.
836                let count = node.named_child_count();
837                for i in (0..count).rev() {
838                    if let Some(child) = node.named_child(i) {
839                        let field = node.field_name_for_named_child(i as u32);
840                        stack.push(Frame::Open { node: child, field });
841                    }
842                }
843            }
844        }
845    }
846}
847
848/// Pick the outermost-by-byte-range capture in a match: smallest start
849/// byte wins; ties break to the largest end byte; final tiebreak is the
850/// lowest capture index for stability across queries that differ only
851/// in capture declaration order. Used when a query lacks an explicit
852/// `@root` so the apply phase still picks a deterministic primary.
853fn outermost_capture<'a, 'tree>(
854    captures: &'a [tree_sitter::QueryCapture<'tree>],
855) -> Option<&'a tree_sitter::QueryCapture<'tree>> {
856    captures.iter().min_by(|a, b| {
857        a.node
858            .start_byte()
859            .cmp(&b.node.start_byte())
860            .then_with(|| b.node.end_byte().cmp(&a.node.end_byte()))
861            .then_with(|| a.index.cmp(&b.index))
862    })
863}
864
865/// Render a tree-sitter `QueryError` with the offending fragment and a
866/// caret pointing at the byte offset, so callers see something useful
867/// instead of the raw `QueryError { row, column, offset, kind, message }`
868/// Debug output.
869fn format_query_error(query_src: &str, err: &tree_sitter::QueryError) -> String {
870    let kind = match err.kind {
871        tree_sitter::QueryErrorKind::Syntax => "syntax",
872        tree_sitter::QueryErrorKind::NodeType => "unknown node type",
873        tree_sitter::QueryErrorKind::Field => "unknown field",
874        tree_sitter::QueryErrorKind::Capture => "unknown capture",
875        tree_sitter::QueryErrorKind::Predicate => "bad predicate",
876        tree_sitter::QueryErrorKind::Structure => "structural mismatch",
877        tree_sitter::QueryErrorKind::Language => "language mismatch",
878    };
879    let line = query_src.lines().nth(err.row).unwrap_or("");
880    let caret_col = err.column.min(line.len());
881    let caret = format!("{}^", " ".repeat(caret_col));
882    let msg = err.message.trim();
883    format!(
884        "tree-sitter query {kind} error at line {row}, column {col}: {msg}\n  | {line}\n  | {caret}",
885        row = err.row + 1,
886        col = err.column + 1,
887    )
888}
889
/// Escape `s` for use inside a double-quoted tree-sitter query string
/// literal, such as the argument of an `#eq?` predicate. Backslashes and
/// double quotes are backslash-escaped, and newlines become `\n`, because
/// a raw newline cannot appear inside a query string. All other characters
/// pass through unchanged.
fn escape_query_string(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, ch| {
        match ch {
            '\\' => out.push_str(r"\\"),
            '"' => out.push_str(r#"\""#),
            '\n' => out.push_str(r"\n"),
            plain => out.push(plain),
        }
        out
    })
}
902
903fn metavar_at(node: Node<'_>, src: &[u8]) -> Option<String> {
904    if node.named_child_count() != 0 {
905        return None;
906    }
907    let text = node.utf8_text(src).ok()?;
908    let stripped = text.strip_prefix(METAVAR_PREFIX)?.strip_suffix(METAVAR_SUFFIX)?;
909    if stripped.is_empty() {
910        return None;
911    }
912    Some(stripped.to_owned())
913}
914
915/// Walk the subtree rooted at `node` and, if it contains exactly one
916/// ellipsis identifier (`$$$NAME` → `__RECAST_ELLIPSIS_NAME__`) and no
917/// other named leaves carrying meaningful content (no literals, no
918/// single-node metavars), return the ellipsis name. Such a subtree
919/// collapses to a single `(_) @NAME` wildcard in the generated query
920/// so the parent field can match any shape.
921fn subtree_ellipsis_capture(node: Node<'_>, src: &[u8]) -> Option<String> {
922    let mut ellipsis: Option<String> = None;
923    let mut other_leaves = 0usize;
924    let mut stack = vec![node];
925    while let Some(n) = stack.pop() {
926        if !n.is_named() {
927            continue;
928        }
929        if n.named_child_count() == 0 {
930            let text = n.utf8_text(src).ok()?;
931            if let Some(stripped) =
932                text.strip_prefix(ELLIPSIS_PREFIX).and_then(|s| s.strip_suffix(METAVAR_SUFFIX))
933                && !stripped.is_empty()
934            {
935                if ellipsis.is_some() {
936                    return None;
937                }
938                ellipsis = Some(stripped.to_owned());
939                continue;
940            }
941            other_leaves += 1;
942            continue;
943        }
944        let mut c = n.walk();
945        for child in n.named_children(&mut c) {
946            stack.push(child);
947        }
948    }
949    if other_leaves == 0 { ellipsis } else { None }
950}
951
952#[cfg(test)]
953#[path = "structural_tests.rs"]
954mod tests;