1use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::TraversalDirection;
6use crate::graph::MemoryGraph;
7use crate::types::{AmemResult, Edge, EdgeType, EventType};
8
/// Centrality measure computed by `QueryEngine::centrality`.
#[derive(Debug, Clone, PartialEq)]
pub enum CentralityAlgorithm {
    /// PageRank by power iteration. `damping` is the share of rank passed along edges each
    /// round; the remaining `1 - damping` is spread uniformly over all nodes.
    PageRank { damping: f32 },
    /// Total (in + out) degree, normalised by `2 * (n - 1)`.
    Degree,
    /// Betweenness centrality (Brandes' algorithm) on the undirected view of the graph.
    Betweenness,
}
19
20pub struct CentralityParams {
22 pub algorithm: CentralityAlgorithm,
23 pub max_iterations: u32,
24 pub tolerance: f32,
25 pub top_k: usize,
26 pub event_types: Vec<EventType>,
27 pub edge_types: Vec<EdgeType>,
28}
29
30pub struct CentralityResult {
32 pub scores: Vec<(u64, f32)>,
34 pub algorithm: CentralityAlgorithm,
35 pub iterations: u32,
36 pub converged: bool,
37}
38
39pub struct ShortestPathParams {
41 pub source_id: u64,
42 pub target_id: u64,
43 pub edge_types: Vec<EdgeType>,
44 pub direction: TraversalDirection,
45 pub max_depth: u32,
46 pub weighted: bool,
47}
48
49pub struct PathResult {
51 pub path: Vec<u64>,
53 pub edges: Vec<Edge>,
55 pub cost: f32,
57 pub found: bool,
58}
59
60impl super::query::QueryEngine {
61 pub fn centrality(
63 &self,
64 graph: &MemoryGraph,
65 params: CentralityParams,
66 ) -> AmemResult<CentralityResult> {
67 let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
68 let edge_filter: HashSet<EdgeType> = params.edge_types.iter().copied().collect();
69
70 let node_ids: Vec<u64> = graph
72 .nodes()
73 .iter()
74 .filter(|n| type_filter.is_empty() || type_filter.contains(&n.event_type))
75 .map(|n| n.id)
76 .collect();
77
78 let node_set: HashSet<u64> = node_ids.iter().copied().collect();
79
80 let edges: Vec<&Edge> = graph
82 .edges()
83 .iter()
84 .filter(|e| {
85 node_set.contains(&e.source_id)
86 && node_set.contains(&e.target_id)
87 && (edge_filter.is_empty() || edge_filter.contains(&e.edge_type))
88 })
89 .collect();
90
91 match params.algorithm {
92 CentralityAlgorithm::PageRank { damping } => self.pagerank(
93 &node_ids,
94 &edges,
95 damping,
96 params.max_iterations,
97 params.tolerance,
98 params.top_k,
99 ),
100 CentralityAlgorithm::Degree => self.degree_centrality(&node_ids, &edges, params.top_k),
101 CentralityAlgorithm::Betweenness => {
102 self.betweenness_centrality(&node_ids, &edges, params.top_k)
103 }
104 }
105 }
106
107 fn pagerank(
108 &self,
109 node_ids: &[u64],
110 edges: &[&Edge],
111 damping: f32,
112 max_iterations: u32,
113 tolerance: f32,
114 top_k: usize,
115 ) -> AmemResult<CentralityResult> {
116 let n = node_ids.len();
117 if n == 0 {
118 return Ok(CentralityResult {
119 scores: Vec::new(),
120 algorithm: CentralityAlgorithm::PageRank { damping },
121 iterations: 0,
122 converged: true,
123 });
124 }
125
126 let id_to_idx: HashMap<u64, usize> = node_ids
127 .iter()
128 .enumerate()
129 .map(|(i, &id)| (id, i))
130 .collect();
131
132 let mut outgoing: Vec<Vec<usize>> = vec![Vec::new(); n];
134 let mut incoming: Vec<Vec<usize>> = vec![Vec::new(); n];
135
136 for edge in edges {
137 if let (Some(&src_idx), Some(&tgt_idx)) = (
138 id_to_idx.get(&edge.source_id),
139 id_to_idx.get(&edge.target_id),
140 ) {
141 outgoing[src_idx].push(tgt_idx);
142 incoming[tgt_idx].push(src_idx);
143 }
144 }
145
146 let mut pr = vec![1.0 / n as f32; n];
147 let mut iterations = 0;
148 let mut converged = false;
149
150 for _ in 0..max_iterations {
151 iterations += 1;
152 let mut new_pr = vec![(1.0 - damping) / n as f32; n];
153
154 let dangling_sum: f32 = (0..n)
156 .filter(|&i| outgoing[i].is_empty())
157 .map(|i| pr[i])
158 .sum();
159
160 for i in 0..n {
161 new_pr[i] += damping * dangling_sum / n as f32;
162 for &j in &incoming[i] {
163 let out_degree = outgoing[j].len() as f32;
164 if out_degree > 0.0 {
165 new_pr[i] += damping * pr[j] / out_degree;
166 }
167 }
168 }
169
170 let max_diff = (0..n)
172 .map(|i| (new_pr[i] - pr[i]).abs())
173 .fold(0.0f32, f32::max);
174
175 pr = new_pr;
176
177 if max_diff < tolerance {
178 converged = true;
179 break;
180 }
181 }
182
183 let mut scores: Vec<(u64, f32)> = node_ids
184 .iter()
185 .zip(pr.iter())
186 .map(|(&id, &s)| (id, s))
187 .collect();
188 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189 scores.truncate(top_k);
190
191 Ok(CentralityResult {
192 scores,
193 algorithm: CentralityAlgorithm::PageRank { damping },
194 iterations,
195 converged,
196 })
197 }
198
199 fn degree_centrality(
200 &self,
201 node_ids: &[u64],
202 edges: &[&Edge],
203 top_k: usize,
204 ) -> AmemResult<CentralityResult> {
205 let n = node_ids.len();
206 let mut degrees: HashMap<u64, u32> = HashMap::new();
207 for &id in node_ids {
208 degrees.insert(id, 0);
209 }
210
211 for edge in edges {
212 *degrees.entry(edge.source_id).or_insert(0) += 1;
213 *degrees.entry(edge.target_id).or_insert(0) += 1;
214 }
215
216 let max_possible = if n > 1 { 2 * (n - 1) } else { 1 };
217
218 let mut scores: Vec<(u64, f32)> = degrees
219 .into_iter()
220 .map(|(id, deg)| (id, deg as f32 / max_possible as f32))
221 .collect();
222 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
223 scores.truncate(top_k);
224
225 Ok(CentralityResult {
226 scores,
227 algorithm: CentralityAlgorithm::Degree,
228 iterations: 0,
229 converged: true,
230 })
231 }
232
233 fn betweenness_centrality(
234 &self,
235 node_ids: &[u64],
236 edges: &[&Edge],
237 top_k: usize,
238 ) -> AmemResult<CentralityResult> {
239 let n = node_ids.len();
240 if n == 0 {
241 return Ok(CentralityResult {
242 scores: Vec::new(),
243 algorithm: CentralityAlgorithm::Betweenness,
244 iterations: 0,
245 converged: true,
246 });
247 }
248
249 let id_to_idx: HashMap<u64, usize> = node_ids
250 .iter()
251 .enumerate()
252 .map(|(i, &id)| (id, i))
253 .collect();
254
255 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
257 for edge in edges {
258 if let (Some(&src), Some(&tgt)) = (
259 id_to_idx.get(&edge.source_id),
260 id_to_idx.get(&edge.target_id),
261 ) {
262 adj[src].push(tgt);
263 adj[tgt].push(src);
264 }
265 }
266
267 let mut betweenness = vec![0.0f32; n];
268
269 let sources: Vec<usize> = if n > 10_000 {
271 (0..1000.min(n)).collect()
272 } else {
273 (0..n).collect()
274 };
275
276 for &s in &sources {
278 let mut stack: Vec<usize> = Vec::new();
279 let mut pred: Vec<Vec<usize>> = vec![Vec::new(); n];
280 let mut sigma = vec![0.0f64; n];
281 sigma[s] = 1.0;
282 let mut dist: Vec<i64> = vec![-1; n];
283 dist[s] = 0;
284 let mut queue = VecDeque::new();
285 queue.push_back(s);
286
287 while let Some(v) = queue.pop_front() {
288 stack.push(v);
289 for &w in &adj[v] {
290 if dist[w] < 0 {
291 queue.push_back(w);
292 dist[w] = dist[v] + 1;
293 }
294 if dist[w] == dist[v] + 1 {
295 sigma[w] += sigma[v];
296 pred[w].push(v);
297 }
298 }
299 }
300
301 let mut delta = vec![0.0f64; n];
302 while let Some(w) = stack.pop() {
303 for &v in &pred[w] {
304 delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
305 }
306 if w != s {
307 betweenness[w] += delta[w] as f32;
308 }
309 }
310 }
311
312 let norm = if n > 2 {
314 ((n - 1) * (n - 2)) as f32
315 } else {
316 1.0
317 };
318
319 let mut scores: Vec<(u64, f32)> = node_ids
320 .iter()
321 .enumerate()
322 .map(|(i, &id)| (id, betweenness[i] / norm))
323 .collect();
324 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
325 scores.truncate(top_k);
326
327 Ok(CentralityResult {
328 scores,
329 algorithm: CentralityAlgorithm::Betweenness,
330 iterations: 0,
331 converged: true,
332 })
333 }
334
335 pub fn shortest_path(
337 &self,
338 graph: &MemoryGraph,
339 params: ShortestPathParams,
340 ) -> AmemResult<PathResult> {
341 if params.source_id == params.target_id {
343 return Ok(PathResult {
344 path: vec![params.source_id],
345 edges: Vec::new(),
346 cost: 0.0,
347 found: true,
348 });
349 }
350
351 if graph.get_node(params.source_id).is_none() {
353 return Err(crate::types::AmemError::NodeNotFound(params.source_id));
354 }
355 if graph.get_node(params.target_id).is_none() {
356 return Err(crate::types::AmemError::NodeNotFound(params.target_id));
357 }
358
359 let edge_filter: HashSet<EdgeType> = params.edge_types.iter().copied().collect();
360
361 if params.weighted {
362 self.dijkstra_path(graph, ¶ms, &edge_filter)
363 } else {
364 self.bidirectional_bfs(graph, ¶ms, &edge_filter)
365 }
366 }
367
368 fn bidirectional_bfs(
369 &self,
370 graph: &MemoryGraph,
371 params: &ShortestPathParams,
372 edge_filter: &HashSet<EdgeType>,
373 ) -> AmemResult<PathResult> {
374 let mut forward_visited: HashMap<u64, u64> = HashMap::new(); let mut backward_visited: HashMap<u64, u64> = HashMap::new();
376 let mut forward_queue: VecDeque<(u64, u32)> = VecDeque::new();
377 let mut backward_queue: VecDeque<(u64, u32)> = VecDeque::new();
378
379 forward_visited.insert(params.source_id, params.source_id);
380 backward_visited.insert(params.target_id, params.target_id);
381 forward_queue.push_back((params.source_id, 0));
382 backward_queue.push_back((params.target_id, 0));
383
384 let half_depth = params.max_depth / 2 + 1;
385 let mut meeting_node: Option<u64> = None;
386
387 let get_neighbors = |node_id: u64, forward: bool| -> Vec<u64> {
389 let mut neighbors = Vec::new();
390 match params.direction {
391 TraversalDirection::Forward | TraversalDirection::Both => {
392 if forward {
393 for edge in graph.edges_from(node_id) {
394 if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
395 neighbors.push(edge.target_id);
396 }
397 }
398 }
399 }
400 TraversalDirection::Backward => {}
401 }
402 match params.direction {
403 TraversalDirection::Backward | TraversalDirection::Both => {
404 if forward {
405 for edge in graph.edges_to(node_id) {
406 if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
407 neighbors.push(edge.source_id);
408 }
409 }
410 }
411 }
412 TraversalDirection::Forward => {}
413 }
414 if !forward {
416 let mut rev_neighbors = Vec::new();
417 match params.direction {
418 TraversalDirection::Forward | TraversalDirection::Both => {
419 for edge in graph.edges_to(node_id) {
420 if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
421 rev_neighbors.push(edge.source_id);
422 }
423 }
424 }
425 TraversalDirection::Backward => {}
426 }
427 match params.direction {
428 TraversalDirection::Backward | TraversalDirection::Both => {
429 for edge in graph.edges_from(node_id) {
430 if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
431 rev_neighbors.push(edge.target_id);
432 }
433 }
434 }
435 TraversalDirection::Forward => {}
436 }
437 return rev_neighbors;
438 }
439 neighbors
440 };
441
442 'outer: while !forward_queue.is_empty() || !backward_queue.is_empty() {
443 if let Some((node, depth)) = forward_queue.pop_front() {
445 if depth < half_depth {
446 for neighbor in get_neighbors(node, true) {
447 forward_visited.entry(neighbor).or_insert_with(|| {
448 forward_queue.push_back((neighbor, depth + 1));
449 node
450 });
451 if backward_visited.contains_key(&neighbor) {
452 forward_visited.entry(neighbor).or_insert(node);
453 meeting_node = Some(neighbor);
454 break 'outer;
455 }
456 }
457 }
458 }
459
460 if let Some((node, depth)) = backward_queue.pop_front() {
462 if depth < half_depth {
463 for neighbor in get_neighbors(node, false) {
464 backward_visited.entry(neighbor).or_insert_with(|| {
465 backward_queue.push_back((neighbor, depth + 1));
466 node
467 });
468 if forward_visited.contains_key(&neighbor) {
469 backward_visited.entry(neighbor).or_insert(node);
470 meeting_node = Some(neighbor);
471 break 'outer;
472 }
473 }
474 }
475 }
476 }
477
478 match meeting_node {
479 Some(mid) => {
480 let mut forward_path = Vec::new();
482 let mut current = mid;
483 while current != params.source_id {
484 forward_path.push(current);
485 current = forward_visited[¤t];
486 }
487 forward_path.push(params.source_id);
488 forward_path.reverse();
489
490 let mut backward_path = Vec::new();
491 current = mid;
492 while current != params.target_id {
493 current = backward_visited[¤t];
494 backward_path.push(current);
495 }
496
497 let mut path = forward_path;
498 path.extend(backward_path);
499
500 let cost = (path.len() - 1) as f32;
501
502 let mut edges = Vec::new();
504 for i in 0..path.len() - 1 {
505 for edge in graph.edges_from(path[i]) {
506 if edge.target_id == path[i + 1] {
507 edges.push(*edge);
508 break;
509 }
510 }
511 if edges.len() < i + 1 {
512 for edge in graph.edges_from(path[i + 1]) {
514 if edge.target_id == path[i] {
515 edges.push(*edge);
516 break;
517 }
518 }
519 }
520 }
521
522 Ok(PathResult {
523 path,
524 edges,
525 cost,
526 found: true,
527 })
528 }
529 None => Ok(PathResult {
530 path: Vec::new(),
531 edges: Vec::new(),
532 cost: 0.0,
533 found: false,
534 }),
535 }
536 }
537
538 fn dijkstra_path(
539 &self,
540 graph: &MemoryGraph,
541 params: &ShortestPathParams,
542 edge_filter: &HashSet<EdgeType>,
543 ) -> AmemResult<PathResult> {
544 #[derive(PartialEq)]
545 struct State {
546 cost: f32,
547 node: u64,
548 }
549 impl Eq for State {}
550 impl PartialOrd for State {
551 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
552 Some(self.cmp(other))
553 }
554 }
555 impl Ord for State {
556 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
557 other
558 .cost
559 .partial_cmp(&self.cost)
560 .unwrap_or(std::cmp::Ordering::Equal)
561 }
562 }
563
564 let mut dist: HashMap<u64, f32> = HashMap::new();
565 let mut prev: HashMap<u64, u64> = HashMap::new();
566 let mut heap = BinaryHeap::new();
567
568 dist.insert(params.source_id, 0.0);
569 heap.push(State {
570 cost: 0.0,
571 node: params.source_id,
572 });
573
574 while let Some(State { cost, node }) = heap.pop() {
575 if node == params.target_id {
576 let mut path = Vec::new();
578 let mut current = params.target_id;
579 while current != params.source_id {
580 path.push(current);
581 current = prev[¤t];
582 }
583 path.push(params.source_id);
584 path.reverse();
585
586 let mut edges = Vec::new();
588 for i in 0..path.len() - 1 {
589 for edge in graph.edges_from(path[i]) {
590 if edge.target_id == path[i + 1] {
591 edges.push(*edge);
592 break;
593 }
594 }
595 }
596
597 return Ok(PathResult {
598 path,
599 edges,
600 cost,
601 found: true,
602 });
603 }
604
605 if cost > *dist.get(&node).unwrap_or(&f32::INFINITY) {
606 continue;
607 }
608
609 for edge in graph.edges_from(node) {
611 if !edge_filter.is_empty() && !edge_filter.contains(&edge.edge_type) {
612 continue;
613 }
614 let edge_cost = 1.0 - edge.weight; let next_cost = cost + edge_cost;
616
617 if next_cost < *dist.get(&edge.target_id).unwrap_or(&f32::INFINITY) {
618 dist.insert(edge.target_id, next_cost);
619 prev.insert(edge.target_id, node);
620 heap.push(State {
621 cost: next_cost,
622 node: edge.target_id,
623 });
624 }
625 }
626
627 if matches!(
629 params.direction,
630 TraversalDirection::Backward | TraversalDirection::Both
631 ) {
632 for edge in graph.edges_to(node) {
633 if !edge_filter.is_empty() && !edge_filter.contains(&edge.edge_type) {
634 continue;
635 }
636 let edge_cost = 1.0 - edge.weight;
637 let next_cost = cost + edge_cost;
638
639 if next_cost < *dist.get(&edge.source_id).unwrap_or(&f32::INFINITY) {
640 dist.insert(edge.source_id, next_cost);
641 prev.insert(edge.source_id, node);
642 heap.push(State {
643 cost: next_cost,
644 node: edge.source_id,
645 });
646 }
647 }
648 }
649 }
650
651 Ok(PathResult {
652 path: Vec::new(),
653 edges: Vec::new(),
654 cost: 0.0,
655 found: false,
656 })
657 }
658}