1use crate::db::Database;
2use crate::types::{Edge, EdgeKind, GraphPath, Node, NodeEdge};
3use anyhow::Result;
4use serde::Serialize;
5use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone, Default, Serialize)]
8pub struct Subgraph {
9 pub nodes: HashMap<String, Node>,
10 pub edges: Vec<Edge>,
11 pub roots: Vec<String>,
12}
13
14pub struct GraphTraverser<'a> {
15 db: &'a Database,
16}
17
18impl<'a> GraphTraverser<'a> {
19 pub fn new(db: &'a Database) -> Self {
20 Self { db }
21 }
22
23 pub fn get_callers(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
24 self.walk_edges(
25 node_id,
26 max_depth,
27 Direction::Incoming,
28 &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
29 )
30 }
31
32 pub fn get_callees(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
33 self.walk_edges(
34 node_id,
35 max_depth,
36 Direction::Outgoing,
37 &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
38 )
39 }
40
41 pub fn get_impact_radius(&self, node_id: &str, max_depth: usize) -> Result<Subgraph> {
42 let Some(root) = self.db.get_node(node_id)? else {
43 return Ok(Subgraph::default());
44 };
45 let mut out = Subgraph::default();
46 out.roots.push(node_id.to_string());
47 out.nodes.insert(root.id.clone(), root);
48
49 let mut visited = HashSet::new();
50 let mut seen_edges = BTreeSet::new();
51 let mut queue = VecDeque::new();
52 queue.push_back((node_id.to_string(), 0usize));
53
54 while let Some((current, depth)) = queue.pop_front() {
55 if depth > max_depth || !visited.insert(current.clone()) {
56 continue;
57 }
58
59 if let Some(node) = self.db.get_node(¤t)? {
60 if matches!(
61 node.kind,
62 crate::types::NodeKind::Class
63 | crate::types::NodeKind::Interface
64 | crate::types::NodeKind::Struct
65 | crate::types::NodeKind::Trait
66 | crate::types::NodeKind::Protocol
67 | crate::types::NodeKind::Module
68 | crate::types::NodeKind::Enum
69 ) {
70 for edge in self
71 .db
72 .get_outgoing_edges(¤t, Some(&[EdgeKind::Contains]))?
73 {
74 if let Some(child) = self.db.get_node(&edge.target)? {
75 out.nodes.insert(child.id.clone(), child.clone());
76 push_unique_edge(&mut out.edges, &mut seen_edges, edge.clone());
77 queue.push_back((child.id, depth));
78 }
79 }
80 }
81 }
82
83 if depth == max_depth {
84 continue;
85 }
86
87 for edge in self.db.get_incoming_edges(¤t, None)? {
88 if let Some(source) = self.db.get_node(&edge.source)? {
89 out.nodes.insert(source.id.clone(), source.clone());
90 push_unique_edge(&mut out.edges, &mut seen_edges, edge);
91 queue.push_back((source.id, depth + 1));
92 }
93 }
94 }
95
96 out.edges.sort_by(edge_sort_key);
97 Ok(out)
98 }
99
100 pub fn find_paths(
101 &self,
102 from_node_id: &str,
103 to_node_id: &str,
104 max_depth: usize,
105 max_paths: usize,
106 ) -> Result<Vec<GraphPath>> {
107 if max_depth == 0 || max_paths == 0 {
108 return Ok(Vec::new());
109 }
110 let Some(root) = self.db.get_node(from_node_id)? else {
111 return Ok(Vec::new());
112 };
113 if self.db.get_node(to_node_id)?.is_none() {
114 return Ok(Vec::new());
115 }
116
117 let path_kinds = [
118 EdgeKind::Calls,
119 EdgeKind::References,
120 EdgeKind::Imports,
121 EdgeKind::Extends,
122 EdgeKind::Implements,
123 ];
124 let mut out = Vec::new();
125 let mut queue = VecDeque::new();
126 queue.push_back(PathState {
127 node_id: from_node_id.to_string(),
128 nodes: vec![root],
129 edges: Vec::new(),
130 visited: BTreeSet::from([from_node_id.to_string()]),
131 });
132
133 while let Some(state) = queue.pop_front() {
134 if state.edges.len() >= max_depth {
135 continue;
136 }
137
138 let mut outgoing = self
139 .db
140 .get_outgoing_edges(&state.node_id, Some(&path_kinds))?;
141 outgoing.sort_by(edge_sort_key);
142 for edge in outgoing {
143 if state.visited.contains(&edge.target) {
144 continue;
145 }
146 let Some(next_node) = self.db.get_node(&edge.target)? else {
147 continue;
148 };
149 let mut nodes = state.nodes.clone();
150 nodes.push(next_node.clone());
151 let mut edges = state.edges.clone();
152 edges.push(edge.clone());
153 if edge.target == to_node_id {
154 out.push(GraphPath { nodes, edges });
155 if out.len() >= max_paths {
156 return Ok(out);
157 }
158 continue;
159 }
160 let mut visited = state.visited.clone();
161 visited.insert(edge.target.clone());
162 queue.push_back(PathState {
163 node_id: edge.target,
164 nodes,
165 edges,
166 visited,
167 });
168 }
169 }
170
171 Ok(out)
172 }
173
174 fn walk_edges(
175 &self,
176 node_id: &str,
177 max_depth: usize,
178 direction: Direction,
179 kinds: &[EdgeKind],
180 ) -> Result<Vec<NodeEdge>> {
181 let mut out = Vec::new();
182 let mut visited = HashSet::new();
183 let mut emitted = BTreeSet::new();
184 let mut queue = VecDeque::new();
185 queue.push_back((node_id.to_string(), 0usize));
186
187 while let Some((current, depth)) = queue.pop_front() {
188 if depth >= max_depth || !visited.insert(current.clone()) {
189 continue;
190 }
191
192 let mut edges = match direction {
193 Direction::Incoming => self.db.get_incoming_edges(¤t, Some(kinds))?,
194 Direction::Outgoing => self.db.get_outgoing_edges(¤t, Some(kinds))?,
195 };
196 edges.sort_by(edge_sort_key);
197
198 for edge in edges {
199 let next = match direction {
200 Direction::Incoming => &edge.source,
201 Direction::Outgoing => &edge.target,
202 };
203 if visited.contains(next) || emitted.contains(next) {
204 continue;
205 }
206 if let Some(node) = self.db.get_node(next)? {
207 queue.push_back((node.id.clone(), depth + 1));
208 emitted.insert(node.id.clone());
209 out.push(NodeEdge {
210 node,
211 edge,
212 depth: depth + 1,
213 });
214 }
215 }
216 }
217
218 out.sort_by(node_edge_sort);
219 Ok(out)
220 }
221}
222
223struct PathState {
224 node_id: String,
225 nodes: Vec<Node>,
226 edges: Vec<Edge>,
227 visited: BTreeSet<String>,
228}
229
/// Which end of an edge `walk_edges` moves to.
enum Direction {
    /// Move from an edge's source to its target.
    Outgoing,
    /// Move from an edge's target to its source.
    Incoming,
}
234
235fn push_unique_edge(edges: &mut Vec<Edge>, seen: &mut BTreeSet<String>, edge: Edge) {
236 let key = edge_key(&edge);
237 if seen.insert(key) {
238 edges.push(edge);
239 }
240}
241
242fn edge_key(edge: &Edge) -> String {
243 format!(
244 "{}\0{}\0{}\0{:?}\0{:?}",
245 edge.source,
246 edge.target,
247 edge.kind.as_str(),
248 edge.line,
249 edge.col
250 )
251}
252
253fn edge_sort_key(a: &Edge, b: &Edge) -> std::cmp::Ordering {
254 edge_key(a).cmp(&edge_key(b))
255}
256
257fn node_edge_sort(a: &NodeEdge, b: &NodeEdge) -> std::cmp::Ordering {
258 a.depth
259 .cmp(&b.depth)
260 .then_with(|| a.node.file_path.cmp(&b.node.file_path))
261 .then_with(|| a.node.start_line.cmp(&b.node.start_line))
262 .then_with(|| a.node.kind.as_str().cmp(b.node.kind.as_str()))
263 .then_with(|| a.node.name.cmp(&b.node.name))
264}