1use crate::ts::errors::TreeSitterError;
2use crate::ts::parser::RustParser;
3use crate::ts::query::{queries, QueryEngine, QueryMatch};
4use std::path::Path;
5
/// Describes which Rust item to look for and how to identify it.
///
/// Each variant is rendered into query text by [`StructuralTarget::to_query`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructuralTarget {
    /// A function, identified by `name`.
    Function { name: String },

    /// A method, identified by the type it belongs to and its own name.
    Method {
        type_name: String,
        method_name: String,
    },

    /// A struct, identified by `name`.
    Struct { name: String },

    /// An enum, identified by `name`.
    Enum { name: String },

    /// A `const` item, identified by `name`.
    Const { name: String },

    /// Every `const` whose name matches `pattern`
    /// (e.g. `^STATSIG_` selects names starting with `STATSIG_`).
    ConstMatching { pattern: String },

    /// A `static` item, identified by `name`.
    Static { name: String },

    /// An `impl` block for `type_name`.
    Impl { type_name: String },

    /// An `impl trait_name for type_name` block.
    ImplTrait {
        trait_name: String,
        type_name: String,
    },

    /// `use` declarations selected by `path_pattern`.
    Use { path_pattern: String },

    /// A caller-supplied query string, used verbatim.
    Custom { query: String },
}
48
49impl StructuralTarget {
50 pub fn to_query(&self) -> String {
52 match self {
53 StructuralTarget::Function { name } => queries::function_by_name(name),
54 StructuralTarget::Method {
55 type_name,
56 method_name,
57 } => queries::method_by_name(type_name, method_name),
58 StructuralTarget::Struct { name } => queries::struct_by_name(name),
59 StructuralTarget::Enum { name } => queries::enum_by_name(name),
60 StructuralTarget::Const { name } => queries::const_by_name(name),
61 StructuralTarget::ConstMatching { pattern } => queries::const_matching(pattern),
62 StructuralTarget::Static { name } => queries::static_by_name(name),
63 StructuralTarget::Impl { type_name } => queries::impl_by_type(type_name),
64 StructuralTarget::ImplTrait {
65 trait_name,
66 type_name,
67 } => queries::impl_trait_for_type(trait_name, type_name),
68 StructuralTarget::Use { path_pattern } => queries::use_declaration(path_pattern),
69 StructuralTarget::Custom { query } => query.clone(),
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct LocatorResult {
77 pub byte_start: usize,
79 pub byte_end: usize,
80 pub text: String,
82 pub captures: std::collections::HashMap<String, CaptureInfo>,
84}
85
/// Span and text of one named capture within a located item.
#[derive(Debug, Clone)]
pub struct CaptureInfo {
    /// Byte offset where the captured node starts (inclusive).
    pub byte_start: usize,
    /// Byte offset where the captured node ends (exclusive).
    pub byte_end: usize,
    /// Text of the captured node.
    pub text: String,
}
92
93impl From<QueryMatch> for LocatorResult {
94 fn from(m: QueryMatch) -> Self {
95 LocatorResult {
96 byte_start: m.byte_start,
97 byte_end: m.byte_end,
98 text: String::new(), captures: m
100 .captures
101 .into_iter()
102 .map(|(k, v)| {
103 (
104 k,
105 CaptureInfo {
106 byte_start: v.byte_start,
107 byte_end: v.byte_end,
108 text: v.text,
109 },
110 )
111 })
112 .collect(),
113 }
114 }
115}
116
117pub struct StructuralLocator {
119 parser: RustParser,
120}
121
122impl StructuralLocator {
123 pub fn new() -> Result<Self, TreeSitterError> {
125 Ok(Self {
126 parser: RustParser::new()?,
127 })
128 }
129
130 pub fn locate(
132 &mut self,
133 source: &str,
134 target: &StructuralTarget,
135 ) -> Result<LocatorResult, TreeSitterError> {
136 let parsed = self.parser.parse_with_source(source)?;
137 let query_str = target.to_query();
138 let engine = QueryEngine::new(&query_str)?;
139
140 let m = engine.find_unique(&parsed)?;
141 let mut result = LocatorResult::from(m);
142 result.text = source[result.byte_start..result.byte_end].to_string();
143
144 Ok(result)
145 }
146
147 pub fn locate_all(
149 &mut self,
150 source: &str,
151 target: &StructuralTarget,
152 ) -> Result<Vec<LocatorResult>, TreeSitterError> {
153 let parsed = self.parser.parse_with_source(source)?;
154 let query_str = target.to_query();
155 let engine = QueryEngine::new(&query_str)?;
156
157 let matches = engine.find_all(&parsed);
158 let results = matches
159 .into_iter()
160 .map(|m| {
161 let mut result = LocatorResult::from(m);
162 result.text = source[result.byte_start..result.byte_end].to_string();
163 result
164 })
165 .collect();
166
167 Ok(results)
168 }
169
170 pub fn locate_in_file(
172 &mut self,
173 path: &Path,
174 target: &StructuralTarget,
175 ) -> Result<LocatorResult, TreeSitterError> {
176 let source = std::fs::read_to_string(path).map_err(|e| TreeSitterError::Io {
177 path: path.to_path_buf(),
178 source: e,
179 })?;
180 self.locate(&source, target)
181 }
182
183 pub fn has_errors(&mut self, source: &str) -> Result<bool, TreeSitterError> {
185 let parsed = self.parser.parse_with_source(source)?;
186 Ok(parsed.has_errors())
187 }
188
189 pub fn parser_mut(&mut self) -> &mut RustParser {
191 &mut self.parser
192 }
193}
194
195impl Default for StructuralLocator {
196 fn default() -> Self {
197 Self::new().expect("failed to create default StructuralLocator")
198 }
199}
200
201pub mod pooled {
206 use super::*;
207 use crate::pool;
208
209 pub fn locate(
211 source: &str,
212 target: &StructuralTarget,
213 ) -> Result<LocatorResult, TreeSitterError> {
214 pool::with_parser(|parser| {
215 let parsed = parser.parse_with_source(source)?;
216 let query_str = target.to_query();
217 let engine = QueryEngine::new(&query_str)?;
218
219 let m = engine.find_unique(&parsed)?;
220 let mut result = LocatorResult::from(m);
221 result.text = source[result.byte_start..result.byte_end].to_string();
222
223 Ok(result)
224 })?
225 }
226
227 pub fn locate_all(
229 source: &str,
230 target: &StructuralTarget,
231 ) -> Result<Vec<LocatorResult>, TreeSitterError> {
232 pool::with_parser(|parser| {
233 let parsed = parser.parse_with_source(source)?;
234 let query_str = target.to_query();
235 let engine = QueryEngine::new(&query_str)?;
236
237 let matches = engine.find_all(&parsed);
238 let results = matches
239 .into_iter()
240 .map(|m| {
241 let mut result = LocatorResult::from(m);
242 result.text = source[result.byte_start..result.byte_end].to_string();
243 result
244 })
245 .collect();
246
247 Ok(results)
248 })?
249 }
250
251 pub fn find_function(source: &str, name: &str) -> Result<LocatorResult, TreeSitterError> {
253 locate(
254 source,
255 &StructuralTarget::Function {
256 name: name.to_string(),
257 },
258 )
259 }
260}
261
262impl StructuralLocator {
264 pub fn find_function(
266 &mut self,
267 source: &str,
268 name: &str,
269 ) -> Result<LocatorResult, TreeSitterError> {
270 self.locate(
271 source,
272 &StructuralTarget::Function {
273 name: name.to_string(),
274 },
275 )
276 }
277
278 pub fn find_struct(
280 &mut self,
281 source: &str,
282 name: &str,
283 ) -> Result<LocatorResult, TreeSitterError> {
284 self.locate(
285 source,
286 &StructuralTarget::Struct {
287 name: name.to_string(),
288 },
289 )
290 }
291
292 pub fn find_const(
294 &mut self,
295 source: &str,
296 name: &str,
297 ) -> Result<LocatorResult, TreeSitterError> {
298 self.locate(
299 source,
300 &StructuralTarget::Const {
301 name: name.to_string(),
302 },
303 )
304 }
305
306 pub fn find_consts_matching(
308 &mut self,
309 source: &str,
310 pattern: &str,
311 ) -> Result<Vec<LocatorResult>, TreeSitterError> {
312 self.locate_all(
313 source,
314 &StructuralTarget::ConstMatching {
315 pattern: pattern.to_string(),
316 },
317 )
318 }
319
320 pub fn find_impl(
322 &mut self,
323 source: &str,
324 type_name: &str,
325 ) -> Result<LocatorResult, TreeSitterError> {
326 self.locate(
327 source,
328 &StructuralTarget::Impl {
329 type_name: type_name.to_string(),
330 },
331 )
332 }
333
334 pub fn find_method(
336 &mut self,
337 source: &str,
338 type_name: &str,
339 method_name: &str,
340 ) -> Result<LocatorResult, TreeSitterError> {
341 self.locate(
342 source,
343 &StructuralTarget::Method {
344 type_name: type_name.to_string(),
345 method_name: method_name.to_string(),
346 },
347 )
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh locator; construction is expected to succeed in tests.
    fn new_locator() -> StructuralLocator {
        StructuralLocator::new().unwrap()
    }

    #[test]
    fn locate_function() {
        let source = r#"
fn helper() -> i32 {
    42
}

fn main() {
    let x = helper();
    println!("{}", x);
}
"#;

        let found = new_locator().find_function(source, "main").unwrap();
        for needle in ["fn main()", "println!"] {
            assert!(found.text.contains(needle));
        }
    }

    #[test]
    fn locate_struct() {
        let source = r#"
/// A configuration struct
#[derive(Debug)]
struct Config {
    name: String,
    value: i32,
}
"#;

        let found = new_locator().find_struct(source, "Config").unwrap();
        for needle in ["struct Config", "name: String"] {
            assert!(found.text.contains(needle));
        }
    }

    #[test]
    fn locate_consts_by_pattern() {
        let source = r#"
const STATSIG_API_KEY: &str = "key123";
const STATSIG_ENDPOINT: &str = "https://api.statsig.com";
const OTEL_ENABLED: bool = true;
"#;

        let hits = new_locator()
            .find_consts_matching(source, "^STATSIG_")
            .unwrap();
        assert_eq!(hits.len(), 2);

        let names: Vec<&str> = hits
            .iter()
            .map(|hit| hit.captures["name"].text.as_str())
            .collect();
        for expected in ["STATSIG_API_KEY", "STATSIG_ENDPOINT"] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn locate_impl_block() {
        let source = r#"
struct Foo;

impl Foo {
    fn new() -> Self {
        Foo
    }

    fn method(&self) -> i32 {
        42
    }
}
"#;

        let found = new_locator().find_impl(source, "Foo").unwrap();
        for needle in ["impl Foo", "fn new()", "fn method(&self)"] {
            assert!(found.text.contains(needle));
        }
    }

    #[test]
    fn byte_span_accuracy() {
        let source = "fn foo() {}\nfn bar() {}";

        let found = new_locator().find_function(source, "bar").unwrap();

        assert_eq!(&source[found.byte_start..found.byte_end], "fn bar() {}");
    }
}