// aster/map/symbol_reference_analyzer.rs

//! Symbol reference analyzer.
//!
//! Analyzes symbol-level reference relationships such as function calls and variable reads/writes.
4
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8
9use super::types::{ClassNode, FunctionNode, LocationInfo, ModuleNode};
10use super::types_enhanced::{SymbolCall, SymbolEntry, SymbolKind};
11
12/// 符号信息
13#[derive(Debug, Clone)]
14struct SymbolInfo {
15    id: String,
16    name: String,
17    kind: SymbolKind,
18    module_id: String,
19    location: LocationInfo,
20    signature: Option<String>,
21    parent: Option<String>,
22}
23
24/// 调用信息
25#[derive(Debug, Clone)]
26struct CallInfo {
27    caller_symbol: String,
28    callee_symbol: String,
29    callee_name: String,
30    call_type: CallType,
31    location: LocationInfo,
32}
33
34/// 调用类型
35#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
36#[serde(rename_all = "lowercase")]
37pub enum CallType {
38    Direct,
39    Method,
40    Constructor,
41}
42
43/// 符号引用分析结果
44#[derive(Debug, Clone)]
45pub struct SymbolReferenceResult {
46    pub symbols: HashMap<String, SymbolEntry>,
47    pub calls: Vec<SymbolCall>,
48}
49
50/// 符号引用分析器
51pub struct SymbolReferenceAnalyzer {
52    root_path: PathBuf,
53    /// 符号索引
54    symbol_index: HashMap<String, SymbolInfo>,
55    /// 名称到符号的映射
56    name_to_symbols: HashMap<String, Vec<String>>,
57}
58
59impl SymbolReferenceAnalyzer {
60    /// 创建新的分析器
61    pub fn new(root_path: impl AsRef<Path>) -> Self {
62        Self {
63            root_path: root_path.as_ref().to_path_buf(),
64            symbol_index: HashMap::new(),
65            name_to_symbols: HashMap::new(),
66        }
67    }
68
69    /// 分析模块列表,提取符号引用关系
70    pub fn analyze(&mut self, modules: &[ModuleNode]) -> SymbolReferenceResult {
71        // 1. 构建符号索引
72        self.build_symbol_index(modules);
73
74        // 2. 分析调用关系
75        let calls = self.analyze_call_relations(modules);
76
77        // 3. 转换为输出格式
78        let symbols = self.convert_to_symbol_entries();
79
80        SymbolReferenceResult { symbols, calls }
81    }
82
83    /// 构建符号索引
84    fn build_symbol_index(&mut self, modules: &[ModuleNode]) {
85        self.symbol_index.clear();
86        self.name_to_symbols.clear();
87
88        for module in modules {
89            // 函数
90            for func in &module.functions {
91                let info = SymbolInfo {
92                    id: func.id.clone(),
93                    name: func.name.clone(),
94                    kind: SymbolKind::Function,
95                    module_id: module.id.clone(),
96                    location: func.location.clone(),
97                    signature: Some(func.signature.clone()),
98                    parent: None,
99                };
100                self.add_symbol(info);
101            }
102
103            // 类
104            for cls in &module.classes {
105                let class_info = SymbolInfo {
106                    id: cls.id.clone(),
107                    name: cls.name.clone(),
108                    kind: SymbolKind::Class,
109                    module_id: module.id.clone(),
110                    location: cls.location.clone(),
111                    signature: None,
112                    parent: None,
113                };
114                self.add_symbol(class_info);
115
116                // 方法
117                for method in &cls.methods {
118                    let method_info = SymbolInfo {
119                        id: method.id.clone(),
120                        name: method.name.clone(),
121                        kind: SymbolKind::Method,
122                        module_id: module.id.clone(),
123                        location: method.location.clone(),
124                        signature: Some(method.signature.clone()),
125                        parent: Some(cls.id.clone()),
126                    };
127                    self.add_symbol(method_info);
128                }
129
130                // 属性
131                for prop in &cls.properties {
132                    let prop_info = SymbolInfo {
133                        id: prop.id.clone(),
134                        name: prop.name.clone(),
135                        kind: SymbolKind::Property,
136                        module_id: module.id.clone(),
137                        location: prop.location.clone(),
138                        signature: None,
139                        parent: Some(cls.id.clone()),
140                    };
141                    self.add_symbol(prop_info);
142                }
143            }
144
145            // 接口
146            for iface in &module.interfaces {
147                let info = SymbolInfo {
148                    id: iface.id.clone(),
149                    name: iface.name.clone(),
150                    kind: SymbolKind::Interface,
151                    module_id: module.id.clone(),
152                    location: iface.location.clone(),
153                    signature: None,
154                    parent: None,
155                };
156                self.add_symbol(info);
157            }
158
159            // 类型
160            for type_node in &module.types {
161                let info = SymbolInfo {
162                    id: type_node.id.clone(),
163                    name: type_node.name.clone(),
164                    kind: SymbolKind::Type,
165                    module_id: module.id.clone(),
166                    location: type_node.location.clone(),
167                    signature: None,
168                    parent: None,
169                };
170                self.add_symbol(info);
171            }
172
173            // 枚举
174            for enum_node in &module.enums {
175                let info = SymbolInfo {
176                    id: enum_node.id.clone(),
177                    name: enum_node.name.clone(),
178                    kind: SymbolKind::Enum,
179                    module_id: module.id.clone(),
180                    location: enum_node.location.clone(),
181                    signature: None,
182                    parent: None,
183                };
184                self.add_symbol(info);
185            }
186
187            // 变量
188            for var in &module.variables {
189                let kind = if var.kind == super::types::VariableKind::Const {
190                    SymbolKind::Constant
191                } else {
192                    SymbolKind::Variable
193                };
194                let info = SymbolInfo {
195                    id: var.id.clone(),
196                    name: var.name.clone(),
197                    kind,
198                    module_id: module.id.clone(),
199                    location: var.location.clone(),
200                    signature: None,
201                    parent: None,
202                };
203                self.add_symbol(info);
204            }
205
206            // 导出的符号
207            for exp in &module.exports {
208                if exp.name.starts_with('*') {
209                    continue;
210                }
211
212                let existing_id = format!("{}::{}", module.id, exp.name);
213                if self.symbol_index.contains_key(&existing_id) {
214                    continue;
215                }
216
217                let starts_with_uppercase =
218                    exp.name.chars().next().is_some_and(|c| c.is_uppercase());
219                let looks_like_type = starts_with_uppercase && !exp.name.contains('_');
220
221                let info = SymbolInfo {
222                    id: existing_id,
223                    name: exp.name.clone(),
224                    kind: if looks_like_type {
225                        SymbolKind::Type
226                    } else {
227                        SymbolKind::Variable
228                    },
229                    module_id: module.id.clone(),
230                    location: exp.location.clone(),
231                    signature: None,
232                    parent: None,
233                };
234                self.add_symbol(info);
235            }
236        }
237    }
238
239    /// 添加符号到索引
240    fn add_symbol(&mut self, info: SymbolInfo) {
241        let name = info.name.clone();
242        let id = info.id.clone();
243
244        self.symbol_index.insert(id.clone(), info);
245
246        self.name_to_symbols.entry(name).or_default().push(id);
247    }
248
249    /// 分析调用关系
250    fn analyze_call_relations(&self, modules: &[ModuleNode]) -> Vec<SymbolCall> {
251        let mut call_map: HashMap<String, SymbolCall> = HashMap::new();
252
253        for module in modules {
254            // 读取文件内容
255            let file_path = self.root_path.join(&module.id);
256            let content = match fs::read_to_string(&file_path) {
257                Ok(c) => c,
258                Err(_) => continue,
259            };
260
261            let lines: Vec<&str> = content.lines().collect();
262
263            // 分析函数内的调用
264            for func in &module.functions {
265                let func_calls = self.analyze_calls_in_function(func, module, &lines, None);
266                self.merge_calls_into_map(func_calls, &mut call_map);
267            }
268
269            // 分析类方法内的调用
270            for cls in &module.classes {
271                for method in &cls.methods {
272                    let method_calls = self.analyze_calls_in_method(method, module, &lines, cls);
273                    self.merge_calls_into_map(method_calls, &mut call_map);
274                }
275            }
276        }
277
278        call_map.into_values().collect()
279    }
280
281    /// 分析方法内的调用
282    fn analyze_calls_in_method(
283        &self,
284        method: &super::types::MethodNode,
285        module: &ModuleNode,
286        lines: &[&str],
287        parent_class: &ClassNode,
288    ) -> Vec<CallInfo> {
289        // 创建一个临时的 FunctionNode 风格的数据来复用逻辑
290        let func_like = FunctionNode {
291            id: method.id.clone(),
292            name: method.name.clone(),
293            signature: method.signature.clone(),
294            parameters: method.parameters.clone(),
295            return_type: method.return_type.clone(),
296            location: method.location.clone(),
297            is_async: method.is_async,
298            is_exported: false,
299            is_generator: false,
300            documentation: method.documentation.clone(),
301            calls: vec![],
302            called_by: vec![],
303        };
304        self.analyze_calls_in_function(&func_like, module, lines, Some(parent_class))
305    }
306
307    /// 分析函数/方法内的调用
308    fn analyze_calls_in_function(
309        &self,
310        func: &FunctionNode,
311        module: &ModuleNode,
312        lines: &[&str],
313        parent_class: Option<&ClassNode>,
314    ) -> Vec<CallInfo> {
315        use once_cell::sync::Lazy;
316
317        static RE_FUNC_CALL: Lazy<regex::Regex> =
318            Lazy::new(|| regex::Regex::new(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(").unwrap());
319        static RE_METHOD_CALL: Lazy<regex::Regex> = Lazy::new(|| {
320            regex::Regex::new(
321                r"(?:([a-zA-Z_][a-zA-Z0-9_]*)|self|this)\.([a-zA-Z_][a-zA-Z0-9_]*)\s*\(",
322            )
323            .unwrap()
324        });
325        static RE_CONSTRUCTOR: Lazy<regex::Regex> =
326            Lazy::new(|| regex::Regex::new(r"(?:new\s+|::new\s*\()([A-Z][a-zA-Z0-9_]*)").unwrap());
327
328        let mut calls = Vec::new();
329        let caller_symbol = func.id.clone();
330
331        // 获取函数体的行范围
332        let start_line = (func.location.start_line as usize).saturating_sub(1);
333        let end_line = (func.location.end_line as usize).min(lines.len());
334
335        // 忽略的关键字
336        let ignored: std::collections::HashSet<&str> = [
337            "if",
338            "else",
339            "for",
340            "while",
341            "switch",
342            "case",
343            "catch",
344            "try",
345            "return",
346            "throw",
347            "typeof",
348            "instanceof",
349            "delete",
350            "void",
351            "function",
352            "class",
353            "const",
354            "let",
355            "var",
356            "import",
357            "export",
358            "async",
359            "await",
360            "yield",
361            "super",
362            "this",
363            "fn",
364            "pub",
365            "mod",
366            "use",
367            "impl",
368            "struct",
369            "enum",
370            "trait",
371            "match",
372            "loop",
373        ]
374        .into_iter()
375        .collect();
376
377        for (i, line) in lines.iter().enumerate().take(end_line).skip(start_line) {
378            let line_num = (i + 1) as u32;
379
380            // 跳过注释行
381            let trimmed = line.trim();
382            if trimmed.starts_with("//") || trimmed.starts_with('*') || trimmed.starts_with("/*") {
383                continue;
384            }
385
386            // 模式 1: 普通函数调用 functionName(
387            for cap in RE_FUNC_CALL.captures_iter(line) {
388                if let Some(func_name) = cap.get(1) {
389                    let name = func_name.as_str();
390                    if ignored.contains(name) {
391                        continue;
392                    }
393
394                    let targets = self.find_target_symbols(name, module);
395                    for target_id in targets {
396                        calls.push(CallInfo {
397                            caller_symbol: caller_symbol.clone(),
398                            callee_symbol: target_id,
399                            callee_name: name.to_string(),
400                            call_type: CallType::Direct,
401                            location: LocationInfo {
402                                file: module.id.clone(),
403                                start_line: line_num,
404                                start_column: func_name.start() as u32,
405                                end_line: line_num,
406                                end_column: func_name.end() as u32,
407                            },
408                        });
409                    }
410                }
411            }
412
413            // 模式 2: 方法调用 obj.methodName( 或 self.methodName(
414            for cap in RE_METHOD_CALL.captures_iter(line) {
415                let obj_name = cap.get(1).map(|m| m.as_str());
416                if let Some(method_name) = cap.get(2) {
417                    let name = method_name.as_str();
418                    if ignored.contains(name) {
419                        continue;
420                    }
421
422                    // self.method() 或 this.method() 调用
423                    if obj_name.is_none() {
424                        if let Some(cls) = parent_class {
425                            let target_id = format!("{}::{}::{}", module.id, cls.name, name);
426                            if self.symbol_index.contains_key(&target_id) {
427                                calls.push(CallInfo {
428                                    caller_symbol: caller_symbol.clone(),
429                                    callee_symbol: target_id,
430                                    callee_name: name.to_string(),
431                                    call_type: CallType::Method,
432                                    location: LocationInfo {
433                                        file: module.id.clone(),
434                                        start_line: line_num,
435                                        start_column: method_name.start() as u32,
436                                        end_line: line_num,
437                                        end_column: method_name.end() as u32,
438                                    },
439                                });
440                            }
441                        }
442                    } else {
443                        // obj.method() 调用
444                        let targets = self.find_method_targets(name);
445                        for target_id in targets {
446                            calls.push(CallInfo {
447                                caller_symbol: caller_symbol.clone(),
448                                callee_symbol: target_id,
449                                callee_name: name.to_string(),
450                                call_type: CallType::Method,
451                                location: LocationInfo {
452                                    file: module.id.clone(),
453                                    start_line: line_num,
454                                    start_column: method_name.start() as u32,
455                                    end_line: line_num,
456                                    end_column: method_name.end() as u32,
457                                },
458                            });
459                        }
460                    }
461                }
462            }
463
464            // 模式 3: 构造函数调用 new ClassName( 或 ClassName::new(
465            for cap in RE_CONSTRUCTOR.captures_iter(line) {
466                if let Some(class_name) = cap.get(1) {
467                    let name = class_name.as_str();
468                    let targets = self.find_target_symbols(name, module);
469                    for target_id in targets {
470                        if let Some(symbol) = self.symbol_index.get(&target_id) {
471                            if symbol.kind == SymbolKind::Class {
472                                calls.push(CallInfo {
473                                    caller_symbol: caller_symbol.clone(),
474                                    callee_symbol: target_id,
475                                    callee_name: name.to_string(),
476                                    call_type: CallType::Constructor,
477                                    location: LocationInfo {
478                                        file: module.id.clone(),
479                                        start_line: line_num,
480                                        start_column: class_name.start() as u32,
481                                        end_line: line_num,
482                                        end_column: class_name.end() as u32,
483                                    },
484                                });
485                            }
486                        }
487                    }
488                }
489            }
490        }
491
492        calls
493    }
494
495    /// 查找目标符号
496    fn find_target_symbols(&self, name: &str, current_module: &ModuleNode) -> Vec<String> {
497        let candidates = match self.name_to_symbols.get(name) {
498            Some(c) => c.clone(),
499            None => return vec![],
500        };
501
502        // 获取当前模块导入的符号
503        let imported_symbols: std::collections::HashSet<String> = current_module
504            .imports
505            .iter()
506            .flat_map(|imp| imp.symbols.iter().cloned())
507            .collect();
508
509        let mut same_module = Vec::new();
510        let mut imported = Vec::new();
511        let mut others = Vec::new();
512
513        for candidate_id in candidates {
514            if let Some(symbol) = self.symbol_index.get(&candidate_id) {
515                if symbol.module_id == current_module.id {
516                    same_module.push(candidate_id);
517                } else if imported_symbols.contains(name) {
518                    imported.push(candidate_id);
519                } else {
520                    others.push(candidate_id);
521                }
522            }
523        }
524
525        // 返回最可能的目标
526        if !same_module.is_empty() {
527            return same_module;
528        }
529        if !imported.is_empty() {
530            return imported;
531        }
532        others.into_iter().take(1).collect()
533    }
534
535    /// 查找方法目标
536    fn find_method_targets(&self, method_name: &str) -> Vec<String> {
537        self.symbol_index
538            .iter()
539            .filter(|(_, symbol)| symbol.kind == SymbolKind::Method && symbol.name == method_name)
540            .map(|(id, _)| id.clone())
541            .collect()
542    }
543
544    /// 合并调用到 Map(去重)
545    fn merge_calls_into_map(&self, calls: Vec<CallInfo>, map: &mut HashMap<String, SymbolCall>) {
546        for call in calls {
547            let key = format!("{}::{}", call.caller_symbol, call.callee_symbol);
548
549            if let Some(existing) = map.get_mut(&key) {
550                existing.locations.push(call.location);
551            } else {
552                map.insert(
553                    key,
554                    SymbolCall {
555                        caller: call.caller_symbol,
556                        callee: call.callee_symbol,
557                        call_type: match call.call_type {
558                            CallType::Direct => "direct".to_string(),
559                            CallType::Method => "method".to_string(),
560                            CallType::Constructor => "constructor".to_string(),
561                        },
562                        locations: vec![call.location],
563                    },
564                );
565            }
566        }
567    }
568
569    /// 转换为 SymbolEntry 格式
570    fn convert_to_symbol_entries(&self) -> HashMap<String, SymbolEntry> {
571        let mut entries = HashMap::new();
572
573        for (id, info) in &self.symbol_index {
574            let mut entry = SymbolEntry {
575                id: info.id.clone(),
576                name: info.name.clone(),
577                kind: info.kind,
578                module_id: info.module_id.clone(),
579                location: info.location.clone(),
580                signature: info.signature.clone(),
581                semantic: None,
582                parent: info.parent.clone(),
583                children: None,
584            };
585
586            // 收集子符号
587            if info.kind == SymbolKind::Class {
588                let children: Vec<String> = self
589                    .symbol_index
590                    .iter()
591                    .filter(|(_, child)| child.parent.as_ref() == Some(id))
592                    .map(|(child_id, _)| child_id.clone())
593                    .collect();
594
595                if !children.is_empty() {
596                    entry.children = Some(children);
597                }
598            }
599
600            entries.insert(id.clone(), entry);
601        }
602
603        entries
604    }
605}
606
// ============================================================================
// Convenience functions
// ============================================================================
610
611/// 分析符号引用
612pub fn analyze_symbol_references(
613    root_path: impl AsRef<Path>,
614    modules: &[ModuleNode],
615) -> SymbolReferenceResult {
616    let mut analyzer = SymbolReferenceAnalyzer::new(root_path);
617    analyzer.analyze(modules)
618}