Skip to main content

codex_patcher/sg/
matcher.rs

1use crate::cache;
2use crate::sg::errors::AstGrepError;
3use crate::sg::lang::rust;
4use ast_grep_core::tree_sitter::StrDoc;
5use ast_grep_core::{AstGrep, NodeMatch};
6use ast_grep_language::SupportLang;
7use std::collections::HashMap;
8
/// A single ast-grep pattern match together with its captured metavariables.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Byte offset at which the match starts.
    pub byte_start: usize,
    /// Byte offset at which the match ends.
    pub byte_end: usize,
    /// The matched text.
    pub text: String,
    /// Captured metavariables: name -> captured text.
    ///
    /// Only the text of each capture is stored. To recover a byte span for a
    /// capture, use [`PatternMatch::find_capture_span`].
    pub captures: HashMap<String, String>,
}

impl PatternMatch {
    /// Find the byte span of the capture `name`, expressed in the same
    /// coordinates as `byte_start` / `byte_end`.
    ///
    /// Captures are stored as text only, so the span is recovered by searching
    /// for the captured text inside the matched text. This is a heuristic:
    ///
    /// 1. The first occurrence that does not cut an identifier-like word in
    ///    half is preferred. Without this, a capture such as `t` in
    ///    `let t = 1;` would be located inside the keyword `let`.
    /// 2. If every occurrence splits a word, the plain first occurrence is
    ///    used as a fallback.
    ///
    /// If the captured text occurs several times on word boundaries, the
    /// earliest such occurrence wins, which may still not be the captured node.
    ///
    /// Returns `None` when there is no capture called `name` or when its text
    /// does not occur in the matched text.
    pub fn find_capture_span(&self, name: &str) -> Option<(usize, usize)> {
        // Characters that can be part of an identifier-like word.
        fn is_word_char(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }

        let capture_text = self.captures.get(name)?.as_str();

        // A boundary only matters on a side where the capture itself begins
        // or ends with a word character.
        let starts_with_word = matches!(capture_text.chars().next(), Some(c) if is_word_char(c));
        let ends_with_word =
            matches!(capture_text.chars().next_back(), Some(c) if is_word_char(c));

        // True when an occurrence at `offset` would be glued to a neighbouring
        // word character, i.e. it is only a fragment of a longer word.
        let splits_a_word = |offset: usize| {
            let before = self.text[..offset].chars().next_back();
            let after = self.text[offset + capture_text.len()..].chars().next();
            (starts_with_word && matches!(before, Some(c) if is_word_char(c)))
                || (ends_with_word && matches!(after, Some(c) if is_word_char(c)))
        };

        let offset = self
            .text
            .match_indices(capture_text)
            .map(|(offset, _)| offset)
            .find(|&offset| !splits_a_word(offset))
            // Fallback keeps the previous "first occurrence" behaviour.
            .or_else(|| self.text.find(capture_text))?;

        let start = self.byte_start + offset;
        Some((start, start + capture_text.len()))
    }
}
36
37/// Pattern matcher using ast-grep's metavariable syntax.
38///
39/// # Metavariable Syntax
40///
41/// - `$NAME` - Matches a single node and captures it
42/// - `$$$NAME` - Matches zero or more nodes (variadic)
43/// - `$_` - Matches any single node (anonymous)
44///
45/// # Example Patterns
46///
47/// ```text
48/// fn $NAME($$$PARAMS) { $$$BODY }     // Match function definition
49/// struct $NAME { $$$FIELDS }           // Match struct definition
50/// $EXPR.clone()                        // Match .clone() calls
51/// OtelExporter::$VARIANT               // Match enum variants
52/// ```
53pub struct PatternMatcher {
54    source: String,
55    sg: AstGrep<StrDoc<SupportLang>>,
56}
57
58impl PatternMatcher {
59    /// Create a new pattern matcher for the given source code.
60    pub fn new(source: &str) -> Self {
61        let sg = AstGrep::new(source, rust());
62        Self {
63            source: source.to_string(),
64            sg,
65        }
66    }
67
68    /// Find all matches for a pattern.
69    pub fn find_all(&self, pattern: &str) -> Result<Vec<PatternMatch>, AstGrepError> {
70        let pat = cache::get_or_compile_pattern(pattern, rust());
71        let root = self.sg.root();
72        let matches: Vec<_> = root.find_all(&pat).collect();
73
74        let results = matches
75            .into_iter()
76            .map(|m| self.node_match_to_pattern_match(m))
77            .collect();
78
79        Ok(results)
80    }
81
82    /// Find exactly one match for a pattern.
83    pub fn find_unique(&self, pattern: &str) -> Result<PatternMatch, AstGrepError> {
84        let matches = self.find_all(pattern)?;
85
86        match matches.len() {
87            0 => Err(AstGrepError::NoMatch),
88            1 => Ok(matches.into_iter().next().expect("len checked == 1")),
89            n => Err(AstGrepError::AmbiguousMatch { count: n }),
90        }
91    }
92
93    /// Check if a pattern has any matches.
94    pub fn has_match(&self, pattern: &str) -> bool {
95        let pat = cache::get_or_compile_pattern(pattern, rust());
96        self.sg.root().find(&pat).is_some()
97    }
98
99    /// Find matches within a specific byte range (for context constraints).
100    pub fn find_in_range(
101        &self,
102        pattern: &str,
103        start: usize,
104        end: usize,
105    ) -> Result<Vec<PatternMatch>, AstGrepError> {
106        let matches = self.find_all(pattern)?;
107
108        let filtered: Vec<_> = matches
109            .into_iter()
110            .filter(|m| m.byte_start >= start && m.byte_end <= end)
111            .collect();
112
113        Ok(filtered)
114    }
115
116    /// Find matches that are inside a function with the given name.
117    pub fn find_in_function(
118        &self,
119        pattern: &str,
120        function_name: &str,
121    ) -> Result<Vec<PatternMatch>, AstGrepError> {
122        // First, find the function
123        let func_pattern = format!("fn {function_name}($$$PARAMS) {{ $$$BODY }}");
124        let func_matches = self.find_all(&func_pattern)?;
125
126        let mut results = Vec::new();
127
128        for func_match in func_matches {
129            // Find pattern matches within the function body
130            let inner_matches =
131                self.find_in_range(pattern, func_match.byte_start, func_match.byte_end)?;
132            results.extend(inner_matches);
133        }
134
135        Ok(results)
136    }
137
138    /// Get the source code.
139    pub fn source(&self) -> &str {
140        &self.source
141    }
142
143    /// Find all nodes of a specific kind, optionally filtering by a pattern on a field.
144    ///
145    /// This is useful for constructs that aren't valid standalone Rust syntax,
146    /// like match arms (`PAT => BODY`). Since match arms can't be parsed in
147    /// isolation, we find them by kind and optionally filter by matching a
148    /// pattern against a specific field.
149    ///
150    /// # Example
151    ///
152    /// ```ignore
153    /// // Find all match arms where the pattern is OtelExporter::Statsig
154    /// let arms = matcher.find_by_kind_with_field(
155    ///     "match_arm",
156    ///     Some(("pattern", "OtelExporter::Statsig")),
157    /// )?;
158    /// ```
159    pub fn find_by_kind_with_field(
160        &self,
161        kind: &str,
162        field_filter: Option<(&str, &str)>,
163    ) -> Result<Vec<PatternMatch>, AstGrepError> {
164        let root = self.sg.root();
165        let mut results = Vec::new();
166
167        // Use depth-first traversal to find all nodes
168        for node in root.dfs() {
169            if node.kind() != kind {
170                continue;
171            }
172
173            // If we have a field filter, check it
174            if let Some((field_name, pattern)) = field_filter {
175                let pat = cache::get_or_compile_pattern(pattern, rust());
176                let field_node = node.field(field_name);
177
178                if let Some(field) = field_node {
179                    if field.find(&pat).is_none() {
180                        continue;
181                    }
182                } else {
183                    continue;
184                }
185            }
186
187            // Node matches - convert to PatternMatch
188            let range = node.range();
189            let byte_start = range.start;
190            let byte_end = range.end;
191            let text = self.source[byte_start..byte_end].to_string();
192
193            results.push(PatternMatch {
194                byte_start,
195                byte_end,
196                text,
197                captures: HashMap::new(), // No captures for kind-based matching
198            });
199        }
200
201        Ok(results)
202    }
203
204    /// Find match arms by their pattern.
205    ///
206    /// Convenience method for finding match arms since they can't be matched
207    /// directly with patterns (not valid standalone Rust syntax).
208    ///
209    /// # Example
210    ///
211    /// ```ignore
212    /// let arms = matcher.find_match_arms("OtelExporter::Statsig")?;
213    /// for arm in arms {
214    ///     println!("Found arm: {}", arm.text);
215    /// }
216    /// ```
217    pub fn find_match_arms(&self, pattern: &str) -> Result<Vec<PatternMatch>, AstGrepError> {
218        self.find_by_kind_with_field("match_arm", Some(("pattern", pattern)))
219    }
220
221    fn node_match_to_pattern_match(&self, m: NodeMatch<StrDoc<SupportLang>>) -> PatternMatch {
222        let node = m.get_node();
223        let range = node.range();
224        let byte_start = range.start;
225        let byte_end = range.end;
226        let text = self.source[byte_start..byte_end].to_string();
227
228        // Convert MetaVarEnv to HashMap<String, String>
229        let env = m.get_env().clone();
230        let captures: HashMap<String, String> = env.into();
231
232        PatternMatch {
233            byte_start,
234            byte_end,
235            text,
236            captures,
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matcher over `src`; every test starts from here.
    fn matcher_for(src: &str) -> PatternMatcher {
        PatternMatcher::new(src)
    }

    /// A fully specified function pattern matches exactly that function and
    /// captures its body.
    #[test]
    fn find_function_by_pattern() {
        const SRC: &str = r#"
fn helper() -> i32 { 42 }

fn main() {
    let x = helper();
    println!("{}", x);
}
"#;
        let found = matcher_for(SRC)
            .find_all("fn main() { $$$BODY }")
            .unwrap();

        assert_eq!(found.len(), 1);
        let main_fn = &found[0];
        assert!(main_fn.text.contains("fn main()"));
        assert!(main_fn.captures.contains_key("BODY"));
    }

    /// A variadic metavariable captures the struct's fields.
    #[test]
    fn find_struct_fields() {
        const SRC: &str = r#"
struct Config {
    name: String,
    value: i32,
}
"#;
        let config = matcher_for(SRC)
            .find_unique("struct Config { $$$FIELDS }")
            .unwrap();

        assert!(config.captures.contains_key("FIELDS"));
    }

    /// Only the `.clone()` calls are matched, not `.to_string()`.
    #[test]
    fn find_method_calls() {
        const SRC: &str = r#"
fn test() {
    let a = foo.clone();
    let b = bar.clone();
    let c = baz.to_string();
}
"#;
        let clone_calls = matcher_for(SRC).find_all("$EXPR.clone()").unwrap();

        assert_eq!(clone_calls.len(), 2);
    }

    /// The wildcard arm is not a path, so only the two variants match.
    #[test]
    fn find_enum_variants() {
        const SRC: &str = r#"
match exporter {
    OtelExporter::Statsig => do_statsig(),
    OtelExporter::None => do_nothing(),
    _ => other(),
}
"#;
        let variants = matcher_for(SRC)
            .find_all("OtelExporter::$VARIANT")
            .unwrap();

        assert_eq!(variants.len(), 2);
    }

    /// `find_unique` succeeds when there is exactly one match.
    #[test]
    fn find_unique_success() {
        let src = "fn main() { println!(\"hello\"); }";
        let only = matcher_for(src)
            .find_unique("fn main() { $$$BODY }")
            .unwrap();

        assert!(only.text.contains("fn main()"));
    }

    /// `find_unique` reports `NoMatch` when nothing matches.
    #[test]
    fn find_unique_no_match() {
        let src = "fn main() {}";
        let outcome = matcher_for(src).find_unique("fn nonexistent() { $$$BODY }");

        assert!(matches!(outcome, Err(AstGrepError::NoMatch)));
    }

    /// `find_unique` reports the number of matches when there are several.
    #[test]
    fn find_unique_ambiguous() {
        const SRC: &str = r#"
fn foo() {}
fn bar() {}
"#;
        let outcome = matcher_for(SRC).find_unique("fn $NAME() {}");

        assert!(matches!(
            outcome,
            Err(AstGrepError::AmbiguousMatch { count: 2 })
        ));
    }

    /// Matches are restricted to the named function.
    #[test]
    fn find_in_function_context() {
        const SRC: &str = r#"
fn outer() {
    let x = foo.clone();
}

fn inner() {
    let y = bar.clone();
}
"#;
        let scoped = matcher_for(SRC)
            .find_in_function("$EXPR.clone()", "inner")
            .unwrap();

        assert_eq!(scoped.len(), 1);
        assert!(scoped[0].captures.contains_key("EXPR"));
    }

    /// Slicing the source with the reported span reproduces the match.
    #[test]
    fn byte_spans_accurate() {
        let src = "fn foo() { let x = 1; }";
        let whole_fn = matcher_for(src)
            .find_unique("fn $NAME() { $$$BODY }")
            .unwrap();

        // The function is the entire source, so the span must cover all of it.
        let sliced = &src[whole_fn.byte_start..whole_fn.byte_end];
        assert_eq!(sliced, src);
    }

    /// Match arms are found through their `pattern` field.
    #[test]
    fn find_match_arms_by_pattern() {
        const SRC: &str = r#"
match exporter {
    OtelExporter::Statsig => {
        OtelExporter::OtlpHttp { endpoint: url }
    }
    OtelExporter::None => None,
    _ => exporter.clone(),
}
"#;
        let matcher = matcher_for(SRC);

        // The Statsig arm, including its body.
        let statsig_arms = matcher.find_match_arms("OtelExporter::Statsig").unwrap();
        assert_eq!(statsig_arms.len(), 1);
        let statsig_text = &statsig_arms[0].text;
        assert!(statsig_text.contains("OtelExporter::Statsig"));
        assert!(statsig_text.contains("OtlpHttp"));

        // The None arm.
        let none_arms = matcher.find_match_arms("OtelExporter::None").unwrap();
        assert_eq!(none_arms.len(), 1);

        // Any variant of the enum.
        let any_variant_arms = matcher.find_match_arms("OtelExporter::$VARIANT").unwrap();
        assert_eq!(any_variant_arms.len(), 2); // Statsig and None
    }

    /// Kind-based search works for node kinds other than match arms.
    #[test]
    fn find_by_kind_generic() {
        const SRC: &str = r#"
struct Foo { x: i32 }
struct Bar { y: String }
"#;
        let matcher = matcher_for(SRC);

        // No filter: every struct item.
        let all_structs = matcher
            .find_by_kind_with_field("struct_item", None)
            .unwrap();
        assert_eq!(all_structs.len(), 2);

        // Filter on the "name" field with a metavariable.
        // Note: The "name" field of struct_item is a type_identifier node
        let named_structs = matcher
            .find_by_kind_with_field("struct_item", Some(("name", "$NAME")))
            .unwrap();
        // $NAME matches any identifier, so both structs match
        assert_eq!(named_structs.len(), 2);

        // To filter by a specific name, check the text of the match
        let only_foo: Vec<_> = named_structs
            .iter()
            .filter(|candidate| candidate.text.contains("Foo"))
            .collect();
        assert_eq!(only_foo.len(), 1);
    }
}