agcodex_core/code_tools/
tree_sitter.rs

1//! Tree-sitter primary structural tool.
2
3use super::CodeTool;
4use super::ToolError;
5use super::queries::CompiledQuery;
6use super::queries::QueryLibrary;
7use super::queries::QueryType;
8use agcodex_ast::AstEngine;
9use agcodex_ast::CompressionLevel;
10use agcodex_ast::Language;
11use agcodex_ast::LanguageRegistry;
12use agcodex_ast::ParsedAst;
13use dashmap::DashMap;
14use std::path::Path;
15use std::path::PathBuf;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tree_sitter::Query;
19use tree_sitter::QueryCursor;
20use tree_sitter::StreamingIterator;
21use walkdir::WalkDir;
22
/// Structural code-search tool built on tree-sitter.
///
/// All state is held behind `Arc`, so cloning the tool shares the same engine, registry,
/// runtime, query cache and query library.
#[derive(Debug, Clone)]
pub struct TreeSitterTool {
    /// AST engine used to parse files and to run symbol searches.
    engine: Arc<AstEngine>,
    /// Language registry used to detect the language of files found on disk.
    registry: Arc<LanguageRegistry>,
    /// Tokio runtime on which the synchronous `search` entry point drives the async search.
    runtime: Arc<Runtime>,
    /// Compiles, caches and executes ad-hoc tree-sitter queries.
    query_engine: Arc<QueryEngine>,
    /// Library of predefined queries, looked up by language and `QueryType`.
    query_library: Arc<QueryLibrary>,
}
32
/// Query engine for compiling, caching and executing tree-sitter queries.
#[derive(Debug)]
struct QueryEngine {
    /// Cache of compiled queries, keyed by the language and the exact query text.
    query_cache: DashMap<(Language, String), Arc<Query>>,
    /// Registry handle stored at construction; not read by any method visible in this file.
    _registry: Arc<LanguageRegistry>,
}
40
41impl QueryEngine {
42    fn new(registry: Arc<LanguageRegistry>) -> Self {
43        Self {
44            query_cache: DashMap::new(),
45            _registry: registry,
46        }
47    }
48
49    /// Compile a tree-sitter query pattern for a specific language
50    fn compile_query(&self, language: Language, pattern: &str) -> Result<Arc<Query>, ToolError> {
51        // Check cache first
52        let cache_key = (language, pattern.to_string());
53        if let Some(query) = self.query_cache.get(&cache_key) {
54            return Ok(query.clone());
55        }
56
57        // Compile the query
58        let ts_language = language.parser();
59        let query = Query::new(&ts_language, pattern)
60            .map_err(|e| ToolError::InvalidQuery(format!("Failed to compile query: {}", e)))?;
61
62        let query = Arc::new(query);
63        self.query_cache.insert(cache_key, query.clone());
64        Ok(query)
65    }
66
67    /// Execute a query against a parsed AST
68    fn execute_query(&self, query: &Query, ast: &ParsedAst, source: &[u8]) -> Vec<TsQueryMatch> {
69        let mut cursor = QueryCursor::new();
70        let mut results = Vec::new();
71
72        // Iterate over matches manually
73        let mut query_matches = cursor.matches(query, ast.tree.root_node(), source);
74        loop {
75            query_matches.advance();
76            let Some(m) = query_matches.get() else {
77                break;
78            };
79            for capture in m.captures {
80                let node = capture.node;
81                let text = std::str::from_utf8(&source[node.byte_range()])
82                    .unwrap_or("")
83                    .to_string();
84
85                results.push(TsQueryMatch {
86                    _capture_name: query.capture_names()[capture.index as usize].to_string(),
87                    node_kind: node.kind().to_string(),
88                    text,
89                    _start_byte: node.start_byte(),
90                    _end_byte: node.end_byte(),
91                    start_position: (node.start_position().row, node.start_position().column),
92                    end_position: (node.end_position().row, node.end_position().column),
93                });
94            }
95        }
96
97        results
98    }
99}
100
/// One captured node produced by executing a query (one entry per capture, not per match).
#[derive(Debug, Clone)]
struct TsQueryMatch {
    /// Name of the capture (the `@name` in the query) that produced this entry.
    _capture_name: String,
    /// Tree-sitter node kind of the captured node.
    node_kind: String,
    /// Source text of the captured node; empty if its bytes are not valid UTF-8.
    text: String,
    /// Byte offset where the captured node starts.
    _start_byte: usize,
    /// Byte offset where the captured node ends.
    _end_byte: usize,
    /// `(row, column)` of the node start, exactly as reported by tree-sitter.
    start_position: (usize, usize),
    /// `(row, column)` of the node end, exactly as reported by tree-sitter.
    end_position: (usize, usize),
}
112
/// A search request for [`TreeSitterTool`].
#[derive(Debug, Clone)]
pub struct TsQuery {
    /// Optional language name; when files are discovered on disk, only files whose detected
    /// language name equals this string are searched.
    pub language: Option<String>,
    /// The search text; how it is interpreted depends on `search_type`.
    pub pattern: String,
    /// Files to search. When empty, files are discovered under the current working directory.
    pub files: Vec<PathBuf>,
    /// Selects how `pattern` is interpreted.
    pub search_type: TsSearchType,
}
120
/// How the `pattern` of a [`TsQuery`] is interpreted.
#[derive(Debug, Clone)]
pub enum TsSearchType {
    /// AST pattern matching: a simplified pattern mapped to a library query or converted
    /// into tree-sitter query syntax.
    Pattern,
    /// Symbol search through the AST engine.
    Symbol,
    /// The pattern is already written in the tree-sitter query language.
    Query,
}
127
/// A single search result.
///
/// For pattern and query searches, `line` and `end_line` are the tree-sitter rows plus one,
/// while the columns are passed through unchanged. For symbol searches all four values are
/// copied verbatim from the symbol's location.
#[derive(Debug, Clone)]
pub struct TsMatch {
    /// Display form of the path of the file containing the match.
    pub file: String,
    /// Line on which the match starts.
    pub line: usize,
    /// Column at which the match starts.
    pub column: usize,
    /// Line on which the match ends.
    pub end_line: usize,
    /// Column at which the match ends.
    pub end_column: usize,
    /// Text of the matched node, or the symbol name for symbol searches.
    pub matched_text: String,
    /// Tree-sitter node kind, or the `Debug` form of the symbol kind for symbol searches.
    pub node_kind: String,
    /// Surrounding source lines, or the symbol signature for symbol searches.
    pub context: Option<String>,
}
139
140impl TreeSitterTool {
141    pub fn new() -> Self {
142        let registry = Arc::new(LanguageRegistry::new());
143        let query_library = Arc::new(QueryLibrary::new());
144
145        // Precompile common queries for better performance
146        if let Err(e) = query_library.precompile_all() {
147            eprintln!("Warning: Failed to precompile queries: {}", e);
148        }
149
150        Self {
151            engine: Arc::new(AstEngine::new(CompressionLevel::Medium)),
152            registry: registry.clone(),
153            runtime: Arc::new(Runtime::new().expect("Failed to create tokio runtime")),
154            query_engine: Arc::new(QueryEngine::new(registry)),
155            query_library,
156        }
157    }
158
159    /// Find target files based on language or pattern
160    fn find_target_files(&self, query: &TsQuery) -> Result<Vec<PathBuf>, ToolError> {
161        let mut files = Vec::new();
162
163        // If specific files are provided, use them
164        if !query.files.is_empty() {
165            return Ok(query.files.clone());
166        }
167
168        // Otherwise, search for files in the current directory
169        let current_dir = std::env::current_dir().map_err(ToolError::Io)?;
170
171        for entry in WalkDir::new(current_dir)
172            .follow_links(true)
173            .into_iter()
174            .filter_map(Result::ok)
175            .filter(|e| e.file_type().is_file())
176        {
177            let path = entry.path();
178
179            // Try to detect language
180            if let Ok(detected_lang) = self.registry.detect_language(path) {
181                // If language filter is specified, check it
182                if let Some(ref lang_filter) = query.language {
183                    if detected_lang.name() == lang_filter {
184                        files.push(path.to_path_buf());
185                    }
186                } else {
187                    // No language filter, include all parseable files
188                    files.push(path.to_path_buf());
189                }
190            }
191        }
192
193        Ok(files)
194    }
195
196    /// Extract context around a match
197    fn extract_context(
198        &self,
199        source: &str,
200        start_line: usize,
201        end_line: usize,
202        context_lines: usize,
203    ) -> String {
204        let lines: Vec<&str> = source.lines().collect();
205        let total_lines = lines.len();
206
207        let context_start = start_line.saturating_sub(context_lines);
208        let context_end = (end_line + context_lines).min(total_lines - 1);
209
210        let mut result = String::new();
211        for i in context_start..=context_end {
212            if i < lines.len() {
213                if i == start_line {
214                    result.push_str(">>> ");
215                }
216                result.push_str(lines[i]);
217                result.push('\n');
218            }
219        }
220
221        result
222    }
223
224    /// Search within a parsed tree using pattern or query
225    async fn search_in_tree(
226        &self,
227        ast: &ParsedAst,
228        file_path: &Path,
229        query: &TsQuery,
230    ) -> Result<Vec<TsMatch>, ToolError> {
231        let source = ast.source.as_bytes();
232        let mut matches = Vec::new();
233
234        match query.search_type {
235            TsSearchType::Pattern => {
236                // Try structured query first, fall back to pattern conversion
237                let query_type = self.infer_query_type(&query.pattern);
238
239                let compiled_query = if let Some(qt) = query_type {
240                    // Use structured query from library
241                    match self.get_structured_query(ast.language, qt) {
242                        Ok(structured) => structured.query.clone(),
243                        Err(_) => {
244                            // Fall back to pattern conversion
245                            let query_pattern = self.convert_pattern_to_query(&query.pattern);
246                            self.query_engine
247                                .compile_query(ast.language, &query_pattern)?
248                        }
249                    }
250                } else {
251                    // Use pattern conversion
252                    let query_pattern = self.convert_pattern_to_query(&query.pattern);
253                    self.query_engine
254                        .compile_query(ast.language, &query_pattern)?
255                };
256
257                let query_matches = self
258                    .query_engine
259                    .execute_query(&compiled_query, ast, source);
260
261                for qm in query_matches {
262                    matches.push(TsMatch {
263                        file: file_path.display().to_string(),
264                        line: qm.start_position.0 + 1,
265                        column: qm.start_position.1,
266                        end_line: qm.end_position.0 + 1,
267                        end_column: qm.end_position.1,
268                        matched_text: qm.text.clone(),
269                        node_kind: qm.node_kind,
270                        context: Some(self.extract_context(
271                            &ast.source,
272                            qm.start_position.0,
273                            qm.end_position.0,
274                            2,
275                        )),
276                    });
277                }
278            }
279            TsSearchType::Query => {
280                // Direct tree-sitter query language
281                let compiled_query = self
282                    .query_engine
283                    .compile_query(ast.language, &query.pattern)?;
284
285                let query_matches = self
286                    .query_engine
287                    .execute_query(&compiled_query, ast, source);
288
289                for qm in query_matches {
290                    matches.push(TsMatch {
291                        file: file_path.display().to_string(),
292                        line: qm.start_position.0 + 1,
293                        column: qm.start_position.1,
294                        end_line: qm.end_position.0 + 1,
295                        end_column: qm.end_position.1,
296                        matched_text: qm.text.clone(),
297                        node_kind: qm.node_kind,
298                        context: Some(self.extract_context(
299                            &ast.source,
300                            qm.start_position.0,
301                            qm.end_position.0,
302                            2,
303                        )),
304                    });
305                }
306            }
307            TsSearchType::Symbol => {
308                // Use existing symbol search
309                let symbols = self
310                    .engine
311                    .search_symbols(&query.pattern)
312                    .await
313                    .map_err(|e| ToolError::InvalidQuery(format!("Symbol search error: {}", e)))?;
314
315                for s in symbols {
316                    if PathBuf::from(&s.location.file_path) == file_path {
317                        matches.push(TsMatch {
318                            file: file_path.display().to_string(),
319                            line: s.location.start_line,
320                            column: s.location.start_column,
321                            end_line: s.location.end_line,
322                            end_column: s.location.end_column,
323                            matched_text: s.name.clone(),
324                            node_kind: format!("{:?}", s.kind),
325                            context: Some(s.signature),
326                        });
327                    }
328                }
329            }
330        }
331
332        Ok(matches)
333    }
334
335    /// Infer query type from pattern string
336    fn infer_query_type(&self, pattern: &str) -> Option<QueryType> {
337        if pattern.starts_with("function ") || pattern.contains("function") {
338            Some(QueryType::Functions)
339        } else if pattern.starts_with("class ") || pattern.contains("class") {
340            Some(QueryType::Classes)
341        } else if pattern.starts_with("import ") || pattern.contains("import") {
342            Some(QueryType::Imports)
343        } else if pattern.starts_with("method ") || pattern.contains("method") {
344            Some(QueryType::Methods)
345        } else {
346            None
347        }
348    }
349
350    /// Get a structured query using the new query library
351    fn get_structured_query(
352        &self,
353        language: agcodex_ast::Language,
354        query_type: QueryType,
355    ) -> Result<Arc<CompiledQuery>, ToolError> {
356        self.query_library
357            .get_query(language, query_type)
358            .map_err(|e| ToolError::InvalidQuery(format!("Query library error: {}", e)))
359    }
360
361    /// Convert a simple pattern to tree-sitter query syntax (legacy support)
362    fn convert_pattern_to_query(&self, pattern: &str) -> String {
363        // This is a simplified conversion - in practice, you'd want more sophisticated parsing
364        // For now, we'll handle common patterns
365
366        if pattern.starts_with("function ") {
367            let func_name = pattern.trim_start_matches("function ").trim();
368            if func_name == "*" {
369                // Match all functions
370                "[
371                    (function_declaration) @func
372                    (function_definition) @func
373                    (method_declaration) @func
374                    (method_definition) @func
375                ]"
376                .to_string()
377            } else {
378                // Match specific function name
379                format!(
380                    "[
381                        (function_declaration name: (identifier) @name (#eq? @name \"{}\"))
382                        (function_definition name: (identifier) @name (#eq? @name \"{}\"))
383                        (method_declaration name: (identifier) @name (#eq? @name \"{}\"))
384                        (method_definition name: (identifier) @name (#eq? @name \"{}\"))
385                    ] @func",
386                    func_name, func_name, func_name, func_name
387                )
388            }
389        } else if pattern.starts_with("class ") {
390            let class_name = pattern.trim_start_matches("class ").trim();
391            if class_name == "*" {
392                "[
393                    (class_declaration) @class
394                    (class_definition) @class
395                ] @class"
396                    .to_string()
397            } else {
398                format!(
399                    "[
400                        (class_declaration name: (identifier) @name (#eq? @name \"{}\"))
401                        (class_definition name: (identifier) @name (#eq? @name \"{}\"))
402                    ] @class",
403                    class_name, class_name
404                )
405            }
406        } else if pattern.starts_with("import ") {
407            "[
408                (import_statement) @import
409                (import_declaration) @import
410                (use_declaration) @import
411            ] @import"
412                .to_string()
413        } else {
414            // Default: try to match as identifier
415            format!("(identifier) @id (#eq? @id \"{}\")", pattern)
416        }
417    }
418
419    /// Execute a structured query using the query library
420    pub async fn search_structured(
421        &self,
422        language: agcodex_ast::Language,
423        query_type: QueryType,
424        files: Vec<PathBuf>,
425    ) -> Result<Vec<TsMatch>, ToolError> {
426        let compiled_query = self.get_structured_query(language, query_type)?;
427        let mut all_matches = Vec::new();
428
429        for file_path in &files {
430            // Parse the file using AstEngine
431            let ast = self
432                .engine
433                .parse_file(file_path)
434                .await
435                .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?;
436
437            // Skip if language doesn't match
438            if ast.language != language {
439                continue;
440            }
441
442            let source = ast.source.as_bytes();
443            let query_matches =
444                self.query_engine
445                    .execute_query(&compiled_query.query, &ast, source);
446
447            for qm in query_matches {
448                all_matches.push(TsMatch {
449                    file: file_path.display().to_string(),
450                    line: qm.start_position.0 + 1,
451                    column: qm.start_position.1,
452                    end_line: qm.end_position.0 + 1,
453                    end_column: qm.end_position.1,
454                    matched_text: qm.text.clone(),
455                    node_kind: qm.node_kind,
456                    context: Some(self.extract_context(
457                        &ast.source,
458                        qm.start_position.0,
459                        qm.end_position.0,
460                        2,
461                    )),
462                });
463            }
464        }
465
466        Ok(all_matches)
467    }
468
469    /// Get query library statistics
470    pub fn query_stats(&self) -> crate::code_tools::queries::QueryLibraryStats {
471        self.query_library.stats()
472    }
473
474    /// Check if a language supports a specific query type
475    pub fn supports_query(&self, language: agcodex_ast::Language, query_type: &QueryType) -> bool {
476        self.query_library.supports_query(language, query_type)
477    }
478
479    async fn search_async(&self, mut query: TsQuery) -> Result<Vec<TsMatch>, ToolError> {
480        // Find target files if not specified
481        if query.files.is_empty() {
482            query.files = self.find_target_files(&query)?;
483        }
484
485        let mut all_matches = Vec::new();
486
487        for file_path in &query.files {
488            // Parse the file using AstEngine
489            let ast = self
490                .engine
491                .parse_file(file_path)
492                .await
493                .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?;
494
495            // Execute search within the tree
496            let matches = self.search_in_tree(&ast, file_path, &query).await?;
497            all_matches.extend(matches);
498        }
499
500        Ok(all_matches)
501    }
502}
503
504impl CodeTool for TreeSitterTool {
505    type Query = TsQuery;
506    type Output = Vec<TsMatch>;
507
508    fn search(&self, query: Self::Query) -> Result<Self::Output, ToolError> {
509        self.runtime.block_on(self.search_async(query))
510    }
511}
512
513impl Default for TreeSitterTool {
514    fn default() -> Self {
515        Self::new()
516    }
517}