1mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// In-memory graph engine combining a petgraph `DiGraph` with id-keyed
/// lookup tables and cached centrality scores.
pub struct GraphEngine {
    /// Directed graph. Each node weight is a node id (the same string used as
    /// a key into `nodes`); each edge weight is an `f64`.
    pub(crate) graph: DiGraph<String, f64>,
    /// Maps a node id to its index inside `graph`.
    pub(crate) id_to_index: HashMap<String, NodeIndex>,
    /// Full node records, keyed by node id.
    pub(crate) nodes: HashMap<String, GraphNode>,
    /// Full edge records. NOTE(review): presumably keyed by edge id; the
    /// insertion code is not in this file, so confirm against `add_edge`.
    pub(crate) edges: HashMap<String, Edge>,
    /// PageRank score per node id, refreshed by `recompute_centrality`.
    pub(crate) cached_pagerank: HashMap<String, f64>,
    /// Betweenness score per node id, refreshed by `recompute_centrality`.
    pub(crate) cached_betweenness: HashMap<String, f64>,
}
30
31impl GraphEngine {
32 pub fn new() -> Self {
34 Self {
35 graph: DiGraph::new(),
36 id_to_index: HashMap::new(),
37 nodes: HashMap::new(),
38 edges: HashMap::new(),
39 cached_pagerank: HashMap::new(),
40 cached_betweenness: HashMap::new(),
41 }
42 }
43
44 pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
46 let mut engine = Self::new();
47
48 let nodes = storage.all_graph_nodes()?;
50 for node in nodes {
51 engine.add_node(node)?;
52 }
53
54 let edges = storage.all_graph_edges()?;
56 for edge in edges {
57 engine.add_edge(edge)?;
58 }
59
60 engine.compute_centrality();
62
63 Ok(engine)
64 }
65
66 pub fn node_count(&self) -> usize {
68 self.nodes.len()
69 }
70
71 pub fn edge_count(&self) -> usize {
73 self.edges.len()
74 }
75
76 pub fn expand(
78 &self,
79 start_ids: &[String],
80 max_hops: usize,
81 ) -> Result<Vec<GraphNode>, CodememError> {
82 let mut visited = std::collections::HashSet::new();
83 let mut result = Vec::new();
84
85 for start_id in start_ids {
86 let nodes = self.bfs(start_id, max_hops)?;
87 for node in nodes {
88 if visited.insert(node.id.clone()) {
89 result.push(node);
90 }
91 }
92 }
93
94 Ok(result)
95 }
96
97 pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
99 let idx = self
100 .id_to_index
101 .get(node_id)
102 .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
103
104 let mut result = Vec::new();
105 for neighbor_idx in self.graph.neighbors(*idx) {
106 if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
107 if let Some(node) = self.nodes.get(neighbor_id) {
108 result.push(node.clone());
109 }
110 }
111 }
112
113 Ok(result)
114 }
115
116 pub fn connected_components(&self) -> Vec<Vec<String>> {
122 let mut visited: HashSet<NodeIndex> = HashSet::new();
123 let mut components: Vec<Vec<String>> = Vec::new();
124
125 for &start_idx in self.id_to_index.values() {
126 if visited.contains(&start_idx) {
127 continue;
128 }
129
130 let mut component: Vec<String> = Vec::new();
132 let mut queue: VecDeque<NodeIndex> = VecDeque::new();
133 queue.push_back(start_idx);
134 visited.insert(start_idx);
135
136 while let Some(current) = queue.pop_front() {
137 if let Some(node_id) = self.graph.node_weight(current) {
138 component.push(node_id.clone());
139 }
140
141 for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
143 if visited.insert(neighbor) {
144 queue.push_back(neighbor);
145 }
146 }
147
148 for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
150 if visited.insert(neighbor) {
151 queue.push_back(neighbor);
152 }
153 }
154 }
155
156 component.sort();
157 components.push(component);
158 }
159
160 components.sort();
161 components
162 }
163
164 pub fn compute_centrality(&mut self) {
170 let n = self.nodes.len();
171 if n <= 1 {
172 for node in self.nodes.values_mut() {
173 node.centrality = 0.0;
174 }
175 return;
176 }
177
178 let denominator = (n - 1) as f64;
179
180 let centrality_map: HashMap<String, f64> = self
182 .id_to_index
183 .iter()
184 .map(|(id, &idx)| {
185 let in_deg = self
186 .graph
187 .neighbors_directed(idx, Direction::Incoming)
188 .count();
189 let out_deg = self
190 .graph
191 .neighbors_directed(idx, Direction::Outgoing)
192 .count();
193 let centrality = (in_deg + out_deg) as f64 / denominator;
194 (id.clone(), centrality)
195 })
196 .collect();
197
198 for (id, centrality) in ¢rality_map {
200 if let Some(node) = self.nodes.get_mut(id) {
201 node.centrality = *centrality;
202 }
203 }
204 }
205
206 pub fn get_all_nodes(&self) -> Vec<GraphNode> {
208 self.nodes.values().cloned().collect()
209 }
210
211 pub fn recompute_centrality(&mut self) {
216 self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
217 self.cached_betweenness = self.betweenness_centrality();
218 }
219
220 pub fn get_pagerank(&self, node_id: &str) -> f64 {
222 self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
223 }
224
225 pub fn get_betweenness(&self, node_id: &str) -> f64 {
227 self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
228 }
229
230 pub fn graph_strength_for_memory(&self, memory_id: &str) -> f64 {
237 let idx = match self.id_to_index.get(memory_id) {
238 Some(idx) => *idx,
239 None => return 0.0,
240 };
241
242 let mut max_pagerank = 0.0_f64;
243 let mut max_betweenness = 0.0_f64;
244 let mut code_neighbor_count = 0_usize;
245 let mut total_edge_weight = 0.0_f64;
246
247 for direction in &[Direction::Outgoing, Direction::Incoming] {
249 for neighbor_idx in self.graph.neighbors_directed(idx, *direction) {
250 if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
251 if neighbor_id.starts_with("sym:")
253 || neighbor_id.starts_with("file:")
254 || neighbor_id.starts_with("chunk:")
255 || neighbor_id.starts_with("pkg:")
256 {
257 code_neighbor_count += 1;
258 let pr = self
259 .cached_pagerank
260 .get(neighbor_id)
261 .copied()
262 .unwrap_or(0.0);
263 let bt = self
264 .cached_betweenness
265 .get(neighbor_id)
266 .copied()
267 .unwrap_or(0.0);
268 max_pagerank = max_pagerank.max(pr);
269 max_betweenness = max_betweenness.max(bt);
270
271 for edge in self.edges.values() {
273 if (edge.src == memory_id && edge.dst == *neighbor_id)
274 || (edge.dst == memory_id && edge.src == *neighbor_id)
275 {
276 total_edge_weight += edge.weight;
277 break;
278 }
279 }
280 }
281 }
282 }
283 }
284
285 if code_neighbor_count == 0 {
286 return 0.0;
287 }
288
289 let connectivity_bonus = (code_neighbor_count as f64 / 5.0).min(1.0);
290 let edge_weight_bonus = (total_edge_weight / code_neighbor_count as f64).min(1.0);
291
292 (0.4 * max_pagerank
293 + 0.3 * max_betweenness
294 + 0.2 * connectivity_bonus
295 + 0.1 * edge_weight_bonus)
296 .min(1.0)
297 }
298
299 pub fn max_degree(&self) -> f64 {
302 if self.nodes.len() <= 1 {
303 return 1.0;
304 }
305 self.id_to_index
306 .values()
307 .map(|&idx| {
308 let in_deg = self
309 .graph
310 .neighbors_directed(idx, Direction::Incoming)
311 .count();
312 let out_deg = self
313 .graph
314 .neighbors_directed(idx, Direction::Outgoing)
315 .count();
316 (in_deg + out_deg) as f64
317 })
318 .fold(1.0f64, f64::max)
319 }
320}
321
322#[cfg(test)]
323#[path = "tests/lib_tests.rs"]
324mod tests;