lean_ctx/tools/
ctx_graph_diagram.rs1use std::collections::HashMap;
2
3use crate::core::call_graph::CallGraph;
4use crate::core::graph_index;
5
/// How many of the most-connected nodes a diagram keeps; edges touching none
/// of them are left out of the rendered output.
const DEFAULT_MAX_NODES: usize = 30;
/// Traversal depth around a focus file when the caller does not pass `depth`.
const DEFAULT_DEPTH: usize = 2;
8
9pub fn handle(
10 file: Option<&str>,
11 depth: Option<usize>,
12 kind: Option<&str>,
13 project_root: &str,
14) -> String {
15 let max_depth = depth.unwrap_or(DEFAULT_DEPTH);
16 let graph_kind = kind.unwrap_or("deps");
17
18 match graph_kind {
19 "calls" => render_call_graph(file, max_depth, project_root),
20 _ => render_dep_graph(file, max_depth, project_root),
21 }
22}
23
24fn render_dep_graph(file: Option<&str>, depth: usize, project_root: &str) -> String {
25 let index = graph_index::load_or_build(project_root);
26
27 if index.edges.is_empty() {
28 return "No dependency edges found in project index.".to_string();
29 }
30
31 let edges: Vec<_> = if let Some(focus) = file {
32 let reachable = bfs_reachable_files(focus, &index.edges, depth);
33 index
34 .edges
35 .iter()
36 .filter(|e| reachable.contains(e.from.as_str()) || reachable.contains(e.to.as_str()))
37 .collect()
38 } else {
39 index.edges.iter().collect()
40 };
41
42 if edges.is_empty() {
43 return format!(
44 "No dependency edges found{}",
45 file.map(|f| format!(" for '{f}'")).unwrap_or_default()
46 );
47 }
48
49 let top_edges = select_top_edges(&edges, DEFAULT_MAX_NODES);
50
51 let mut mermaid = String::from("```mermaid\nflowchart TD\n");
52 for edge in &top_edges {
53 let from_id = sanitize_node_id(&edge.from);
54 let to_id = sanitize_node_id(&edge.to);
55 let from_label = shorten_path(&edge.from);
56 let to_label = shorten_path(&edge.to);
57 mermaid.push_str(&format!(
58 " {from_id}[\"{from_label}\"] -->|{}| {to_id}[\"{to_label}\"]\n",
59 edge.kind
60 ));
61 }
62 mermaid.push_str("```");
63
64 let total = index.edges.len();
65 let shown = top_edges.len();
66 if shown < total {
67 format!("{mermaid}\n\n({shown}/{total} edges shown, top by connectivity)")
68 } else {
69 mermaid
70 }
71}
72
73fn bfs_reachable_files(
74 start: &str,
75 edges: &[graph_index::IndexEdge],
76 max_depth: usize,
77) -> std::collections::HashSet<String> {
78 let mut visited = std::collections::HashSet::new();
79 let mut queue: std::collections::VecDeque<(String, usize)> = std::collections::VecDeque::new();
80
81 for edge in edges {
82 if edge.from.contains(start) || edge.to.contains(start) {
83 if edge.from.contains(start) {
84 visited.insert(edge.from.clone());
85 queue.push_back((edge.from.clone(), 0));
86 }
87 if edge.to.contains(start) {
88 visited.insert(edge.to.clone());
89 queue.push_back((edge.to.clone(), 0));
90 }
91 }
92 }
93
94 while let Some((node, d)) = queue.pop_front() {
95 if d >= max_depth {
96 continue;
97 }
98 for edge in edges {
99 let neighbor = if edge.from == node {
100 &edge.to
101 } else if edge.to == node {
102 &edge.from
103 } else {
104 continue;
105 };
106 if visited.insert(neighbor.clone()) {
107 queue.push_back((neighbor.clone(), d + 1));
108 }
109 }
110 }
111
112 visited
113}
114
115fn render_call_graph(file: Option<&str>, _depth: usize, project_root: &str) -> String {
116 let index = graph_index::load_or_build(project_root);
117 let call_graph = CallGraph::load_or_build(project_root, &index);
118 let _ = call_graph.save();
119
120 if call_graph.edges.is_empty() {
121 return "No call edges found. Run ctx_callgraph first to build the call graph.".to_string();
122 }
123
124 let edges: Vec<_> = if let Some(focus) = file {
125 call_graph
126 .edges
127 .iter()
128 .filter(|e| {
129 e.caller_file.contains(focus)
130 || e.caller_symbol.contains(focus)
131 || e.callee_name.contains(focus)
132 })
133 .collect()
134 } else {
135 call_graph.edges.iter().collect()
136 };
137
138 if edges.is_empty() {
139 return format!(
140 "No call edges found{}",
141 file.map(|f| format!(" matching '{f}'")).unwrap_or_default()
142 );
143 }
144
145 let top_nodes = select_top_call_nodes(&edges, DEFAULT_MAX_NODES);
146
147 let mut mermaid = String::from("```mermaid\nflowchart LR\n");
148 let mut seen = std::collections::HashSet::new();
149
150 for edge in &edges {
151 if !top_nodes.contains(&edge.caller_symbol.as_str())
152 && !top_nodes.contains(&edge.callee_name.as_str())
153 {
154 continue;
155 }
156 let key = format!("{}→{}", edge.caller_symbol, edge.callee_name);
157 if !seen.insert(key) {
158 continue;
159 }
160 let from_id = sanitize_node_id(&edge.caller_symbol);
161 let to_id = sanitize_node_id(&edge.callee_name);
162 mermaid.push_str(&format!(" {from_id} --> {to_id}\n"));
163 }
164 mermaid.push_str("```");
165
166 let total = call_graph.edges.len();
167 let shown = seen.len();
168 if shown < total {
169 format!("{mermaid}\n\n({shown}/{total} call edges shown, top by connectivity)")
170 } else {
171 mermaid
172 }
173}
174
175fn select_top_edges<'a>(
176 edges: &'a [&'a graph_index::IndexEdge],
177 max_nodes: usize,
178) -> Vec<&'a graph_index::IndexEdge> {
179 let mut node_counts: HashMap<&str, usize> = HashMap::new();
180 for edge in edges {
181 *node_counts.entry(&edge.from).or_insert(0) += 1;
182 *node_counts.entry(&edge.to).or_insert(0) += 1;
183 }
184
185 let mut nodes_sorted: Vec<_> = node_counts.into_iter().collect();
186 nodes_sorted.sort_by_key(|x| std::cmp::Reverse(x.1));
187 let top: std::collections::HashSet<&str> = nodes_sorted
188 .iter()
189 .take(max_nodes)
190 .map(|(n, _)| *n)
191 .collect();
192
193 edges
194 .iter()
195 .filter(|e| top.contains(e.from.as_str()) || top.contains(e.to.as_str()))
196 .copied()
197 .collect()
198}
199
200fn select_top_call_nodes<'a>(
201 edges: &[&'a crate::core::call_graph::CallEdge],
202 max_nodes: usize,
203) -> std::collections::HashSet<&'a str> {
204 let mut counts: HashMap<&str, usize> = HashMap::new();
205 for edge in edges {
206 *counts.entry(&edge.caller_symbol).or_insert(0) += 1;
207 *counts.entry(&edge.callee_name).or_insert(0) += 1;
208 }
209
210 let mut sorted: Vec<_> = counts.into_iter().collect();
211 sorted.sort_by_key(|x| std::cmp::Reverse(x.1));
212 sorted.into_iter().take(max_nodes).map(|(n, _)| n).collect()
213}
214
/// Converts an arbitrary path or symbol into a Mermaid node identifier.
///
/// Each character that is neither alphanumeric (per `char::is_alphanumeric`,
/// so non-ASCII letters and digits are kept) nor `_` becomes `_`. Distinct
/// names can therefore share an id (`a-b` and `a_b`).
///
/// Two outputs Mermaid cannot use as ids are adjusted:
/// * an empty name becomes `"_"`;
/// * a bare lowercase `end`, which breaks flowcharts (it is the keyword that
///   closes a `subgraph`), becomes `"end_"`.
fn sanitize_node_id(name: &str) -> String {
    let id: String = name
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    match id.as_str() {
        "" => "_".to_string(),
        "end" => "end_".to_string(),
        _ => id,
    }
}
226
/// Abbreviates a `/`-separated path to its last two components for use as a
/// diagram label, e.g. `a/b/c/main.rs` becomes `…/c/main.rs`.
///
/// Paths containing at most one `/` are returned unchanged.
fn shorten_path(path: &str) -> String {
    // Splitting from the right into at most three pieces yields the file name,
    // its parent, and an optional remainder; the remainder exists only when
    // the path has two or more separators.
    let mut pieces = path.rsplitn(3, '/');
    match (pieces.next(), pieces.next(), pieces.next()) {
        (Some(file_name), Some(parent), Some(_)) => format!("…/{parent}/{file_name}"),
        _ => path.to_owned(),
    }
}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sanitize_node_id_removes_special_chars() {
        let cases = [("src/main.rs", "src_main_rs"), ("foo::bar", "foo__bar")];
        for (input, expected) in cases {
            assert_eq!(sanitize_node_id(input), expected, "input: {input}");
        }
    }

    #[test]
    fn shorten_path_keeps_short_paths() {
        for short in ["main.rs", "src/main.rs"] {
            assert_eq!(shorten_path(short), short);
        }
    }

    #[test]
    fn shorten_path_truncates_long_paths() {
        let shortened = shorten_path("a/b/c/main.rs");
        assert_eq!(shortened, "…/c/main.rs");
    }

    #[test]
    fn render_dep_graph_empty_index() {
        let output = render_dep_graph(None, 2, "/nonexistent/path");
        let recognised = output.contains("No dependency edges") || output.contains("flowchart");
        assert!(recognised, "unexpected output: {output}");
    }

    #[test]
    fn render_call_graph_empty() {
        let output = render_call_graph(None, 2, "/nonexistent/path");
        let recognised = output.contains("No call edges") || output.contains("flowchart");
        assert!(recognised, "unexpected output: {output}");
    }
}
268}