1use crate::db::Database;
2use crate::types::{Edge, EdgeKind, Node, NodeEdge};
3use anyhow::Result;
4use serde::Serialize;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone, Default, Serialize)]
8pub struct Subgraph {
9 pub nodes: HashMap<String, Node>,
10 pub edges: Vec<Edge>,
11 pub roots: Vec<String>,
12}
13
14pub struct GraphTraverser<'a> {
15 db: &'a Database,
16}
17
18impl<'a> GraphTraverser<'a> {
19 pub fn new(db: &'a Database) -> Self {
20 Self { db }
21 }
22
23 pub fn get_callers(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
24 self.walk_edges(
25 node_id,
26 max_depth,
27 Direction::Incoming,
28 &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
29 )
30 }
31
32 pub fn get_callees(&self, node_id: &str, max_depth: usize) -> Result<Vec<NodeEdge>> {
33 self.walk_edges(
34 node_id,
35 max_depth,
36 Direction::Outgoing,
37 &[EdgeKind::Calls, EdgeKind::References, EdgeKind::Imports],
38 )
39 }
40
41 pub fn get_impact_radius(&self, node_id: &str, max_depth: usize) -> Result<Subgraph> {
42 let Some(root) = self.db.get_node(node_id)? else {
43 return Ok(Subgraph::default());
44 };
45 let mut out = Subgraph::default();
46 out.roots.push(node_id.to_string());
47 out.nodes.insert(root.id.clone(), root);
48
49 let mut visited = HashSet::new();
50 let mut queue = VecDeque::new();
51 queue.push_back((node_id.to_string(), 0usize));
52
53 while let Some((current, depth)) = queue.pop_front() {
54 if depth > max_depth || !visited.insert(current.clone()) {
55 continue;
56 }
57
58 if let Some(node) = self.db.get_node(¤t)? {
59 if matches!(
60 node.kind,
61 crate::types::NodeKind::Class
62 | crate::types::NodeKind::Interface
63 | crate::types::NodeKind::Struct
64 | crate::types::NodeKind::Trait
65 | crate::types::NodeKind::Protocol
66 | crate::types::NodeKind::Module
67 | crate::types::NodeKind::Enum
68 ) {
69 for edge in self
70 .db
71 .get_outgoing_edges(¤t, Some(&[EdgeKind::Contains]))?
72 {
73 if let Some(child) = self.db.get_node(&edge.target)? {
74 out.nodes.insert(child.id.clone(), child.clone());
75 out.edges.push(edge.clone());
76 queue.push_back((child.id, depth));
77 }
78 }
79 }
80 }
81
82 if depth == max_depth {
83 continue;
84 }
85
86 for edge in self.db.get_incoming_edges(¤t, None)? {
87 if let Some(source) = self.db.get_node(&edge.source)? {
88 out.nodes.insert(source.id.clone(), source.clone());
89 out.edges.push(edge);
90 queue.push_back((source.id, depth + 1));
91 }
92 }
93 }
94
95 Ok(out)
96 }
97
98 fn walk_edges(
99 &self,
100 node_id: &str,
101 max_depth: usize,
102 direction: Direction,
103 kinds: &[EdgeKind],
104 ) -> Result<Vec<NodeEdge>> {
105 let mut out = Vec::new();
106 let mut visited = HashSet::new();
107 let mut queue = VecDeque::new();
108 queue.push_back((node_id.to_string(), 0usize));
109
110 while let Some((current, depth)) = queue.pop_front() {
111 if depth >= max_depth || !visited.insert(current.clone()) {
112 continue;
113 }
114
115 let edges = match direction {
116 Direction::Incoming => self.db.get_incoming_edges(¤t, Some(kinds))?,
117 Direction::Outgoing => self.db.get_outgoing_edges(¤t, Some(kinds))?,
118 };
119
120 for edge in edges {
121 let next = match direction {
122 Direction::Incoming => &edge.source,
123 Direction::Outgoing => &edge.target,
124 };
125 if visited.contains(next) {
126 continue;
127 }
128 if let Some(node) = self.db.get_node(next)? {
129 queue.push_back((node.id.clone(), depth + 1));
130 out.push(NodeEdge { node, edge });
131 }
132 }
133 }
134
135 Ok(out)
136 }
137}
138
/// Which end of an edge a traversal moves towards.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Follow edges backwards: from a node to the sources of its incoming edges.
    Incoming,
    /// Follow edges forwards: from a node to the targets of its outgoing edges.
    Outgoing,
}