Skip to main content

sqry_cli/commands/
subgraph.rs

1//! Subgraph command implementation
2//!
3//! Provides CLI interface for extracting a focused subgraph around symbols.
4
5use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph_for_cli};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::node::NodeId;
12use sqry_core::graph::unified::{
13    EdgeFilter, TraversalConfig, TraversalDirection, TraversalLimits, traverse,
14};
15use std::collections::HashSet;
16
/// Complete result of the subgraph command.
///
/// This is the value that gets serialized to JSON or rendered as text.
#[derive(Debug, Serialize)]
struct SubgraphOutput {
    /// Seed symbol names as requested by the caller
    seeds: Vec<String>,
    /// Nodes in the subgraph
    nodes: Vec<SubgraphNode>,
    /// Edges in the subgraph
    edges: Vec<SubgraphEdge>,
    /// Summary counters for the subgraph
    stats: SubgraphStats,
}
29
/// A single node of the extracted subgraph, flattened to plain strings and
/// numbers for output.
#[derive(Debug, Clone, Serialize)]
struct SubgraphNode {
    /// Identifier used by edges to refer to this node (the qualified name)
    id: String,
    /// Simple (unqualified) symbol name; empty if it cannot be resolved
    name: String,
    /// Qualified symbol name, falling back to the simple name when absent
    qualified_name: String,
    /// Debug rendering of the node kind
    kind: String,
    /// Display form of the file path; empty if it cannot be resolved
    file: String,
    /// Start line of the node
    line: u32,
    /// Display language inferred from the file extension, or "Unknown"
    language: String,
    /// Whether this is a seed node
    is_seed: bool,
    /// Depth from nearest seed
    depth: usize,
}
44
/// A directed edge of the extracted subgraph, with endpoints given by name.
#[derive(Debug, Clone, Serialize)]
struct SubgraphEdge {
    /// Name of the node the edge starts at (qualified name when available)
    source: String,
    /// Name of the node the edge points to (qualified name when available)
    target: String,
    /// Debug rendering of the edge kind
    kind: String,
}
51
/// Summary counters reported alongside the subgraph.
#[derive(Debug, Serialize)]
struct SubgraphStats {
    /// Number of nodes in the output
    node_count: usize,
    /// Number of edges in the output
    edge_count: usize,
    /// Largest node depth recorded during collection
    max_depth_reached: usize,
}
58
59/// Find seed nodes in the graph matching the given symbol names.
60fn find_seed_nodes(
61    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
62    symbols: &[String],
63) -> Vec<NodeId> {
64    let strings = graph.strings();
65    let mut seed_nodes: Vec<NodeId> = Vec::new();
66
67    for symbol in symbols {
68        let found = graph.nodes().iter().find(|(_, entry)| {
69            // Check qualified name
70            if let Some(qn_id) = entry.qualified_name
71                && let Some(qn) = strings.resolve(qn_id)
72                && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
73            {
74                return true;
75            }
76            // Check simple name
77            if let Some(name) = strings.resolve(entry.name)
78                && name.as_ref() == symbol.as_str()
79            {
80                return true;
81            }
82            false
83        });
84
85        if let Some((node_id, _)) = found {
86            seed_nodes.push(node_id);
87        }
88    }
89
90    seed_nodes
91}
92
/// Result of BFS subgraph collection.
struct SubgraphBfsResult {
    /// Every node returned by the traversal
    visited: HashSet<NodeId>,
    /// Depth recorded for each visited node (seeds are 0)
    node_depths: std::collections::HashMap<NodeId, usize>,
    /// Traversed edges as (source, target, debug-rendered edge kind)
    collected_edges: Vec<(NodeId, NodeId, String)>,
    /// Largest value stored in `node_depths`
    max_depth_reached: usize,
}
100
101/// Collect a subgraph via BFS from seed nodes, following callers and/or callees.
102///
103/// Uses the traversal kernel with configurable direction and edge filter.
104/// Converts the kernel's `TraversalResult` into the `SubgraphBfsResult`
105/// expected by downstream code.
106#[allow(clippy::similar_names)]
107fn collect_subgraph_bfs(
108    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
109    seed_nodes: &[NodeId],
110    max_depth: usize,
111    max_nodes: usize,
112    include_callers: bool,
113    include_callees: bool,
114    include_imports: bool,
115) -> SubgraphBfsResult {
116    let snapshot = graph.snapshot();
117
118    let direction = match (include_callers, include_callees) {
119        (true, true) => TraversalDirection::Both,
120        (true, false) => TraversalDirection::Incoming,
121        #[allow(clippy::match_same_arms)] // Subgraph mode arms intentionally separate
122        (false, true) => TraversalDirection::Outgoing,
123        // If neither is selected, default to outgoing (no edges will match anyway)
124        (false, false) => TraversalDirection::Outgoing,
125    };
126
127    let edge_filter = if include_imports {
128        EdgeFilter::calls_and_imports()
129    } else {
130        EdgeFilter::calls_only()
131    };
132
133    let config = TraversalConfig {
134        direction,
135        edge_filter,
136        limits: TraversalLimits {
137            max_depth: u32::try_from(max_depth).unwrap_or(u32::MAX),
138            max_nodes: Some(max_nodes),
139            max_edges: None,
140            max_paths: None,
141        },
142    };
143
144    let result = traverse(&snapshot, seed_nodes, &config, None);
145
146    let mut visited: HashSet<NodeId> = HashSet::new();
147    let mut node_depths: std::collections::HashMap<NodeId, usize> =
148        std::collections::HashMap::new();
149    let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
150    let mut max_depth_reached: usize = 0;
151
152    // Populate visited set and node depths from traversal result
153    for (idx, mat_node) in result.nodes.iter().enumerate() {
154        visited.insert(mat_node.node_id);
155
156        // Seed nodes have depth 0; others get depth from the first edge
157        let depth = if seed_nodes.contains(&mat_node.node_id) {
158            0
159        } else {
160            result
161                .edges
162                .iter()
163                .filter(|e| e.source_idx == idx || e.target_idx == idx)
164                .map(|e| e.depth as usize)
165                .min()
166                .unwrap_or(0)
167        };
168
169        node_depths.insert(mat_node.node_id, depth);
170        max_depth_reached = max_depth_reached.max(depth);
171    }
172
173    // Convert edges to (source_id, target_id, kind_str) tuples
174    for edge in &result.edges {
175        let source_id = result.nodes[edge.source_idx].node_id;
176        let target_id = result.nodes[edge.target_idx].node_id;
177        let kind_str = format!("{:?}", edge.raw_kind);
178        collected_edges.push((source_id, target_id, kind_str));
179    }
180
181    SubgraphBfsResult {
182        visited,
183        node_depths,
184        collected_edges,
185        max_depth_reached,
186    }
187}
188
/// Map a file extension (without the leading dot) to a display language name.
///
/// Matching is exact and case-sensitive. Extensions that are not in the table
/// are returned unchanged.
fn extension_to_display_language(ext: &str) -> &str {
    // (extension, display name) pairs; every extension appears once
    const DISPLAY_NAMES: &[(&str, &str)] = &[
        ("rs", "Rust"),
        ("py", "Python"),
        ("js", "JavaScript"),
        ("ts", "TypeScript"),
        ("go", "Go"),
        ("java", "Java"),
        ("c", "C"),
        ("h", "C"),
        ("cpp", "C++"),
        ("hpp", "C++"),
        ("cc", "C++"),
        ("rb", "Ruby"),
        ("swift", "Swift"),
        ("kt", "Kotlin"),
    ];

    DISPLAY_NAMES
        .iter()
        .find(|&&(known, _)| known == ext)
        .map_or(ext, |&(_, display)| display)
}
206
207/// Build `SubgraphNode` list from visited node IDs in BFS results.
208fn build_subgraph_nodes(
209    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
210    bfs: &SubgraphBfsResult,
211    seed_nodes: &[NodeId],
212) -> Vec<SubgraphNode> {
213    let strings = graph.strings();
214    let files = graph.files();
215    let seed_set: HashSet<_> = seed_nodes.iter().collect();
216
217    let mut nodes: Vec<SubgraphNode> = bfs
218        .visited
219        .iter()
220        .filter_map(|&node_id| {
221            let entry = graph.nodes().get(node_id)?;
222            let name = strings
223                .resolve(entry.name)
224                .map(|s| s.to_string())
225                .unwrap_or_default();
226            let qualified_name = entry
227                .qualified_name
228                .and_then(|id| strings.resolve(id))
229                .map_or_else(|| name.clone(), |s| s.to_string());
230
231            let file_path = files
232                .resolve(entry.file)
233                .map(|p| p.display().to_string())
234                .unwrap_or_default();
235
236            // Infer language from file extension
237            let language = files.resolve(entry.file).map_or_else(
238                || "Unknown".to_string(),
239                |p| {
240                    p.extension()
241                        .and_then(|ext| ext.to_str())
242                        .map_or("Unknown", extension_to_display_language)
243                        .to_string()
244                },
245            );
246
247            Some(SubgraphNode {
248                id: qualified_name.clone(),
249                name,
250                qualified_name,
251                kind: format!("{:?}", entry.kind),
252                file: file_path,
253                line: entry.start_line,
254                language,
255                is_seed: seed_set.contains(&node_id),
256                depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
257            })
258        })
259        .collect();
260
261    // Sort for determinism
262    nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
263    nodes
264}
265
266/// Build `SubgraphEdge` list from collected edges, resolving node IDs to names.
267fn build_subgraph_edges(
268    graph: &sqry_core::graph::unified::concurrent::CodeGraph,
269    bfs: &SubgraphBfsResult,
270) -> Vec<SubgraphEdge> {
271    let strings = graph.strings();
272
273    // Create a map from NodeId to qualified_name for edge resolution
274    let node_names: std::collections::HashMap<NodeId, String> = bfs
275        .visited
276        .iter()
277        .filter_map(|&node_id| {
278            let entry = graph.nodes().get(node_id)?;
279            let name = strings
280                .resolve(entry.name)
281                .map(|s| s.to_string())
282                .unwrap_or_default();
283            let qn = entry
284                .qualified_name
285                .and_then(|id| strings.resolve(id))
286                .map_or_else(|| name, |s| s.to_string());
287            Some((node_id, qn))
288        })
289        .collect();
290
291    let mut edges: Vec<SubgraphEdge> = bfs
292        .collected_edges
293        .iter()
294        .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
295        .filter_map(|(src, tgt, kind)| {
296            let src_name = node_names.get(src)?.clone();
297            let tgt_name = node_names.get(tgt)?.clone();
298            Some(SubgraphEdge {
299                source: src_name,
300                target: tgt_name,
301                kind: kind.clone(),
302            })
303        })
304        .collect();
305
306    // Deduplicate edges
307    edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
308    edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
309    edges
310}
311
312/// Run the subgraph command.
313///
314/// # Errors
315/// Returns an error if the graph cannot be loaded or symbols cannot be found.
316#[allow(clippy::similar_names)]
317// Callers/callees naming mirrors CLI flag semantics.
318pub fn run_subgraph(
319    cli: &Cli,
320    symbols: &[String],
321    path: Option<&str>,
322    max_depth: usize,
323    max_nodes: usize,
324    include_callers: bool,
325    include_callees: bool,
326    include_imports: bool,
327) -> Result<()> {
328    let mut streams = OutputStreams::new();
329
330    if symbols.is_empty() {
331        return Err(anyhow!("At least one seed symbol is required"));
332    }
333
334    // Find index
335    let search_path = path.map_or_else(
336        || std::env::current_dir().unwrap_or_default(),
337        std::path::PathBuf::from,
338    );
339
340    let index_location = find_nearest_index(&search_path);
341    let Some(ref loc) = index_location else {
342        streams
343            .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
344        return Ok(());
345    };
346
347    // Load graph
348    let config = GraphLoadConfig::default();
349    let graph = load_unified_graph_for_cli(&loc.index_root, &config, cli)
350        .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
351
352    // Find seed nodes
353    let seed_nodes = find_seed_nodes(&graph, symbols);
354    if seed_nodes.is_empty() {
355        streams.write_diagnostic("No seed symbols found in the graph.")?;
356        return Ok(());
357    }
358
359    // BFS to collect subgraph
360    let bfs = collect_subgraph_bfs(
361        &graph,
362        &seed_nodes,
363        max_depth,
364        max_nodes,
365        include_callers,
366        include_callees,
367        include_imports,
368    );
369
370    // Build output
371    let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
372    let edges = build_subgraph_edges(&graph, &bfs);
373
374    let stats = SubgraphStats {
375        node_count: nodes.len(),
376        edge_count: edges.len(),
377        max_depth_reached: bfs.max_depth_reached,
378    };
379
380    let output = SubgraphOutput {
381        seeds: symbols.to_vec(),
382        nodes,
383        edges,
384        stats,
385    };
386
387    // Output
388    if cli.json {
389        let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
390        streams.write_result(&json)?;
391    } else {
392        let text = format_subgraph_text(&output);
393        streams.write_result(&text)?;
394    }
395
396    Ok(())
397}
398
399fn format_subgraph_text(output: &SubgraphOutput) -> String {
400    let mut lines = Vec::new();
401
402    lines.push(format!(
403        "Subgraph around {} seed(s): {}",
404        output.seeds.len(),
405        output.seeds.join(", ")
406    ));
407    lines.push(format!(
408        "Stats: {} nodes, {} edges, max depth {}",
409        output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
410    ));
411    lines.push(String::new());
412
413    lines.push("Nodes:".to_string());
414    for node in &output.nodes {
415        let seed_marker = if node.is_seed { " [SEED]" } else { "" };
416        lines.push(format!(
417            "  {} [{}] depth={}{} ",
418            node.qualified_name, node.kind, node.depth, seed_marker
419        ));
420        lines.push(format!("    {}:{}", node.file, node.line));
421    }
422
423    if !output.edges.is_empty() {
424        lines.push(String::new());
425        lines.push("Edges:".to_string());
426        for edge in &output.edges {
427            lines.push(format!(
428                "  {} --[{}]--> {}",
429                edge.source, edge.kind, edge.target
430            ));
431        }
432    }
433
434    lines.join("\n")
435}