codeprism_core/linkers/
symbol_resolver.rs

1//! Symbol resolver for creating cross-file edges
2//!
3//! This module resolves imports, function calls, and other references across files
4//! to create a complete dependency graph after initial parsing.
5
6use crate::ast::{Edge, EdgeKind, Node, NodeId, NodeKind};
7use crate::error::Result;
8use crate::graph::GraphStore;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12
13/// Symbol resolver for cross-file linking
14pub struct SymbolResolver {
15    graph: Arc<GraphStore>,
16    /// Index of importable symbols by module path
17    module_symbols: HashMap<String, Vec<NodeId>>,
18    /// Index of symbols by qualified name (module.symbol)
19    qualified_symbols: HashMap<String, NodeId>,
20    /// Import resolution cache
21    #[allow(dead_code)]
22    import_cache: HashMap<String, String>,
23}
24
25impl SymbolResolver {
26    /// Create a new symbol resolver
27    pub fn new(graph: Arc<GraphStore>) -> Self {
28        Self {
29            graph,
30            module_symbols: HashMap::new(),
31            qualified_symbols: HashMap::new(),
32            import_cache: HashMap::new(),
33        }
34    }
35
36    /// Resolve all cross-file relationships
37    pub fn resolve_all(&mut self) -> Result<Vec<Edge>> {
38        let mut new_edges = Vec::new();
39
40        // Build symbol indices
41        self.build_symbol_indices()?;
42
43        // Resolve imports
44        new_edges.extend(self.resolve_imports()?);
45
46        // Resolve function calls
47        new_edges.extend(self.resolve_function_calls()?);
48
49        // Resolve class instantiations
50        new_edges.extend(self.resolve_class_instantiations()?);
51
52        // Resolve inheritance relationships
53        new_edges.extend(self.resolve_inheritance()?);
54
55        Ok(new_edges)
56    }
57
58    /// Build indices of available symbols for resolution
59    fn build_symbol_indices(&mut self) -> Result<()> {
60        // Get all nodes and organize by module
61        for (file_path, node_ids) in self.graph.iter_file_index() {
62            let module_name = self.file_path_to_module_name(&file_path);
63
64            for node_id in node_ids {
65                if let Some(node) = self.graph.get_node(&node_id) {
66                    match node.kind {
67                        NodeKind::Class | NodeKind::Function | NodeKind::Variable => {
68                            // Add to module symbols
69                            self.module_symbols
70                                .entry(module_name.clone())
71                                .or_default()
72                                .push(node_id);
73
74                            // Add to qualified symbols
75                            let qualified_name = format!("{}.{}", module_name, node.name);
76                            self.qualified_symbols.insert(qualified_name, node_id);
77                        }
78                        _ => {}
79                    }
80                }
81            }
82        }
83
84        Ok(())
85    }
86
87    /// Resolve import statements to create edges to imported symbols
88    fn resolve_imports(&mut self) -> Result<Vec<Edge>> {
89        let mut edges = Vec::new();
90
91        // Find all import nodes
92        let import_nodes = self.graph.get_nodes_by_kind(NodeKind::Import);
93
94        for import_node in import_nodes {
95            edges.extend(self.resolve_single_import(&import_node)?);
96        }
97
98        Ok(edges)
99    }
100
101    /// Resolve a single import node
102    fn resolve_single_import(&mut self, import_node: &Node) -> Result<Vec<Edge>> {
103        let mut edges = Vec::new();
104
105        // Parse import statement
106        let import_parts = self.parse_import_statement(&import_node.name);
107
108        for (module_path, symbol_name) in import_parts {
109            // Find the target symbol
110            if let Some(target_id) = self.find_symbol_in_module(&module_path, &symbol_name) {
111                // Create import edge
112                edges.push(Edge::new(import_node.id, target_id, EdgeKind::Imports));
113            }
114        }
115
116        Ok(edges)
117    }
118
119    /// Resolve function calls to actual function definitions
120    fn resolve_function_calls(&mut self) -> Result<Vec<Edge>> {
121        let mut edges = Vec::new();
122
123        // Find all call nodes
124        let call_nodes = self.graph.get_nodes_by_kind(NodeKind::Call);
125
126        for call_node in call_nodes {
127            if let Some(target_id) = self.resolve_call_target(&call_node)? {
128                edges.push(Edge::new(call_node.id, target_id, EdgeKind::Calls));
129            }
130        }
131
132        Ok(edges)
133    }
134
135    /// Resolve class instantiations (calls to __init__)
136    fn resolve_class_instantiations(&mut self) -> Result<Vec<Edge>> {
137        let mut edges = Vec::new();
138
139        // Find call nodes that might be class instantiations
140        let call_nodes = self.graph.get_nodes_by_kind(NodeKind::Call);
141
142        for call_node in call_nodes {
143            // Check if this is a class name (first letter uppercase)
144            if call_node
145                .name
146                .chars()
147                .next()
148                .is_some_and(|c| c.is_uppercase())
149            {
150                if let Some(class_id) = self.find_class_by_name(&call_node.name) {
151                    // Find the __init__ method of this class
152                    if let Some(init_id) = self.find_method_in_class(class_id, "__init__") {
153                        edges.push(Edge::new(call_node.id, init_id, EdgeKind::Calls));
154                    }
155                }
156            }
157        }
158
159        Ok(edges)
160    }
161
162    /// Parse import statement to extract module and symbol names
163    fn parse_import_statement(&self, import_name: &str) -> Vec<(String, String)> {
164        let mut results = Vec::new();
165
166        // Handle different import patterns
167        if import_name.contains('.') {
168            // Module.symbol or complex import
169            let parts: Vec<&str> = import_name.split('.').collect();
170            if parts.len() >= 2 {
171                let module = parts[..parts.len() - 1].join(".");
172                let symbol = parts.last().unwrap().to_string();
173                results.push((module, symbol));
174            }
175        } else {
176            // Simple module import - all exportable symbols
177            if let Some(symbols) = self.module_symbols.get(import_name) {
178                for symbol_id in symbols {
179                    if let Some(node) = self.graph.get_node(symbol_id) {
180                        results.push((import_name.to_string(), node.name.clone()));
181                    }
182                }
183            }
184        }
185
186        results
187    }
188
189    /// Find a symbol in a specific module
190    fn find_symbol_in_module(&self, module_path: &str, symbol_name: &str) -> Option<NodeId> {
191        // Try qualified name first
192        let qualified_name = format!("{}.{}", module_path, symbol_name);
193        if let Some(node_id) = self.qualified_symbols.get(&qualified_name) {
194            return Some(*node_id);
195        }
196
197        // Try by module and name
198        if let Some(symbol_ids) = self.module_symbols.get(module_path) {
199            for symbol_id in symbol_ids {
200                if let Some(node) = self.graph.get_node(symbol_id) {
201                    if node.name == symbol_name {
202                        return Some(*symbol_id);
203                    }
204                }
205            }
206        }
207
208        None
209    }
210
211    /// Resolve the target of a function call
212    fn resolve_call_target(&self, call_node: &Node) -> Result<Option<NodeId>> {
213        // Get the file where this call is made
214        let calling_file = &call_node.file;
215
216        // First check for local functions in the same file
217        let file_nodes = self.graph.get_nodes_in_file(calling_file);
218        for node in &file_nodes {
219            if matches!(node.kind, NodeKind::Function | NodeKind::Method)
220                && node.name == call_node.name
221            {
222                return Ok(Some(node.id));
223            }
224        }
225
226        // Check imported functions
227        // Find import nodes in the same file
228        for node in &file_nodes {
229            if node.kind == NodeKind::Import {
230                let import_parts = self.parse_import_statement(&node.name);
231                for (module_path, symbol_name) in import_parts {
232                    if symbol_name == call_node.name {
233                        if let Some(target_id) =
234                            self.find_symbol_in_module(&module_path, &symbol_name)
235                        {
236                            return Ok(Some(target_id));
237                        }
238                    }
239                }
240            }
241        }
242
243        Ok(None)
244    }
245
246    /// Find a class by name (could be local or imported)
247    fn find_class_by_name(&self, class_name: &str) -> Option<NodeId> {
248        // Search all class nodes for matching name
249        let class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
250        for node in class_nodes {
251            if node.name == class_name {
252                return Some(node.id);
253            }
254        }
255        None
256    }
257
258    /// Find a method within a specific class
259    fn find_method_in_class(&self, class_id: NodeId, method_name: &str) -> Option<NodeId> {
260        // Get the class node to find its file
261        if let Some(class_node) = self.graph.get_node(&class_id) {
262            let file_nodes = self.graph.get_nodes_in_file(&class_node.file);
263
264            for node in file_nodes {
265                if node.kind == NodeKind::Method && node.name == method_name {
266                    // Check if this method is within the class span
267                    if node.span.start_line >= class_node.span.start_line
268                        && node.span.end_line <= class_node.span.end_line
269                    {
270                        return Some(node.id);
271                    }
272                }
273            }
274        }
275        None
276    }
277
278    /// Convert file path to module name
279    fn file_path_to_module_name(&self, file_path: &Path) -> String {
280        // Convert file path to Python module name
281        if let Some(stem) = file_path.file_stem().and_then(|s| s.to_str()) {
282            if stem == "__init__" {
283                // For __init__.py, use parent directory name
284                if let Some(parent) = file_path.parent() {
285                    if let Some(parent_name) = parent.file_name().and_then(|s| s.to_str()) {
286                        return parent_name.to_string();
287                    }
288                }
289            }
290
291            // Convert path separators to dots for module name
292            let path_str = file_path.to_string_lossy();
293            let module_path = path_str
294                .replace(['/', '\\'], ".")
295                .replace(".py", "")
296                .replace(".__init__", "");
297
298            return module_path;
299        }
300
301        "unknown".to_string()
302    }
303
304    /// Resolve inheritance relationships (class extends parent class)
305    fn resolve_inheritance(&mut self) -> Result<Vec<Edge>> {
306        let mut edges = Vec::new();
307
308        // Find all class nodes
309        let class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
310
311        for class_node in class_nodes {
312            // Get all outgoing Call edges from this class (inheritance is represented as Call)
313            let outgoing_edges = self.graph.get_outgoing_edges(&class_node.id);
314
315            for edge in outgoing_edges {
316                if edge.kind == EdgeKind::Calls {
317                    // Check if the target node represents a base class reference
318                    if let Some(call_node) = self.graph.get_node(&edge.target) {
319                        if call_node.kind == NodeKind::Call {
320                            // Try to resolve this call to an actual class
321                            if let Some(target_class_id) =
322                                self.resolve_base_class_name(&call_node.name, &class_node.file)
323                            {
324                                // Create inheritance edge: child class -> parent class
325                                edges.push(Edge::new(
326                                    class_node.id,
327                                    target_class_id,
328                                    EdgeKind::Calls,
329                                ));
330                            }
331                        }
332                    }
333                }
334            }
335        }
336
337        Ok(edges)
338    }
339
340    /// Resolve a base class name to its actual class node
341    fn resolve_base_class_name(
342        &self,
343        class_name: &str,
344        calling_file: &std::path::PathBuf,
345    ) -> Option<NodeId> {
346        // First check for local classes in the same file
347        let file_nodes = self.graph.get_nodes_in_file(calling_file);
348        for node in &file_nodes {
349            if node.kind == NodeKind::Class && node.name == class_name {
350                return Some(node.id);
351            }
352        }
353
354        // Check imported classes
355        // Find import nodes in the same file and see if they import this class
356        for node in &file_nodes {
357            if node.kind == NodeKind::Import {
358                let import_parts = self.parse_import_statement(&node.name);
359                for (module_path, symbol_name) in import_parts {
360                    if symbol_name == class_name {
361                        if let Some(target_id) =
362                            self.find_symbol_in_module(&module_path, &symbol_name)
363                        {
364                            // Verify it's actually a class
365                            if let Some(target_node) = self.graph.get_node(&target_id) {
366                                if target_node.kind == NodeKind::Class {
367                                    return Some(target_id);
368                                }
369                            }
370                        }
371                    }
372                }
373            }
374        }
375
376        // Fallback: search all classes by name
377        let all_class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
378        for node in all_class_nodes {
379            if node.name == class_name {
380                return Some(node.id);
381            }
382        }
383
384        None
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a resolver backed by an empty graph store.
    fn empty_resolver() -> SymbolResolver {
        SymbolResolver::new(Arc::new(GraphStore::new()))
    }

    /// Regular files map to their dotted path; `__init__.py` maps to its package.
    #[test]
    fn test_module_name_conversion() {
        let resolver = empty_resolver();

        let cases = [
            (
                "src/rustic_ai/core/guild/agent.py",
                "src.rustic_ai.core.guild.agent",
            ),
            ("src/utils/__init__.py", "utils"),
        ];

        for (raw_path, expected) in cases {
            let path = PathBuf::from(raw_path);
            assert_eq!(resolver.file_path_to_module_name(&path), expected);
        }
    }

    /// A dotted import yields exactly one `(module, symbol)` pair.
    #[test]
    fn test_import_parsing() {
        let parsed = empty_resolver().parse_import_statement("rustic_ai.core.guild.Agent");

        let expected = vec![("rustic_ai.core.guild".to_string(), "Agent".to_string())];
        assert_eq!(parsed, expected);
    }
}