codex_patcher/sg/
matcher.rs1use crate::cache;
2use crate::sg::errors::AstGrepError;
3use crate::sg::lang::rust;
4use ast_grep_core::tree_sitter::StrDoc;
5use ast_grep_core::{AstGrep, NodeMatch};
6use ast_grep_language::SupportLang;
7use std::collections::HashMap;
8
/// One match of a pattern against a source string.
///
/// `byte_start..byte_end` is the byte range of the match, `text` is the
/// matched text, and `captures` maps capture names (for example `"BODY"` for
/// a `$$$BODY` metavariable) to the text they captured.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Byte offset at which the match begins.
    pub byte_start: usize,
    /// Byte offset one past the last byte of the match.
    pub byte_end: usize,
    /// The matched text.
    pub text: String,
    /// Captured text, keyed by capture name.
    pub captures: HashMap<String, String>,
}

impl PatternMatch {
    /// Computes a `(start, end)` byte span for the capture called `name`.
    ///
    /// The span is derived by searching `text` for the first occurrence of
    /// the captured string and shifting that position by `byte_start`. The
    /// capture's real position is not recorded, so if the same text also
    /// appears earlier inside the match, the earlier occurrence is reported.
    ///
    /// Returns `None` when there is no capture named `name`, or when the
    /// captured string cannot be found inside `text`.
    pub fn find_capture_span(&self, name: &str) -> Option<(usize, usize)> {
        self.captures.get(name).and_then(|captured| {
            self.text.find(captured.as_str()).map(|relative| {
                // `relative` is measured from the start of `text`; shift it by
                // `byte_start` to express it in the same terms as the match.
                let begin = self.byte_start + relative;
                (begin, begin + captured.len())
            })
        })
    }
}
36
37pub struct PatternMatcher {
54 source: String,
55 sg: AstGrep<StrDoc<SupportLang>>,
56}
57
58impl PatternMatcher {
59 pub fn new(source: &str) -> Self {
61 let sg = AstGrep::new(source, rust());
62 Self {
63 source: source.to_string(),
64 sg,
65 }
66 }
67
68 pub fn find_all(&self, pattern: &str) -> Result<Vec<PatternMatch>, AstGrepError> {
70 let pat = cache::get_or_compile_pattern(pattern, rust());
71 let root = self.sg.root();
72 let matches: Vec<_> = root.find_all(&pat).collect();
73
74 let results = matches
75 .into_iter()
76 .map(|m| self.node_match_to_pattern_match(m))
77 .collect();
78
79 Ok(results)
80 }
81
82 pub fn find_unique(&self, pattern: &str) -> Result<PatternMatch, AstGrepError> {
84 let matches = self.find_all(pattern)?;
85
86 match matches.len() {
87 0 => Err(AstGrepError::NoMatch),
88 1 => Ok(matches.into_iter().next().expect("len checked == 1")),
89 n => Err(AstGrepError::AmbiguousMatch { count: n }),
90 }
91 }
92
93 pub fn has_match(&self, pattern: &str) -> bool {
95 let pat = cache::get_or_compile_pattern(pattern, rust());
96 self.sg.root().find(&pat).is_some()
97 }
98
99 pub fn find_in_range(
101 &self,
102 pattern: &str,
103 start: usize,
104 end: usize,
105 ) -> Result<Vec<PatternMatch>, AstGrepError> {
106 let matches = self.find_all(pattern)?;
107
108 let filtered: Vec<_> = matches
109 .into_iter()
110 .filter(|m| m.byte_start >= start && m.byte_end <= end)
111 .collect();
112
113 Ok(filtered)
114 }
115
116 pub fn find_in_function(
118 &self,
119 pattern: &str,
120 function_name: &str,
121 ) -> Result<Vec<PatternMatch>, AstGrepError> {
122 let func_pattern = format!("fn {function_name}($$$PARAMS) {{ $$$BODY }}");
124 let func_matches = self.find_all(&func_pattern)?;
125
126 let mut results = Vec::new();
127
128 for func_match in func_matches {
129 let inner_matches =
131 self.find_in_range(pattern, func_match.byte_start, func_match.byte_end)?;
132 results.extend(inner_matches);
133 }
134
135 Ok(results)
136 }
137
138 pub fn source(&self) -> &str {
140 &self.source
141 }
142
143 pub fn find_by_kind_with_field(
160 &self,
161 kind: &str,
162 field_filter: Option<(&str, &str)>,
163 ) -> Result<Vec<PatternMatch>, AstGrepError> {
164 let root = self.sg.root();
165 let mut results = Vec::new();
166
167 for node in root.dfs() {
169 if node.kind() != kind {
170 continue;
171 }
172
173 if let Some((field_name, pattern)) = field_filter {
175 let pat = cache::get_or_compile_pattern(pattern, rust());
176 let field_node = node.field(field_name);
177
178 if let Some(field) = field_node {
179 if field.find(&pat).is_none() {
180 continue;
181 }
182 } else {
183 continue;
184 }
185 }
186
187 let range = node.range();
189 let byte_start = range.start;
190 let byte_end = range.end;
191 let text = self.source[byte_start..byte_end].to_string();
192
193 results.push(PatternMatch {
194 byte_start,
195 byte_end,
196 text,
197 captures: HashMap::new(), });
199 }
200
201 Ok(results)
202 }
203
204 pub fn find_match_arms(&self, pattern: &str) -> Result<Vec<PatternMatch>, AstGrepError> {
218 self.find_by_kind_with_field("match_arm", Some(("pattern", pattern)))
219 }
220
221 fn node_match_to_pattern_match(&self, m: NodeMatch<StrDoc<SupportLang>>) -> PatternMatch {
222 let node = m.get_node();
223 let range = node.range();
224 let byte_start = range.start;
225 let byte_end = range.end;
226 let text = self.source[byte_start..byte_end].to_string();
227
228 let env = m.get_env().clone();
230 let captures: HashMap<String, String> = env.into();
231
232 PatternMatch {
233 byte_start,
234 byte_end,
235 text,
236 captures,
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully spelled-out function pattern selects only `main` and captures
    /// its body under `BODY`.
    #[test]
    fn find_function_by_pattern() {
        let matcher = PatternMatcher::new(
            r#"
fn helper() -> i32 { 42 }

fn main() {
    let x = helper();
    println!("{}", x);
}
"#,
        );

        let found = matcher.find_all("fn main() { $$$BODY }").unwrap();
        assert_eq!(found.len(), 1);

        let only = &found[0];
        assert!(only.text.contains("fn main()"));
        assert!(only.captures.contains_key("BODY"));
    }

    /// A struct pattern captures the field list under `FIELDS`.
    #[test]
    fn find_struct_fields() {
        let matcher = PatternMatcher::new(
            r#"
struct Config {
    name: String,
    value: i32,
}
"#,
        );

        let config = matcher.find_unique("struct Config { $$$FIELDS }").unwrap();
        assert!(config.captures.contains_key("FIELDS"));
    }

    /// `$EXPR.clone()` matches the two `clone` calls but not `to_string`.
    #[test]
    fn find_method_calls() {
        let matcher = PatternMatcher::new(
            r#"
fn test() {
    let a = foo.clone();
    let b = bar.clone();
    let c = baz.to_string();
}
"#,
        );

        let clone_calls = matcher.find_all("$EXPR.clone()").unwrap();
        assert_eq!(clone_calls.len(), 2);
    }

    /// A path pattern with a metavariable matches both named variants and
    /// skips the wildcard arm.
    #[test]
    fn find_enum_variants() {
        let matcher = PatternMatcher::new(
            r#"
match exporter {
    OtelExporter::Statsig => do_statsig(),
    OtelExporter::None => do_nothing(),
    _ => other(),
}
"#,
        );

        let variants = matcher.find_all("OtelExporter::$VARIANT").unwrap();
        assert_eq!(variants.len(), 2);
    }

    /// `find_unique` returns the match when exactly one exists.
    #[test]
    fn find_unique_success() {
        let matcher = PatternMatcher::new("fn main() { println!(\"hello\"); }");

        let main_fn = matcher.find_unique("fn main() { $$$BODY }").unwrap();
        assert!(main_fn.text.contains("fn main()"));
    }

    /// `find_unique` reports `NoMatch` when nothing matches.
    #[test]
    fn find_unique_no_match() {
        let matcher = PatternMatcher::new("fn main() {}");

        let outcome = matcher.find_unique("fn nonexistent() { $$$BODY }");
        assert!(matches!(outcome, Err(AstGrepError::NoMatch)));
    }

    /// `find_unique` reports `AmbiguousMatch` with the number of matches.
    #[test]
    fn find_unique_ambiguous() {
        let matcher = PatternMatcher::new(
            r#"
fn foo() {}
fn bar() {}
"#,
        );

        let outcome = matcher.find_unique("fn $NAME() {}");
        assert!(matches!(
            outcome,
            Err(AstGrepError::AmbiguousMatch { count: 2 })
        ));
    }

    /// Only the `clone` call inside `inner` is returned when the search is
    /// scoped to that function.
    #[test]
    fn find_in_function_context() {
        let matcher = PatternMatcher::new(
            r#"
fn outer() {
    let x = foo.clone();
}

fn inner() {
    let y = bar.clone();
}
"#,
        );

        let scoped = matcher.find_in_function("$EXPR.clone()", "inner").unwrap();
        assert_eq!(scoped.len(), 1);
        assert!(scoped[0].captures.contains_key("EXPR"));
    }

    /// The reported byte range slices the original source back out exactly.
    #[test]
    fn byte_spans_accurate() {
        let source = "fn foo() { let x = 1; }";
        let matcher = PatternMatcher::new(source);

        let whole = matcher.find_unique("fn $NAME() { $$$BODY }").unwrap();
        assert_eq!(&source[whole.byte_start..whole.byte_end], source);
    }

    /// Match arms are selected by their pattern, and the returned text covers
    /// the whole arm including its body.
    #[test]
    fn find_match_arms_by_pattern() {
        let matcher = PatternMatcher::new(
            r#"
match exporter {
    OtelExporter::Statsig => {
        OtelExporter::OtlpHttp { endpoint: url }
    }
    OtelExporter::None => None,
    _ => exporter.clone(),
}
"#,
        );

        let statsig_arms = matcher.find_match_arms("OtelExporter::Statsig").unwrap();
        assert_eq!(statsig_arms.len(), 1);
        let statsig_text = &statsig_arms[0].text;
        assert!(statsig_text.contains("OtelExporter::Statsig"));
        assert!(statsig_text.contains("OtlpHttp"));

        let none_arms = matcher.find_match_arms("OtelExporter::None").unwrap();
        assert_eq!(none_arms.len(), 1);

        let any_variant = matcher.find_match_arms("OtelExporter::$VARIANT").unwrap();
        assert_eq!(any_variant.len(), 2);
    }

    /// Kind-based lookup works with and without a field filter.
    #[test]
    fn find_by_kind_generic() {
        let matcher = PatternMatcher::new(
            r#"
struct Foo { x: i32 }
struct Bar { y: String }
"#,
        );

        let unfiltered = matcher
            .find_by_kind_with_field("struct_item", None)
            .unwrap();
        assert_eq!(unfiltered.len(), 2);

        let named = matcher
            .find_by_kind_with_field("struct_item", Some(("name", "$NAME")))
            .unwrap();
        assert_eq!(named.len(), 2);

        let foo_count = named.iter().filter(|m| m.text.contains("Foo")).count();
        assert_eq!(foo_count, 1);
    }
}