// codex_patcher/ts/query.rs
use std::collections::HashMap;

use ast_grep_language::{LanguageExt, SupportLang};
use tree_sitter::{Query, QueryCursor, StreamingIterator};

use crate::ts::errors::TreeSitterError;
use crate::ts::parser::ParsedSource;

7#[derive(Debug, Clone)]
9pub struct QueryMatch {
10 pub byte_start: usize,
12 pub byte_end: usize,
13 pub captures: HashMap<String, CapturedNode>,
15}
16
/// A syntax node captured by a query, copied out of the tree into owned data
/// so it outlives the parse.
#[derive(Debug, Clone)]
pub struct CapturedNode {
    /// Byte offset in the parsed source where the node begins.
    pub byte_start: usize,
    /// Byte offset in the parsed source where the node ends (tree-sitter's
    /// `end_byte`, i.e. one past the last byte).
    pub byte_end: usize,
    /// The source text spanned by the node.
    pub text: String,
    /// The grammar kind of the node, e.g. `identifier` or `function_item`.
    pub kind: String,
}

25pub struct QueryEngine {
27 query: Query,
28 capture_names: Vec<String>,
29}
30
31impl QueryEngine {
32 pub fn new(query_str: &str) -> Result<Self, TreeSitterError> {
45 let language = SupportLang::Rust.get_ts_language();
46 let query =
47 Query::new(&language, query_str).map_err(|e| TreeSitterError::InvalidQuery {
48 message: e.to_string(),
49 })?;
50
51 let capture_names = query
52 .capture_names()
53 .iter()
54 .map(|s| s.to_string())
55 .collect();
56
57 Ok(Self {
58 query,
59 capture_names,
60 })
61 }
62
63 pub fn find_all<'a>(&self, parsed: &'a ParsedSource<'a>) -> Vec<QueryMatch> {
65 let mut cursor = QueryCursor::new();
66 let mut matches = cursor.matches(&self.query, parsed.root_node(), parsed.source.as_bytes());
67
68 let mut results = Vec::new();
69
70 while let Some(m) = matches.next() {
72 let mut captures = HashMap::new();
73 let mut overall_start = usize::MAX;
74 let mut overall_end = 0usize;
75
76 for capture in m.captures {
77 let node = capture.node;
78 let name = &self.capture_names[capture.index as usize];
79 let text = parsed.node_text(node).to_string();
80
81 overall_start = overall_start.min(node.start_byte());
82 overall_end = overall_end.max(node.end_byte());
83
84 captures.insert(
85 name.clone(),
86 CapturedNode {
87 byte_start: node.start_byte(),
88 byte_end: node.end_byte(),
89 text,
90 kind: node.kind().to_string(),
91 },
92 );
93 }
94
95 if overall_start != usize::MAX {
96 results.push(QueryMatch {
97 byte_start: overall_start,
98 byte_end: overall_end,
99 captures,
100 });
101 }
102 }
103
104 results
105 }
106
107 pub fn find_unique<'a>(
109 &self,
110 parsed: &'a ParsedSource<'a>,
111 ) -> Result<QueryMatch, TreeSitterError> {
112 let matches = self.find_all(parsed);
113
114 match matches.len() {
115 0 => Err(TreeSitterError::NoMatch),
116 1 => Ok(matches.into_iter().next().unwrap()),
117 n => Err(TreeSitterError::AmbiguousMatch { count: n }),
118 }
119 }
120
121 pub fn capture_names(&self) -> &[String] {
123 &self.capture_names
124 }
125}
126
/// Builders for common tree-sitter queries over Rust source, for use with
/// [`QueryEngine::new`](super::QueryEngine::new).
///
/// Every value interpolated into an `#eq?` / `#match?` predicate is escaped
/// for tree-sitter's string-literal syntax. The predicate therefore receives
/// the name or regex exactly as passed, including backslashes (`\d`, `\.`),
/// double quotes, and newlines.
pub mod queries {
    /// Escapes `value` for use between double quotes in a tree-sitter query.
    ///
    /// Tree-sitter's query parser treats `\` as an escape introducer: `\n`,
    /// `\r`, `\t` and `\0` become control characters, and any other escaped
    /// character is kept while the backslash is dropped (so a raw regex `\d`
    /// would reach the predicate as `d`). The parser also ends the literal at
    /// the first unescaped `"` and rejects a raw newline. Escaping these
    /// characters makes the parsed literal equal to `value`.
    fn escape(value: &str) -> String {
        let mut escaped = String::with_capacity(value.len());
        for ch in value.chars() {
            match ch {
                '\\' => escaped.push_str(r"\\"),
                '"' => escaped.push_str(r#"\""#),
                '\n' => escaped.push_str(r"\n"),
                '\r' => escaped.push_str(r"\r"),
                '\t' => escaped.push_str(r"\t"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Matches every `function_item` whose name is exactly `name`.
    ///
    /// This includes methods inside `impl` blocks, which use the same node kind
    /// (see [`method_by_name`]).
    ///
    /// Captures: `@name` (the identifier) and `@function` (the whole item).
    pub fn function_by_name(name: &str) -> String {
        let name = escape(name);
        format!(
            r#"(function_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @function"#
        )
    }

    /// Matches a method named exactly `method_name` inside an `impl` block
    /// whose type text matches the regex `type_name`.
    ///
    /// `type_name` is searched unanchored in the type's text. For example,
    /// `"Foo"` also matches `impl FooBar` or `impl Wrapper<Foo>`; pass
    /// `"^Foo$"` for an exact match.
    ///
    /// Captures: `@type`, `@method_name`, `@method`. The enclosing `impl` is
    /// not captured.
    pub fn method_by_name(type_name: &str, method_name: &str) -> String {
        let type_name = escape(type_name);
        let method_name = escape(method_name);
        format!(
            r#"(impl_item
                type: (_) @type
                (#match? @type "{type_name}")
                body: (declaration_list
                    (function_item
                        name: (identifier) @method_name
                        (#eq? @method_name "{method_name}")
                    ) @method
                )
            )"#
        )
    }

    /// Matches the struct named exactly `name`.
    ///
    /// Captures: `@name` and `@struct`.
    pub fn struct_by_name(name: &str) -> String {
        let name = escape(name);
        format!(
            r#"(struct_item
                name: (type_identifier) @name
                (#eq? @name "{name}")
            ) @struct"#
        )
    }

    /// Matches the enum named exactly `name`.
    ///
    /// Captures: `@name` and `@enum`.
    pub fn enum_by_name(name: &str) -> String {
        let name = escape(name);
        format!(
            r#"(enum_item
                name: (type_identifier) @name
                (#eq? @name "{name}")
            ) @enum"#
        )
    }

    /// Matches the `const` item named exactly `name`.
    ///
    /// Captures: `@name` and `@const`.
    pub fn const_by_name(name: &str) -> String {
        let name = escape(name);
        format!(
            r#"(const_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @const"#
        )
    }

    /// Matches the `static` item named exactly `name`.
    ///
    /// Captures: `@name` and `@static`.
    pub fn static_by_name(name: &str) -> String {
        let name = escape(name);
        format!(
            r#"(static_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @static"#
        )
    }

    /// Matches `impl` blocks whose type is the plain identifier `type_name`.
    ///
    /// Only a bare `type_identifier` in the type position matches. Generic or
    /// path-qualified types (e.g. `Foo<T>`, `crate::Foo`) are different node
    /// kinds in the tree-sitter Rust grammar and are not matched.
    ///
    /// Captures: `@type` and `@impl`.
    pub fn impl_by_type(type_name: &str) -> String {
        let type_name = escape(type_name);
        format!(
            r#"(impl_item
                type: (type_identifier) @type
                (#eq? @type "{type_name}")
            ) @impl"#
        )
    }

    /// Matches `impl <trait_name> for <type_name>` where both are plain
    /// identifiers.
    ///
    /// As with [`impl_by_type`], generic or path-qualified traits and types are
    /// not matched.
    ///
    /// Captures: `@trait`, `@type` and `@impl`.
    pub fn impl_trait_for_type(trait_name: &str, type_name: &str) -> String {
        let trait_name = escape(trait_name);
        let type_name = escape(type_name);
        format!(
            r#"(impl_item
                trait: (type_identifier) @trait
                (#eq? @trait "{trait_name}")
                type: (type_identifier) @type
                (#eq? @type "{type_name}")
            ) @impl"#
        )
    }

    /// Matches `use` declarations whose argument text matches the regex
    /// `path_pattern` (unanchored search).
    ///
    /// Captures: `@path` and `@use`.
    pub fn use_declaration(path_pattern: &str) -> String {
        let path_pattern = escape(path_pattern);
        format!(
            r#"(use_declaration
                argument: (_) @path
                (#match? @path "{path_pattern}")
            ) @use"#
        )
    }

    /// Every `function_item`. Captures: `@name` and `@function`.
    pub const ALL_FUNCTIONS: &str = r#"(function_item
        name: (identifier) @name
    ) @function"#;

    /// Every `struct_item`. Captures: `@name` and `@struct`.
    pub const ALL_STRUCTS: &str = r#"(struct_item
        name: (type_identifier) @name
    ) @struct"#;

    /// Every `impl_item`. Captures: `@type` and `@impl`.
    pub const ALL_IMPLS: &str = r#"(impl_item
        type: (_) @type
    ) @impl"#;

    /// Matches `const` items whose name matches the regex `pattern`
    /// (unanchored search; use `^`/`$` to anchor).
    ///
    /// Captures: `@name` and `@const`.
    pub fn const_matching(pattern: &str) -> String {
        let pattern = escape(pattern);
        format!(
            r#"(const_item
                name: (identifier) @name
                (#match? @name "{pattern}")
            ) @const"#
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ts::parser::RustParser;

    #[test]
    fn find_function_by_name() {
        let mut rust_parser = RustParser::new().unwrap();
        let code = r#"
fn helper() {}

fn main() {
    helper();
}

fn other() {}
"#;
        let tree = rust_parser.parse_with_source(code).unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("main")).unwrap();

        let found = engine.find_all(&tree);
        assert_eq!(found.len(), 1);

        // Only `main` survives the `#eq?` predicate, and it must carry @name.
        let name = found[0]
            .captures
            .get("name")
            .expect("match should carry an @name capture");
        assert_eq!(name.text, "main");
    }

    #[test]
    fn find_struct_by_name() {
        let mut rust_parser = RustParser::new().unwrap();
        let code = r#"
struct Foo {
    x: i32,
}

struct Bar;
"#;
        let tree = rust_parser.parse_with_source(code).unwrap();
        let engine = QueryEngine::new(&queries::struct_by_name("Foo")).unwrap();

        let only = engine.find_unique(&tree).unwrap();
        assert_eq!(only.captures["name"].text, "Foo");
    }

    #[test]
    fn find_const_by_pattern() {
        let mut rust_parser = RustParser::new().unwrap();
        let code = r#"
const STATSIG_API_KEY: &str = "secret";
const STATSIG_ENDPOINT: &str = "https://example.com";
const OTHER_CONST: i32 = 42;
"#;
        let tree = rust_parser.parse_with_source(code).unwrap();
        let engine = QueryEngine::new(&queries::const_matching("^STATSIG_")).unwrap();

        // Both STATSIG_* consts match; OTHER_CONST does not.
        assert_eq!(engine.find_all(&tree).len(), 2);
    }

    #[test]
    fn ambiguous_match_error() {
        let mut rust_parser = RustParser::new().unwrap();
        let code = r#"
fn test() {}
fn test() {}
"#;
        let tree = rust_parser.parse_with_source(code).unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("test")).unwrap();

        let outcome = engine.find_unique(&tree);
        assert!(matches!(
            outcome,
            Err(TreeSitterError::AmbiguousMatch { count: 2 })
        ));
    }

    #[test]
    fn no_match_error() {
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source("fn main() {}").unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("nonexistent")).unwrap();

        assert!(matches!(
            engine.find_unique(&tree),
            Err(TreeSitterError::NoMatch)
        ));
    }
}