// Source: codex_patcher/ts/locator.rs

use crate::ts::errors::TreeSitterError;
use crate::ts::parser::RustParser;
use crate::ts::query::{queries, QueryEngine, QueryMatch};
use std::collections::HashMap;
use std::path::Path;

/// High-level structural target for locating Rust code constructs.
///
/// Each variant names one kind of item to search for. A target is turned
/// into tree-sitter query text by [`StructuralTarget::to_query`].
///
/// All payloads are plain `String`s, so the type is cheap to compare and can
/// be hashed (e.g. to deduplicate targets or key a cache by target).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StructuralTarget {
    /// A function by name (top-level or in module).
    Function { name: String },

    /// A method in an impl block.
    Method {
        type_name: String,
        method_name: String,
    },

    /// A struct by name.
    Struct { name: String },

    /// An enum by name.
    Enum { name: String },

    /// A const item by name.
    Const { name: String },

    /// Const items matching a regex pattern.
    ConstMatching { pattern: String },

    /// A static item by name.
    Static { name: String },

    /// An impl block for a type.
    Impl { type_name: String },

    /// An impl block for a trait on a type.
    ImplTrait {
        trait_name: String,
        type_name: String,
    },

    /// A use declaration matching a path pattern.
    Use { path_pattern: String },

    /// Custom tree-sitter query, used verbatim.
    Custom { query: String },
}

49impl StructuralTarget {
50    /// Convert to a tree-sitter query string.
51    pub fn to_query(&self) -> String {
52        match self {
53            StructuralTarget::Function { name } => queries::function_by_name(name),
54            StructuralTarget::Method {
55                type_name,
56                method_name,
57            } => queries::method_by_name(type_name, method_name),
58            StructuralTarget::Struct { name } => queries::struct_by_name(name),
59            StructuralTarget::Enum { name } => queries::enum_by_name(name),
60            StructuralTarget::Const { name } => queries::const_by_name(name),
61            StructuralTarget::ConstMatching { pattern } => queries::const_matching(pattern),
62            StructuralTarget::Static { name } => queries::static_by_name(name),
63            StructuralTarget::Impl { type_name } => queries::impl_by_type(type_name),
64            StructuralTarget::ImplTrait {
65                trait_name,
66                type_name,
67            } => queries::impl_trait_for_type(trait_name, type_name),
68            StructuralTarget::Use { path_pattern } => queries::use_declaration(path_pattern),
69            StructuralTarget::Custom { query } => query.clone(),
70        }
71    }
72}
73
/// Result of locating a structural target.
///
/// `byte_start..byte_end` is the half-open byte range of the match within
/// the source that was searched; the locator uses exactly that range to
/// slice out `text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocatorResult {
    /// Start of the byte range of the entire matched construct.
    pub byte_start: usize,
    /// End (exclusive) of the byte range of the entire matched construct.
    pub byte_end: usize,
    /// The matched text.
    pub text: String,
    /// Named captures from the query, keyed by capture name.
    pub captures: HashMap<String, CaptureInfo>,
}

/// Location and text of a single named capture within a match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CaptureInfo {
    /// Start of the capture's byte range.
    pub byte_start: usize,
    /// End of the capture's byte range.
    pub byte_end: usize,
    /// The captured text.
    pub text: String,
}

93impl From<QueryMatch> for LocatorResult {
94    fn from(m: QueryMatch) -> Self {
95        LocatorResult {
96            byte_start: m.byte_start,
97            byte_end: m.byte_end,
98            text: String::new(), // Will be filled in by locator
99            captures: m
100                .captures
101                .into_iter()
102                .map(|(k, v)| {
103                    (
104                        k,
105                        CaptureInfo {
106                            byte_start: v.byte_start,
107                            byte_end: v.byte_end,
108                            text: v.text,
109                        },
110                    )
111                })
112                .collect(),
113        }
114    }
115}
116
117/// Structural code locator using tree-sitter queries.
118pub struct StructuralLocator {
119    parser: RustParser,
120}
121
122impl StructuralLocator {
123    /// Create a new structural locator.
124    pub fn new() -> Result<Self, TreeSitterError> {
125        Ok(Self {
126            parser: RustParser::new()?,
127        })
128    }
129
130    /// Locate a structural target in source code, expecting exactly one match.
131    pub fn locate(
132        &mut self,
133        source: &str,
134        target: &StructuralTarget,
135    ) -> Result<LocatorResult, TreeSitterError> {
136        let parsed = self.parser.parse_with_source(source)?;
137        let query_str = target.to_query();
138        let engine = QueryEngine::new(&query_str)?;
139
140        let m = engine.find_unique(&parsed)?;
141        let mut result = LocatorResult::from(m);
142        result.text = source[result.byte_start..result.byte_end].to_string();
143
144        Ok(result)
145    }
146
147    /// Locate all matches for a structural target.
148    pub fn locate_all(
149        &mut self,
150        source: &str,
151        target: &StructuralTarget,
152    ) -> Result<Vec<LocatorResult>, TreeSitterError> {
153        let parsed = self.parser.parse_with_source(source)?;
154        let query_str = target.to_query();
155        let engine = QueryEngine::new(&query_str)?;
156
157        let matches = engine.find_all(&parsed);
158        let results = matches
159            .into_iter()
160            .map(|m| {
161                let mut result = LocatorResult::from(m);
162                result.text = source[result.byte_start..result.byte_end].to_string();
163                result
164            })
165            .collect();
166
167        Ok(results)
168    }
169
170    /// Locate a target in a file.
171    pub fn locate_in_file(
172        &mut self,
173        path: &Path,
174        target: &StructuralTarget,
175    ) -> Result<LocatorResult, TreeSitterError> {
176        let source = std::fs::read_to_string(path).map_err(|e| TreeSitterError::Io {
177            path: path.to_path_buf(),
178            source: e,
179        })?;
180        self.locate(&source, target)
181    }
182
183    /// Check if source has syntax errors.
184    pub fn has_errors(&mut self, source: &str) -> Result<bool, TreeSitterError> {
185        let parsed = self.parser.parse_with_source(source)?;
186        Ok(parsed.has_errors())
187    }
188
189    /// Get the underlying parser for direct tree-sitter access.
190    pub fn parser_mut(&mut self) -> &mut RustParser {
191        &mut self.parser
192    }
193}
194
195impl Default for StructuralLocator {
196    fn default() -> Self {
197        Self::new().expect("failed to create default StructuralLocator")
198    }
199}
200
201/// Pooled location functions that reuse parsers from thread-local pool.
202///
203/// These functions provide significant performance improvements for multi-patch
204/// workloads by avoiding redundant parser allocation and initialization.
205pub mod pooled {
206    use super::*;
207    use crate::pool;
208
209    /// Locate a structural target using pooled parser.
210    pub fn locate(
211        source: &str,
212        target: &StructuralTarget,
213    ) -> Result<LocatorResult, TreeSitterError> {
214        pool::with_parser(|parser| {
215            let parsed = parser.parse_with_source(source)?;
216            let query_str = target.to_query();
217            let engine = QueryEngine::new(&query_str)?;
218
219            let m = engine.find_unique(&parsed)?;
220            let mut result = LocatorResult::from(m);
221            result.text = source[result.byte_start..result.byte_end].to_string();
222
223            Ok(result)
224        })?
225    }
226
227    /// Locate all matches using pooled parser.
228    pub fn locate_all(
229        source: &str,
230        target: &StructuralTarget,
231    ) -> Result<Vec<LocatorResult>, TreeSitterError> {
232        pool::with_parser(|parser| {
233            let parsed = parser.parse_with_source(source)?;
234            let query_str = target.to_query();
235            let engine = QueryEngine::new(&query_str)?;
236
237            let matches = engine.find_all(&parsed);
238            let results = matches
239                .into_iter()
240                .map(|m| {
241                    let mut result = LocatorResult::from(m);
242                    result.text = source[result.byte_start..result.byte_end].to_string();
243                    result
244                })
245                .collect();
246
247            Ok(results)
248        })?
249    }
250
251    /// Find a function by name using pooled parser.
252    pub fn find_function(source: &str, name: &str) -> Result<LocatorResult, TreeSitterError> {
253        locate(
254            source,
255            &StructuralTarget::Function {
256                name: name.to_string(),
257            },
258        )
259    }
260}
261
262/// Convenience functions for common operations.
263impl StructuralLocator {
264    /// Find a function by name.
265    pub fn find_function(
266        &mut self,
267        source: &str,
268        name: &str,
269    ) -> Result<LocatorResult, TreeSitterError> {
270        self.locate(
271            source,
272            &StructuralTarget::Function {
273                name: name.to_string(),
274            },
275        )
276    }
277
278    /// Find a struct by name.
279    pub fn find_struct(
280        &mut self,
281        source: &str,
282        name: &str,
283    ) -> Result<LocatorResult, TreeSitterError> {
284        self.locate(
285            source,
286            &StructuralTarget::Struct {
287                name: name.to_string(),
288            },
289        )
290    }
291
292    /// Find a const by name.
293    pub fn find_const(
294        &mut self,
295        source: &str,
296        name: &str,
297    ) -> Result<LocatorResult, TreeSitterError> {
298        self.locate(
299            source,
300            &StructuralTarget::Const {
301                name: name.to_string(),
302            },
303        )
304    }
305
306    /// Find all consts matching a pattern.
307    pub fn find_consts_matching(
308        &mut self,
309        source: &str,
310        pattern: &str,
311    ) -> Result<Vec<LocatorResult>, TreeSitterError> {
312        self.locate_all(
313            source,
314            &StructuralTarget::ConstMatching {
315                pattern: pattern.to_string(),
316            },
317        )
318    }
319
320    /// Find an impl block for a type.
321    pub fn find_impl(
322        &mut self,
323        source: &str,
324        type_name: &str,
325    ) -> Result<LocatorResult, TreeSitterError> {
326        self.locate(
327            source,
328            &StructuralTarget::Impl {
329                type_name: type_name.to_string(),
330            },
331        )
332    }
333
334    /// Find a method in an impl block.
335    pub fn find_method(
336        &mut self,
337        source: &str,
338        type_name: &str,
339        method_name: &str,
340    ) -> Result<LocatorResult, TreeSitterError> {
341        self.locate(
342            source,
343            &StructuralTarget::Method {
344                type_name: type_name.to_string(),
345                method_name: method_name.to_string(),
346            },
347        )
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniquely named function is found and its whole body is returned.
    #[test]
    fn locate_function() {
        let mut locator = StructuralLocator::new().unwrap();
        let source = r#"
fn helper() -> i32 {
    42
}

fn main() {
    let x = helper();
    println!("{}", x);
}
"#;

        let result = locator.find_function(source, "main").unwrap();
        assert!(result.text.contains("fn main()"));
        assert!(result.text.contains("println!"));
    }

    /// A struct is found by name and the match includes its fields.
    #[test]
    fn locate_struct() {
        let mut locator = StructuralLocator::new().unwrap();
        let source = r#"
/// A configuration struct
#[derive(Debug)]
struct Config {
    name: String,
    value: i32,
}
"#;

        let result = locator.find_struct(source, "Config").unwrap();
        assert!(result.text.contains("struct Config"));
        assert!(result.text.contains("name: String"));
    }

    /// Pattern search returns only the consts whose name matches the regex.
    #[test]
    fn locate_consts_by_pattern() {
        let mut locator = StructuralLocator::new().unwrap();
        let source = r#"
const STATSIG_API_KEY: &str = "key123";
const STATSIG_ENDPOINT: &str = "https://api.statsig.com";
const OTEL_ENABLED: bool = true;
"#;

        let results = locator.find_consts_matching(source, "^STATSIG_").unwrap();
        assert_eq!(results.len(), 2);

        let names: Vec<_> = results
            .iter()
            .map(|r| r.captures["name"].text.as_str())
            .collect();
        assert!(names.contains(&"STATSIG_API_KEY"));
        assert!(names.contains(&"STATSIG_ENDPOINT"));
    }

    /// An inherent impl block is found by type name, methods included.
    #[test]
    fn locate_impl_block() {
        let mut locator = StructuralLocator::new().unwrap();
        let source = r#"
struct Foo;

impl Foo {
    fn new() -> Self {
        Foo
    }

    fn method(&self) -> i32 {
        42
    }
}
"#;

        let result = locator.find_impl(source, "Foo").unwrap();
        assert!(result.text.contains("impl Foo"));
        assert!(result.text.contains("fn new()"));
        assert!(result.text.contains("fn method(&self)"));
    }

    /// The reported byte range slices out exactly the matched item.
    #[test]
    fn byte_span_accuracy() {
        let mut locator = StructuralLocator::new().unwrap();
        let source = "fn foo() {}\nfn bar() {}";

        let result = locator.find_function(source, "bar").unwrap();

        // Verify the byte span extracts exactly the function
        let extracted = &source[result.byte_start..result.byte_end];
        assert_eq!(extracted, "fn bar() {}");
    }
}