raz_core/
ast.rs

1//! Abstract Syntax Tree parsing and analysis using tree-sitter
2//!
3//! This module provides AST-based analysis of Rust source code to enable
4//! context-aware command suggestions based on cursor position and code structure.
5
6use crate::{Position, Range, RazError, RazResult, Symbol, SymbolKind};
7use std::collections::HashMap;
8use tree_sitter::{Language, Node, Parser, Query, QueryCursor, StreamingIteratorMut, Tree};
9
/// AST analyzer for Rust source code.
///
/// Pairs a tree-sitter parser with the compiled `QuerySet` that the
/// extraction methods run against parsed trees.
pub struct RustAnalyzer {
    // Tree-sitter parser; its language is installed in `RustAnalyzer::new`.
    parser: Parser,
    // Symbol-extraction queries, compiled once in `RustAnalyzer::new`.
    queries: QuerySet,
}
15
16impl RustAnalyzer {
17    /// Create a new Rust AST analyzer
18    pub fn new() -> RazResult<Self> {
19        let language = tree_sitter_rust::LANGUAGE;
20        let mut parser = Parser::new();
21        parser
22            .set_language(&language.into())
23            .map_err(|e| RazError::analysis(format!("Failed to set language: {e}")))?;
24
25        let queries = QuerySet::new(language.into())?;
26
27        Ok(Self { parser, queries })
28    }
29
30    /// Parse source code and return the syntax tree
31    pub fn parse(&mut self, source: &str) -> RazResult<Tree> {
32        self.parser
33            .parse(source, None)
34            .ok_or_else(|| RazError::analysis("Failed to parse source code".to_string()))
35    }
36
37    /// Extract all symbols from the AST
38    pub fn extract_symbols(&self, tree: &Tree, source: &str) -> RazResult<Vec<Symbol>> {
39        let mut symbols = Vec::new();
40        let root_node = tree.root_node();
41
42        // Extract tests first to identify test functions
43        let test_symbols = self.extract_tests(&root_node, source)?;
44        let test_names: std::collections::HashSet<String> =
45            test_symbols.iter().map(|s| s.name.clone()).collect();
46        symbols.extend(test_symbols);
47
48        // Extract functions (but skip test functions)
49        let functions = self.extract_functions(&root_node, source)?;
50        for func in functions {
51            if !test_names.contains(&func.name) {
52                symbols.push(func);
53            }
54        }
55
56        // Extract structs
57        symbols.extend(self.extract_structs(&root_node, source)?);
58
59        // Extract enums
60        symbols.extend(self.extract_enums(&root_node, source)?);
61
62        // Extract traits
63        symbols.extend(self.extract_traits(&root_node, source)?);
64
65        // Extract modules
66        symbols.extend(self.extract_modules(&root_node, source)?);
67
68        // Extract constants
69        symbols.extend(self.extract_constants(&root_node, source)?);
70
71        // Extract type aliases
72        symbols.extend(self.extract_type_aliases(&root_node, source)?);
73
74        // Extract macros
75        symbols.extend(self.extract_macros(&root_node, source)?);
76
77        Ok(symbols)
78    }
79
80    /// Find the symbol at a specific cursor position
81    pub fn symbol_at_position(
82        &self,
83        tree: &Tree,
84        source: &str,
85        position: Position,
86    ) -> RazResult<Option<Symbol>> {
87        let symbols = self.extract_symbols(tree, source)?;
88
89        // Find the most specific symbol that contains the cursor position
90        let mut best_match: Option<Symbol> = None;
91        let mut smallest_range = u32::MAX;
92
93        for symbol in symbols {
94            if symbol.range.contains_position(position) {
95                let range_size = symbol.range.end.line - symbol.range.start.line;
96                if range_size < smallest_range {
97                    smallest_range = range_size;
98                    best_match = Some(symbol);
99                }
100            }
101        }
102
103        Ok(best_match)
104    }
105
106    /// Get the context around a cursor position (parent functions, modules, etc.)
107    pub fn context_at_position(
108        &self,
109        tree: &Tree,
110        source: &str,
111        position: Position,
112    ) -> RazResult<SymbolContext> {
113        let symbols = self.extract_symbols(tree, source)?;
114        let mut context = SymbolContext::default();
115
116        // Find all symbols that contain the position
117        for symbol in symbols {
118            if symbol.range.contains_position(position) {
119                match symbol.kind {
120                    SymbolKind::Function => {
121                        if symbol.modifiers.contains(&"test".to_string()) {
122                            context.in_test_function = Some(symbol.clone());
123                        } else {
124                            context.in_function = Some(symbol.clone());
125                        }
126                    }
127                    SymbolKind::Struct => context.in_struct = Some(symbol.clone()),
128                    SymbolKind::Enum => context.in_enum = Some(symbol.clone()),
129                    SymbolKind::Trait => context.in_trait = Some(symbol.clone()),
130                    SymbolKind::Module => context.in_module = Some(symbol.clone()),
131                    SymbolKind::Impl => context.in_impl = Some(symbol.clone()),
132                    _ => {}
133                }
134            }
135        }
136
137        Ok(context)
138    }
139
140    /// Extract function definitions from the AST
141    fn extract_functions(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
142        let mut functions = Vec::new();
143        let mut cursor = QueryCursor::new();
144
145        let mut matches = cursor.matches(&self.queries.functions, *node, source.as_bytes());
146
147        while let Some(match_) = matches.next_mut() {
148            if let Some(function) = self.parse_function_match(match_, source)? {
149                functions.push(function);
150            }
151        }
152
153        Ok(functions)
154    }
155
156    /// Extract struct definitions from the AST
157    fn extract_structs(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
158        let mut structs = Vec::new();
159        let mut cursor = QueryCursor::new();
160
161        let mut matches = cursor.matches(&self.queries.structs, *node, source.as_bytes());
162
163        while let Some(match_) = matches.next_mut() {
164            if let Some(struct_symbol) = self.parse_struct_match(match_, source)? {
165                structs.push(struct_symbol);
166            }
167        }
168
169        Ok(structs)
170    }
171
172    /// Extract enum definitions from the AST
173    fn extract_enums(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
174        let mut enums = Vec::new();
175        let mut cursor = QueryCursor::new();
176
177        let mut matches = cursor.matches(&self.queries.enums, *node, source.as_bytes());
178
179        while let Some(match_) = matches.next_mut() {
180            if let Some(enum_symbol) = self.parse_enum_match(match_, source)? {
181                enums.push(enum_symbol);
182            }
183        }
184
185        Ok(enums)
186    }
187
188    /// Extract trait definitions from the AST
189    fn extract_traits(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
190        let mut traits = Vec::new();
191        let mut cursor = QueryCursor::new();
192
193        let mut matches = cursor.matches(&self.queries.traits, *node, source.as_bytes());
194
195        while let Some(match_) = matches.next_mut() {
196            if let Some(trait_symbol) = self.parse_trait_match(match_, source)? {
197                traits.push(trait_symbol);
198            }
199        }
200
201        Ok(traits)
202    }
203
204    /// Extract module definitions from the AST
205    fn extract_modules(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
206        let mut modules = Vec::new();
207        let mut cursor = QueryCursor::new();
208
209        let mut matches = cursor.matches(&self.queries.modules, *node, source.as_bytes());
210
211        while let Some(match_) = matches.next_mut() {
212            if let Some(module_symbol) = self.parse_module_match(match_, source)? {
213                modules.push(module_symbol);
214            }
215        }
216
217        Ok(modules)
218    }
219
220    /// Extract test functions from the AST
221    fn extract_tests(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
222        let mut tests = Vec::new();
223
224        // Find test attributes first
225        let mut cursor = QueryCursor::new();
226        let mut attr_matches = cursor.matches(&self.queries.tests, *node, source.as_bytes());
227
228        // Collect test attribute nodes
229        let mut test_attrs = Vec::new();
230        while let Some(match_) = attr_matches.next_mut() {
231            for capture in match_.captures {
232                let node = capture.node;
233                let capture_name = &self.queries.tests.capture_names()[capture.index as usize];
234                if capture_name as &str == "test.attr" {
235                    test_attrs.push(node);
236                }
237            }
238        }
239
240        // Now find function items that have test attributes
241        for test_attr in test_attrs {
242            if let Some(function_item) = self.find_function_with_test_attr(test_attr, node) {
243                if let Some(test_symbol) =
244                    self.create_test_symbol_from_function(function_item, source)?
245                {
246                    tests.push(test_symbol);
247                }
248            }
249        }
250
251        // Also find functions that start with "test_" even without #[test] attribute
252        // This handles cases where the AST might not capture the attribute properly
253        let function_query = Query::new(
254            &tree_sitter_rust::LANGUAGE.into(),
255            r#"
256            (function_item
257                name: (identifier) @func.name
258            ) @func
259            "#,
260        )
261        .map_err(|e| RazError::analysis(format!("Failed to create function query: {e}")))?;
262
263        let mut func_matches = cursor.matches(&function_query, *node, source.as_bytes());
264        while let Some(match_) = func_matches.next_mut() {
265            let mut func_name = None;
266            let mut func_node = None;
267
268            for capture in match_.captures {
269                let capture_name = &function_query.capture_names()[capture.index as usize];
270                match capture_name as &str {
271                    "func.name" => {
272                        func_name = capture
273                            .node
274                            .utf8_text(source.as_bytes())
275                            .ok()
276                            .map(|s| s.to_string());
277                    }
278                    "func" => {
279                        func_node = Some(capture.node);
280                    }
281                    _ => {}
282                }
283            }
284
285            if let (Some(name), Some(node)) = (func_name, func_node) {
286                if name.starts_with("test_") && !tests.iter().any(|t| t.name == name) {
287                    tests.push(Symbol {
288                        name,
289                        kind: SymbolKind::Test,
290                        range: self.node_to_range(node),
291                        modifiers: vec!["test".to_string()],
292                        children: Vec::new(),
293                        metadata: HashMap::new(),
294                    });
295                }
296            }
297        }
298
299        Ok(tests)
300    }
301
302    fn find_function_with_test_attr<'a>(
303        &self,
304        test_attr: Node<'a>,
305        _root: &Node,
306    ) -> Option<Node<'a>> {
307        // Walk up from test attribute to find the function_item
308        let mut current = test_attr;
309        while let Some(parent) = current.parent() {
310            if parent.kind() == "function_item" {
311                return Some(parent);
312            }
313            current = parent;
314        }
315        None
316    }
317
318    fn create_test_symbol_from_function(
319        &self,
320        function_node: Node<'_>,
321        source: &str,
322    ) -> RazResult<Option<Symbol>> {
323        // Extract function name
324        let mut cursor = function_node.walk();
325
326        for child in function_node.children(&mut cursor) {
327            if child.kind() == "identifier" {
328                let name = child
329                    .utf8_text(source.as_bytes())
330                    .map_err(|e| {
331                        RazError::analysis(format!("Failed to extract test function name: {e}"))
332                    })?
333                    .to_string();
334
335                return Ok(Some(Symbol {
336                    name,
337                    kind: SymbolKind::Test,
338                    range: self.node_to_range(function_node),
339                    modifiers: vec!["test".to_string()],
340                    children: Vec::new(),
341                    metadata: HashMap::new(),
342                }));
343            }
344        }
345
346        Ok(None)
347    }
348
349    /// Extract constant definitions from the AST
350    fn extract_constants(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
351        let mut constants = Vec::new();
352        let mut cursor = QueryCursor::new();
353
354        let mut matches = cursor.matches(&self.queries.constants, *node, source.as_bytes());
355
356        while let Some(match_) = matches.next_mut() {
357            if let Some(const_symbol) = self.parse_constant_match(match_, source)? {
358                constants.push(const_symbol);
359            }
360        }
361
362        Ok(constants)
363    }
364
365    /// Extract type alias definitions from the AST
366    fn extract_type_aliases(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
367        let mut type_aliases = Vec::new();
368        let mut cursor = QueryCursor::new();
369
370        let mut matches = cursor.matches(&self.queries.type_aliases, *node, source.as_bytes());
371
372        while let Some(match_) = matches.next_mut() {
373            if let Some(type_alias) = self.parse_type_alias_match(match_, source)? {
374                type_aliases.push(type_alias);
375            }
376        }
377
378        Ok(type_aliases)
379    }
380
381    /// Extract macro definitions from the AST
382    fn extract_macros(&self, node: &Node, source: &str) -> RazResult<Vec<Symbol>> {
383        let mut macros = Vec::new();
384        let mut cursor = QueryCursor::new();
385
386        let mut matches = cursor.matches(&self.queries.macros, *node, source.as_bytes());
387
388        while let Some(match_) = matches.next_mut() {
389            if let Some(macro_symbol) = self.parse_macro_match(match_, source)? {
390                macros.push(macro_symbol);
391            }
392        }
393
394        Ok(macros)
395    }
396
397    /// Parse a function match from the query result
398    fn parse_function_match(
399        &self,
400        match_: &tree_sitter::QueryMatch,
401        source: &str,
402    ) -> RazResult<Option<Symbol>> {
403        let mut name = None;
404        let mut range = None;
405        let mut modifiers = Vec::new();
406
407        for capture in match_.captures {
408            let node = capture.node;
409            let capture_name = &self.queries.functions.capture_names()[capture.index as usize];
410
411            match capture_name as &str {
412                "function.name" => {
413                    name = Some(
414                        node.utf8_text(source.as_bytes())
415                            .map_err(|e| {
416                                RazError::analysis(format!("Failed to extract function name: {e}"))
417                            })?
418                            .to_string(),
419                    );
420                    range = Some(self.node_to_range(node));
421                }
422                "function.async" => modifiers.push("async".to_string()),
423                "function.unsafe" => modifiers.push("unsafe".to_string()),
424                "function.const" => modifiers.push("const".to_string()),
425                "function.extern" => modifiers.push("extern".to_string()),
426                _ => {}
427            }
428        }
429
430        if let (Some(name), Some(range)) = (name, range) {
431            Ok(Some(Symbol {
432                name,
433                kind: SymbolKind::Function,
434                range,
435                modifiers,
436                children: Vec::new(),
437                metadata: HashMap::new(),
438            }))
439        } else {
440            Ok(None)
441        }
442    }
443
444    /// Parse a struct match from the query result
445    fn parse_struct_match(
446        &self,
447        match_: &tree_sitter::QueryMatch,
448        source: &str,
449    ) -> RazResult<Option<Symbol>> {
450        let mut name = None;
451        let mut range = None;
452        let mut modifiers = Vec::new();
453
454        for capture in match_.captures {
455            let node = capture.node;
456            let capture_name = &self.queries.structs.capture_names()[capture.index as usize];
457
458            match capture_name as &str {
459                "struct.name" => {
460                    name = Some(
461                        node.utf8_text(source.as_bytes())
462                            .map_err(|e| {
463                                RazError::analysis(format!("Failed to extract struct name: {e}"))
464                            })?
465                            .to_string(),
466                    );
467                    // Get the range of the entire struct item, not just the name
468                    if let Some(parent) = node.parent() {
469                        if parent.kind() == "struct_item" {
470                            range = Some(self.node_to_range(parent));
471                        } else {
472                            range = Some(self.node_to_range(node));
473                        }
474                    } else {
475                        range = Some(self.node_to_range(node));
476                    }
477                }
478                "struct.vis" => {
479                    let vis_text = node.utf8_text(source.as_bytes()).map_err(|e| {
480                        RazError::analysis(format!("Failed to extract struct visibility: {e}"))
481                    })?;
482                    modifiers.push(vis_text.to_string());
483                }
484                _ => {}
485            }
486        }
487
488        if let (Some(name), Some(range)) = (name, range) {
489            Ok(Some(Symbol {
490                name,
491                kind: SymbolKind::Struct,
492                range,
493                modifiers,
494                children: Vec::new(),
495                metadata: HashMap::new(),
496            }))
497        } else {
498            Ok(None)
499        }
500    }
501
502    /// Parse an enum match from the query result
503    fn parse_enum_match(
504        &self,
505        match_: &tree_sitter::QueryMatch,
506        source: &str,
507    ) -> RazResult<Option<Symbol>> {
508        let mut name = None;
509        let mut range = None;
510        let mut modifiers = Vec::new();
511
512        for capture in match_.captures {
513            let node = capture.node;
514            let capture_name = &self.queries.enums.capture_names()[capture.index as usize];
515
516            match capture_name as &str {
517                "enum.name" => {
518                    name = Some(
519                        node.utf8_text(source.as_bytes())
520                            .map_err(|e| {
521                                RazError::analysis(format!("Failed to extract enum name: {e}"))
522                            })?
523                            .to_string(),
524                    );
525                    range = Some(self.node_to_range(node));
526                }
527                "enum.pub" => modifiers.push("pub".to_string()),
528                _ => {}
529            }
530        }
531
532        if let (Some(name), Some(range)) = (name, range) {
533            Ok(Some(Symbol {
534                name,
535                kind: SymbolKind::Enum,
536                range,
537                modifiers,
538                children: Vec::new(),
539                metadata: HashMap::new(),
540            }))
541        } else {
542            Ok(None)
543        }
544    }
545
546    /// Parse a trait match from the query result
547    fn parse_trait_match(
548        &self,
549        match_: &tree_sitter::QueryMatch,
550        source: &str,
551    ) -> RazResult<Option<Symbol>> {
552        let mut name = None;
553        let mut range = None;
554        let mut modifiers = Vec::new();
555
556        for capture in match_.captures {
557            let node = capture.node;
558            let capture_name = &self.queries.traits.capture_names()[capture.index as usize];
559
560            match capture_name as &str {
561                "trait.name" => {
562                    name = Some(
563                        node.utf8_text(source.as_bytes())
564                            .map_err(|e| {
565                                RazError::analysis(format!("Failed to extract trait name: {e}"))
566                            })?
567                            .to_string(),
568                    );
569                    range = Some(self.node_to_range(node));
570                }
571                "trait.pub" => modifiers.push("pub".to_string()),
572                "trait.unsafe" => modifiers.push("unsafe".to_string()),
573                _ => {}
574            }
575        }
576
577        if let (Some(name), Some(range)) = (name, range) {
578            Ok(Some(Symbol {
579                name,
580                kind: SymbolKind::Trait,
581                range,
582                modifiers,
583                children: Vec::new(),
584                metadata: HashMap::new(),
585            }))
586        } else {
587            Ok(None)
588        }
589    }
590
591    /// Parse a module match from the query result
592    fn parse_module_match(
593        &self,
594        match_: &tree_sitter::QueryMatch,
595        source: &str,
596    ) -> RazResult<Option<Symbol>> {
597        let mut name = None;
598        let mut range = None;
599        let mut modifiers = Vec::new();
600
601        for capture in match_.captures {
602            let node = capture.node;
603            let capture_name = &self.queries.modules.capture_names()[capture.index as usize];
604
605            match capture_name as &str {
606                "module.name" => {
607                    name = Some(
608                        node.utf8_text(source.as_bytes())
609                            .map_err(|e| {
610                                RazError::analysis(format!("Failed to extract module name: {e}"))
611                            })?
612                            .to_string(),
613                    );
614                }
615                "module" => {
616                    // This captures the entire module
617                    range = Some(self.node_to_range(node));
618                }
619                "module.pub" => modifiers.push("pub".to_string()),
620                _ => {}
621            }
622        }
623
624        if let (Some(name), Some(range)) = (name, range) {
625            Ok(Some(Symbol {
626                name,
627                kind: SymbolKind::Module,
628                range,
629                modifiers,
630                children: Vec::new(),
631                metadata: HashMap::new(),
632            }))
633        } else {
634            Ok(None)
635        }
636    }
637
638    /// Parse a constant match from the query result
639    fn parse_constant_match(
640        &self,
641        match_: &tree_sitter::QueryMatch,
642        source: &str,
643    ) -> RazResult<Option<Symbol>> {
644        let mut name = None;
645        let mut range = None;
646        let mut modifiers = Vec::new();
647
648        for capture in match_.captures {
649            let node = capture.node;
650            let capture_name = &self.queries.constants.capture_names()[capture.index as usize];
651
652            match capture_name as &str {
653                "const.name" => {
654                    name = Some(
655                        node.utf8_text(source.as_bytes())
656                            .map_err(|e| {
657                                RazError::analysis(format!("Failed to extract constant name: {e}"))
658                            })?
659                            .to_string(),
660                    );
661                    range = Some(self.node_to_range(node));
662                }
663                "const.pub" => modifiers.push("pub".to_string()),
664                _ => {}
665            }
666        }
667
668        if let (Some(name), Some(range)) = (name, range) {
669            Ok(Some(Symbol {
670                name,
671                kind: SymbolKind::Constant,
672                range,
673                modifiers,
674                children: Vec::new(),
675                metadata: HashMap::new(),
676            }))
677        } else {
678            Ok(None)
679        }
680    }
681
682    /// Parse a type alias match from the query result
683    fn parse_type_alias_match(
684        &self,
685        match_: &tree_sitter::QueryMatch,
686        source: &str,
687    ) -> RazResult<Option<Symbol>> {
688        let mut name = None;
689        let mut range = None;
690        let mut modifiers = Vec::new();
691
692        for capture in match_.captures {
693            let node = capture.node;
694            let capture_name = &self.queries.type_aliases.capture_names()[capture.index as usize];
695
696            match capture_name as &str {
697                "type.name" => {
698                    name = Some(
699                        node.utf8_text(source.as_bytes())
700                            .map_err(|e| {
701                                RazError::analysis(format!(
702                                    "Failed to extract type alias name: {e}"
703                                ))
704                            })?
705                            .to_string(),
706                    );
707                    range = Some(self.node_to_range(node));
708                }
709                "type.pub" => modifiers.push("pub".to_string()),
710                _ => {}
711            }
712        }
713
714        if let (Some(name), Some(range)) = (name, range) {
715            Ok(Some(Symbol {
716                name,
717                kind: SymbolKind::TypeAlias,
718                range,
719                modifiers,
720                children: Vec::new(),
721                metadata: HashMap::new(),
722            }))
723        } else {
724            Ok(None)
725        }
726    }
727
728    /// Parse a macro match from the query result
729    fn parse_macro_match(
730        &self,
731        match_: &tree_sitter::QueryMatch,
732        source: &str,
733    ) -> RazResult<Option<Symbol>> {
734        let mut name = None;
735        let mut range = None;
736        let mut modifiers = Vec::new();
737
738        for capture in match_.captures {
739            let node = capture.node;
740            let capture_name = &self.queries.macros.capture_names()[capture.index as usize];
741
742            match capture_name as &str {
743                "macro.name" => {
744                    name = Some(
745                        node.utf8_text(source.as_bytes())
746                            .map_err(|e| {
747                                RazError::analysis(format!("Failed to extract macro name: {e}"))
748                            })?
749                            .to_string(),
750                    );
751                    range = Some(self.node_to_range(node));
752                }
753                "macro.pub" => modifiers.push("pub".to_string()),
754                _ => {}
755            }
756        }
757
758        if let (Some(name), Some(range)) = (name, range) {
759            Ok(Some(Symbol {
760                name,
761                kind: SymbolKind::Macro,
762                range,
763                modifiers,
764                children: Vec::new(),
765                metadata: HashMap::new(),
766            }))
767        } else {
768            Ok(None)
769        }
770    }
771
772    /// Convert a tree-sitter node to a Range
773    fn node_to_range(&self, node: Node) -> Range {
774        Range {
775            start: Position {
776                line: node.start_position().row as u32,
777                column: node.start_position().column as u32,
778            },
779            end: Position {
780                line: node.end_position().row as u32,
781                column: node.end_position().column as u32,
782            },
783        }
784    }
785}
786
/// Context information about symbols surrounding a cursor position.
///
/// Produced by `RustAnalyzer::context_at_position`. Each slot holds at most
/// one symbol: when several symbols of the same kind contain the cursor, the
/// one visited last replaces the earlier ones.
#[derive(Debug, Clone, Default)]
pub struct SymbolContext {
    /// The non-test function containing the cursor (if any)
    pub in_function: Option<Symbol>,

    /// The test function containing the cursor (if any)
    pub in_test_function: Option<Symbol>,

    /// The struct containing the cursor (if any)
    pub in_struct: Option<Symbol>,

    /// The enum containing the cursor (if any)
    pub in_enum: Option<Symbol>,

    /// The trait containing the cursor (if any)
    pub in_trait: Option<Symbol>,

    /// The module containing the cursor (if any)
    pub in_module: Option<Symbol>,

    /// The impl block containing the cursor (if any)
    // NOTE(review): `RustAnalyzer::extract_symbols` has no pass that emits
    // impl symbols, so this looks like it always stays `None` — confirm.
    pub in_impl: Option<Symbol>,
}
811
/// Collection of tree-sitter queries for extracting symbols.
///
/// All queries are compiled once in `QuerySet::new` and reused by the
/// `RustAnalyzer` extraction methods.
struct QuerySet {
    // Matches `function_item`; captures the name as `@function.name`.
    functions: Query,
    // Matches `struct_item`; captures `@struct.name` and an optional `@struct.vis`.
    structs: Query,
    // Matches `enum_item`; captures the name as `@enum.name`.
    enums: Query,
    // Matches `trait_item`; captures the name as `@trait.name`.
    traits: Query,
    // Matches `mod_item`; captures `@module.name` and the whole item as `@module`.
    modules: Query,
    // Matches an attribute whose identifier equals `test`; captured as `@test.attr`.
    tests: Query,
    // Matches `const_item`; captures the name as `@const.name`.
    constants: Query,
    // Matches `type_item`; captures the name as `@type.name`.
    type_aliases: Query,
    // Matches `macro_definition`; captures the name as `@macro.name`.
    macros: Query,
}
824
825impl QuerySet {
826    fn new(language: Language) -> RazResult<Self> {
827        let functions = Query::new(
828            &language,
829            r#"
830            (function_item 
831                name: (identifier) @function.name
832            )
833        "#,
834        )
835        .map_err(|e| RazError::analysis(format!("Failed to create functions query: {e}")))?;
836
837        let structs = Query::new(
838            &language,
839            r#"
840            (struct_item 
841                (visibility_modifier)? @struct.vis
842                name: (type_identifier) @struct.name
843            )
844        "#,
845        )
846        .map_err(|e| RazError::analysis(format!("Failed to create structs query: {e}")))?;
847
848        let enums = Query::new(
849            &language,
850            r#"
851            (enum_item 
852                name: (type_identifier) @enum.name
853            )
854        "#,
855        )
856        .map_err(|e| RazError::analysis(format!("Failed to create enums query: {e}")))?;
857
858        let traits = Query::new(
859            &language,
860            r#"
861            (trait_item 
862                name: (type_identifier) @trait.name
863            )
864        "#,
865        )
866        .map_err(|e| RazError::analysis(format!("Failed to create traits query: {e}")))?;
867
868        let modules = Query::new(
869            &language,
870            r#"
871            (mod_item 
872                name: (identifier) @module.name
873            ) @module
874        "#,
875        )
876        .map_err(|e| RazError::analysis(format!("Failed to create modules query: {e}")))?;
877
878        let tests = Query::new(
879            &language,
880            r#"
881            (attribute_item
882                (attribute
883                    (identifier) @test.attr
884                    (#eq? @test.attr "test")
885                )
886            )
887        "#,
888        )
889        .map_err(|e| RazError::analysis(format!("Failed to create tests query: {e}")))?;
890
891        let constants = Query::new(
892            &language,
893            r#"
894            (const_item 
895                name: (identifier) @const.name
896            )
897        "#,
898        )
899        .map_err(|e| RazError::analysis(format!("Failed to create constants query: {e}")))?;
900
901        let type_aliases = Query::new(
902            &language,
903            r#"
904            (type_item 
905                name: (type_identifier) @type.name
906            )
907        "#,
908        )
909        .map_err(|e| RazError::analysis(format!("Failed to create type aliases query: {e}")))?;
910
911        let macros = Query::new(
912            &language,
913            r#"
914            (macro_definition 
915                name: (identifier) @macro.name
916            )
917        "#,
918        )
919        .map_err(|e| RazError::analysis(format!("Failed to create macros query: {e}")))?;
920
921        Ok(Self {
922            functions,
923            structs,
924            enums,
925            traits,
926            modules,
927            tests,
928            constants,
929            type_aliases,
930            macros,
931        })
932    }
933}
934
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an analyzer, panicking if construction fails.
    fn new_analyzer() -> RustAnalyzer {
        RustAnalyzer::new().unwrap()
    }

    /// Parse `code` with a fresh analyzer and return every extracted symbol.
    fn symbols_of(code: &str) -> Vec<Symbol> {
        let mut analyzer = new_analyzer();
        let tree = analyzer.parse(code).unwrap();
        analyzer.extract_symbols(&tree, code).unwrap()
    }

    /// Constructing the analyzer must succeed.
    #[test]
    fn test_rust_analyzer_creation() {
        assert!(RustAnalyzer::new().is_ok());
    }

    /// A lone free function yields exactly one `Function` symbol.
    #[test]
    fn test_parse_simple_function() {
        let source = r#"
            fn hello_world() {
                println!("Hello, world!");
            }
        "#;

        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let only = &symbols[0];
        assert_eq!(only.name, "hello_world");
        assert_eq!(only.kind, SymbolKind::Function);
    }

    /// A public struct yields one `Struct` symbol carrying the `pub` modifier.
    #[test]
    fn test_parse_struct() {
        let source = r#"
            pub struct Person {
                name: String,
                age: u32,
            }
        "#;

        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let only = &symbols[0];
        assert_eq!(only.name, "Person");
        assert_eq!(only.kind, SymbolKind::Struct);
        let wanted = String::from("pub");
        assert!(only.modifiers.contains(&wanted));
    }

    /// A `#[test]` function should be reported once, as a `Test` symbol.
    #[test]
    #[ignore] // TODO: Fix tree-sitter test attribute parsing
    fn test_parse_test_function() {
        let source = r#"
            #[test]
            fn test_addition() {
                assert_eq!(2 + 2, 4);
            }
        "#;

        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let only = &symbols[0];
        assert_eq!(only.name, "test_addition");
        assert_eq!(only.kind, SymbolKind::Test);
        let wanted = String::from("test");
        assert!(only.modifiers.contains(&wanted));
    }

    /// A cursor placed inside a function body resolves to that function.
    #[test]
    #[ignore] // TODO: Fix tree-sitter symbol position detection
    fn test_symbol_at_position() {
        let mut analyzer = new_analyzer();
        let source = r#"
fn main() {
    println!("Hello");
}

fn helper() {
    println!("Helper");
}
        "#;

        let tree = analyzer.parse(source).unwrap();

        // (cursor line, name of the function expected at that line)
        for (line, expected) in [(2, "main"), (6, "helper")] {
            let found = analyzer
                .symbol_at_position(&tree, source, Position { line, column: 4 })
                .unwrap();
            assert!(found.is_some());
            assert_eq!(found.unwrap().name, expected);
        }
    }

    /// A cursor inside a test nested in a module reports both enclosing scopes.
    #[test]
    #[ignore] // TODO: Fix tree-sitter context position detection
    fn test_context_at_position() {
        let mut analyzer = new_analyzer();
        let source = r#"
mod tests {
    #[test]
    fn test_something() {
        assert_eq!(1, 1);
    }
}
        "#;

        let tree = analyzer.parse(source).unwrap();
        let cursor = Position { line: 4, column: 8 };
        let context = analyzer
            .context_at_position(&tree, source, cursor)
            .unwrap();

        let enclosing_test = context.in_test_function;
        assert!(enclosing_test.is_some());
        assert_eq!(enclosing_test.unwrap().name, "test_something");

        let enclosing_module = context.in_module;
        assert!(enclosing_module.is_some());
        assert_eq!(enclosing_module.unwrap().name, "tests");
    }
}