codeprism_core/linkers/
symbol_resolver.rs1use crate::ast::{Edge, EdgeKind, Node, NodeId, NodeKind};
7use crate::error::Result;
8use crate::graph::GraphStore;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12
13pub struct SymbolResolver {
15 graph: Arc<GraphStore>,
16 module_symbols: HashMap<String, Vec<NodeId>>,
18 qualified_symbols: HashMap<String, NodeId>,
20 #[allow(dead_code)]
22 import_cache: HashMap<String, String>,
23}
24
25impl SymbolResolver {
26 pub fn new(graph: Arc<GraphStore>) -> Self {
28 Self {
29 graph,
30 module_symbols: HashMap::new(),
31 qualified_symbols: HashMap::new(),
32 import_cache: HashMap::new(),
33 }
34 }
35
36 pub fn resolve_all(&mut self) -> Result<Vec<Edge>> {
38 let mut new_edges = Vec::new();
39
40 self.build_symbol_indices()?;
42
43 new_edges.extend(self.resolve_imports()?);
45
46 new_edges.extend(self.resolve_function_calls()?);
48
49 new_edges.extend(self.resolve_class_instantiations()?);
51
52 new_edges.extend(self.resolve_inheritance()?);
54
55 Ok(new_edges)
56 }
57
58 fn build_symbol_indices(&mut self) -> Result<()> {
60 for (file_path, node_ids) in self.graph.iter_file_index() {
62 let module_name = self.file_path_to_module_name(&file_path);
63
64 for node_id in node_ids {
65 if let Some(node) = self.graph.get_node(&node_id) {
66 match node.kind {
67 NodeKind::Class | NodeKind::Function | NodeKind::Variable => {
68 self.module_symbols
70 .entry(module_name.clone())
71 .or_default()
72 .push(node_id);
73
74 let qualified_name = format!("{}.{}", module_name, node.name);
76 self.qualified_symbols.insert(qualified_name, node_id);
77 }
78 _ => {}
79 }
80 }
81 }
82 }
83
84 Ok(())
85 }
86
87 fn resolve_imports(&mut self) -> Result<Vec<Edge>> {
89 let mut edges = Vec::new();
90
91 let import_nodes = self.graph.get_nodes_by_kind(NodeKind::Import);
93
94 for import_node in import_nodes {
95 edges.extend(self.resolve_single_import(&import_node)?);
96 }
97
98 Ok(edges)
99 }
100
101 fn resolve_single_import(&mut self, import_node: &Node) -> Result<Vec<Edge>> {
103 let mut edges = Vec::new();
104
105 let import_parts = self.parse_import_statement(&import_node.name);
107
108 for (module_path, symbol_name) in import_parts {
109 if let Some(target_id) = self.find_symbol_in_module(&module_path, &symbol_name) {
111 edges.push(Edge::new(import_node.id, target_id, EdgeKind::Imports));
113 }
114 }
115
116 Ok(edges)
117 }
118
119 fn resolve_function_calls(&mut self) -> Result<Vec<Edge>> {
121 let mut edges = Vec::new();
122
123 let call_nodes = self.graph.get_nodes_by_kind(NodeKind::Call);
125
126 for call_node in call_nodes {
127 if let Some(target_id) = self.resolve_call_target(&call_node)? {
128 edges.push(Edge::new(call_node.id, target_id, EdgeKind::Calls));
129 }
130 }
131
132 Ok(edges)
133 }
134
135 fn resolve_class_instantiations(&mut self) -> Result<Vec<Edge>> {
137 let mut edges = Vec::new();
138
139 let call_nodes = self.graph.get_nodes_by_kind(NodeKind::Call);
141
142 for call_node in call_nodes {
143 if call_node
145 .name
146 .chars()
147 .next()
148 .is_some_and(|c| c.is_uppercase())
149 {
150 if let Some(class_id) = self.find_class_by_name(&call_node.name) {
151 if let Some(init_id) = self.find_method_in_class(class_id, "__init__") {
153 edges.push(Edge::new(call_node.id, init_id, EdgeKind::Calls));
154 }
155 }
156 }
157 }
158
159 Ok(edges)
160 }
161
162 fn parse_import_statement(&self, import_name: &str) -> Vec<(String, String)> {
164 let mut results = Vec::new();
165
166 if import_name.contains('.') {
168 let parts: Vec<&str> = import_name.split('.').collect();
170 if parts.len() >= 2 {
171 let module = parts[..parts.len() - 1].join(".");
172 let symbol = parts.last().unwrap().to_string();
173 results.push((module, symbol));
174 }
175 } else {
176 if let Some(symbols) = self.module_symbols.get(import_name) {
178 for symbol_id in symbols {
179 if let Some(node) = self.graph.get_node(symbol_id) {
180 results.push((import_name.to_string(), node.name.clone()));
181 }
182 }
183 }
184 }
185
186 results
187 }
188
189 fn find_symbol_in_module(&self, module_path: &str, symbol_name: &str) -> Option<NodeId> {
191 let qualified_name = format!("{}.{}", module_path, symbol_name);
193 if let Some(node_id) = self.qualified_symbols.get(&qualified_name) {
194 return Some(*node_id);
195 }
196
197 if let Some(symbol_ids) = self.module_symbols.get(module_path) {
199 for symbol_id in symbol_ids {
200 if let Some(node) = self.graph.get_node(symbol_id) {
201 if node.name == symbol_name {
202 return Some(*symbol_id);
203 }
204 }
205 }
206 }
207
208 None
209 }
210
211 fn resolve_call_target(&self, call_node: &Node) -> Result<Option<NodeId>> {
213 let calling_file = &call_node.file;
215
216 let file_nodes = self.graph.get_nodes_in_file(calling_file);
218 for node in &file_nodes {
219 if matches!(node.kind, NodeKind::Function | NodeKind::Method)
220 && node.name == call_node.name
221 {
222 return Ok(Some(node.id));
223 }
224 }
225
226 for node in &file_nodes {
229 if node.kind == NodeKind::Import {
230 let import_parts = self.parse_import_statement(&node.name);
231 for (module_path, symbol_name) in import_parts {
232 if symbol_name == call_node.name {
233 if let Some(target_id) =
234 self.find_symbol_in_module(&module_path, &symbol_name)
235 {
236 return Ok(Some(target_id));
237 }
238 }
239 }
240 }
241 }
242
243 Ok(None)
244 }
245
246 fn find_class_by_name(&self, class_name: &str) -> Option<NodeId> {
248 let class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
250 for node in class_nodes {
251 if node.name == class_name {
252 return Some(node.id);
253 }
254 }
255 None
256 }
257
258 fn find_method_in_class(&self, class_id: NodeId, method_name: &str) -> Option<NodeId> {
260 if let Some(class_node) = self.graph.get_node(&class_id) {
262 let file_nodes = self.graph.get_nodes_in_file(&class_node.file);
263
264 for node in file_nodes {
265 if node.kind == NodeKind::Method && node.name == method_name {
266 if node.span.start_line >= class_node.span.start_line
268 && node.span.end_line <= class_node.span.end_line
269 {
270 return Some(node.id);
271 }
272 }
273 }
274 }
275 None
276 }
277
278 fn file_path_to_module_name(&self, file_path: &Path) -> String {
280 if let Some(stem) = file_path.file_stem().and_then(|s| s.to_str()) {
282 if stem == "__init__" {
283 if let Some(parent) = file_path.parent() {
285 if let Some(parent_name) = parent.file_name().and_then(|s| s.to_str()) {
286 return parent_name.to_string();
287 }
288 }
289 }
290
291 let path_str = file_path.to_string_lossy();
293 let module_path = path_str
294 .replace(['/', '\\'], ".")
295 .replace(".py", "")
296 .replace(".__init__", "");
297
298 return module_path;
299 }
300
301 "unknown".to_string()
302 }
303
304 fn resolve_inheritance(&mut self) -> Result<Vec<Edge>> {
306 let mut edges = Vec::new();
307
308 let class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
310
311 for class_node in class_nodes {
312 let outgoing_edges = self.graph.get_outgoing_edges(&class_node.id);
314
315 for edge in outgoing_edges {
316 if edge.kind == EdgeKind::Calls {
317 if let Some(call_node) = self.graph.get_node(&edge.target) {
319 if call_node.kind == NodeKind::Call {
320 if let Some(target_class_id) =
322 self.resolve_base_class_name(&call_node.name, &class_node.file)
323 {
324 edges.push(Edge::new(
326 class_node.id,
327 target_class_id,
328 EdgeKind::Calls,
329 ));
330 }
331 }
332 }
333 }
334 }
335 }
336
337 Ok(edges)
338 }
339
340 fn resolve_base_class_name(
342 &self,
343 class_name: &str,
344 calling_file: &std::path::PathBuf,
345 ) -> Option<NodeId> {
346 let file_nodes = self.graph.get_nodes_in_file(calling_file);
348 for node in &file_nodes {
349 if node.kind == NodeKind::Class && node.name == class_name {
350 return Some(node.id);
351 }
352 }
353
354 for node in &file_nodes {
357 if node.kind == NodeKind::Import {
358 let import_parts = self.parse_import_statement(&node.name);
359 for (module_path, symbol_name) in import_parts {
360 if symbol_name == class_name {
361 if let Some(target_id) =
362 self.find_symbol_in_module(&module_path, &symbol_name)
363 {
364 if let Some(target_node) = self.graph.get_node(&target_id) {
366 if target_node.kind == NodeKind::Class {
367 return Some(target_id);
368 }
369 }
370 }
371 }
372 }
373 }
374 }
375
376 let all_class_nodes = self.graph.get_nodes_by_kind(NodeKind::Class);
378 for node in all_class_nodes {
379 if node.name == class_name {
380 return Some(node.id);
381 }
382 }
383
384 None
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// A resolver over an empty graph. The helpers exercised here only need
    /// the path or import string, not graph contents.
    fn empty_resolver() -> SymbolResolver {
        SymbolResolver::new(Arc::new(GraphStore::new()))
    }

    #[test]
    fn test_module_name_conversion() {
        let resolver = empty_resolver();
        let cases = [
            (
                "src/rustic_ai/core/guild/agent.py",
                "src.rustic_ai.core.guild.agent",
            ),
            // A package's `__init__.py` maps to its directory name alone.
            ("src/utils/__init__.py", "utils"),
        ];

        for (raw_path, expected) in cases {
            let path = PathBuf::from(raw_path);
            assert_eq!(resolver.file_path_to_module_name(&path), expected);
        }
    }

    #[test]
    fn test_import_parsing() {
        let resolver = empty_resolver();

        let parsed = resolver.parse_import_statement("rustic_ai.core.guild.Agent");
        let expected = vec![("rustic_ai.core.guild".to_string(), "Agent".to_string())];
        assert_eq!(parsed, expected);
    }
}
418}