1use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph_for_cli};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::node::NodeId;
12use sqry_core::graph::unified::{
13 EdgeFilter, TraversalConfig, TraversalDirection, TraversalLimits, traverse,
14};
15use std::collections::HashSet;
16
/// Top-level result of the `subgraph` command.
///
/// `run_subgraph` serializes this as pretty JSON when `cli.json` is set and
/// renders it with `format_subgraph_text` otherwise. Field order here is the
/// JSON key order.
#[derive(Debug, Serialize)]
struct SubgraphOutput {
    /// Seed symbols exactly as the caller supplied them (not only the ones
    /// that resolved to a node).
    seeds: Vec<String>,
    /// Nodes reached by the traversal, ordered by qualified name.
    nodes: Vec<SubgraphNode>,
    /// Edges between reached nodes, sorted and de-duplicated.
    edges: Vec<SubgraphEdge>,
    /// Summary counts for the subgraph.
    stats: SubgraphStats,
}
29
/// One node of the emitted subgraph.
#[derive(Debug, Clone, Serialize)]
struct SubgraphNode {
    /// Identifier edges use to refer to this node; currently the same value
    /// as `qualified_name`.
    id: String,
    /// Simple name of the symbol (empty if the interned name cannot be resolved).
    name: String,
    /// Fully qualified name, falling back to `name` when the node has none.
    qualified_name: String,
    /// `Debug` rendering of the node kind.
    kind: String,
    /// Display form of the source file path (empty if it cannot be resolved).
    file: String,
    /// Start line of the node as stored in the graph entry.
    line: u32,
    /// Display language derived from the file extension, or `"Unknown"`.
    language: String,
    /// `true` when this node is one of the resolved seed nodes.
    is_seed: bool,
    /// Traversal depth assigned to the node; seeds are always 0.
    depth: usize,
}
44
/// One directed edge of the emitted subgraph.
#[derive(Debug, Clone, Serialize)]
struct SubgraphEdge {
    /// Label (qualified name, or simple name as fallback) of the source node.
    source: String,
    /// Label (qualified name, or simple name as fallback) of the target node.
    target: String,
    /// `Debug` rendering of the raw edge kind reported by the traversal.
    kind: String,
}
51
/// Summary counts reported alongside the subgraph.
#[derive(Debug, Serialize)]
struct SubgraphStats {
    /// Number of entries in `SubgraphOutput::nodes`.
    node_count: usize,
    /// Number of entries in `SubgraphOutput::edges` (after de-duplication).
    edge_count: usize,
    /// Largest depth assigned to any node during the traversal.
    max_depth_reached: usize,
}
58
59fn find_seed_nodes(
61 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
62 symbols: &[String],
63) -> Vec<NodeId> {
64 let strings = graph.strings();
65 let mut seed_nodes: Vec<NodeId> = Vec::new();
66
67 for symbol in symbols {
68 let found = graph.nodes().iter().find(|(_, entry)| {
69 if entry.is_unified_loser() {
72 return false;
73 }
74 if let Some(qn_id) = entry.qualified_name
76 && let Some(qn) = strings.resolve(qn_id)
77 && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
78 {
79 return true;
80 }
81 if let Some(name) = strings.resolve(entry.name)
83 && name.as_ref() == symbol.as_str()
84 {
85 return true;
86 }
87 false
88 });
89
90 if let Some((node_id, _)) = found {
91 seed_nodes.push(node_id);
92 }
93 }
94
95 seed_nodes
96}
97
98struct SubgraphBfsResult {
100 visited: HashSet<NodeId>,
101 node_depths: std::collections::HashMap<NodeId, usize>,
102 collected_edges: Vec<(NodeId, NodeId, String)>,
103 max_depth_reached: usize,
104}
105
106#[allow(clippy::similar_names)]
125fn collect_subgraph_bfs(
126 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
127 seed_nodes: &[NodeId],
128 max_depth: usize,
129 max_nodes: usize,
130 include_callers: bool,
131 include_callees: bool,
132 include_imports: bool,
133) -> SubgraphBfsResult {
134 let snapshot = graph.snapshot();
135
136 let direction = match (include_callers, include_callees) {
137 (true, true) => TraversalDirection::Both,
138 (true, false) => TraversalDirection::Incoming,
139 #[allow(clippy::match_same_arms)] (false, true) => TraversalDirection::Outgoing,
141 (false, false) => TraversalDirection::Outgoing,
143 };
144
145 let edge_filter = if include_imports {
146 EdgeFilter::calls_and_imports()
147 } else {
148 EdgeFilter::calls_only()
149 };
150
151 let config = TraversalConfig {
152 direction,
153 edge_filter,
154 limits: TraversalLimits {
155 max_depth: u32::try_from(max_depth).unwrap_or(u32::MAX),
156 max_nodes: Some(max_nodes),
157 max_edges: None,
158 max_paths: None,
159 },
160 };
161
162 let result = traverse(&snapshot, seed_nodes, &config, None);
163
164 let mut visited: HashSet<NodeId> = HashSet::new();
165 let mut node_depths: std::collections::HashMap<NodeId, usize> =
166 std::collections::HashMap::new();
167 let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
168 let mut max_depth_reached: usize = 0;
169
170 for (idx, mat_node) in result.nodes.iter().enumerate() {
172 visited.insert(mat_node.node_id);
173
174 let depth = if seed_nodes.contains(&mat_node.node_id) {
176 0
177 } else {
178 result
179 .edges
180 .iter()
181 .filter(|e| e.source_idx == idx || e.target_idx == idx)
182 .map(|e| e.depth as usize)
183 .min()
184 .unwrap_or(0)
185 };
186
187 node_depths.insert(mat_node.node_id, depth);
188 max_depth_reached = max_depth_reached.max(depth);
189 }
190
191 for edge in &result.edges {
193 let source_id = result.nodes[edge.source_idx].node_id;
194 let target_id = result.nodes[edge.target_idx].node_id;
195 let kind_str = format!("{:?}", edge.raw_kind);
196 collected_edges.push((source_id, target_id, kind_str));
197 }
198
199 SubgraphBfsResult {
200 visited,
201 node_depths,
202 collected_edges,
203 max_depth_reached,
204 }
205}
206
/// Maps a file extension (without the leading dot) to a human-readable
/// language name for display.
///
/// Common variant extensions (e.g. `tsx`, `mjs`, `cxx`, `kts`, `pyi`) map to
/// their base language. Unknown extensions are returned unchanged so the user
/// still sees something meaningful. Matching is case-sensitive.
fn extension_to_display_language(ext: &str) -> &str {
    match ext {
        "rs" => "Rust",
        "py" | "pyi" => "Python",
        "js" | "jsx" | "mjs" | "cjs" => "JavaScript",
        "ts" | "tsx" | "mts" | "cts" => "TypeScript",
        "go" => "Go",
        "java" => "Java",
        "c" | "h" => "C",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "C++",
        "rb" => "Ruby",
        "swift" => "Swift",
        "kt" | "kts" => "Kotlin",
        _ => ext,
    }
}
224
225fn build_subgraph_nodes(
227 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
228 bfs: &SubgraphBfsResult,
229 seed_nodes: &[NodeId],
230) -> Vec<SubgraphNode> {
231 let strings = graph.strings();
232 let files = graph.files();
233 let seed_set: HashSet<_> = seed_nodes.iter().collect();
234
235 let mut nodes: Vec<SubgraphNode> = bfs
236 .visited
237 .iter()
238 .filter_map(|&node_id| {
239 let entry = graph.nodes().get(node_id)?;
240 let name = strings
241 .resolve(entry.name)
242 .map(|s| s.to_string())
243 .unwrap_or_default();
244 let qualified_name = entry
245 .qualified_name
246 .and_then(|id| strings.resolve(id))
247 .map_or_else(|| name.clone(), |s| s.to_string());
248
249 let file_path = files
250 .resolve(entry.file)
251 .map(|p| p.display().to_string())
252 .unwrap_or_default();
253
254 let language = files.resolve(entry.file).map_or_else(
256 || "Unknown".to_string(),
257 |p| {
258 p.extension()
259 .and_then(|ext| ext.to_str())
260 .map_or("Unknown", extension_to_display_language)
261 .to_string()
262 },
263 );
264
265 Some(SubgraphNode {
266 id: qualified_name.clone(),
267 name,
268 qualified_name,
269 kind: format!("{:?}", entry.kind),
270 file: file_path,
271 line: entry.start_line,
272 language,
273 is_seed: seed_set.contains(&node_id),
274 depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
275 })
276 })
277 .collect();
278
279 nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
281 nodes
282}
283
284fn build_subgraph_edges(
286 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
287 bfs: &SubgraphBfsResult,
288) -> Vec<SubgraphEdge> {
289 let strings = graph.strings();
290
291 let node_names: std::collections::HashMap<NodeId, String> = bfs
293 .visited
294 .iter()
295 .filter_map(|&node_id| {
296 let entry = graph.nodes().get(node_id)?;
297 let name = strings
298 .resolve(entry.name)
299 .map(|s| s.to_string())
300 .unwrap_or_default();
301 let qn = entry
302 .qualified_name
303 .and_then(|id| strings.resolve(id))
304 .map_or_else(|| name, |s| s.to_string());
305 Some((node_id, qn))
306 })
307 .collect();
308
309 let mut edges: Vec<SubgraphEdge> = bfs
310 .collected_edges
311 .iter()
312 .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
313 .filter_map(|(src, tgt, kind)| {
314 let src_name = node_names.get(src)?.clone();
315 let tgt_name = node_names.get(tgt)?.clone();
316 Some(SubgraphEdge {
317 source: src_name,
318 target: tgt_name,
319 kind: kind.clone(),
320 })
321 })
322 .collect();
323
324 edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
326 edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
327 edges
328}
329
330#[allow(clippy::similar_names)]
335pub fn run_subgraph(
337 cli: &Cli,
338 symbols: &[String],
339 path: Option<&str>,
340 max_depth: usize,
341 max_nodes: usize,
342 include_callers: bool,
343 include_callees: bool,
344 include_imports: bool,
345) -> Result<()> {
346 let mut streams = OutputStreams::new();
347
348 if symbols.is_empty() {
349 return Err(anyhow!("At least one seed symbol is required"));
350 }
351
352 let search_path = path.map_or_else(
354 || std::env::current_dir().unwrap_or_default(),
355 std::path::PathBuf::from,
356 );
357
358 let index_location = find_nearest_index(&search_path);
359 let Some(ref loc) = index_location else {
360 streams
361 .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
362 return Ok(());
363 };
364
365 let config = GraphLoadConfig::default();
367 let graph = load_unified_graph_for_cli(&loc.index_root, &config, cli)
368 .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
369
370 let seed_nodes = find_seed_nodes(&graph, symbols);
372 if seed_nodes.is_empty() {
373 streams.write_diagnostic("No seed symbols found in the graph.")?;
374 return Ok(());
375 }
376
377 let bfs = collect_subgraph_bfs(
379 &graph,
380 &seed_nodes,
381 max_depth,
382 max_nodes,
383 include_callers,
384 include_callees,
385 include_imports,
386 );
387
388 let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
390 let edges = build_subgraph_edges(&graph, &bfs);
391
392 let stats = SubgraphStats {
393 node_count: nodes.len(),
394 edge_count: edges.len(),
395 max_depth_reached: bfs.max_depth_reached,
396 };
397
398 let output = SubgraphOutput {
399 seeds: symbols.to_vec(),
400 nodes,
401 edges,
402 stats,
403 };
404
405 if cli.json {
407 let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
408 streams.write_result(&json)?;
409 } else {
410 let text = format_subgraph_text(&output);
411 streams.write_result(&text)?;
412 }
413
414 Ok(())
415}
416
417fn format_subgraph_text(output: &SubgraphOutput) -> String {
418 let mut lines = Vec::new();
419
420 lines.push(format!(
421 "Subgraph around {} seed(s): {}",
422 output.seeds.len(),
423 output.seeds.join(", ")
424 ));
425 lines.push(format!(
426 "Stats: {} nodes, {} edges, max depth {}",
427 output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
428 ));
429 lines.push(String::new());
430
431 lines.push("Nodes:".to_string());
432 for node in &output.nodes {
433 let seed_marker = if node.is_seed { " [SEED]" } else { "" };
434 lines.push(format!(
435 " {} [{}] depth={}{} ",
436 node.qualified_name, node.kind, node.depth, seed_marker
437 ));
438 lines.push(format!(" {}:{}", node.file, node.line));
439 }
440
441 if !output.edges.is_empty() {
442 lines.push(String::new());
443 lines.push("Edges:".to_string());
444 for edge in &output.edges {
445 lines.push(format!(
446 " {} --[{}]--> {}",
447 edge.source, edge.kind, edge.target
448 ));
449 }
450 }
451
452 lines.join("\n")
453}