1use super::CodeTool;
4use super::ToolError;
5use super::queries::CompiledQuery;
6use super::queries::QueryLibrary;
7use super::queries::QueryType;
8use agcodex_ast::AstEngine;
9use agcodex_ast::CompressionLevel;
10use agcodex_ast::Language;
11use agcodex_ast::LanguageRegistry;
12use agcodex_ast::ParsedAst;
13use dashmap::DashMap;
14use std::path::Path;
15use std::path::PathBuf;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tree_sitter::Query;
19use tree_sitter::QueryCursor;
20use tree_sitter::StreamingIterator;
21use walkdir::WalkDir;
22
23#[derive(Debug, Clone)]
24pub struct TreeSitterTool {
25 engine: Arc<AstEngine>,
26 registry: Arc<LanguageRegistry>,
27 runtime: Arc<Runtime>,
28 query_engine: Arc<QueryEngine>,
29 query_library: Arc<QueryLibrary>,
31}
32
33#[derive(Debug)]
35struct QueryEngine {
36 query_cache: DashMap<(Language, String), Arc<Query>>,
38 _registry: Arc<LanguageRegistry>,
39}
40
41impl QueryEngine {
42 fn new(registry: Arc<LanguageRegistry>) -> Self {
43 Self {
44 query_cache: DashMap::new(),
45 _registry: registry,
46 }
47 }
48
49 fn compile_query(&self, language: Language, pattern: &str) -> Result<Arc<Query>, ToolError> {
51 let cache_key = (language, pattern.to_string());
53 if let Some(query) = self.query_cache.get(&cache_key) {
54 return Ok(query.clone());
55 }
56
57 let ts_language = language.parser();
59 let query = Query::new(&ts_language, pattern)
60 .map_err(|e| ToolError::InvalidQuery(format!("Failed to compile query: {}", e)))?;
61
62 let query = Arc::new(query);
63 self.query_cache.insert(cache_key, query.clone());
64 Ok(query)
65 }
66
67 fn execute_query(&self, query: &Query, ast: &ParsedAst, source: &[u8]) -> Vec<TsQueryMatch> {
69 let mut cursor = QueryCursor::new();
70 let mut results = Vec::new();
71
72 let mut query_matches = cursor.matches(query, ast.tree.root_node(), source);
74 loop {
75 query_matches.advance();
76 let Some(m) = query_matches.get() else {
77 break;
78 };
79 for capture in m.captures {
80 let node = capture.node;
81 let text = std::str::from_utf8(&source[node.byte_range()])
82 .unwrap_or("")
83 .to_string();
84
85 results.push(TsQueryMatch {
86 _capture_name: query.capture_names()[capture.index as usize].to_string(),
87 node_kind: node.kind().to_string(),
88 text,
89 _start_byte: node.start_byte(),
90 _end_byte: node.end_byte(),
91 start_position: (node.start_position().row, node.start_position().column),
92 end_position: (node.end_position().row, node.end_position().column),
93 });
94 }
95 }
96
97 results
98 }
99}
100
/// One captured node produced by running a tree-sitter query.
#[derive(Clone, Debug)]
struct TsQueryMatch {
    /// Name of the query capture that produced this node.
    _capture_name: String,
    /// Grammar node kind of the captured node.
    node_kind: String,
    /// Source text covered by the node (empty when it is not valid UTF-8).
    text: String,
    /// Byte offset where the node starts.
    _start_byte: usize,
    /// Byte offset where the node ends.
    _end_byte: usize,
    /// `(row, column)` of the node start as reported by tree-sitter.
    start_position: (usize, usize),
    /// `(row, column)` of the node end as reported by tree-sitter.
    end_position: (usize, usize),
}
112
113#[derive(Debug, Clone)]
114pub struct TsQuery {
115 pub language: Option<String>,
116 pub pattern: String,
117 pub files: Vec<PathBuf>,
118 pub search_type: TsSearchType,
119}
120
/// How the pattern of a [`TsQuery`] is interpreted.
#[derive(Clone, Debug)]
pub enum TsSearchType {
    /// Shorthand such as `function foo`; answered by a library query or
    /// translated into a tree-sitter query.
    Pattern,
    /// Symbol lookup delegated to the AST engine.
    Symbol,
    /// Raw tree-sitter query text, compiled as given.
    Query,
}
127
/// A single search hit returned to callers.
#[derive(Clone, Debug)]
pub struct TsMatch {
    /// Display form of the path of the file containing the hit.
    pub file: String,
    /// Start line. Pattern and query searches report the tree-sitter row plus
    /// one; symbol searches pass through the line reported by the AST engine.
    pub line: usize,
    /// Start column, copied unchanged from the underlying position.
    pub column: usize,
    /// End line, following the same convention as `line`.
    pub end_line: usize,
    /// End column, copied unchanged from the underlying position.
    pub end_column: usize,
    /// Text of the matched node, or the symbol name for symbol searches.
    pub matched_text: String,
    /// Grammar node kind, or the `Debug` rendering of the symbol kind.
    pub node_kind: String,
    /// Surrounding source lines for pattern/query hits, or the symbol signature.
    pub context: Option<String>,
}
139
140impl TreeSitterTool {
141 pub fn new() -> Self {
142 let registry = Arc::new(LanguageRegistry::new());
143 let query_library = Arc::new(QueryLibrary::new());
144
145 if let Err(e) = query_library.precompile_all() {
147 eprintln!("Warning: Failed to precompile queries: {}", e);
148 }
149
150 Self {
151 engine: Arc::new(AstEngine::new(CompressionLevel::Medium)),
152 registry: registry.clone(),
153 runtime: Arc::new(Runtime::new().expect("Failed to create tokio runtime")),
154 query_engine: Arc::new(QueryEngine::new(registry)),
155 query_library,
156 }
157 }
158
159 fn find_target_files(&self, query: &TsQuery) -> Result<Vec<PathBuf>, ToolError> {
161 let mut files = Vec::new();
162
163 if !query.files.is_empty() {
165 return Ok(query.files.clone());
166 }
167
168 let current_dir = std::env::current_dir().map_err(ToolError::Io)?;
170
171 for entry in WalkDir::new(current_dir)
172 .follow_links(true)
173 .into_iter()
174 .filter_map(Result::ok)
175 .filter(|e| e.file_type().is_file())
176 {
177 let path = entry.path();
178
179 if let Ok(detected_lang) = self.registry.detect_language(path) {
181 if let Some(ref lang_filter) = query.language {
183 if detected_lang.name() == lang_filter {
184 files.push(path.to_path_buf());
185 }
186 } else {
187 files.push(path.to_path_buf());
189 }
190 }
191 }
192
193 Ok(files)
194 }
195
196 fn extract_context(
198 &self,
199 source: &str,
200 start_line: usize,
201 end_line: usize,
202 context_lines: usize,
203 ) -> String {
204 let lines: Vec<&str> = source.lines().collect();
205 let total_lines = lines.len();
206
207 let context_start = start_line.saturating_sub(context_lines);
208 let context_end = (end_line + context_lines).min(total_lines - 1);
209
210 let mut result = String::new();
211 for i in context_start..=context_end {
212 if i < lines.len() {
213 if i == start_line {
214 result.push_str(">>> ");
215 }
216 result.push_str(lines[i]);
217 result.push('\n');
218 }
219 }
220
221 result
222 }
223
224 async fn search_in_tree(
226 &self,
227 ast: &ParsedAst,
228 file_path: &Path,
229 query: &TsQuery,
230 ) -> Result<Vec<TsMatch>, ToolError> {
231 let source = ast.source.as_bytes();
232 let mut matches = Vec::new();
233
234 match query.search_type {
235 TsSearchType::Pattern => {
236 let query_type = self.infer_query_type(&query.pattern);
238
239 let compiled_query = if let Some(qt) = query_type {
240 match self.get_structured_query(ast.language, qt) {
242 Ok(structured) => structured.query.clone(),
243 Err(_) => {
244 let query_pattern = self.convert_pattern_to_query(&query.pattern);
246 self.query_engine
247 .compile_query(ast.language, &query_pattern)?
248 }
249 }
250 } else {
251 let query_pattern = self.convert_pattern_to_query(&query.pattern);
253 self.query_engine
254 .compile_query(ast.language, &query_pattern)?
255 };
256
257 let query_matches = self
258 .query_engine
259 .execute_query(&compiled_query, ast, source);
260
261 for qm in query_matches {
262 matches.push(TsMatch {
263 file: file_path.display().to_string(),
264 line: qm.start_position.0 + 1,
265 column: qm.start_position.1,
266 end_line: qm.end_position.0 + 1,
267 end_column: qm.end_position.1,
268 matched_text: qm.text.clone(),
269 node_kind: qm.node_kind,
270 context: Some(self.extract_context(
271 &ast.source,
272 qm.start_position.0,
273 qm.end_position.0,
274 2,
275 )),
276 });
277 }
278 }
279 TsSearchType::Query => {
280 let compiled_query = self
282 .query_engine
283 .compile_query(ast.language, &query.pattern)?;
284
285 let query_matches = self
286 .query_engine
287 .execute_query(&compiled_query, ast, source);
288
289 for qm in query_matches {
290 matches.push(TsMatch {
291 file: file_path.display().to_string(),
292 line: qm.start_position.0 + 1,
293 column: qm.start_position.1,
294 end_line: qm.end_position.0 + 1,
295 end_column: qm.end_position.1,
296 matched_text: qm.text.clone(),
297 node_kind: qm.node_kind,
298 context: Some(self.extract_context(
299 &ast.source,
300 qm.start_position.0,
301 qm.end_position.0,
302 2,
303 )),
304 });
305 }
306 }
307 TsSearchType::Symbol => {
308 let symbols = self
310 .engine
311 .search_symbols(&query.pattern)
312 .await
313 .map_err(|e| ToolError::InvalidQuery(format!("Symbol search error: {}", e)))?;
314
315 for s in symbols {
316 if PathBuf::from(&s.location.file_path) == file_path {
317 matches.push(TsMatch {
318 file: file_path.display().to_string(),
319 line: s.location.start_line,
320 column: s.location.start_column,
321 end_line: s.location.end_line,
322 end_column: s.location.end_column,
323 matched_text: s.name.clone(),
324 node_kind: format!("{:?}", s.kind),
325 context: Some(s.signature),
326 });
327 }
328 }
329 }
330 }
331
332 Ok(matches)
333 }
334
335 fn infer_query_type(&self, pattern: &str) -> Option<QueryType> {
337 if pattern.starts_with("function ") || pattern.contains("function") {
338 Some(QueryType::Functions)
339 } else if pattern.starts_with("class ") || pattern.contains("class") {
340 Some(QueryType::Classes)
341 } else if pattern.starts_with("import ") || pattern.contains("import") {
342 Some(QueryType::Imports)
343 } else if pattern.starts_with("method ") || pattern.contains("method") {
344 Some(QueryType::Methods)
345 } else {
346 None
347 }
348 }
349
350 fn get_structured_query(
352 &self,
353 language: agcodex_ast::Language,
354 query_type: QueryType,
355 ) -> Result<Arc<CompiledQuery>, ToolError> {
356 self.query_library
357 .get_query(language, query_type)
358 .map_err(|e| ToolError::InvalidQuery(format!("Query library error: {}", e)))
359 }
360
361 fn convert_pattern_to_query(&self, pattern: &str) -> String {
363 if pattern.starts_with("function ") {
367 let func_name = pattern.trim_start_matches("function ").trim();
368 if func_name == "*" {
369 "[
371 (function_declaration) @func
372 (function_definition) @func
373 (method_declaration) @func
374 (method_definition) @func
375 ]"
376 .to_string()
377 } else {
378 format!(
380 "[
381 (function_declaration name: (identifier) @name (#eq? @name \"{}\"))
382 (function_definition name: (identifier) @name (#eq? @name \"{}\"))
383 (method_declaration name: (identifier) @name (#eq? @name \"{}\"))
384 (method_definition name: (identifier) @name (#eq? @name \"{}\"))
385 ] @func",
386 func_name, func_name, func_name, func_name
387 )
388 }
389 } else if pattern.starts_with("class ") {
390 let class_name = pattern.trim_start_matches("class ").trim();
391 if class_name == "*" {
392 "[
393 (class_declaration) @class
394 (class_definition) @class
395 ] @class"
396 .to_string()
397 } else {
398 format!(
399 "[
400 (class_declaration name: (identifier) @name (#eq? @name \"{}\"))
401 (class_definition name: (identifier) @name (#eq? @name \"{}\"))
402 ] @class",
403 class_name, class_name
404 )
405 }
406 } else if pattern.starts_with("import ") {
407 "[
408 (import_statement) @import
409 (import_declaration) @import
410 (use_declaration) @import
411 ] @import"
412 .to_string()
413 } else {
414 format!("(identifier) @id (#eq? @id \"{}\")", pattern)
416 }
417 }
418
419 pub async fn search_structured(
421 &self,
422 language: agcodex_ast::Language,
423 query_type: QueryType,
424 files: Vec<PathBuf>,
425 ) -> Result<Vec<TsMatch>, ToolError> {
426 let compiled_query = self.get_structured_query(language, query_type)?;
427 let mut all_matches = Vec::new();
428
429 for file_path in &files {
430 let ast = self
432 .engine
433 .parse_file(file_path)
434 .await
435 .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?;
436
437 if ast.language != language {
439 continue;
440 }
441
442 let source = ast.source.as_bytes();
443 let query_matches =
444 self.query_engine
445 .execute_query(&compiled_query.query, &ast, source);
446
447 for qm in query_matches {
448 all_matches.push(TsMatch {
449 file: file_path.display().to_string(),
450 line: qm.start_position.0 + 1,
451 column: qm.start_position.1,
452 end_line: qm.end_position.0 + 1,
453 end_column: qm.end_position.1,
454 matched_text: qm.text.clone(),
455 node_kind: qm.node_kind,
456 context: Some(self.extract_context(
457 &ast.source,
458 qm.start_position.0,
459 qm.end_position.0,
460 2,
461 )),
462 });
463 }
464 }
465
466 Ok(all_matches)
467 }
468
469 pub fn query_stats(&self) -> crate::code_tools::queries::QueryLibraryStats {
471 self.query_library.stats()
472 }
473
474 pub fn supports_query(&self, language: agcodex_ast::Language, query_type: &QueryType) -> bool {
476 self.query_library.supports_query(language, query_type)
477 }
478
479 async fn search_async(&self, mut query: TsQuery) -> Result<Vec<TsMatch>, ToolError> {
480 if query.files.is_empty() {
482 query.files = self.find_target_files(&query)?;
483 }
484
485 let mut all_matches = Vec::new();
486
487 for file_path in &query.files {
488 let ast = self
490 .engine
491 .parse_file(file_path)
492 .await
493 .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?;
494
495 let matches = self.search_in_tree(&ast, file_path, &query).await?;
497 all_matches.extend(matches);
498 }
499
500 Ok(all_matches)
501 }
502}
503
504impl CodeTool for TreeSitterTool {
505 type Query = TsQuery;
506 type Output = Vec<TsMatch>;
507
508 fn search(&self, query: Self::Query) -> Result<Self::Output, ToolError> {
509 self.runtime.block_on(self.search_async(query))
510 }
511}
512
513impl Default for TreeSitterTool {
514 fn default() -> Self {
515 Self::new()
516 }
517}