Skip to main content

codegraph/
graph.rs

1use crate::db::Database;
2use crate::types::{Edge, EdgeKind, GraphPath, Node, NodeEdge};
3use anyhow::Result;
4use serde::Serialize;
5use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone, Default, Serialize)]
8pub struct Subgraph {
9    pub nodes: HashMap<String, Node>,
10    pub edges: Vec<Edge>,
11    pub roots: Vec<String>,
12}
13
14pub struct GraphTraverser<'a> {
15    db: &'a Database,
16}
17
18impl<'a> GraphTraverser<'a> {
19    pub fn new(db: &'a Database) -> Self {
20        Self { db }
21    }
22
23    pub fn get_callers(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
24        self.walk_edges(
25            node_id,
26            max_depth,
27            Direction::Incoming,
28            &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
29        )
30    }
31
32    pub fn get_callees(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
33        self.walk_edges(
34            node_id,
35            max_depth,
36            Direction::Outgoing,
37            &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
38        )
39    }
40
41    pub fn get_impact_radius(&self, node_id: &str, max_depth: usize) -> Result<Subgraph> {
42        let Some(root) = self.db.get_node(node_id)? else {
43            return Ok(Subgraph::default());
44        };
45        let mut out = Subgraph::default();
46        out.roots.push(node_id.to_string());
47        out.nodes.insert(root.id.clone(), root);
48
49        let mut visited = HashSet::new();
50        let mut seen_edges = BTreeSet::new();
51        let mut queue = VecDeque::new();
52        queue.push_back((node_id.to_string(), 0usize));
53
54        while let Some((current, depth)) = queue.pop_front() {
55            if depth > max_depth || !visited.insert(current.clone()) {
56                continue;
57            }
58
59            if let Some(node) = self.db.get_node(&current)? {
60                if matches!(
61                    node.kind,
62                    crate::types::NodeKind::Class
63                        | crate::types::NodeKind::Interface
64                        | crate::types::NodeKind::Struct
65                        | crate::types::NodeKind::Trait
66                        | crate::types::NodeKind::Protocol
67                        | crate::types::NodeKind::Module
68                        | crate::types::NodeKind::Enum
69                ) {
70                    for edge in self
71                        .db
72                        .get_outgoing_edges(&current, Some(&[EdgeKind::Contains]))?
73                    {
74                        if let Some(child) = self.db.get_node(&edge.target)? {
75                            out.nodes.insert(child.id.clone(), child.clone());
76                            push_unique_edge(&mut out.edges, &mut seen_edges, edge.clone());
77                            queue.push_back((child.id, depth));
78                        }
79                    }
80                }
81            }
82
83            if depth == max_depth {
84                continue;
85            }
86
87            for edge in self.db.get_incoming_edges(&current, None)? {
88                if let Some(source) = self.db.get_node(&edge.source)? {
89                    out.nodes.insert(source.id.clone(), source.clone());
90                    push_unique_edge(&mut out.edges, &mut seen_edges, edge);
91                    queue.push_back((source.id, depth + 1));
92                }
93            }
94        }
95
96        out.edges.sort_by(edge_sort_key);
97        Ok(out)
98    }
99
100    pub fn find_paths(
101        &self,
102        from_node_id: &str,
103        to_node_id: &str,
104        max_depth: usize,
105        max_paths: usize,
106    ) -> Result<Vec<GraphPath>> {
107        if max_depth == 0 || max_paths == 0 {
108            return Ok(Vec::new());
109        }
110        let Some(root) = self.db.get_node(from_node_id)? else {
111            return Ok(Vec::new());
112        };
113        if self.db.get_node(to_node_id)?.is_none() {
114            return Ok(Vec::new());
115        }
116
117        let path_kinds = [
118            EdgeKind::Calls,
119            EdgeKind::References,
120            EdgeKind::Imports,
121            EdgeKind::Extends,
122            EdgeKind::Implements,
123        ];
124        let mut out = Vec::new();
125        let mut queue = VecDeque::new();
126        queue.push_back(PathState {
127            node_id: from_node_id.to_string(),
128            nodes: vec![root],
129            edges: Vec::new(),
130            visited: BTreeSet::from([from_node_id.to_string()]),
131        });
132
133        while let Some(state) = queue.pop_front() {
134            if state.edges.len() >= max_depth {
135                continue;
136            }
137
138            let mut outgoing = self
139                .db
140                .get_outgoing_edges(&state.node_id, Some(&path_kinds))?;
141            outgoing.sort_by(edge_sort_key);
142            for edge in outgoing {
143                if state.visited.contains(&edge.target) {
144                    continue;
145                }
146                let Some(next_node) = self.db.get_node(&edge.target)? else {
147                    continue;
148                };
149                let mut nodes = state.nodes.clone();
150                nodes.push(next_node.clone());
151                let mut edges = state.edges.clone();
152                edges.push(edge.clone());
153                if edge.target == to_node_id {
154                    out.push(GraphPath { nodes, edges });
155                    if out.len() >= max_paths {
156                        return Ok(out);
157                    }
158                    continue;
159                }
160                let mut visited = state.visited.clone();
161                visited.insert(edge.target.clone());
162                queue.push_back(PathState {
163                    node_id: edge.target,
164                    nodes,
165                    edges,
166                    visited,
167                });
168            }
169        }
170
171        Ok(out)
172    }
173
174    fn walk_edges(
175        &self,
176        node_id: &str,
177        max_depth: usize,
178        direction: Direction,
179        kinds: &[EdgeKind],
180    ) -> Result<Vec<NodeEdge>> {
181        let mut out = Vec::new();
182        let mut visited = HashSet::new();
183        let mut emitted = BTreeSet::new();
184        let mut queue = VecDeque::new();
185        queue.push_back((node_id.to_string(), 0usize));
186
187        while let Some((current, depth)) = queue.pop_front() {
188            if depth >= max_depth || !visited.insert(current.clone()) {
189                continue;
190            }
191
192            let mut edges = match direction {
193                Direction::Incoming => self.db.get_incoming_edges(&current, Some(kinds))?,
194                Direction::Outgoing => self.db.get_outgoing_edges(&current, Some(kinds))?,
195            };
196            edges.sort_by(edge_sort_key);
197
198            for edge in edges {
199                let next = match direction {
200                    Direction::Incoming => &edge.source,
201                    Direction::Outgoing => &edge.target,
202                };
203                if visited.contains(next) || emitted.contains(next) {
204                    continue;
205                }
206                if let Some(node) = self.db.get_node(next)? {
207                    queue.push_back((node.id.clone(), depth + 1));
208                    emitted.insert(node.id.clone());
209                    out.push(NodeEdge {
210                        node,
211                        edge,
212                        depth: depth + 1,
213                    });
214                }
215            }
216        }
217
218        out.sort_by(node_edge_sort);
219        Ok(out)
220    }
221}
222
223struct PathState {
224    node_id: String,
225    nodes: Vec<Node>,
226    edges: Vec<Edge>,
227    visited: BTreeSet<String>,
228}
229
/// Which end of an edge `walk_edges` moves towards.
enum Direction {
    /// Follow edges that end at the current node, moving to `edge.source`.
    Incoming,
    /// Follow edges that start at the current node, moving to `edge.target`.
    Outgoing,
}
234
235fn push_unique_edge(edges: &mut Vec<Edge>, seen: &mut BTreeSet<String>, edge: Edge) {
236    let key = edge_key(&edge);
237    if seen.insert(key) {
238        edges.push(edge);
239    }
240}
241
242fn edge_key(edge: &Edge) -> String {
243    format!(
244        "{}\0{}\0{}\0{:?}\0{:?}",
245        edge.source,
246        edge.target,
247        edge.kind.as_str(),
248        edge.line,
249        edge.col
250    )
251}
252
253fn edge_sort_key(a: &Edge, b: &Edge) -> std::cmp::Ordering {
254    edge_key(a).cmp(&edge_key(b))
255}
256
257fn node_edge_sort(a: &NodeEdge, b: &NodeEdge) -> std::cmp::Ordering {
258    a.depth
259        .cmp(&b.depth)
260        .then_with(|| a.node.file_path.cmp(&b.node.file_path))
261        .then_with(|| a.node.start_line.cmp(&b.node.start_line))
262        .then_with(|| a.node.kind.as_str().cmp(b.node.kind.as_str()))
263        .then_with(|| a.node.name.cmp(&b.node.name))
264}