Skip to main content

codegraph/
graph.rs

1use crate::db::Database;
2use crate::types::{Edge, EdgeKind, Node, NodeEdge};
3use anyhow::Result;
4use serde::Serialize;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone, Default, Serialize)]
8pub struct Subgraph {
9    pub nodes: HashMap<String, Node>,
10    pub edges: Vec<Edge>,
11    pub roots: Vec<String>,
12}
13
14pub struct GraphTraverser<'a> {
15    db: &'a Database,
16}
17
18impl<'a> GraphTraverser<'a> {
19    pub fn new(db: &'a Database) -> Self {
20        Self { db }
21    }
22
23    pub fn get_callers(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
24        self.walk_edges(
25            node_id,
26            max_depth,
27            Direction::Incoming,
28            &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
29        )
30    }
31
32    pub fn get_callees(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
33        self.walk_edges(
34            node_id,
35            max_depth,
36            Direction::Outgoing,
37            &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
38        )
39    }
40
41    pub fn get_impact_radius(&self, node_id: &str, max_depth: usize) -> Result<Subgraph> {
42        let Some(root) = self.db.get_node(node_id)? else {
43            return Ok(Subgraph::default());
44        };
45        let mut out = Subgraph::default();
46        out.roots.push(node_id.to_string());
47        out.nodes.insert(root.id.clone(), root);
48
49        let mut visited = HashSet::new();
50        let mut queue = VecDeque::new();
51        queue.push_back((node_id.to_string(), 0usize));
52
53        while let Some((current, depth)) = queue.pop_front() {
54            if depth > max_depth || !visited.insert(current.clone()) {
55                continue;
56            }
57
58            if let Some(node) = self.db.get_node(&current)? {
59                if matches!(
60                    node.kind,
61                    crate::types::NodeKind::Class
62                        | crate::types::NodeKind::Interface
63                        | crate::types::NodeKind::Struct
64                        | crate::types::NodeKind::Trait
65                        | crate::types::NodeKind::Protocol
66                        | crate::types::NodeKind::Module
67                        | crate::types::NodeKind::Enum
68                ) {
69                    for edge in self
70                        .db
71                        .get_outgoing_edges(&current, Some(&[EdgeKind::Contains]))?
72                    {
73                        if let Some(child) = self.db.get_node(&edge.target)? {
74                            out.nodes.insert(child.id.clone(), child.clone());
75                            out.edges.push(edge.clone());
76                            queue.push_back((child.id, depth));
77                        }
78                    }
79                }
80            }
81
82            if depth == max_depth {
83                continue;
84            }
85
86            for edge in self.db.get_incoming_edges(&current, None)? {
87                if let Some(source) = self.db.get_node(&edge.source)? {
88                    out.nodes.insert(source.id.clone(), source.clone());
89                    out.edges.push(edge);
90                    queue.push_back((source.id, depth + 1));
91                }
92            }
93        }
94
95        Ok(out)
96    }
97
98    fn walk_edges(
99        &self,
100        node_id: &str,
101        max_depth: usize,
102        direction: Direction,
103        kinds: &[EdgeKind],
104    ) -> Result<Vec<NodeEdge>> {
105        let mut out = Vec::new();
106        let mut visited = HashSet::new();
107        let mut queue = VecDeque::new();
108        queue.push_back((node_id.to_string(), 0usize));
109
110        while let Some((current, depth)) = queue.pop_front() {
111            if depth >= max_depth || !visited.insert(current.clone()) {
112                continue;
113            }
114
115            let edges = match direction {
116                Direction::Incoming => self.db.get_incoming_edges(&current, Some(kinds))?,
117                Direction::Outgoing => self.db.get_outgoing_edges(&current, Some(kinds))?,
118            };
119
120            for edge in edges {
121                let next = match direction {
122                    Direction::Incoming => &edge.source,
123                    Direction::Outgoing => &edge.target,
124                };
125                if visited.contains(next) {
126                    continue;
127                }
128                if let Some(node) = self.db.get_node(next)? {
129                    queue.push_back((node.id.clone(), depth + 1));
130                    out.push(NodeEdge { node, edge });
131                }
132            }
133        }
134
135        Ok(out)
136    }
137}
138
/// Which side of an edge a traversal follows from the current node.
///
/// The enum is passed by value and matched several times per traversal, so it
/// derives `Copy` (plus the usual comparison/debug traits) to stay cheap and
/// easy to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Follow edges pointing at the current node, stepping to each edge's source.
    Incoming,
    /// Follow edges leaving the current node, stepping to each edge's target.
    Outgoing,
}