Skip to main content

sqry_cli/commands/
subgraph.rs

1//! Subgraph command implementation
2//!
3//! Provides CLI interface for extracting a focused subgraph around symbols.
4
5use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::edge::EdgeKind;
12use sqry_core::graph::unified::node::NodeId;
13use std::collections::{HashSet, VecDeque};
14
15/// Subgraph output
16#[derive(Debug, Serialize)]
17struct SubgraphOutput {
18    /// Seed symbols
19    seeds: Vec<String>,
20    /// Nodes in the subgraph
21    nodes: Vec<SubgraphNode>,
22    /// Edges in the subgraph
23    edges: Vec<SubgraphEdge>,
24    /// Statistics
25    stats: SubgraphStats,
26}
27
28#[derive(Debug, Clone, Serialize)]
29struct SubgraphNode {
30    id: String,
31    name: String,
32    qualified_name: String,
33    kind: String,
34    file: String,
35    line: u32,
36    language: String,
37    /// Whether this is a seed node
38    is_seed: bool,
39    /// Depth from nearest seed
40    depth: usize,
41}
42
43#[derive(Debug, Clone, Serialize)]
44struct SubgraphEdge {
45    source: String,
46    target: String,
47    kind: String,
48}
49
50#[derive(Debug, Serialize)]
51struct SubgraphStats {
52    node_count: usize,
53    edge_count: usize,
54    max_depth_reached: usize,
55}
56
57/// Find seed nodes in the graph matching the given symbol names.
58fn find_seed_nodes(
59    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
60    symbols: &[String],
61) -> Vec<NodeId> {
62    let strings = graph.strings();
63    let mut seed_nodes: Vec<NodeId> = Vec::new();
64
65    for symbol in symbols {
66        let found = graph.nodes().iter().find(|(_, entry)| {
67            // Check qualified name
68            if let Some(qn_id) = entry.qualified_name
69                && let Some(qn) = strings.resolve(qn_id)
70                && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
71            {
72                return true;
73            }
74            // Check simple name
75            if let Some(name) = strings.resolve(entry.name)
76                && name.as_ref() == symbol.as_str()
77            {
78                return true;
79            }
80            false
81        });
82
83        if let Some((node_id, _)) = found {
84            seed_nodes.push(node_id);
85        }
86    }
87
88    seed_nodes
89}
90
91/// Result of BFS subgraph collection.
92struct SubgraphBfsResult {
93    visited: HashSet<NodeId>,
94    node_depths: std::collections::HashMap<NodeId, usize>,
95    collected_edges: Vec<(NodeId, NodeId, String)>,
96    max_depth_reached: usize,
97}
98
99/// Process outgoing edges (callees) from a node during BFS traversal.
100#[allow(clippy::too_many_arguments)]
101fn process_callee_edges(
102    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
103    node_id: NodeId,
104    include_imports: bool,
105    collected_edges: &mut Vec<(NodeId, NodeId, String)>,
106    visited: &mut HashSet<NodeId>,
107    node_depths: &mut std::collections::HashMap<NodeId, usize>,
108    queue: &mut VecDeque<(NodeId, usize)>,
109    depth: usize,
110    max_nodes: usize,
111) {
112    for edge_ref in graph.edges().edges_from(node_id) {
113        let is_call = matches!(edge_ref.kind, EdgeKind::Calls { .. });
114        let is_import = matches!(edge_ref.kind, EdgeKind::Imports { .. });
115
116        if is_call || (include_imports && is_import) {
117            let kind_str = format!("{:?}", edge_ref.kind);
118            collected_edges.push((node_id, edge_ref.target, kind_str));
119
120            if !visited.contains(&edge_ref.target) && visited.len() < max_nodes {
121                visited.insert(edge_ref.target);
122                node_depths.insert(edge_ref.target, depth + 1);
123                queue.push_back((edge_ref.target, depth + 1));
124            }
125        }
126    }
127}
128
129/// Process incoming edges (callers) to a node during BFS traversal.
130fn process_caller_edges(
131    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
132    node_id: NodeId,
133    collected_edges: &mut Vec<(NodeId, NodeId, String)>,
134    visited: &mut HashSet<NodeId>,
135    node_depths: &mut std::collections::HashMap<NodeId, usize>,
136    queue: &mut VecDeque<(NodeId, usize)>,
137    depth: usize,
138    max_nodes: usize,
139) {
140    for edge_ref in graph.edges().edges_to(node_id) {
141        if matches!(edge_ref.kind, EdgeKind::Calls { .. }) {
142            let kind_str = format!("{:?}", edge_ref.kind);
143            collected_edges.push((edge_ref.source, node_id, kind_str));
144
145            if !visited.contains(&edge_ref.source) && visited.len() < max_nodes {
146                visited.insert(edge_ref.source);
147                node_depths.insert(edge_ref.source, depth + 1);
148                queue.push_back((edge_ref.source, depth + 1));
149            }
150        }
151    }
152}
153
154/// Collect a subgraph via BFS from seed nodes, following callers and/or callees.
155#[allow(clippy::similar_names)]
156fn collect_subgraph_bfs(
157    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
158    seed_nodes: &[NodeId],
159    max_depth: usize,
160    max_nodes: usize,
161    include_callers: bool,
162    include_callees: bool,
163    include_imports: bool,
164) -> SubgraphBfsResult {
165    let mut visited: HashSet<NodeId> = HashSet::new();
166    let mut node_depths: std::collections::HashMap<NodeId, usize> =
167        std::collections::HashMap::new();
168    let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
169    let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
170
171    // Initialize with seeds
172    for &seed in seed_nodes {
173        visited.insert(seed);
174        node_depths.insert(seed, 0);
175        queue.push_back((seed, 0));
176    }
177
178    let mut max_depth_reached = 0;
179
180    while let Some((node_id, depth)) = queue.pop_front() {
181        if visited.len() >= max_nodes {
182            break;
183        }
184        if depth >= max_depth {
185            continue;
186        }
187
188        max_depth_reached = max_depth_reached.max(depth);
189
190        // Process outgoing edges (callees)
191        if include_callees {
192            process_callee_edges(
193                graph,
194                node_id,
195                include_imports,
196                &mut collected_edges,
197                &mut visited,
198                &mut node_depths,
199                &mut queue,
200                depth,
201                max_nodes,
202            );
203        }
204
205        // Process incoming edges (callers)
206        if include_callers {
207            process_caller_edges(
208                graph,
209                node_id,
210                &mut collected_edges,
211                &mut visited,
212                &mut node_depths,
213                &mut queue,
214                depth,
215                max_nodes,
216            );
217        }
218    }
219
220    SubgraphBfsResult {
221        visited,
222        node_depths,
223        collected_edges,
224        max_depth_reached,
225    }
226}
227
/// Map a file extension (without the leading dot) to a human-readable language name.
///
/// Matching is exact and case-sensitive. An extension that is not in the table
/// is returned unchanged.
fn extension_to_display_language(ext: &str) -> &str {
    /// Known extensions grouped by the display name they map to.
    const DISPLAY_NAMES: &[(&[&str], &str)] = &[
        (&["rs"], "Rust"),
        (&["py"], "Python"),
        (&["js"], "JavaScript"),
        (&["ts"], "TypeScript"),
        (&["go"], "Go"),
        (&["java"], "Java"),
        (&["c", "h"], "C"),
        (&["cpp", "hpp", "cc"], "C++"),
        (&["rb"], "Ruby"),
        (&["swift"], "Swift"),
        (&["kt"], "Kotlin"),
    ];

    DISPLAY_NAMES
        .iter()
        .find(|(extensions, _)| extensions.iter().any(|&known| known == ext))
        .map_or(ext, |&(_, display)| display)
}
245
246/// Build `SubgraphNode` list from visited node IDs in BFS results.
247fn build_subgraph_nodes(
248    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
249    bfs: &SubgraphBfsResult,
250    seed_nodes: &[NodeId],
251) -> Vec<SubgraphNode> {
252    let strings = graph.strings();
253    let files = graph.files();
254    let seed_set: HashSet<_> = seed_nodes.iter().collect();
255
256    let mut nodes: Vec<SubgraphNode> = bfs
257        .visited
258        .iter()
259        .filter_map(|&node_id| {
260            let entry = graph.nodes().get(node_id)?;
261            let name = strings
262                .resolve(entry.name)
263                .map(|s| s.to_string())
264                .unwrap_or_default();
265            let qualified_name = entry
266                .qualified_name
267                .and_then(|id| strings.resolve(id))
268                .map_or_else(|| name.clone(), |s| s.to_string());
269
270            let file_path = files
271                .resolve(entry.file)
272                .map(|p| p.display().to_string())
273                .unwrap_or_default();
274
275            // Infer language from file extension
276            let language = files.resolve(entry.file).map_or_else(
277                || "Unknown".to_string(),
278                |p| {
279                    p.extension()
280                        .and_then(|ext| ext.to_str())
281                        .map_or("Unknown", extension_to_display_language)
282                        .to_string()
283                },
284            );
285
286            Some(SubgraphNode {
287                id: qualified_name.clone(),
288                name,
289                qualified_name,
290                kind: format!("{:?}", entry.kind),
291                file: file_path,
292                line: entry.start_line,
293                language,
294                is_seed: seed_set.contains(&node_id),
295                depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
296            })
297        })
298        .collect();
299
300    // Sort for determinism
301    nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
302    nodes
303}
304
305/// Build `SubgraphEdge` list from collected edges, resolving node IDs to names.
306fn build_subgraph_edges(
307    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
308    bfs: &SubgraphBfsResult,
309) -> Vec<SubgraphEdge> {
310    let strings = graph.strings();
311
312    // Create a map from NodeId to qualified_name for edge resolution
313    let node_names: std::collections::HashMap<NodeId, String> = bfs
314        .visited
315        .iter()
316        .filter_map(|&node_id| {
317            let entry = graph.nodes().get(node_id)?;
318            let name = strings
319                .resolve(entry.name)
320                .map(|s| s.to_string())
321                .unwrap_or_default();
322            let qn = entry
323                .qualified_name
324                .and_then(|id| strings.resolve(id))
325                .map_or_else(|| name, |s| s.to_string());
326            Some((node_id, qn))
327        })
328        .collect();
329
330    let mut edges: Vec<SubgraphEdge> = bfs
331        .collected_edges
332        .iter()
333        .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
334        .filter_map(|(src, tgt, kind)| {
335            let src_name = node_names.get(src)?.clone();
336            let tgt_name = node_names.get(tgt)?.clone();
337            Some(SubgraphEdge {
338                source: src_name,
339                target: tgt_name,
340                kind: kind.clone(),
341            })
342        })
343        .collect();
344
345    // Deduplicate edges
346    edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
347    edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
348    edges
349}
350
351/// Run the subgraph command.
352///
353/// # Errors
354/// Returns an error if the graph cannot be loaded or symbols cannot be found.
355#[allow(clippy::similar_names)]
356// Callers/callees naming mirrors CLI flag semantics.
357pub fn run_subgraph(
358    cli: &Cli,
359    symbols: &[String],
360    path: Option<&str>,
361    max_depth: usize,
362    max_nodes: usize,
363    include_callers: bool,
364    include_callees: bool,
365    include_imports: bool,
366) -> Result<()> {
367    let mut streams = OutputStreams::new();
368
369    if symbols.is_empty() {
370        return Err(anyhow!("At least one seed symbol is required"));
371    }
372
373    // Find index
374    let search_path = path.map_or_else(
375        || std::env::current_dir().unwrap_or_default(),
376        std::path::PathBuf::from,
377    );
378
379    let index_location = find_nearest_index(&search_path);
380    let Some(ref loc) = index_location else {
381        streams
382            .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
383        return Ok(());
384    };
385
386    // Load graph
387    let config = GraphLoadConfig::default();
388    let graph = load_unified_graph(&loc.index_root, &config)
389        .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
390
391    // Find seed nodes
392    let seed_nodes = find_seed_nodes(&graph, symbols);
393    if seed_nodes.is_empty() {
394        streams.write_diagnostic("No seed symbols found in the graph.")?;
395        return Ok(());
396    }
397
398    // BFS to collect subgraph
399    let bfs = collect_subgraph_bfs(
400        &graph,
401        &seed_nodes,
402        max_depth,
403        max_nodes,
404        include_callers,
405        include_callees,
406        include_imports,
407    );
408
409    // Build output
410    let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
411    let edges = build_subgraph_edges(&graph, &bfs);
412
413    let stats = SubgraphStats {
414        node_count: nodes.len(),
415        edge_count: edges.len(),
416        max_depth_reached: bfs.max_depth_reached,
417    };
418
419    let output = SubgraphOutput {
420        seeds: symbols.to_vec(),
421        nodes,
422        edges,
423        stats,
424    };
425
426    // Output
427    if cli.json {
428        let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
429        streams.write_result(&json)?;
430    } else {
431        let text = format_subgraph_text(&output);
432        streams.write_result(&text)?;
433    }
434
435    Ok(())
436}
437
438fn format_subgraph_text(output: &SubgraphOutput) -> String {
439    let mut lines = Vec::new();
440
441    lines.push(format!(
442        "Subgraph around {} seed(s): {}",
443        output.seeds.len(),
444        output.seeds.join(", ")
445    ));
446    lines.push(format!(
447        "Stats: {} nodes, {} edges, max depth {}",
448        output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
449    ));
450    lines.push(String::new());
451
452    lines.push("Nodes:".to_string());
453    for node in &output.nodes {
454        let seed_marker = if node.is_seed { " [SEED]" } else { "" };
455        lines.push(format!(
456            "  {} [{}] depth={}{} ",
457            node.qualified_name, node.kind, node.depth, seed_marker
458        ));
459        lines.push(format!("    {}:{}", node.file, node.line));
460    }
461
462    if !output.edges.is_empty() {
463        lines.push(String::new());
464        lines.push("Edges:".to_string());
465        for edge in &output.edges {
466            lines.push(format!(
467                "  {} --[{}]--> {}",
468                edge.source, edge.kind, edge.target
469            ));
470        }
471    }
472
473    lines.join("\n")
474}