context_creator/core/semantic/
graph_traverser.rs1use crate::core::semantic::dependency_types::{DependencyEdgeType, DependencyNode as RichNode};
7use anyhow::{anyhow, Result};
8use petgraph::algo::toposort;
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::Dfs;
11use std::collections::{HashSet, VecDeque};
12
/// Settings that bound a graph traversal started by `GraphTraverser`.
///
/// Construct with [`Default::default`] and override individual fields with
/// struct-update syntax, e.g. `TraversalOptions { max_depth: 2, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct TraversalOptions {
    /// Maximum number of edges followed away from the start node by the
    /// depth-limited traversals; the start node itself sits at depth 0.
    pub max_depth: usize,
    /// Presumably selects whether type-level nodes are wanted.
    /// NOTE(review): not read by `GraphTraverser` in this file — confirm
    /// where (or whether) it is applied.
    pub include_types: bool,
    /// Presumably selects whether function-level nodes are wanted.
    /// NOTE(review): not read by `GraphTraverser` in this file — confirm
    /// where (or whether) it is applied.
    pub include_functions: bool,
}

impl Default for TraversalOptions {
    /// A depth limit of 5 with both inclusion flags switched on.
    fn default() -> Self {
        let max_depth = 5;
        TraversalOptions {
            include_types: true,
            include_functions: true,
            max_depth,
        }
    }
}
33
/// Stateless helper that answers traversal, reachability and ordering
/// queries over a dependency graph.
///
/// It carries no data of its own; `GraphTraverser {}` and the bare
/// `GraphTraverser` value both construct it.
pub struct GraphTraverser;
38
39impl GraphTraverser {
40 pub fn new() -> Self {
42 Self {}
43 }
44
45 pub fn traverse_bfs(
47 &self,
48 graph: &DiGraph<RichNode, DependencyEdgeType>,
49 start: NodeIndex,
50 options: &TraversalOptions,
51 ) -> Vec<NodeIndex> {
52 let mut visited = Vec::new();
53 let mut seen = HashSet::new();
54 let mut queue = VecDeque::new();
55
56 queue.push_back((start, 0));
58 seen.insert(start);
59
60 while let Some((node, depth)) = queue.pop_front() {
61 if depth > options.max_depth {
63 continue;
64 }
65
66 visited.push(node);
67
68 for neighbor in graph.neighbors(node) {
70 if !seen.contains(&neighbor) {
71 seen.insert(neighbor);
72 queue.push_back((neighbor, depth + 1));
73 }
74 }
75 }
76
77 visited
78 }
79
80 pub fn traverse_dfs(
82 &self,
83 graph: &DiGraph<RichNode, DependencyEdgeType>,
84 start: NodeIndex,
85 options: &TraversalOptions,
86 ) -> Vec<NodeIndex> {
87 let mut visited = Vec::new();
88 let mut dfs = Dfs::new(graph, start);
89 let mut depths = HashMap::new();
90 depths.insert(start, 0);
91
92 while let Some(node) = dfs.next(graph) {
93 let current_depth = *depths.get(&node).unwrap_or(&0);
94
95 if current_depth <= options.max_depth {
97 visited.push(node);
98
99 for neighbor in graph.neighbors(node) {
101 depths.entry(neighbor).or_insert(current_depth + 1);
102 }
103 }
104 }
105
106 visited
107 }
108
109 pub fn topological_sort(
111 &self,
112 graph: &DiGraph<RichNode, DependencyEdgeType>,
113 ) -> Result<Vec<NodeIndex>> {
114 match toposort(graph, None) {
115 Ok(order) => Ok(order),
116 Err(_) => Err(anyhow!(
117 "Graph contains a cycle, topological sort not possible"
118 )),
119 }
120 }
121
122 pub fn find_reachable_nodes(
124 &self,
125 graph: &DiGraph<RichNode, DependencyEdgeType>,
126 start: NodeIndex,
127 ) -> HashSet<NodeIndex> {
128 let mut reachable = HashSet::new();
129 let mut dfs = Dfs::new(graph, start);
130
131 while let Some(node) = dfs.next(graph) {
132 reachable.insert(node);
133 }
134
135 reachable
136 }
137
138 pub fn get_nodes_at_depth(
140 &self,
141 graph: &DiGraph<RichNode, DependencyEdgeType>,
142 start: NodeIndex,
143 target_depth: usize,
144 ) -> Vec<NodeIndex> {
145 let mut nodes_at_depth = Vec::new();
146 let mut queue = VecDeque::new();
147 let mut seen = HashSet::new();
148
149 queue.push_back((start, 0));
150 seen.insert(start);
151
152 while let Some((node, depth)) = queue.pop_front() {
153 if depth == target_depth {
154 nodes_at_depth.push(node);
155 } else if depth < target_depth {
156 for neighbor in graph.neighbors(node) {
158 if !seen.contains(&neighbor) {
159 seen.insert(neighbor);
160 queue.push_back((neighbor, depth + 1));
161 }
162 }
163 }
164 }
165
166 nodes_at_depth
167 }
168
169 pub fn find_shortest_path(
171 &self,
172 graph: &DiGraph<RichNode, DependencyEdgeType>,
173 start: NodeIndex,
174 end: NodeIndex,
175 ) -> Option<Vec<NodeIndex>> {
176 use petgraph::algo::dijkstra;
177
178 let predecessors = dijkstra(graph, start, Some(end), |_| 1);
179
180 if !predecessors.contains_key(&end) {
181 return None;
182 }
183
184 let mut path = vec![end];
186 let mut current = end;
187
188 while current != start {
191 if let Some(neighbor) = graph
192 .neighbors_directed(current, petgraph::Direction::Incoming)
193 .next()
194 {
195 path.push(neighbor);
196 current = neighbor;
197 } else {
198 break;
199 }
200 }
201
202 path.reverse();
203 Some(path)
204 }
205}
206
207impl Default for GraphTraverser {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213use std::collections::HashMap;
215
// Unit tests live in `graph_traverser_tests.rs` next to this file and are
// compiled only for test builds.
#[cfg(test)]
#[path = "graph_traverser_tests.rs"]
mod tests;