Skip to main content

lean_ctx/tools/
ctx_graph_diagram.rs

1use std::collections::HashMap;
2
3use crate::core::call_graph::CallGraph;
4use crate::core::graph_index;
5
/// Traversal depth around the focus file used when the caller passes no `depth`.
const DEFAULT_DEPTH: usize = 2;
/// Number of best-connected nodes used to trim large diagrams. Every edge that
/// touches one of these nodes is kept, so a rendered diagram can contain more
/// nodes than this.
const DEFAULT_MAX_NODES: usize = 30;
8
9pub fn handle(
10    file: Option<&str>,
11    depth: Option<usize>,
12    kind: Option<&str>,
13    project_root: &str,
14) -> String {
15    let max_depth = depth.unwrap_or(DEFAULT_DEPTH);
16    let graph_kind = kind.unwrap_or("deps");
17
18    match graph_kind {
19        "calls" => render_call_graph(file, max_depth, project_root),
20        _ => render_dep_graph(file, max_depth, project_root),
21    }
22}
23
24fn render_dep_graph(file: Option<&str>, depth: usize, project_root: &str) -> String {
25    let index = graph_index::load_or_build(project_root);
26
27    if index.edges.is_empty() {
28        return "No dependency edges found in project index.".to_string();
29    }
30
31    let edges: Vec<_> = if let Some(focus) = file {
32        let reachable = bfs_reachable_files(focus, &index.edges, depth);
33        index
34            .edges
35            .iter()
36            .filter(|e| reachable.contains(e.from.as_str()) || reachable.contains(e.to.as_str()))
37            .collect()
38    } else {
39        index.edges.iter().collect()
40    };
41
42    if edges.is_empty() {
43        return format!(
44            "No dependency edges found{}",
45            file.map(|f| format!(" for '{f}'")).unwrap_or_default()
46        );
47    }
48
49    let top_edges = select_top_edges(&edges, DEFAULT_MAX_NODES);
50
51    let mut mermaid = String::from("```mermaid\nflowchart TD\n");
52    for edge in &top_edges {
53        let from_id = sanitize_node_id(&edge.from);
54        let to_id = sanitize_node_id(&edge.to);
55        let from_label = shorten_path(&edge.from);
56        let to_label = shorten_path(&edge.to);
57        mermaid.push_str(&format!(
58            "    {from_id}[\"{from_label}\"] -->|{}| {to_id}[\"{to_label}\"]\n",
59            edge.kind
60        ));
61    }
62    mermaid.push_str("```");
63
64    let total = index.edges.len();
65    let shown = top_edges.len();
66    if shown < total {
67        format!("{mermaid}\n\n({shown}/{total} edges shown, top by connectivity)")
68    } else {
69        mermaid
70    }
71}
72
73fn bfs_reachable_files(
74    start: &str,
75    edges: &[graph_index::IndexEdge],
76    max_depth: usize,
77) -> std::collections::HashSet<String> {
78    let mut visited = std::collections::HashSet::new();
79    let mut queue: std::collections::VecDeque<(String, usize)> = std::collections::VecDeque::new();
80
81    for edge in edges {
82        if edge.from.contains(start) || edge.to.contains(start) {
83            if edge.from.contains(start) {
84                visited.insert(edge.from.clone());
85                queue.push_back((edge.from.clone(), 0));
86            }
87            if edge.to.contains(start) {
88                visited.insert(edge.to.clone());
89                queue.push_back((edge.to.clone(), 0));
90            }
91        }
92    }
93
94    while let Some((node, d)) = queue.pop_front() {
95        if d >= max_depth {
96            continue;
97        }
98        for edge in edges {
99            let neighbor = if edge.from == node {
100                &edge.to
101            } else if edge.to == node {
102                &edge.from
103            } else {
104                continue;
105            };
106            if visited.insert(neighbor.clone()) {
107                queue.push_back((neighbor.clone(), d + 1));
108            }
109        }
110    }
111
112    visited
113}
114
115fn render_call_graph(file: Option<&str>, _depth: usize, project_root: &str) -> String {
116    let index = graph_index::load_or_build(project_root);
117    let call_graph = CallGraph::load_or_build(project_root, &index);
118    let _ = call_graph.save();
119
120    if call_graph.edges.is_empty() {
121        return "No call edges found. Run ctx_callgraph first to build the call graph.".to_string();
122    }
123
124    let edges: Vec<_> = if let Some(focus) = file {
125        call_graph
126            .edges
127            .iter()
128            .filter(|e| {
129                e.caller_file.contains(focus)
130                    || e.caller_symbol.contains(focus)
131                    || e.callee_name.contains(focus)
132            })
133            .collect()
134    } else {
135        call_graph.edges.iter().collect()
136    };
137
138    if edges.is_empty() {
139        return format!(
140            "No call edges found{}",
141            file.map(|f| format!(" matching '{f}'")).unwrap_or_default()
142        );
143    }
144
145    let top_nodes = select_top_call_nodes(&edges, DEFAULT_MAX_NODES);
146
147    let mut mermaid = String::from("```mermaid\nflowchart LR\n");
148    let mut seen = std::collections::HashSet::new();
149
150    for edge in &edges {
151        if !top_nodes.contains(&edge.caller_symbol.as_str())
152            && !top_nodes.contains(&edge.callee_name.as_str())
153        {
154            continue;
155        }
156        let key = format!("{}→{}", edge.caller_symbol, edge.callee_name);
157        if !seen.insert(key) {
158            continue;
159        }
160        let from_id = sanitize_node_id(&edge.caller_symbol);
161        let to_id = sanitize_node_id(&edge.callee_name);
162        mermaid.push_str(&format!("    {from_id} --> {to_id}\n"));
163    }
164    mermaid.push_str("```");
165
166    let total = call_graph.edges.len();
167    let shown = seen.len();
168    if shown < total {
169        format!("{mermaid}\n\n({shown}/{total} call edges shown, top by connectivity)")
170    } else {
171        mermaid
172    }
173}
174
175fn select_top_edges<'a>(
176    edges: &'a [&'a graph_index::IndexEdge],
177    max_nodes: usize,
178) -> Vec<&'a graph_index::IndexEdge> {
179    let mut node_counts: HashMap<&str, usize> = HashMap::new();
180    for edge in edges {
181        *node_counts.entry(&edge.from).or_insert(0) += 1;
182        *node_counts.entry(&edge.to).or_insert(0) += 1;
183    }
184
185    let mut nodes_sorted: Vec<_> = node_counts.into_iter().collect();
186    nodes_sorted.sort_by_key(|x| std::cmp::Reverse(x.1));
187    let top: std::collections::HashSet<&str> = nodes_sorted
188        .iter()
189        .take(max_nodes)
190        .map(|(n, _)| *n)
191        .collect();
192
193    edges
194        .iter()
195        .filter(|e| top.contains(e.from.as_str()) || top.contains(e.to.as_str()))
196        .copied()
197        .collect()
198}
199
200fn select_top_call_nodes<'a>(
201    edges: &[&'a crate::core::call_graph::CallEdge],
202    max_nodes: usize,
203) -> std::collections::HashSet<&'a str> {
204    let mut counts: HashMap<&str, usize> = HashMap::new();
205    for edge in edges {
206        *counts.entry(&edge.caller_symbol).or_insert(0) += 1;
207        *counts.entry(&edge.callee_name).or_insert(0) += 1;
208    }
209
210    let mut sorted: Vec<_> = counts.into_iter().collect();
211    sorted.sort_by_key(|x| std::cmp::Reverse(x.1));
212    sorted.into_iter().take(max_nodes).map(|(n, _)| n).collect()
213}
214
/// Maps `name` to a Mermaid-safe identifier fragment. Every character that is
/// neither alphanumeric (Unicode-aware) nor `_` becomes `_`. The mapping is
/// lossy, so distinct names can produce the same fragment.
fn sanitize_node_id(name: &str) -> String {
    let mut id = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        id.push(if keep { ch } else { '_' });
    }
    id
}
226
/// Abbreviates a `/`-separated path to its last two components, prefixed with
/// `…/`. Paths containing fewer than two `/` are returned unchanged.
fn shorten_path(path: &str) -> String {
    // The second `/` from the right marks where the last two components start.
    match path.rmatch_indices('/').nth(1) {
        Some((cut, _)) => format!("…/{}", &path[cut + 1..]),
        None => path.to_string(),
    }
}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sanitize_node_id_removes_special_chars() {
        for (input, expected) in [("src/main.rs", "src_main_rs"), ("foo::bar", "foo__bar")] {
            assert_eq!(sanitize_node_id(input), expected);
        }
    }

    #[test]
    fn shorten_path_keeps_short_paths() {
        for path in ["main.rs", "src/main.rs"] {
            assert_eq!(shorten_path(path), path);
        }
    }

    #[test]
    fn shorten_path_truncates_long_paths() {
        assert_eq!(shorten_path("a/b/c/main.rs"), "…/c/main.rs");
    }

    // A missing project root must yield either an explanatory message or a
    // diagram, never a panic.
    #[test]
    fn render_dep_graph_empty_index() {
        let out = render_dep_graph(None, 2, "/nonexistent/path");
        let ok = out.contains("No dependency edges") || out.contains("flowchart");
        assert!(ok, "unexpected output: {out}");
    }

    #[test]
    fn render_call_graph_empty() {
        let out = render_call_graph(None, 2, "/nonexistent/path");
        let ok = out.contains("No call edges") || out.contains("flowchart");
        assert!(ok, "unexpected output: {out}");
    }
}