Skip to main content

code_analyze_mcp/
graph.rs

1use crate::types::SemanticAnalysis;
2use std::collections::{HashMap, HashSet, VecDeque};
3use std::path::{Path, PathBuf};
4use thiserror::Error;
5use tracing::instrument;
6
/// Signature record for one definition: `(file, line, parameters, return_type)`.
///
/// `parameters` holds the raw parameter strings reported by the analyzer; the graph
/// only ever reads the first entry, treating it as the whole parameter list.
/// `return_type` is `None` when no return type was recorded (always the case for
/// classes).
type FunctionTypeInfo = (
    PathBuf,
    usize,
    Vec<String>,
    Option<String>,
);
9
10#[derive(Debug, Error)]
11pub enum GraphError {
12    #[error("Symbol not found: {0}")]
13    SymbolNotFound(String),
14}
15
/// Strip scope prefixes from a callee name, keeping only its final segment.
///
/// The text after the rightmost `::` or `.` separator is returned, whichever of the
/// two occurs later. So `self.method`, `Type::method` and `module::function` reduce to
/// `method` / `function`, and mixed forms such as `Type::new().build` reduce to
/// `build`. A name without any separator is returned unchanged.
fn strip_scope_prefix(name: &str) -> &str {
    // Byte offsets just past each kind of separator. Both separators are ASCII, so the
    // resulting offset is always a valid char boundary.
    let after_path = name.rfind("::").map(|pos| pos + 2);
    let after_dot = name.rfind('.').map(|pos| pos + 1);
    // `None` orders below every `Some`, so `max` selects the later separator (if any).
    match after_path.max(after_dot) {
        Some(start) => &name[start..],
        None => name,
    }
}
28
/// One call path produced by the chain queries on `CallGraph`.
///
/// Each element is `(symbol name, file, line)`, where the file and line identify the
/// call site associated with that step of the path.
#[derive(Clone, Debug)]
pub struct CallChain {
    pub chain: Vec<(String, PathBuf, usize)>,
}
33
34#[derive(Debug, Clone)]
35pub struct CallGraph {
36    pub callers: HashMap<String, Vec<(PathBuf, usize, String)>>,
37    pub callees: HashMap<String, Vec<(PathBuf, usize, String)>>,
38    pub definitions: HashMap<String, Vec<(PathBuf, usize)>>,
39    // Internal: maps function name to type info for type-aware disambiguation
40    function_types: HashMap<String, Vec<FunctionTypeInfo>>,
41}
42
43impl CallGraph {
44    pub fn new() -> Self {
45        Self {
46            callers: HashMap::new(),
47            callees: HashMap::new(),
48            definitions: HashMap::new(),
49            function_types: HashMap::new(),
50        }
51    }
52
53    /// Count parameters in a parameter string.
54    /// Handles: "(x: i32, y: String)" -> 2, "(&self, x: i32)" -> 2, "()" -> 0, "(&self)" -> 1
55    fn count_parameters(params_str: &str) -> usize {
56        if params_str.is_empty() || params_str == "()" {
57            return 0;
58        }
59        // Remove outer parens and trim
60        let inner = params_str
61            .trim_start_matches('(')
62            .trim_end_matches(')')
63            .trim();
64        if inner.is_empty() {
65            return 0;
66        }
67        // Count commas + 1 to get parameter count
68        inner.split(',').count()
69    }
70
71    /// Match a callee by parameter count and return type.
72    /// Returns the index of the best match in the candidates list, or None if no good match.
73    /// Strategy: prefer candidates with matching param count, then by return type match.
74    fn match_by_type(
75        &self,
76        candidates: &[FunctionTypeInfo],
77        expected_param_count: Option<usize>,
78        expected_return_type: Option<&str>,
79    ) -> Option<usize> {
80        if candidates.is_empty() {
81            return None;
82        }
83
84        // If we have no type info to match against, return None (fallback to line proximity)
85        if expected_param_count.is_none() && expected_return_type.is_none() {
86            return None;
87        }
88
89        let mut best_idx = 0;
90        let mut best_score = 0;
91
92        for (idx, (_path, _line, params, ret_type)) in candidates.iter().enumerate() {
93            let mut score = 0;
94
95            // Score parameter count match
96            if let Some(expected_count) = expected_param_count
97                && !params.is_empty()
98            {
99                let actual_count = Self::count_parameters(&params[0]);
100                if actual_count == expected_count {
101                    score += 2;
102                }
103            }
104
105            // Score return type match
106            if let Some(expected_ret) = expected_return_type
107                && let Some(actual_ret) = ret_type
108                && actual_ret == expected_ret
109            {
110                score += 1;
111            }
112
113            // Prefer candidates with more type info
114            if !params.is_empty() {
115                score += 1;
116            }
117            if ret_type.is_some() {
118                score += 1;
119            }
120
121            if score > best_score {
122                best_score = score;
123                best_idx = idx;
124            }
125        }
126
127        // Only return a match if we found a meaningful score
128        (best_score > 0).then_some(best_idx)
129    }
130
131    /// Resolve a callee name using four strategies:
132    /// 1. Try the raw callee name first in definitions
133    /// 2. If not found, try the stripped name (via strip_scope_prefix)
134    /// 3. If multiple definitions exist, prefer same-file candidates
135    /// 4. Among same-file candidates, use type info as tiebreaker, then line proximity
136    /// 5. If no same-file candidates, use any definition (first one)
137    ///
138    /// Returns the resolved callee name (which may be the stripped version).
139    fn resolve_callee(
140        &self,
141        callee: &str,
142        call_file: &Path,
143        call_line: usize,
144        arg_count: Option<usize>,
145        definitions: &HashMap<String, Vec<(PathBuf, usize)>>,
146        function_types: &HashMap<String, Vec<FunctionTypeInfo>>,
147    ) -> String {
148        // Try raw callee name first
149        if let Some(defs) = definitions.get(callee) {
150            return self.pick_best_definition(
151                defs,
152                call_file,
153                call_line,
154                arg_count,
155                callee,
156                function_types,
157            );
158        }
159
160        // Try stripped name
161        let stripped = strip_scope_prefix(callee);
162        if stripped != callee
163            && let Some(defs) = definitions.get(stripped)
164        {
165            return self.pick_best_definition(
166                defs,
167                call_file,
168                call_line,
169                arg_count,
170                stripped,
171                function_types,
172            );
173        }
174
175        // No definition found; return the original callee
176        callee.to_string()
177    }
178
179    /// Pick the best definition from a list based on same-file preference, type matching, and line proximity.
180    fn pick_best_definition(
181        &self,
182        defs: &[(PathBuf, usize)],
183        call_file: &Path,
184        call_line: usize,
185        arg_count: Option<usize>,
186        resolved_name: &str,
187        function_types: &HashMap<String, Vec<FunctionTypeInfo>>,
188    ) -> String {
189        // Filter to same-file candidates
190        let same_file_defs: Vec<_> = defs.iter().filter(|(path, _)| path == call_file).collect();
191
192        if !same_file_defs.is_empty() {
193            // Try type-aware disambiguation if we have type info
194            if let Some(type_info) = function_types.get(resolved_name) {
195                let same_file_types: Vec<_> = type_info
196                    .iter()
197                    .filter(|(path, _, _, _)| path == call_file)
198                    .cloned()
199                    .collect();
200
201                if !same_file_types.is_empty() && same_file_types.len() > 1 {
202                    // Group candidates by line proximity (within 5 lines)
203                    let mut proximity_groups: Vec<Vec<usize>> = vec![];
204                    for (idx, (_, def_line, _, _)) in same_file_types.iter().enumerate() {
205                        let mut placed = false;
206                        for group in &mut proximity_groups {
207                            if let Some((_, first_line, _, _)) = same_file_types.get(group[0])
208                                && first_line.abs_diff(*def_line) <= 5
209                            {
210                                group.push(idx);
211                                placed = true;
212                                break;
213                            }
214                        }
215                        if !placed {
216                            proximity_groups.push(vec![idx]);
217                        }
218                    }
219
220                    // Find the closest proximity group
221                    let closest_group = proximity_groups.iter().min_by_key(|group| {
222                        group
223                            .iter()
224                            .map(|idx| {
225                                if let Some((_, def_line, _, _)) = same_file_types.get(*idx) {
226                                    def_line.abs_diff(call_line)
227                                } else {
228                                    usize::MAX
229                                }
230                            })
231                            .min()
232                            .unwrap_or(usize::MAX)
233                    });
234
235                    if let Some(group) = closest_group {
236                        // Within the closest group, try type matching
237                        if group.len() > 1 {
238                            // Collect candidates for type matching
239                            let candidates: Vec<_> = group
240                                .iter()
241                                .filter_map(|idx| same_file_types.get(*idx).cloned())
242                                .collect();
243                            // Try to match by type using argument count from call site
244                            if let Some(_best_idx) =
245                                self.match_by_type(&candidates, arg_count, None)
246                            {
247                                return resolved_name.to_string();
248                            }
249                        }
250                    }
251                }
252            }
253
254            // Fallback to line proximity
255            let _best = same_file_defs
256                .iter()
257                .min_by_key(|(_, def_line)| (*def_line).abs_diff(call_line));
258            return resolved_name.to_string();
259        }
260
261        // No same-file candidates; use any definition (first one)
262        resolved_name.to_string()
263    }
264
265    #[instrument(skip_all)]
266    pub fn build_from_results(
267        results: Vec<(PathBuf, SemanticAnalysis)>,
268    ) -> Result<Self, GraphError> {
269        let mut graph = CallGraph::new();
270
271        // Build definitions and function_types maps first
272        for (path, analysis) in &results {
273            for func in &analysis.functions {
274                graph
275                    .definitions
276                    .entry(func.name.clone())
277                    .or_default()
278                    .push((path.clone(), func.line));
279                graph
280                    .function_types
281                    .entry(func.name.clone())
282                    .or_default()
283                    .push((
284                        path.clone(),
285                        func.line,
286                        func.parameters.clone(),
287                        func.return_type.clone(),
288                    ));
289            }
290            for class in &analysis.classes {
291                graph
292                    .definitions
293                    .entry(class.name.clone())
294                    .or_default()
295                    .push((path.clone(), class.line));
296                graph
297                    .function_types
298                    .entry(class.name.clone())
299                    .or_default()
300                    .push((path.clone(), class.line, vec![], None));
301            }
302        }
303
304        // Process calls with resolved callee names
305        for (path, analysis) in &results {
306            for call in &analysis.calls {
307                let resolved_callee = graph.resolve_callee(
308                    &call.callee,
309                    path,
310                    call.line,
311                    call.arg_count,
312                    &graph.definitions,
313                    &graph.function_types,
314                );
315
316                graph.callees.entry(call.caller.clone()).or_default().push((
317                    path.clone(),
318                    call.line,
319                    resolved_callee.clone(),
320                ));
321                graph.callers.entry(resolved_callee).or_default().push((
322                    path.clone(),
323                    call.line,
324                    call.caller.clone(),
325                ));
326            }
327            for reference in &analysis.references {
328                graph
329                    .callers
330                    .entry(reference.symbol.clone())
331                    .or_default()
332                    .push((path.clone(), reference.line, "<reference>".to_string()));
333            }
334        }
335
336        let total_edges = graph.callees.values().map(|v| v.len()).sum::<usize>()
337            + graph.callers.values().map(|v| v.len()).sum::<usize>();
338        let file_count = results.len();
339
340        tracing::debug!(
341            definitions = graph.definitions.len(),
342            edges = total_edges,
343            files = file_count,
344            "graph built"
345        );
346
347        Ok(graph)
348    }
349
350    fn find_chains_bfs(
351        &self,
352        symbol: &str,
353        follow_depth: u32,
354        is_incoming: bool,
355    ) -> Result<Vec<CallChain>, GraphError> {
356        let graph_map = if is_incoming {
357            &self.callers
358        } else {
359            &self.callees
360        };
361
362        if !self.definitions.contains_key(symbol) && !graph_map.contains_key(symbol) {
363            return Err(GraphError::SymbolNotFound(symbol.to_string()));
364        }
365
366        let mut chains = Vec::new();
367        let mut visited = HashSet::new();
368        let mut queue = VecDeque::new();
369        queue.push_back((symbol.to_string(), 0));
370        visited.insert(symbol.to_string());
371
372        while let Some((current, depth)) = queue.pop_front() {
373            if depth > follow_depth {
374                continue;
375            }
376
377            if let Some(neighbors) = graph_map.get(&current) {
378                for (path, line, neighbor) in neighbors {
379                    let mut chain = vec![(current.clone(), path.clone(), *line)];
380                    let mut chain_node = neighbor.clone();
381                    let mut chain_depth = depth;
382
383                    while chain_depth < follow_depth {
384                        if let Some(next_neighbors) = graph_map.get(&chain_node) {
385                            if let Some((p, l, n)) = next_neighbors.first() {
386                                if is_incoming {
387                                    chain.insert(0, (chain_node.clone(), p.clone(), *l));
388                                } else {
389                                    chain.push((chain_node.clone(), p.clone(), *l));
390                                }
391                                chain_node = n.clone();
392                                chain_depth += 1;
393                            } else {
394                                break;
395                            }
396                        } else {
397                            break;
398                        }
399                    }
400
401                    if is_incoming {
402                        chain.insert(0, (neighbor.clone(), path.clone(), *line));
403                    } else {
404                        chain.push((neighbor.clone(), path.clone(), *line));
405                    }
406                    chains.push(CallChain { chain });
407
408                    if !visited.contains(neighbor) && depth < follow_depth {
409                        visited.insert(neighbor.clone());
410                        queue.push_back((neighbor.clone(), depth + 1));
411                    }
412                }
413            }
414        }
415
416        Ok(chains)
417    }
418
419    #[instrument(skip(self))]
420    pub fn find_incoming_chains(
421        &self,
422        symbol: &str,
423        follow_depth: u32,
424    ) -> Result<Vec<CallChain>, GraphError> {
425        self.find_chains_bfs(symbol, follow_depth, true)
426    }
427
428    #[instrument(skip(self))]
429    pub fn find_outgoing_chains(
430        &self,
431        symbol: &str,
432        follow_depth: u32,
433    ) -> Result<Vec<CallChain>, GraphError> {
434        self.find_chains_bfs(symbol, follow_depth, false)
435    }
436}
437
438impl Default for CallGraph {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CallInfo, FunctionInfo};

    /// Build a `SemanticAnalysis` containing only functions and calls.
    ///
    /// `funcs` entries are `(name, line, parameters, return_type)`, and each function
    /// gets an `end_line` five lines after its start. `calls` entries are
    /// `(caller, callee, line, arg_count)` with column 0.
    fn make_typed_analysis(
        funcs: Vec<(&str, usize, Vec<String>, Option<&str>)>,
        calls: Vec<(&str, &str, usize, Option<usize>)>,
    ) -> SemanticAnalysis {
        SemanticAnalysis {
            functions: funcs
                .into_iter()
                .map(|(n, l, params, ret_type)| FunctionInfo {
                    name: n.to_string(),
                    line: l,
                    end_line: l + 5,
                    parameters: params,
                    return_type: ret_type.map(|s| s.to_string()),
                })
                .collect(),
            classes: vec![],
            imports: vec![],
            references: vec![],
            call_frequency: Default::default(),
            calls: calls
                .into_iter()
                .map(|(c, e, l, arg_count)| CallInfo {
                    caller: c.to_string(),
                    callee: e.to_string(),
                    line: l,
                    column: 0,
                    arg_count,
                })
                .collect(),
            assignments: vec![],
            field_accesses: vec![],
        }
    }

    /// Untyped variant of `make_typed_analysis`: no parameters, return types or
    /// argument counts.
    fn make_analysis(
        funcs: Vec<(&str, usize)>,
        calls: Vec<(&str, &str, usize)>,
    ) -> SemanticAnalysis {
        make_typed_analysis(
            funcs
                .into_iter()
                .map(|(name, line)| (name, line, vec![], None))
                .collect(),
            calls
                .into_iter()
                .map(|(caller, callee, line)| (caller, callee, line, None))
                .collect(),
        )
    }

    /// Build a graph from a single analysed file named `test.rs`.
    fn single_file_graph(analysis: SemanticAnalysis) -> CallGraph {
        CallGraph::build_from_results(vec![(PathBuf::from("test.rs"), analysis)])
            .expect("Failed to build graph")
    }

    /// Symbol names along a chain, in order.
    fn chain_names(chain: &CallChain) -> Vec<&str> {
        chain.chain.iter().map(|(name, _, _)| name.as_str()).collect()
    }

    #[test]
    fn test_graph_construction() {
        let graph = single_file_graph(make_analysis(
            vec![("main", 1), ("foo", 10), ("bar", 20)],
            vec![("main", "foo", 2), ("foo", "bar", 15)],
        ));
        assert!(graph.definitions.contains_key("main"));
        assert!(graph.definitions.contains_key("foo"));
        assert_eq!(graph.callees["main"][0].2, "foo");
        assert_eq!(graph.callers["foo"][0].2, "main");
    }

    #[test]
    fn test_find_incoming_chains_depth_zero() {
        let graph = single_file_graph(make_analysis(
            vec![("main", 1), ("foo", 10)],
            vec![("main", "foo", 2)],
        ));
        let chains = graph
            .find_incoming_chains("foo", 0)
            .expect("Failed to find chains");
        // Depth 0 yields exactly the direct edge, ordered caller-first.
        assert_eq!(chains.len(), 1);
        assert_eq!(chain_names(&chains[0]), ["main", "foo"]);
        assert!(
            chains[0]
                .chain
                .iter()
                .all(|(_, path, line)| path == &PathBuf::from("test.rs") && *line == 2)
        );
    }

    #[test]
    fn test_find_outgoing_chains_depth_zero() {
        let graph = single_file_graph(make_analysis(
            vec![("main", 1), ("foo", 10)],
            vec![("main", "foo", 2)],
        ));
        let chains = graph
            .find_outgoing_chains("main", 0)
            .expect("Failed to find chains");
        assert_eq!(chains.len(), 1);
        assert_eq!(chain_names(&chains[0]), ["main", "foo"]);
    }

    #[test]
    fn test_symbol_not_found() {
        let result = CallGraph::new().find_incoming_chains("nonexistent", 0);
        assert!(matches!(
            result,
            Err(GraphError::SymbolNotFound(name)) if name == "nonexistent"
        ));
    }

    #[test]
    fn test_same_file_preference() {
        // Two files each define "helper"; a.rs also has a call from "main" to "helper".
        // The edge is recorded under "helper" with the call site in a.rs.
        let analysis_a = make_analysis(
            vec![("main", 1), ("helper", 10)],
            vec![("main", "helper", 5)],
        );
        let analysis_b = make_analysis(vec![("helper", 20)], vec![]);

        let graph = CallGraph::build_from_results(vec![
            (PathBuf::from("a.rs"), analysis_a),
            (PathBuf::from("b.rs"), analysis_b),
        ])
        .expect("Failed to build graph");

        let main_callees = &graph.callees["main"];
        assert_eq!(main_callees.len(), 1);
        assert_eq!(main_callees[0].2, "helper");
        assert_eq!(main_callees[0].0, PathBuf::from("a.rs"));

        let helper_callers = &graph.callers["helper"];
        assert!(
            helper_callers
                .iter()
                .any(|(path, _, _)| path == &PathBuf::from("a.rs"))
        );
    }

    #[test]
    fn test_line_proximity() {
        // "process" is defined at lines 10 and 50; the call at line 12 must still be
        // recorded as an edge to the "process" node.
        let graph = single_file_graph(make_analysis(
            vec![("process", 10), ("process", 50)],
            vec![("main", "process", 12)],
        ));

        let main_callees = &graph.callees["main"];
        assert_eq!(main_callees.len(), 1);
        assert_eq!(main_callees[0].2, "process");

        assert!(
            graph.callers["process"]
                .iter()
                .any(|(_, line, caller)| *line == 12 && caller == "main")
        );
    }

    #[test]
    fn test_scope_prefix_stripping() {
        // "method" is defined once; qualified call forms must resolve to it.
        let graph = single_file_graph(make_analysis(
            vec![("method", 10)],
            vec![
                ("caller1", "self.method", 5),
                ("caller2", "Type::method", 15),
                ("caller3", "module::method", 25),
            ],
        ));

        for caller in ["caller1", "caller2", "caller3"] {
            assert_eq!(graph.callees[caller][0].2, "method");
        }

        let method_callers = &graph.callers["method"];
        assert_eq!(method_callers.len(), 3);
        for expected in ["caller1", "caller2", "caller3"] {
            assert!(method_callers.iter().any(|(_, _, caller)| caller == expected));
        }
    }

    #[test]
    fn test_raw_name_preferred_over_stripped() {
        // When the qualified name itself is defined, it is used verbatim.
        let graph = single_file_graph(make_analysis(
            vec![("Type::method", 5), ("method", 10)],
            vec![("main", "Type::method", 20)],
        ));
        assert_eq!(graph.callees["main"][0].2, "Type::method");
        assert!(graph.callers.contains_key("Type::method"));
    }

    #[test]
    fn test_unresolved_callee_kept_verbatim() {
        // Neither the raw nor the stripped name is defined: keep the original text.
        let graph = single_file_graph(make_analysis(
            vec![("main", 1)],
            vec![("main", "external::helper", 3)],
        ));
        assert_eq!(graph.callees["main"][0].2, "external::helper");
        assert!(graph.callers.contains_key("external::helper"));
        assert!(!graph.callers.contains_key("helper"));
    }

    #[test]
    fn test_no_same_file_fallback() {
        // a.rs calls "helper", which is only defined in b.rs; the edge still exists.
        let analysis_a = make_analysis(vec![("main", 1)], vec![("main", "helper", 5)]);
        let analysis_b = make_analysis(vec![("helper", 10)], vec![]);

        let graph = CallGraph::build_from_results(vec![
            (PathBuf::from("a.rs"), analysis_a),
            (PathBuf::from("b.rs"), analysis_b),
        ])
        .expect("Failed to build graph");

        let main_callees = &graph.callees["main"];
        assert_eq!(main_callees.len(), 1);
        assert_eq!(main_callees[0].2, "helper");

        assert!(
            graph.callers["helper"]
                .iter()
                .any(|(path, _, caller)| path == &PathBuf::from("a.rs") && caller == "main")
        );
    }

    #[test]
    fn test_type_disambiguation_by_params() {
        // Two overloads of "process" (1 and 2 params) around a call with 2 arguments.
        // The graph is keyed by name, so the edge is recorded under "process"
        // whichever overload is the intended target.
        let graph = single_file_graph(make_typed_analysis(
            vec![
                ("process", 10, vec!["(x: i32)".to_string()], Some("i32")),
                (
                    "process",
                    12,
                    vec!["(x: i32, y: String)".to_string()],
                    Some("String"),
                ),
                ("main", 1, vec![], None),
            ],
            vec![("main", "process", 11, Some(2))],
        ));

        let main_callees = &graph.callees["main"];
        assert_eq!(main_callees.len(), 1);
        assert_eq!(main_callees[0].2, "process");

        assert!(
            graph.callers["process"]
                .iter()
                .any(|(_, line, caller)| *line == 11 && caller == "main")
        );
    }

    #[test]
    fn test_type_disambiguation_fallback() {
        // No type info and no argument count: resolution must not regress.
        let graph = single_file_graph(make_analysis(
            vec![("process", 10), ("process", 50), ("main", 1)],
            vec![("main", "process", 12)],
        ));

        let main_callees = &graph.callees["main"];
        assert_eq!(main_callees.len(), 1);
        assert_eq!(main_callees[0].2, "process");

        assert!(
            graph.callers["process"]
                .iter()
                .any(|(_, line, caller)| *line == 12 && caller == "main")
        );
    }

    #[test]
    fn test_count_parameters() {
        assert_eq!(CallGraph::count_parameters("(x: i32, y: String)"), 2);
        assert_eq!(CallGraph::count_parameters("(&self, x: i32)"), 2);
        assert_eq!(CallGraph::count_parameters("(&self)"), 1);
        assert_eq!(CallGraph::count_parameters("()"), 0);
        assert_eq!(CallGraph::count_parameters(""), 0);
    }

    #[test]
    fn test_strip_scope_prefix() {
        assert_eq!(strip_scope_prefix("self.method"), "method");
        assert_eq!(strip_scope_prefix("Type::method"), "method");
        assert_eq!(strip_scope_prefix("module::function"), "function");
        assert_eq!(strip_scope_prefix("plain"), "plain");
    }
}