Skip to main content

codex_patcher/ts/
query.rs

1use crate::ts::errors::TreeSitterError;
2use crate::ts::parser::ParsedSource;
3use ast_grep_language::{LanguageExt, SupportLang};
4use std::collections::HashMap;
5use tree_sitter::{Query, QueryCursor, StreamingIterator};
6
/// A single match produced by running a tree-sitter query.
///
/// Two matches compare equal when their byte ranges and every captured node
/// are equal, which lets callers and tests use `==` / `assert_eq!` on results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryMatch {
    /// Smallest start byte offset among the nodes captured by this match.
    pub byte_start: usize,
    /// Largest end byte offset among the nodes captured by this match.
    pub byte_end: usize,
    /// Captured nodes keyed by capture name (the name as written in the
    /// query, without the leading `@`).
    pub captures: HashMap<String, CapturedNode>,
}

/// Owned snapshot of one captured syntax node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapturedNode {
    /// Byte offset at which the node starts in the source.
    pub byte_start: usize,
    /// Byte offset at which the node ends in the source.
    pub byte_end: usize,
    /// Source text covered by the node.
    pub text: String,
    /// Tree-sitter node kind, e.g. `identifier`.
    pub kind: String,
}
24
25/// Engine for executing tree-sitter queries against parsed Rust source.
26pub struct QueryEngine {
27    query: Query,
28    capture_names: Vec<String>,
29}
30
31impl QueryEngine {
32    /// Create a new query engine from a tree-sitter query string.
33    ///
34    /// # Query Syntax
35    ///
36    /// Tree-sitter queries use S-expression syntax:
37    /// ```text
38    /// (function_item
39    ///   name: (identifier) @func_name
40    ///   body: (block) @body)
41    /// ```
42    ///
43    /// Captures are prefixed with `@` and can be referenced by name.
44    pub fn new(query_str: &str) -> Result<Self, TreeSitterError> {
45        let language = SupportLang::Rust.get_ts_language();
46        let query =
47            Query::new(&language, query_str).map_err(|e| TreeSitterError::InvalidQuery {
48                message: e.to_string(),
49            })?;
50
51        let capture_names = query
52            .capture_names()
53            .iter()
54            .map(|s| s.to_string())
55            .collect();
56
57        Ok(Self {
58            query,
59            capture_names,
60        })
61    }
62
63    /// Execute the query against parsed source and return all matches.
64    pub fn find_all<'a>(&self, parsed: &'a ParsedSource<'a>) -> Vec<QueryMatch> {
65        let mut cursor = QueryCursor::new();
66        let mut matches = cursor.matches(&self.query, parsed.root_node(), parsed.source.as_bytes());
67
68        let mut results = Vec::new();
69
70        // tree-sitter 0.25+ uses StreamingIterator
71        while let Some(m) = matches.next() {
72            let mut captures = HashMap::new();
73            let mut overall_start = usize::MAX;
74            let mut overall_end = 0usize;
75
76            for capture in m.captures {
77                let node = capture.node;
78                let name = &self.capture_names[capture.index as usize];
79                let text = parsed.node_text(node).to_string();
80
81                overall_start = overall_start.min(node.start_byte());
82                overall_end = overall_end.max(node.end_byte());
83
84                captures.insert(
85                    name.clone(),
86                    CapturedNode {
87                        byte_start: node.start_byte(),
88                        byte_end: node.end_byte(),
89                        text,
90                        kind: node.kind().to_string(),
91                    },
92                );
93            }
94
95            if overall_start != usize::MAX {
96                results.push(QueryMatch {
97                    byte_start: overall_start,
98                    byte_end: overall_end,
99                    captures,
100                });
101            }
102        }
103
104        results
105    }
106
107    /// Execute the query and expect exactly one match.
108    pub fn find_unique<'a>(
109        &self,
110        parsed: &'a ParsedSource<'a>,
111    ) -> Result<QueryMatch, TreeSitterError> {
112        let matches = self.find_all(parsed);
113
114        match matches.len() {
115            0 => Err(TreeSitterError::NoMatch),
116            1 => Ok(matches.into_iter().next().unwrap()),
117            n => Err(TreeSitterError::AmbiguousMatch { count: n }),
118        }
119    }
120
121    /// Get capture names defined in the query.
122    pub fn capture_names(&self) -> &[String] {
123        &self.capture_names
124    }
125}
126
/// Common tree-sitter queries for Rust constructs.
///
/// Builders that compare against an exact name (`#eq?`) escape the name so it
/// is always a single, well-formed query string literal. Builders that take a
/// pattern (`#match?`) insert the pattern verbatim; see the notes on each.
pub mod queries {
    /// Escape `text` so it can sit between the double quotes of a query
    /// string literal without ending the literal early.
    ///
    /// Text made only of ordinary identifier characters is returned unchanged.
    fn escape_literal(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Query for a function by exact name.
    ///
    /// Captures: `@name` (the identifier) and `@function` (the whole item).
    pub fn function_by_name(name: &str) -> String {
        let name = escape_literal(name);
        format!(
            r#"(function_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @function"#
        )
    }

    /// Query for a function in an impl block.
    ///
    /// `method_name` is compared for exact equality and is escaped.
    /// `type_name` is used as a `#match?` pattern against the text of the
    /// impl's `type` node and is inserted verbatim, so a `"` inside it would
    /// end the string literal early.
    /// NOTE(review): `#match?` is presumably an unanchored regex match, so
    /// `Foo` would also match `FooBar` — confirm this is intended.
    ///
    /// Captures: `@type`, `@method_name` and `@method`.
    pub fn method_by_name(type_name: &str, method_name: &str) -> String {
        let method_name = escape_literal(method_name);
        format!(
            r#"(impl_item
                type: (_) @type
                (#match? @type "{type_name}")
                body: (declaration_list
                    (function_item
                        name: (identifier) @method_name
                        (#eq? @method_name "{method_name}")
                    ) @method
                )
            )"#
        )
    }

    /// Query for a struct by exact name.
    ///
    /// Captures: `@name` and `@struct`.
    pub fn struct_by_name(name: &str) -> String {
        let name = escape_literal(name);
        format!(
            r#"(struct_item
                name: (type_identifier) @name
                (#eq? @name "{name}")
            ) @struct"#
        )
    }

    /// Query for an enum by exact name.
    ///
    /// Captures: `@name` and `@enum`.
    pub fn enum_by_name(name: &str) -> String {
        let name = escape_literal(name);
        format!(
            r#"(enum_item
                name: (type_identifier) @name
                (#eq? @name "{name}")
            ) @enum"#
        )
    }

    /// Query for a const item by exact name.
    ///
    /// Captures: `@name` and `@const`.
    pub fn const_by_name(name: &str) -> String {
        let name = escape_literal(name);
        format!(
            r#"(const_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @const"#
        )
    }

    /// Query for a static item by exact name.
    ///
    /// Captures: `@name` and `@static`.
    pub fn static_by_name(name: &str) -> String {
        let name = escape_literal(name);
        format!(
            r#"(static_item
                name: (identifier) @name
                (#eq? @name "{name}")
            ) @static"#
        )
    }

    /// Query for an impl block by exact type name.
    ///
    /// Only impls whose `type` is a plain `type_identifier` are matched.
    ///
    /// Captures: `@type` and `@impl`.
    pub fn impl_by_type(type_name: &str) -> String {
        let type_name = escape_literal(type_name);
        format!(
            r#"(impl_item
                type: (type_identifier) @type
                (#eq? @type "{type_name}")
            ) @impl"#
        )
    }

    /// Query for an impl block with a trait.
    ///
    /// Both names are compared for exact equality against plain
    /// `type_identifier` nodes.
    ///
    /// Captures: `@trait`, `@type` and `@impl`.
    pub fn impl_trait_for_type(trait_name: &str, type_name: &str) -> String {
        let trait_name = escape_literal(trait_name);
        let type_name = escape_literal(type_name);
        format!(
            r#"(impl_item
                trait: (type_identifier) @trait
                (#eq? @trait "{trait_name}")
                type: (type_identifier) @type
                (#eq? @type "{type_name}")
            ) @impl"#
        )
    }

    /// Query for use statements matching a path pattern.
    ///
    /// `path_pattern` is inserted verbatim into a double-quoted string
    /// literal, so a `"` inside it would end the literal early.
    /// NOTE(review): tree-sitter appears to process backslash escapes inside
    /// string literals, so regex escapes such as `\d` may need doubling by
    /// the caller — confirm.
    ///
    /// Captures: `@path` and `@use`.
    pub fn use_declaration(path_pattern: &str) -> String {
        format!(
            r#"(use_declaration
                argument: (_) @path
                (#match? @path "{path_pattern}")
            ) @use"#
        )
    }

    /// Query for all functions in file.
    pub const ALL_FUNCTIONS: &str = r#"(function_item
        name: (identifier) @name
    ) @function"#;

    /// Query for all structs in file.
    pub const ALL_STRUCTS: &str = r#"(struct_item
        name: (type_identifier) @name
    ) @struct"#;

    /// Query for all impl blocks in file.
    pub const ALL_IMPLS: &str = r#"(impl_item
        type: (_) @type
    ) @impl"#;

    /// Query for const items matching a name pattern.
    ///
    /// `pattern` is inserted verbatim into a double-quoted string literal,
    /// so a `"` inside it would end the literal early.
    /// NOTE(review): tree-sitter appears to process backslash escapes inside
    /// string literals, so regex escapes such as `\d` may need doubling by
    /// the caller — confirm.
    ///
    /// Captures: `@name` and `@const`.
    pub fn const_matching(pattern: &str) -> String {
        format!(
            r#"(const_item
                name: (identifier) @name
                (#match? @name "{pattern}")
            ) @const"#
        )
    }
}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ts::parser::RustParser;

    /// Only the function whose name equals the requested one is returned.
    #[test]
    fn find_function_by_name() {
        const SOURCE: &str = r#"
fn helper() {}

fn main() {
    helper();
}

fn other() {}
"#;
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source(SOURCE).unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("main")).unwrap();

        let found = engine.find_all(&tree);
        assert_eq!(found.len(), 1);

        let name_capture = found[0].captures.get("name");
        assert!(name_capture.is_some());
        assert_eq!(name_capture.unwrap().text, "main");
    }

    /// `find_unique` succeeds when exactly one struct carries the name.
    #[test]
    fn find_struct_by_name() {
        const SOURCE: &str = r#"
struct Foo {
    x: i32,
}

struct Bar;
"#;
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source(SOURCE).unwrap();
        let engine = QueryEngine::new(&queries::struct_by_name("Foo")).unwrap();

        let only = engine.find_unique(&tree).unwrap();
        assert_eq!(only.captures["name"].text, "Foo");
    }

    /// A name pattern selects every const whose name matches it.
    #[test]
    fn find_const_by_pattern() {
        const SOURCE: &str = r#"
const STATSIG_API_KEY: &str = "secret";
const STATSIG_ENDPOINT: &str = "https://example.com";
const OTHER_CONST: i32 = 42;
"#;
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source(SOURCE).unwrap();
        let engine = QueryEngine::new(&queries::const_matching("^STATSIG_")).unwrap();

        assert_eq!(engine.find_all(&tree).len(), 2);
    }

    /// Two functions with the same name make `find_unique` report ambiguity.
    #[test]
    fn ambiguous_match_error() {
        const SOURCE: &str = r#"
fn test() {}
fn test() {}
"#;
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source(SOURCE).unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("test")).unwrap();

        let outcome = engine.find_unique(&tree);
        let is_ambiguous_pair = matches!(
            outcome,
            Err(TreeSitterError::AmbiguousMatch { count: 2 })
        );
        assert!(is_ambiguous_pair);
    }

    /// A name that does not occur makes `find_unique` report no match.
    #[test]
    fn no_match_error() {
        const SOURCE: &str = "fn main() {}";
        let mut rust_parser = RustParser::new().unwrap();
        let tree = rust_parser.parse_with_source(SOURCE).unwrap();
        let engine = QueryEngine::new(&queries::function_by_name("nonexistent")).unwrap();

        let outcome = engine.find_unique(&tree);
        let is_no_match = matches!(outcome, Err(TreeSitterError::NoMatch));
        assert!(is_no_match);
    }
}