1use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph_for_cli};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::node::NodeId;
12use sqry_core::graph::unified::{
13 EdgeFilter, TraversalConfig, TraversalDirection, TraversalLimits, traverse,
14};
15use std::collections::HashSet;
16
/// Top-level result of the `subgraph` command.
///
/// Serialized with `serde_json::to_string_pretty` when `cli.json` is set;
/// otherwise rendered by `format_subgraph_text`.
#[derive(Debug, Serialize)]
struct SubgraphOutput {
    /// Seed symbols exactly as supplied by the caller, including any that did not
    /// resolve to a graph node.
    seeds: Vec<String>,
    /// Nodes of the extracted subgraph.
    nodes: Vec<SubgraphNode>,
    /// Edges between nodes of the extracted subgraph.
    edges: Vec<SubgraphEdge>,
    /// Summary counts for `nodes` / `edges` and the deepest traversal level.
    stats: SubgraphStats,
}
29
/// A single node of the extracted subgraph, as emitted in command output.
#[derive(Debug, Clone, Serialize)]
struct SubgraphNode {
    /// Output identifier; `build_subgraph_nodes` fills it with the qualified name.
    id: String,
    /// Simple (unqualified) symbol name; empty if it could not be resolved.
    name: String,
    /// Qualified name, falling back to `name` when the node has none.
    qualified_name: String,
    /// `Debug` rendering of the node kind.
    kind: String,
    /// Display path of the file containing the node; empty if unresolved.
    file: String,
    /// Line on which the node starts.
    line: u32,
    /// Display language derived from the file extension, or `"Unknown"`.
    language: String,
    /// Whether this node is one of the resolved seed nodes.
    is_seed: bool,
    /// Traversal depth assigned to the node (`0` for seeds).
    depth: usize,
}
44
/// A directed edge of the extracted subgraph, with endpoints identified by name.
#[derive(Debug, Clone, Serialize)]
struct SubgraphEdge {
    /// Qualified name (or simple-name fallback) of the source node.
    source: String,
    /// Qualified name (or simple-name fallback) of the target node.
    target: String,
    /// `Debug` rendering of the raw edge kind reported by the traversal.
    kind: String,
}
51
/// Summary statistics reported alongside a subgraph.
#[derive(Debug, Serialize)]
struct SubgraphStats {
    /// Number of nodes in the output.
    node_count: usize,
    /// Number of (deduplicated) edges in the output.
    edge_count: usize,
    /// Largest depth assigned to any node by the traversal.
    max_depth_reached: usize,
}
58
59fn find_seed_nodes(
61 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
62 symbols: &[String],
63) -> Vec<NodeId> {
64 let strings = graph.strings();
65 let mut seed_nodes: Vec<NodeId> = Vec::new();
66
67 for symbol in symbols {
68 let found = graph.nodes().iter().find(|(_, entry)| {
69 if let Some(qn_id) = entry.qualified_name
71 && let Some(qn) = strings.resolve(qn_id)
72 && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
73 {
74 return true;
75 }
76 if let Some(name) = strings.resolve(entry.name)
78 && name.as_ref() == symbol.as_str()
79 {
80 return true;
81 }
82 false
83 });
84
85 if let Some((node_id, _)) = found {
86 seed_nodes.push(node_id);
87 }
88 }
89
90 seed_nodes
91}
92
/// Raw traversal output gathered by `collect_subgraph_bfs`, before it is turned
/// into serializable nodes and edges.
struct SubgraphBfsResult {
    /// Every node returned by the traversal.
    visited: HashSet<NodeId>,
    /// Depth assigned to each visited node (`0` for seeds).
    node_depths: std::collections::HashMap<NodeId, usize>,
    /// Traversed edges as `(source, target, Debug-formatted raw edge kind)`.
    collected_edges: Vec<(NodeId, NodeId, String)>,
    /// Maximum value stored in `node_depths` (`0` when empty).
    max_depth_reached: usize,
}
100
101#[allow(clippy::similar_names)]
107fn collect_subgraph_bfs(
108 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
109 seed_nodes: &[NodeId],
110 max_depth: usize,
111 max_nodes: usize,
112 include_callers: bool,
113 include_callees: bool,
114 include_imports: bool,
115) -> SubgraphBfsResult {
116 let snapshot = graph.snapshot();
117
118 let direction = match (include_callers, include_callees) {
119 (true, true) => TraversalDirection::Both,
120 (true, false) => TraversalDirection::Incoming,
121 #[allow(clippy::match_same_arms)] (false, true) => TraversalDirection::Outgoing,
123 (false, false) => TraversalDirection::Outgoing,
125 };
126
127 let edge_filter = if include_imports {
128 EdgeFilter::calls_and_imports()
129 } else {
130 EdgeFilter::calls_only()
131 };
132
133 let config = TraversalConfig {
134 direction,
135 edge_filter,
136 limits: TraversalLimits {
137 max_depth: u32::try_from(max_depth).unwrap_or(u32::MAX),
138 max_nodes: Some(max_nodes),
139 max_edges: None,
140 max_paths: None,
141 },
142 };
143
144 let result = traverse(&snapshot, seed_nodes, &config, None);
145
146 let mut visited: HashSet<NodeId> = HashSet::new();
147 let mut node_depths: std::collections::HashMap<NodeId, usize> =
148 std::collections::HashMap::new();
149 let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
150 let mut max_depth_reached: usize = 0;
151
152 for (idx, mat_node) in result.nodes.iter().enumerate() {
154 visited.insert(mat_node.node_id);
155
156 let depth = if seed_nodes.contains(&mat_node.node_id) {
158 0
159 } else {
160 result
161 .edges
162 .iter()
163 .filter(|e| e.source_idx == idx || e.target_idx == idx)
164 .map(|e| e.depth as usize)
165 .min()
166 .unwrap_or(0)
167 };
168
169 node_depths.insert(mat_node.node_id, depth);
170 max_depth_reached = max_depth_reached.max(depth);
171 }
172
173 for edge in &result.edges {
175 let source_id = result.nodes[edge.source_idx].node_id;
176 let target_id = result.nodes[edge.target_idx].node_id;
177 let kind_str = format!("{:?}", edge.raw_kind);
178 collected_edges.push((source_id, target_id, kind_str));
179 }
180
181 SubgraphBfsResult {
182 visited,
183 node_depths,
184 collected_edges,
185 max_depth_reached,
186 }
187}
188
/// Maps a file extension (without the leading dot) to a human-readable language name.
///
/// Common alternate extensions map to the same language as their canonical form:
/// `.tsx` maps to TypeScript, and `.cxx`/`.hh` map to C++. Matching is
/// case-sensitive. Unknown extensions are returned unchanged, so callers still have
/// something meaningful to display.
fn extension_to_display_language(ext: &str) -> &str {
    match ext {
        "rs" => "Rust",
        "py" | "pyi" => "Python",
        "js" | "jsx" | "mjs" | "cjs" => "JavaScript",
        "ts" | "tsx" | "mts" | "cts" => "TypeScript",
        "go" => "Go",
        "java" => "Java",
        "c" | "h" => "C",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "C++",
        "rb" => "Ruby",
        "swift" => "Swift",
        "kt" | "kts" => "Kotlin",
        _ => ext,
    }
}
206
207fn build_subgraph_nodes(
209 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
210 bfs: &SubgraphBfsResult,
211 seed_nodes: &[NodeId],
212) -> Vec<SubgraphNode> {
213 let strings = graph.strings();
214 let files = graph.files();
215 let seed_set: HashSet<_> = seed_nodes.iter().collect();
216
217 let mut nodes: Vec<SubgraphNode> = bfs
218 .visited
219 .iter()
220 .filter_map(|&node_id| {
221 let entry = graph.nodes().get(node_id)?;
222 let name = strings
223 .resolve(entry.name)
224 .map(|s| s.to_string())
225 .unwrap_or_default();
226 let qualified_name = entry
227 .qualified_name
228 .and_then(|id| strings.resolve(id))
229 .map_or_else(|| name.clone(), |s| s.to_string());
230
231 let file_path = files
232 .resolve(entry.file)
233 .map(|p| p.display().to_string())
234 .unwrap_or_default();
235
236 let language = files.resolve(entry.file).map_or_else(
238 || "Unknown".to_string(),
239 |p| {
240 p.extension()
241 .and_then(|ext| ext.to_str())
242 .map_or("Unknown", extension_to_display_language)
243 .to_string()
244 },
245 );
246
247 Some(SubgraphNode {
248 id: qualified_name.clone(),
249 name,
250 qualified_name,
251 kind: format!("{:?}", entry.kind),
252 file: file_path,
253 line: entry.start_line,
254 language,
255 is_seed: seed_set.contains(&node_id),
256 depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
257 })
258 })
259 .collect();
260
261 nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
263 nodes
264}
265
266fn build_subgraph_edges(
268 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
269 bfs: &SubgraphBfsResult,
270) -> Vec<SubgraphEdge> {
271 let strings = graph.strings();
272
273 let node_names: std::collections::HashMap<NodeId, String> = bfs
275 .visited
276 .iter()
277 .filter_map(|&node_id| {
278 let entry = graph.nodes().get(node_id)?;
279 let name = strings
280 .resolve(entry.name)
281 .map(|s| s.to_string())
282 .unwrap_or_default();
283 let qn = entry
284 .qualified_name
285 .and_then(|id| strings.resolve(id))
286 .map_or_else(|| name, |s| s.to_string());
287 Some((node_id, qn))
288 })
289 .collect();
290
291 let mut edges: Vec<SubgraphEdge> = bfs
292 .collected_edges
293 .iter()
294 .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
295 .filter_map(|(src, tgt, kind)| {
296 let src_name = node_names.get(src)?.clone();
297 let tgt_name = node_names.get(tgt)?.clone();
298 Some(SubgraphEdge {
299 source: src_name,
300 target: tgt_name,
301 kind: kind.clone(),
302 })
303 })
304 .collect();
305
306 edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
308 edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
309 edges
310}
311
312#[allow(clippy::similar_names)]
317pub fn run_subgraph(
319 cli: &Cli,
320 symbols: &[String],
321 path: Option<&str>,
322 max_depth: usize,
323 max_nodes: usize,
324 include_callers: bool,
325 include_callees: bool,
326 include_imports: bool,
327) -> Result<()> {
328 let mut streams = OutputStreams::new();
329
330 if symbols.is_empty() {
331 return Err(anyhow!("At least one seed symbol is required"));
332 }
333
334 let search_path = path.map_or_else(
336 || std::env::current_dir().unwrap_or_default(),
337 std::path::PathBuf::from,
338 );
339
340 let index_location = find_nearest_index(&search_path);
341 let Some(ref loc) = index_location else {
342 streams
343 .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
344 return Ok(());
345 };
346
347 let config = GraphLoadConfig::default();
349 let graph = load_unified_graph_for_cli(&loc.index_root, &config, cli)
350 .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
351
352 let seed_nodes = find_seed_nodes(&graph, symbols);
354 if seed_nodes.is_empty() {
355 streams.write_diagnostic("No seed symbols found in the graph.")?;
356 return Ok(());
357 }
358
359 let bfs = collect_subgraph_bfs(
361 &graph,
362 &seed_nodes,
363 max_depth,
364 max_nodes,
365 include_callers,
366 include_callees,
367 include_imports,
368 );
369
370 let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
372 let edges = build_subgraph_edges(&graph, &bfs);
373
374 let stats = SubgraphStats {
375 node_count: nodes.len(),
376 edge_count: edges.len(),
377 max_depth_reached: bfs.max_depth_reached,
378 };
379
380 let output = SubgraphOutput {
381 seeds: symbols.to_vec(),
382 nodes,
383 edges,
384 stats,
385 };
386
387 if cli.json {
389 let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
390 streams.write_result(&json)?;
391 } else {
392 let text = format_subgraph_text(&output);
393 streams.write_result(&text)?;
394 }
395
396 Ok(())
397}
398
399fn format_subgraph_text(output: &SubgraphOutput) -> String {
400 let mut lines = Vec::new();
401
402 lines.push(format!(
403 "Subgraph around {} seed(s): {}",
404 output.seeds.len(),
405 output.seeds.join(", ")
406 ));
407 lines.push(format!(
408 "Stats: {} nodes, {} edges, max depth {}",
409 output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
410 ));
411 lines.push(String::new());
412
413 lines.push("Nodes:".to_string());
414 for node in &output.nodes {
415 let seed_marker = if node.is_seed { " [SEED]" } else { "" };
416 lines.push(format!(
417 " {} [{}] depth={}{} ",
418 node.qualified_name, node.kind, node.depth, seed_marker
419 ));
420 lines.push(format!(" {}:{}", node.file, node.line));
421 }
422
423 if !output.edges.is_empty() {
424 lines.push(String::new());
425 lines.push("Edges:".to_string());
426 for edge in &output.edges {
427 lines.push(format!(
428 " {} --[{}]--> {}",
429 edge.source, edge.kind, edge.target
430 ));
431 }
432 }
433
434 lines.join("\n")
435}