agcodex_ast/
semantic_index.rs

1//! Semantic indexing for symbols, relationships, and call graphs
2
3use crate::error::AstError;
4use crate::error::AstResult;
5use crate::types::AstNodeKind;
6use crate::types::ParsedAst;
7use crate::types::SourceLocation;
8use crate::types::Visibility;
9use dashmap::DashMap;
10// use std::collections::{HashMap, HashSet}; // unused
11use std::path::Path;
12use std::path::PathBuf;
13use tree_sitter::Node;
14use tree_sitter::TreeCursor;
15
16/// Symbol in the codebase
17#[derive(Debug, Clone)]
18pub struct Symbol {
19    pub name: String,
20    pub kind: SymbolKind,
21    pub location: SourceLocation,
22    pub visibility: Visibility,
23    pub signature: String,
24    pub documentation: Option<String>,
25    pub references: Vec<SourceLocation>,
26    pub definitions: Vec<SourceLocation>,
27    pub call_sites: Vec<SourceLocation>,
28}
29
/// Classification assigned to every indexed [`Symbol`].
///
/// The numeric value of each variant is embedded in symbol ids through
/// `kind as u8`, so the discriminants are spelled out explicitly to keep them
/// stable if variants are ever reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    Function = 0,
    Method = 1,
    Class = 2,
    Struct = 3,
    Enum = 4,
    Interface = 5,
    Trait = 6,
    Module = 7,
    Variable = 8,
    Constant = 9,
    Type = 10,
    Property = 11,
    Field = 12,
    Parameter = 13,
}
48
49impl SymbolKind {
50    /// Convert from AST node kind
51    pub const fn from_ast_kind(kind: AstNodeKind) -> Self {
52        match kind {
53            AstNodeKind::Function => Self::Function,
54            AstNodeKind::Class => Self::Class,
55            AstNodeKind::Struct => Self::Struct,
56            AstNodeKind::Enum => Self::Enum,
57            AstNodeKind::Interface => Self::Interface,
58            AstNodeKind::Trait => Self::Trait,
59            AstNodeKind::Module => Self::Module,
60            AstNodeKind::Variable => Self::Variable,
61            AstNodeKind::Constant => Self::Constant,
62            AstNodeKind::Type => Self::Type,
63            _ => Self::Variable,
64        }
65    }
66}
67
68/// Semantic index for code intelligence
69#[derive(Debug)]
70pub struct SemanticIndex {
71    /// Symbol table: symbol_id -> Symbol
72    symbols: DashMap<String, Symbol>,
73    /// File index: file_path -> symbol_ids
74    file_symbols: DashMap<PathBuf, Vec<String>>,
75    /// Call graph: caller_id -> [callee_ids]
76    call_graph: DashMap<String, Vec<String>>,
77    /// Inheritance graph: parent_id -> [child_ids]
78    inheritance_graph: DashMap<String, Vec<String>>,
79    /// Import graph: file_path -> [imported_files]
80    import_graph: DashMap<PathBuf, Vec<PathBuf>>,
81}
82
83impl SemanticIndex {
84    /// Create a new semantic index
85    pub fn new() -> Self {
86        Self {
87            symbols: DashMap::new(),
88            file_symbols: DashMap::new(),
89            call_graph: DashMap::new(),
90            inheritance_graph: DashMap::new(),
91            import_graph: DashMap::new(),
92        }
93    }
94
95    /// Index an AST for semantic information
96    pub fn index_ast(&mut self, path: &Path, ast: &ParsedAst) -> AstResult<()> {
97        let source = ast.source.as_bytes();
98        let mut cursor = ast.tree.root_node().walk();
99        let mut symbols = Vec::new();
100
101        // Extract symbols from AST
102        self.extract_symbols(&mut cursor, source, path, &mut symbols)?;
103
104        // Store file symbols
105        let symbol_ids: Vec<String> = symbols.iter().map(|s| self.get_symbol_id(s)).collect();
106        self.file_symbols.insert(path.to_path_buf(), symbol_ids);
107
108        // Insert symbols into index
109        for symbol in symbols {
110            let id = self.get_symbol_id(&symbol);
111            self.symbols.insert(id, symbol);
112        }
113
114        // Build relationships
115        self.build_relationships(path, ast)?;
116
117        Ok(())
118    }
119
120    /// Extract symbols from AST nodes
121    fn extract_symbols(
122        &self,
123        cursor: &mut TreeCursor,
124        source: &[u8],
125        file_path: &Path,
126        symbols: &mut Vec<Symbol>,
127    ) -> AstResult<()> {
128        let node = cursor.node();
129        let node_type = node.kind();
130        let _ast_kind = AstNodeKind::from_node_type(node_type);
131
132        // Check if this is a symbol definition
133        if self.is_symbol_definition(&node) {
134            let symbol = self.create_symbol(&node, source, file_path)?;
135            symbols.push(symbol);
136        }
137
138        // Recurse into children
139        if cursor.goto_first_child() {
140            loop {
141                self.extract_symbols(cursor, source, file_path, symbols)?;
142                if !cursor.goto_next_sibling() {
143                    break;
144                }
145            }
146            cursor.goto_parent();
147        }
148
149        Ok(())
150    }
151
152    /// Check if a node represents a symbol definition
153    fn is_symbol_definition(&self, node: &Node) -> bool {
154        let node_type = node.kind();
155        matches!(
156            node_type,
157            "function_declaration"
158                | "function_definition"
159                | "function_item"
160                | "method_declaration"
161                | "method_definition"
162                | "class_declaration"
163                | "class_definition"
164                | "struct_item"
165                | "struct_declaration"
166                | "enum_item"
167                | "enum_declaration"
168                | "interface_declaration"
169                | "protocol_declaration"
170                | "trait_item"
171                | "trait_declaration"
172                | "module"
173                | "module_declaration"
174                | "variable_declaration"
175                | "const_item"
176                | "let_declaration"
177                | "type_alias"
178                | "typedef"
179        )
180    }
181
182    /// Create a symbol from an AST node
183    fn create_symbol(&self, node: &Node, source: &[u8], file_path: &Path) -> AstResult<Symbol> {
184        let name = self.extract_symbol_name(node, source)?;
185        let kind = SymbolKind::from_ast_kind(AstNodeKind::from_node_type(node.kind()));
186
187        let location = SourceLocation::new(
188            file_path.display().to_string(),
189            node.start_position().row + 1,
190            node.start_position().column + 1,
191            node.end_position().row + 1,
192            node.end_position().column + 1,
193            (node.start_byte(), node.end_byte()),
194        );
195
196        let signature = self.extract_signature(node, source)?;
197        let visibility = self.extract_visibility(node, source);
198        let documentation = self.extract_documentation(node, source);
199
200        Ok(Symbol {
201            name,
202            kind,
203            location: location.clone(),
204            visibility,
205            signature,
206            documentation,
207            references: Vec::new(),
208            definitions: vec![location],
209            call_sites: Vec::new(),
210        })
211    }
212
213    /// Extract symbol name from node
214    fn extract_symbol_name(&self, node: &Node, source: &[u8]) -> AstResult<String> {
215        // Look for identifier child node
216        for i in 0..node.child_count() {
217            if let Some(child) = node.child(i)
218                && (child.kind() == "identifier" || child.kind() == "name")
219            {
220                let name = std::str::from_utf8(&source[child.byte_range()])
221                    .map_err(|e| AstError::ParserError(e.to_string()))?;
222                return Ok(name.to_string());
223            }
224        }
225
226        // Fallback: use first non-keyword text
227        let text = std::str::from_utf8(&source[node.byte_range()])
228            .map_err(|e| AstError::ParserError(e.to_string()))?;
229
230        // Extract name from text (simple heuristic)
231        let words: Vec<&str> = text.split_whitespace().collect();
232        for word in words {
233            if !Self::is_keyword(word) && word.chars().any(|c| c.is_alphabetic()) {
234                return Ok(word.to_string());
235            }
236        }
237
238        Ok("anonymous".to_string())
239    }
240
241    /// Check if a word is a keyword
242    fn is_keyword(word: &str) -> bool {
243        matches!(
244            word,
245            "fn" | "function"
246                | "def"
247                | "class"
248                | "struct"
249                | "enum"
250                | "interface"
251                | "trait"
252                | "impl"
253                | "module"
254                | "namespace"
255                | "const"
256                | "let"
257                | "var"
258                | "type"
259                | "public"
260                | "private"
261                | "protected"
262                | "static"
263                | "async"
264                | "export"
265                | "import"
266        )
267    }
268
269    /// Extract signature from node
270    fn extract_signature(&self, node: &Node, source: &[u8]) -> AstResult<String> {
271        // Get text up to body/block
272        let mut sig_end = node.end_byte();
273
274        for i in 0..node.child_count() {
275            if let Some(child) = node.child(i)
276                && (child.kind() == "block"
277                    || child.kind() == "compound_statement"
278                    || child.kind() == "function_body")
279            {
280                sig_end = child.start_byte();
281                break;
282            }
283        }
284
285        let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
286            .map_err(|e| AstError::ParserError(e.to_string()))?;
287
288        Ok(signature.trim().to_string())
289    }
290
291    /// Extract visibility from node
292    fn extract_visibility(&self, node: &Node, source: &[u8]) -> Visibility {
293        let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
294        Visibility::from_text(text)
295    }
296
297    /// Extract documentation from node
298    fn extract_documentation(&self, node: &Node, source: &[u8]) -> Option<String> {
299        // Look for preceding comment node
300        if let Some(prev) = node.prev_sibling()
301            && prev.kind().contains("comment")
302        {
303            let doc = std::str::from_utf8(&source[prev.byte_range()]).ok()?;
304            return Some(doc.to_string());
305        }
306        None
307    }
308
309    /// Build relationships between symbols
310    const fn build_relationships(&mut self, _path: &Path, _ast: &ParsedAst) -> AstResult<()> {
311        // TODO: Implement call graph building
312        // TODO: Implement inheritance graph building
313        // TODO: Implement import graph building
314        Ok(())
315    }
316
317    /// Generate symbol ID
318    fn get_symbol_id(&self, symbol: &Symbol) -> String {
319        format!(
320            "{}:{}:{}",
321            symbol.location.file_path, symbol.kind as u8, symbol.name
322        )
323    }
324
325    /// Search for symbols by query
326    pub fn search(&self, query: &str) -> Vec<Symbol> {
327        let query_lower = query.to_lowercase();
328        let mut results = Vec::new();
329
330        for entry in self.symbols.iter() {
331            let symbol = entry.value();
332            if symbol.name.to_lowercase().contains(&query_lower) {
333                results.push(symbol.clone());
334            }
335        }
336
337        results
338    }
339
340    /// Get call graph for a function
341    pub fn get_call_graph(&self, path: &Path, function_name: &str) -> Vec<Symbol> {
342        // Find the function symbol
343        let symbol_id = format!(
344            "{}:{}:{}",
345            path.display(),
346            SymbolKind::Function as u8,
347            function_name
348        );
349
350        // Get callees
351        if let Some(callees) = self.call_graph.get(&symbol_id) {
352            let mut results = Vec::new();
353            for callee_id in callees.value() {
354                if let Some(symbol) = self.symbols.get(callee_id) {
355                    results.push(symbol.clone());
356                }
357            }
358            return results;
359        }
360
361        Vec::new()
362    }
363
364    /// Get symbols defined in a file
365    pub fn get_file_symbols(&self, path: &Path) -> Vec<Symbol> {
366        if let Some(symbol_ids) = self.file_symbols.get(&path.to_path_buf()) {
367            let mut results = Vec::new();
368            for id in symbol_ids.value() {
369                if let Some(symbol) = self.symbols.get(id) {
370                    results.push(symbol.clone());
371                }
372            }
373            return results;
374        }
375        Vec::new()
376    }
377
378    /// Clear the index
379    pub fn clear(&mut self) {
380        self.symbols.clear();
381        self.file_symbols.clear();
382        self.call_graph.clear();
383        self.inheritance_graph.clear();
384        self.import_graph.clear();
385    }
386}
387
388impl Default for SemanticIndex {
389    fn default() -> Self {
390        Self::new()
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_registry::Language;
    use crate::language_registry::LanguageRegistry;

    /// Small Rust fixture: two free functions, a struct and an inherent
    /// constructor.
    const SAMPLE: &str = r#"
pub fn calculate(x: i32, y: i32) -> i32 {
    add(x, y)
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}

pub struct Calculator {
    value: i32,
}

impl Calculator {
    pub fn new() -> Self {
        Self { value: 0 }
    }
}
"#;

    /// Indexing the fixture makes its symbols reachable both through name
    /// search and through the per-file listing.
    #[test]
    fn test_semantic_indexing() {
        let registry = LanguageRegistry::new();
        let parsed = registry.parse(&Language::Rust, SAMPLE).unwrap();
        let file = Path::new("test.rs");

        let mut index = SemanticIndex::new();
        index.index_ast(file, &parsed).unwrap();

        // Case-insensitive substring search finds the `calculate` function.
        let hits = index.search("calc");
        assert!(!hits.is_empty());
        assert!(hits.iter().any(|s| s.name.contains("calculate")));

        // At least the two functions and the struct are listed for the file.
        let in_file = index.get_file_symbols(file);
        assert!(in_file.len() >= 3);
    }
}
439}