Skip to main content

sqry_cli/commands/
subgraph.rs

1//! Subgraph command implementation
2//!
3//! Provides CLI interface for extracting a focused subgraph around symbols.
4
5use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph_for_cli};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::node::NodeId;
12use sqry_core::graph::unified::{
13    EdgeFilter, TraversalConfig, TraversalDirection, TraversalLimits, traverse,
14};
15use std::collections::HashSet;
16
/// Top-level result of the `subgraph` command.
///
/// Serialized with `serde_json::to_string_pretty` when `--json` is set and
/// rendered by `format_subgraph_text` otherwise. Field declaration order is
/// the key order of the JSON object produced by the serde derive.
#[derive(Debug, Serialize)]
struct SubgraphOutput {
    /// Seed symbols exactly as passed on the command line, including any that
    /// did not resolve to a node in the graph.
    seeds: Vec<String>,
    /// Nodes in the subgraph, sorted by qualified name.
    nodes: Vec<SubgraphNode>,
    /// Edges in the subgraph, sorted by (source, target, kind) with exact
    /// duplicates removed.
    edges: Vec<SubgraphEdge>,
    /// Summary counts for `nodes` / `edges` and the deepest depth reached.
    stats: SubgraphStats,
}
29
/// One node of the extracted subgraph, with its names and location resolved
/// to display strings.
#[derive(Debug, Clone, Serialize)]
struct SubgraphNode {
    /// Identifier used to refer to this node; populated with the qualified
    /// name, so two nodes that share a qualified name share an `id`.
    id: String,
    /// Simple (unqualified) name; empty if the name could not be resolved.
    name: String,
    /// Qualified name, falling back to `name` when the node has none.
    qualified_name: String,
    /// `Debug` rendering of the node kind.
    kind: String,
    /// Display form of the node's file path; empty if the file is unresolved.
    file: String,
    /// Start line recorded for the node.
    line: u32,
    /// Display language inferred from the file extension. It is "Unknown" when
    /// the file or extension cannot be resolved, and the raw extension when
    /// the extension is not recognised.
    language: String,
    /// Whether this is a seed node
    is_seed: bool,
    /// Depth relative to the seeds: 0 for seeds, otherwise the smallest depth
    /// of any traversal edge touching this node (0 if none touches it).
    depth: usize,
}
44
/// One edge of the extracted subgraph, with both endpoints given by name.
#[derive(Debug, Clone, Serialize)]
struct SubgraphEdge {
    /// Qualified name (or simple-name fallback) of the source node.
    source: String,
    /// Qualified name (or simple-name fallback) of the target node.
    target: String,
    /// `Debug` rendering of the raw edge kind reported by the traversal.
    kind: String,
}
51
/// Summary counts reported alongside the subgraph.
#[derive(Debug, Serialize)]
struct SubgraphStats {
    /// Number of entries in the output node list.
    node_count: usize,
    /// Number of entries in the output edge list (after deduplication).
    edge_count: usize,
    /// Largest depth assigned to any node during collection.
    max_depth_reached: usize,
}
58
59/// Find seed nodes in the graph matching the given symbol names.
60fn find_seed_nodes(
61    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
62    symbols: &[String],
63) -> Vec<NodeId> {
64    let strings = graph.strings();
65    let mut seed_nodes: Vec<NodeId> = Vec::new();
66
67    for symbol in symbols {
68        let found = graph.nodes().iter().find(|(_, entry)| {
69            // Gate 0d iter-2 fix: skip unified losers from CLI
70            // `subgraph` seed lookup. See `NodeEntry::is_unified_loser`.
71            if entry.is_unified_loser() {
72                return false;
73            }
74            // Check qualified name
75            if let Some(qn_id) = entry.qualified_name
76                && let Some(qn) = strings.resolve(qn_id)
77                && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
78            {
79                return true;
80            }
81            // Check simple name
82            if let Some(name) = strings.resolve(entry.name)
83                && name.as_ref() == symbol.as_str()
84            {
85                return true;
86            }
87            false
88        });
89
90        if let Some((node_id, _)) = found {
91            seed_nodes.push(node_id);
92        }
93    }
94
95    seed_nodes
96}
97
98/// Result of BFS subgraph collection.
99struct SubgraphBfsResult {
100    visited: HashSet<NodeId>,
101    node_depths: std::collections::HashMap<NodeId, usize>,
102    collected_edges: Vec<(NodeId, NodeId, String)>,
103    max_depth_reached: usize,
104}
105
106/// Collect a subgraph via BFS from seed nodes, following callers and/or callees.
107///
108/// Uses the traversal kernel with configurable direction and edge filter.
109/// Converts the kernel's `TraversalResult` into the `SubgraphBfsResult`
110/// expected by downstream code.
111///
112/// # Frontier invariant (DB19)
113///
114/// `seed_nodes` is a `&[NodeId]` — seeds are resolved once at the
115/// handler boundary. The [`traverse`] kernel walks the NodeId-keyed
116/// graph and never re-resolves names at depth ≥ 1. This is the same
117/// invariant that DB17/DB18 locked for `trace_path`, `call-chain-depth`,
118/// and `dependency-tree`: a same-simple-name node at depth 1 never
119/// broadens the frontier because the kernel operates on
120/// generational-indexed `NodeId`s, not names.
121///
122/// Regressions that reintroduced DB15-class same-name broadening would
123/// surface in the CLI `migration_golden_cli_test` golden tests.
124#[allow(clippy::similar_names)]
125fn collect_subgraph_bfs(
126    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
127    seed_nodes: &[NodeId],
128    max_depth: usize,
129    max_nodes: usize,
130    include_callers: bool,
131    include_callees: bool,
132    include_imports: bool,
133) -> SubgraphBfsResult {
134    let snapshot = graph.snapshot();
135
136    let direction = match (include_callers, include_callees) {
137        (true, true) => TraversalDirection::Both,
138        (true, false) => TraversalDirection::Incoming,
139        #[allow(clippy::match_same_arms)] // Subgraph mode arms intentionally separate
140        (false, true) => TraversalDirection::Outgoing,
141        // If neither is selected, default to outgoing (no edges will match anyway)
142        (false, false) => TraversalDirection::Outgoing,
143    };
144
145    let edge_filter = if include_imports {
146        EdgeFilter::calls_and_imports()
147    } else {
148        EdgeFilter::calls_only()
149    };
150
151    let config = TraversalConfig {
152        direction,
153        edge_filter,
154        limits: TraversalLimits {
155            max_depth: u32::try_from(max_depth).unwrap_or(u32::MAX),
156            max_nodes: Some(max_nodes),
157            max_edges: None,
158            max_paths: None,
159        },
160    };
161
162    let result = traverse(&snapshot, seed_nodes, &config, None);
163
164    let mut visited: HashSet<NodeId> = HashSet::new();
165    let mut node_depths: std::collections::HashMap<NodeId, usize> =
166        std::collections::HashMap::new();
167    let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
168    let mut max_depth_reached: usize = 0;
169
170    // Populate visited set and node depths from traversal result
171    for (idx, mat_node) in result.nodes.iter().enumerate() {
172        visited.insert(mat_node.node_id);
173
174        // Seed nodes have depth 0; others get depth from the first edge
175        let depth = if seed_nodes.contains(&mat_node.node_id) {
176            0
177        } else {
178            result
179                .edges
180                .iter()
181                .filter(|e| e.source_idx == idx || e.target_idx == idx)
182                .map(|e| e.depth as usize)
183                .min()
184                .unwrap_or(0)
185        };
186
187        node_depths.insert(mat_node.node_id, depth);
188        max_depth_reached = max_depth_reached.max(depth);
189    }
190
191    // Convert edges to (source_id, target_id, kind_str) tuples
192    for edge in &result.edges {
193        let source_id = result.nodes[edge.source_idx].node_id;
194        let target_id = result.nodes[edge.target_idx].node_id;
195        let kind_str = format!("{:?}", edge.raw_kind);
196        collected_edges.push((source_id, target_id, kind_str));
197    }
198
199    SubgraphBfsResult {
200        visited,
201        node_depths,
202        collected_edges,
203        max_depth_reached,
204    }
205}
206
/// Map a file extension (without the leading dot) to a human-readable
/// language name.
///
/// Matching is case-sensitive. Unrecognised extensions are returned unchanged.
/// Beyond the primary extension of each language, common variants are
/// recognised: JSX/ES-module JavaScript, TSX/ES-module TypeScript, Python stub
/// files, alternative C++ source/header suffixes and Kotlin scripts.
fn extension_to_display_language(ext: &str) -> &str {
    match ext {
        "rs" => "Rust",
        "py" | "pyi" => "Python",
        "js" | "jsx" | "mjs" | "cjs" => "JavaScript",
        "ts" | "tsx" | "mts" | "cts" => "TypeScript",
        "go" => "Go",
        "java" => "Java",
        "c" | "h" => "C",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "C++",
        "rb" => "Ruby",
        "swift" => "Swift",
        "kt" | "kts" => "Kotlin",
        _ => ext,
    }
}
224
225/// Build `SubgraphNode` list from visited node IDs in BFS results.
226fn build_subgraph_nodes(
227    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
228    bfs: &SubgraphBfsResult,
229    seed_nodes: &[NodeId],
230) -> Vec<SubgraphNode> {
231    let strings = graph.strings();
232    let files = graph.files();
233    let seed_set: HashSet<_> = seed_nodes.iter().collect();
234
235    let mut nodes: Vec<SubgraphNode> = bfs
236        .visited
237        .iter()
238        .filter_map(|&node_id| {
239            let entry = graph.nodes().get(node_id)?;
240            let name = strings
241                .resolve(entry.name)
242                .map(|s| s.to_string())
243                .unwrap_or_default();
244            let qualified_name = entry
245                .qualified_name
246                .and_then(|id| strings.resolve(id))
247                .map_or_else(|| name.clone(), |s| s.to_string());
248
249            let file_path = files
250                .resolve(entry.file)
251                .map(|p| p.display().to_string())
252                .unwrap_or_default();
253
254            // Infer language from file extension
255            let language = files.resolve(entry.file).map_or_else(
256                || "Unknown".to_string(),
257                |p| {
258                    p.extension()
259                        .and_then(|ext| ext.to_str())
260                        .map_or("Unknown", extension_to_display_language)
261                        .to_string()
262                },
263            );
264
265            Some(SubgraphNode {
266                id: qualified_name.clone(),
267                name,
268                qualified_name,
269                kind: format!("{:?}", entry.kind),
270                file: file_path,
271                line: entry.start_line,
272                language,
273                is_seed: seed_set.contains(&node_id),
274                depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
275            })
276        })
277        .collect();
278
279    // Sort for determinism
280    nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
281    nodes
282}
283
284/// Build `SubgraphEdge` list from collected edges, resolving node IDs to names.
285fn build_subgraph_edges(
286    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
287    bfs: &SubgraphBfsResult,
288) -> Vec<SubgraphEdge> {
289    let strings = graph.strings();
290
291    // Create a map from NodeId to qualified_name for edge resolution
292    let node_names: std::collections::HashMap<NodeId, String> = bfs
293        .visited
294        .iter()
295        .filter_map(|&node_id| {
296            let entry = graph.nodes().get(node_id)?;
297            let name = strings
298                .resolve(entry.name)
299                .map(|s| s.to_string())
300                .unwrap_or_default();
301            let qn = entry
302                .qualified_name
303                .and_then(|id| strings.resolve(id))
304                .map_or_else(|| name, |s| s.to_string());
305            Some((node_id, qn))
306        })
307        .collect();
308
309    let mut edges: Vec<SubgraphEdge> = bfs
310        .collected_edges
311        .iter()
312        .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
313        .filter_map(|(src, tgt, kind)| {
314            let src_name = node_names.get(src)?.clone();
315            let tgt_name = node_names.get(tgt)?.clone();
316            Some(SubgraphEdge {
317                source: src_name,
318                target: tgt_name,
319                kind: kind.clone(),
320            })
321        })
322        .collect();
323
324    // Deduplicate edges
325    edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
326    edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
327    edges
328}
329
330/// Run the subgraph command.
331///
332/// # Errors
333/// Returns an error if the graph cannot be loaded or symbols cannot be found.
334#[allow(clippy::similar_names)]
335// Callers/callees naming mirrors CLI flag semantics.
336pub fn run_subgraph(
337    cli: &Cli,
338    symbols: &[String],
339    path: Option<&str>,
340    max_depth: usize,
341    max_nodes: usize,
342    include_callers: bool,
343    include_callees: bool,
344    include_imports: bool,
345) -> Result<()> {
346    let mut streams = OutputStreams::new();
347
348    if symbols.is_empty() {
349        return Err(anyhow!("At least one seed symbol is required"));
350    }
351
352    // Find index
353    let search_path = path.map_or_else(
354        || std::env::current_dir().unwrap_or_default(),
355        std::path::PathBuf::from,
356    );
357
358    let index_location = find_nearest_index(&search_path);
359    let Some(ref loc) = index_location else {
360        streams
361            .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
362        return Ok(());
363    };
364
365    // Load graph
366    let config = GraphLoadConfig::default();
367    let graph = load_unified_graph_for_cli(&loc.index_root, &config, cli)
368        .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
369
370    // Find seed nodes
371    let seed_nodes = find_seed_nodes(&graph, symbols);
372    if seed_nodes.is_empty() {
373        streams.write_diagnostic("No seed symbols found in the graph.")?;
374        return Ok(());
375    }
376
377    // BFS to collect subgraph
378    let bfs = collect_subgraph_bfs(
379        &graph,
380        &seed_nodes,
381        max_depth,
382        max_nodes,
383        include_callers,
384        include_callees,
385        include_imports,
386    );
387
388    // Build output
389    let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
390    let edges = build_subgraph_edges(&graph, &bfs);
391
392    let stats = SubgraphStats {
393        node_count: nodes.len(),
394        edge_count: edges.len(),
395        max_depth_reached: bfs.max_depth_reached,
396    };
397
398    let output = SubgraphOutput {
399        seeds: symbols.to_vec(),
400        nodes,
401        edges,
402        stats,
403    };
404
405    // Output
406    if cli.json {
407        let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
408        streams.write_result(&json)?;
409    } else {
410        let text = format_subgraph_text(&output);
411        streams.write_result(&text)?;
412    }
413
414    Ok(())
415}
416
417fn format_subgraph_text(output: &SubgraphOutput) -> String {
418    let mut lines = Vec::new();
419
420    lines.push(format!(
421        "Subgraph around {} seed(s): {}",
422        output.seeds.len(),
423        output.seeds.join(", ")
424    ));
425    lines.push(format!(
426        "Stats: {} nodes, {} edges, max depth {}",
427        output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
428    ));
429    lines.push(String::new());
430
431    lines.push("Nodes:".to_string());
432    for node in &output.nodes {
433        let seed_marker = if node.is_seed { " [SEED]" } else { "" };
434        lines.push(format!(
435            "  {} [{}] depth={}{} ",
436            node.qualified_name, node.kind, node.depth, seed_marker
437        ));
438        lines.push(format!("    {}:{}", node.file, node.line));
439    }
440
441    if !output.edges.is_empty() {
442        lines.push(String::new());
443        lines.push("Edges:".to_string());
444        for edge in &output.edges {
445            lines.push(format!(
446                "  {} --[{}]--> {}",
447                edge.source, edge.kind, edge.target
448            ));
449        }
450    }
451
452    lines.join("\n")
453}