1use std::collections::{HashMap, HashSet, BinaryHeap, VecDeque};
7use std::cmp::Reverse;
8use kotoba_core::prelude::*;
9use crate::graph::{Graph, EdgeData, VertexData};
10
11#[derive(Debug, Clone, Default)]
13pub struct ShortestPathResult {
14 pub distances: HashMap<VertexId, u64>,
15 pub previous: HashMap<VertexId, VertexId>,
16}
17
18#[derive(Debug, Clone)]
20pub struct CentralityResult {
21 pub scores: HashMap<VertexId, f64>,
23 pub algorithm: CentralityAlgorithm,
25}
26
/// Tag identifying which centrality measure produced a `CentralityResult`.
///
/// `Eq` and `Hash` are derived (in addition to the original derives) so the tag
/// can be used as a `HashMap`/`HashSet` key, e.g. to cache results per measure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CentralityAlgorithm {
    /// Sum of in-degree and out-degree.
    Degree,
    /// Share of shortest paths that pass through a vertex.
    Betweenness,
    /// Based on the total distance to reachable vertices.
    Closeness,
    /// Eigenvector centrality (no producer of this tag is visible in this module).
    Eigenvector,
    /// PageRank random-surfer score.
    PageRank,
}
41
42#[derive(Debug, Clone)]
44pub struct PatternMatchResult {
45 pub mappings: Vec<SubgraphMapping>,
47 pub count: usize,
49}
50
51#[derive(Debug, Clone)]
53pub struct SubgraphMapping {
54 pub vertex_map: HashMap<VertexId, VertexId>,
56 pub edge_map: HashMap<EdgeId, EdgeId>,
58}
59
/// Stateless namespace grouping the graph algorithms of this module
/// (shortest paths, centrality measures and subgraph matching); every
/// operation is an associated function taking the graph by reference.
#[derive(Debug)]
pub struct GraphAlgorithms;
63
64impl GraphAlgorithms {
65 pub fn shortest_path_dijkstra(
67 graph: &Graph,
68 source: VertexId,
69 weight_fn: impl Fn(&EdgeData) -> u64,
70 ) -> Result<ShortestPathResult> {
71 let mut distances: HashMap<VertexId, u64> = HashMap::new();
72 let mut predecessors: HashMap<VertexId, VertexId> = HashMap::new();
73 let mut pq: BinaryHeap<Reverse<(u64, VertexId)>> = BinaryHeap::new();
74 let mut visited: HashSet<VertexId> = HashSet::new();
75
76 for &vertex_id in graph.vertices.keys() {
78 distances.insert(vertex_id, u64::MAX);
79 }
80 distances.insert(source, 0);
81 pq.push(Reverse((0, source)));
82
83 while let Some(Reverse((dist, u))) = pq.pop() {
84 if visited.contains(&u) {
85 continue;
86 }
87 visited.insert(u);
88
89 if let Some(neighbors) = graph.adj_out.get(&u) {
91 for &v in neighbors {
92 let edge_weight = graph.edges.values()
94 .find(|e| e.src == u && e.dst == v)
95 .map(&weight_fn)
96 .unwrap_or(1); let new_dist = dist + edge_weight;
99
100 if new_dist < *distances.get(&v).unwrap_or(&u64::MAX) {
101 distances.insert(v, new_dist);
102 predecessors.insert(v, u);
103 pq.push(Reverse((new_dist, v)));
104 }
105 }
106 }
107 }
108
109 Ok(ShortestPathResult {
110 distances,
111 previous: predecessors,
112 })
113 }
114
115 pub fn shortest_path_bellman_ford(
117 graph: &Graph,
118 source: VertexId,
119 weight_fn: impl Fn(&EdgeData) -> u64,
120 ) -> Result<ShortestPathResult> {
121 let mut distances: HashMap<VertexId, u64> = HashMap::new();
122 let mut predecessors: HashMap<VertexId, VertexId> = HashMap::new();
123
124 for &vertex_id in graph.vertices.keys() {
126 distances.insert(vertex_id, u64::MAX);
127 }
128 distances.insert(source, 0);
129
130 let vertex_count = graph.vertices.len();
131
132 for _ in 0..vertex_count - 1 {
134 for edge in graph.edges.values() {
135 let u = edge.src;
136 let v = edge.dst;
137 let weight = weight_fn(edge);
138
139 if let (Some(&dist_u), Some(&dist_v)) = (distances.get(&u), distances.get(&v)) {
140 if dist_u + weight < dist_v {
141 distances.insert(v, dist_u + weight);
142 predecessors.insert(v, u);
143 }
144 }
145 }
146 }
147
148 for edge in graph.edges.values() {
150 let u = edge.src;
151 let v = edge.dst;
152 let weight = weight_fn(edge);
153
154 if let (Some(&dist_u), Some(&dist_v)) = (distances.get(&u), distances.get(&v)) {
155 if dist_u + weight < dist_v {
156 return Err(kotoba_errors::KotobaError::Execution("Negative cycle detected".to_string()));
157 }
158 }
159 }
160
161 Ok(ShortestPathResult {
162 distances,
163 previous: predecessors,
164 })
165 }
166
167 pub fn all_pairs_shortest_paths(
169 graph: &Graph,
170 weight_fn: impl Fn(&EdgeData) -> u64,
171 ) -> Result<HashMap<(VertexId, VertexId), u64>> {
172 let vertices: Vec<VertexId> = graph.vertices.keys().cloned().collect();
173 let _n = vertices.len();
174
175 let mut dist: HashMap<(VertexId, VertexId), u64> = HashMap::new();
177
178 for &u in &vertices {
180 for &v in &vertices {
181 if u == v {
182 dist.insert((u, v), 0);
183 } else {
184 dist.insert((u, v), u64::MAX);
185 }
186 }
187 }
188
189 for edge in graph.edges.values() {
191 let weight = weight_fn(edge);
192 dist.insert((edge.src, edge.dst), weight);
193 }
194
195 for &k in &vertices {
197 for &i in &vertices {
198 for &j in &vertices {
199 let ik_dist = *dist.get(&(i, k)).unwrap_or(&u64::MAX);
200 let kj_dist = *dist.get(&(k, j)).unwrap_or(&u64::MAX);
201 let ij_dist = *dist.get(&(i, j)).unwrap_or(&u64::MAX);
202
203 if ik_dist + kj_dist < ij_dist {
204 dist.insert((i, j), ik_dist + kj_dist);
205 }
206 }
207 }
208 }
209
210 Ok(dist)
211 }
212
213 pub fn shortest_path_astar(
215 graph: &Graph,
216 source: VertexId,
217 target: VertexId,
218 weight_fn: impl Fn(&EdgeData) -> u64,
219 heuristic_fn: impl Fn(VertexId, VertexId) -> u64,
220 ) -> Result<Option<Vec<VertexId>>> {
221 let mut g_score: HashMap<VertexId, u64> = HashMap::new();
222 let mut f_score: HashMap<VertexId, u64> = HashMap::new();
223 let mut came_from: HashMap<VertexId, VertexId> = HashMap::new();
224 let mut open_set: BinaryHeap<Reverse<(u64, VertexId)>> = BinaryHeap::new();
225 let mut open_set_hash: HashSet<VertexId> = HashSet::new();
226 let mut closed_set: HashSet<VertexId> = HashSet::new();
227
228 for &vertex_id in graph.vertices.keys() {
230 g_score.insert(vertex_id, u64::MAX);
231 f_score.insert(vertex_id, u64::MAX);
232 }
233
234 g_score.insert(source, 0);
235 f_score.insert(source, heuristic_fn(source, target));
236 open_set.push(Reverse((f_score[&source], source)));
237 open_set_hash.insert(source);
238
239 while let Some(Reverse((_, current))) = open_set.pop() {
240 open_set_hash.remove(¤t);
241
242 if current == target {
243 return Ok(Some(Self::reconstruct_path(&came_from, current)));
245 }
246
247 if closed_set.contains(¤t) {
248 continue;
249 }
250 closed_set.insert(current);
251
252 if let Some(neighbors) = graph.adj_out.get(¤t) {
254 for &neighbor in neighbors {
255 if closed_set.contains(&neighbor) {
256 continue;
257 }
258
259 let edge_weight = graph.edges.values()
261 .find(|e| e.src == current && e.dst == neighbor)
262 .map(&weight_fn)
263 .unwrap_or(1);
264
265 let tentative_g_score = g_score[¤t] + edge_weight;
266
267 if tentative_g_score < *g_score.get(&neighbor).unwrap_or(&u64::MAX) {
268 came_from.insert(neighbor, current);
269 g_score.insert(neighbor, tentative_g_score);
270 f_score.insert(neighbor, tentative_g_score + heuristic_fn(neighbor, target));
271
272 if !open_set_hash.contains(&neighbor) {
273 open_set.push(Reverse((f_score[&neighbor], neighbor)));
274 open_set_hash.insert(neighbor);
275 }
276 }
277 }
278 }
279 }
280
281 Ok(None)
283 }
284
285 fn reconstruct_path(came_from: &HashMap<VertexId, VertexId>, current: VertexId) -> Vec<VertexId> {
287 let mut path = vec![current];
288 let mut current = current;
289
290 while let Some(&prev) = came_from.get(¤t) {
291 path.push(prev);
292 current = prev;
293 }
294
295 path.reverse();
296 path
297 }
298
299 pub fn degree_centrality(graph: &Graph, normalized: bool) -> CentralityResult {
301 let mut scores: HashMap<VertexId, f64> = HashMap::new();
302 let max_degree = if normalized { graph.vertices.len().saturating_sub(1) as f64 } else { 1.0 };
303
304 for (&vertex_id, _) in &graph.vertices {
305 let out_degree = graph.adj_out.get(&vertex_id).map(|s| s.len()).unwrap_or(0) as f64;
306 let in_degree = graph.adj_in.get(&vertex_id).map(|s| s.len()).unwrap_or(0) as f64;
307 let total_degree = out_degree + in_degree;
308
309 scores.insert(vertex_id, if normalized && max_degree > 0.0 {
310 total_degree / max_degree
311 } else {
312 total_degree
313 });
314 }
315
316 CentralityResult {
317 scores,
318 algorithm: CentralityAlgorithm::Degree,
319 }
320 }
321
322 pub fn betweenness_centrality(graph: &Graph, normalized: bool) -> CentralityResult {
324 let mut scores: HashMap<VertexId, f64> = HashMap::new();
325 let vertices: Vec<VertexId> = graph.vertices.keys().cloned().collect();
326
327 for &v in &vertices {
329 scores.insert(v, 0.0);
330 }
331
332 for &s in &vertices {
333 let mut stack: Vec<VertexId> = Vec::new();
335 let mut predecessors: HashMap<VertexId, Vec<VertexId>> = HashMap::new();
336 let mut sigma: HashMap<VertexId, f64> = HashMap::new();
337 let mut dist: HashMap<VertexId, i32> = HashMap::new();
338 let mut queue: VecDeque<VertexId> = VecDeque::new();
339
340 for &v in &vertices {
342 predecessors.insert(v, Vec::new());
343 sigma.insert(v, 0.0);
344 dist.insert(v, -1);
345 }
346
347 sigma.insert(s, 1.0);
348 dist.insert(s, 0);
349 queue.push_back(s);
350
351 while let Some(v) = queue.pop_front() {
353 stack.push(v);
354
355 if let Some(neighbors) = graph.adj_out.get(&v) {
356 for &w in neighbors {
357 if *dist.get(&w).unwrap_or(&-1) < 0 {
358 queue.push_back(w);
359 dist.insert(w, dist[&v] + 1);
360 }
361
362 if dist[&w] == dist[&v] + 1 {
363 sigma.insert(w, sigma[&w] + sigma[&v]);
364 predecessors.get_mut(&w).unwrap().push(v);
365 }
366 }
367 }
368 }
369
370 let mut delta: HashMap<VertexId, f64> = HashMap::new();
372 for &v in &vertices {
373 delta.insert(v, 0.0);
374 }
375
376 while let Some(w) = stack.pop() {
377 for &v in &predecessors[&w] {
378 let coeff = (sigma[&v] / sigma[&w]) * (1.0 + delta[&w]);
379 delta.insert(v, delta[&v] + coeff);
380 }
381
382 if w != s {
383 scores.insert(w, scores[&w] + delta[&w]);
384 }
385 }
386 }
387
388 if normalized {
390 let n = vertices.len() as f64;
391 if n > 2.0 {
392 let normalization_factor = 1.0 / ((n - 1.0) * (n - 2.0));
393 for score in scores.values_mut() {
394 *score *= normalization_factor;
395 }
396 }
397 }
398
399 CentralityResult {
400 scores,
401 algorithm: CentralityAlgorithm::Betweenness,
402 }
403 }
404
405 pub fn closeness_centrality(graph: &Graph, normalized: bool) -> CentralityResult {
407 let mut scores: HashMap<VertexId, f64> = HashMap::new();
408 let vertices: Vec<VertexId> = graph.vertices.keys().cloned().collect();
409
410 for &source in &vertices {
411 let result = Self::shortest_path_dijkstra(graph, source, |_| 1).unwrap_or_default();
413
414 let mut total_distance = 0.0;
415 let mut reachable_count = 0;
416
417 for &target in &vertices {
418 if let Some(&dist) = result.distances.get(&target) {
419 if dist < u64::MAX {
420 total_distance += dist as f64;
421 reachable_count += 1;
422 }
423 }
424 }
425
426 if reachable_count > 1 {
427 let closeness = if normalized {
428 (reachable_count - 1) as f64 / total_distance
429 } else {
430 1.0 / total_distance
431 };
432 scores.insert(source, closeness);
433 } else {
434 scores.insert(source, 0.0);
435 }
436 }
437
438 CentralityResult {
439 scores,
440 algorithm: CentralityAlgorithm::Closeness,
441 }
442 }
443
444 pub fn pagerank(graph: &Graph, damping_factor: f64, max_iterations: usize, tolerance: f64) -> CentralityResult {
446 let vertices: Vec<VertexId> = graph.vertices.keys().cloned().collect();
447 let n = vertices.len() as f64;
448
449 if n == 0.0 {
450 return CentralityResult {
451 scores: HashMap::new(),
452 algorithm: CentralityAlgorithm::PageRank,
453 };
454 }
455
456 let mut scores: HashMap<VertexId, f64> = vertices.iter()
458 .map(|&v| (v, 1.0 / n))
459 .collect();
460
461 let mut new_scores: HashMap<VertexId, f64> = HashMap::new();
462
463 for _ in 0..max_iterations {
464 let mut converged = true;
465
466 for &v in &vertices {
468 let mut incoming_score = 0.0;
469
470 for (&u, _) in &graph.vertices {
472 if let Some(out_neighbors) = graph.adj_out.get(&u) {
473 if out_neighbors.contains(&v) {
474 let out_degree = out_neighbors.len() as f64;
475 if out_degree > 0.0 {
476 incoming_score += scores[&u] / out_degree;
477 }
478 }
479 }
480 }
481
482 let new_score = (1.0 - damping_factor) / n + damping_factor * incoming_score;
483 new_scores.insert(v, new_score);
484
485 if (new_score - scores[&v]).abs() > tolerance {
487 converged = false;
488 }
489 }
490
491 scores.clone_from(&new_scores);
493
494 if converged {
495 break;
496 }
497 }
498
499 CentralityResult {
500 scores,
501 algorithm: CentralityAlgorithm::PageRank,
502 }
503 }
504
505 pub fn subgraph_isomorphism(pattern: &Graph, target: &Graph) -> PatternMatchResult {
507 let mut mappings = Vec::new();
508
509 if pattern.vertices.is_empty() {
510 return PatternMatchResult {
511 mappings,
512 count: 0,
513 };
514 }
515
516 let pattern_vertices: Vec<VertexId> = pattern.vertices.keys().cloned().collect();
517 let target_vertices: Vec<VertexId> = target.vertices.keys().cloned().collect();
518
519 Self::find_subgraph_matches(
521 pattern,
522 target,
523 &pattern_vertices,
524 &target_vertices,
525 0,
526 &mut HashMap::new(),
527 &mut HashMap::new(),
528 &mut mappings,
529 );
530
531 PatternMatchResult {
532 mappings: mappings.clone(),
533 count: mappings.len(),
534 }
535 }
536
537 fn find_subgraph_matches(
539 pattern: &Graph,
540 target: &Graph,
541 pattern_vertices: &[VertexId],
542 target_vertices: &[VertexId],
543 depth: usize,
544 vertex_map: &mut HashMap<VertexId, VertexId>,
545 edge_map: &mut HashMap<EdgeId, EdgeId>,
546 mappings: &mut Vec<SubgraphMapping>,
547 ) {
548 if depth == pattern_vertices.len() {
549 mappings.push(SubgraphMapping {
551 vertex_map: vertex_map.clone(),
552 edge_map: edge_map.clone(),
553 });
554 return;
555 }
556
557 let pattern_vertex = pattern_vertices[depth];
558
559 for &target_vertex in target_vertices {
560 let pattern_vertex_data = &pattern.vertices[&pattern_vertex];
562 let target_vertex_data = &target.vertices[&target_vertex];
563
564 if !Self::vertices_match(pattern_vertex_data, target_vertex_data) {
565 continue;
566 }
567
568 if Self::is_valid_mapping(pattern, target, pattern_vertex, target_vertex, vertex_map) {
570 vertex_map.insert(pattern_vertex, target_vertex);
572
573 let mut local_edge_map = edge_map.clone();
575 Self::map_edges(pattern, target, pattern_vertex, target_vertex, &mut local_edge_map);
576
577 Self::find_subgraph_matches(
579 pattern,
580 target,
581 pattern_vertices,
582 target_vertices,
583 depth + 1,
584 vertex_map,
585 &mut local_edge_map,
586 mappings,
587 );
588
589 vertex_map.remove(&pattern_vertex);
591 }
592 }
593 }
594
595 fn vertices_match(pattern_vertex: &VertexData, target_vertex: &VertexData) -> bool {
597 pattern_vertex.labels.iter().any(|label| target_vertex.labels.contains(label))
599 }
600
601 fn is_valid_mapping(
603 pattern: &Graph,
604 target: &Graph,
605 pattern_vertex: VertexId,
606 target_vertex: VertexId,
607 vertex_map: &HashMap<VertexId, VertexId>,
608 ) -> bool {
609 for (&pv, &tv) in vertex_map {
611 if let Some(pattern_neighbors) = pattern.adj_out.get(&pv) {
613 if pattern_neighbors.contains(&pattern_vertex) {
614 if let Some(target_neighbors) = target.adj_out.get(&tv) {
616 if !target_neighbors.contains(&target_vertex) {
617 return false;
618 }
619 } else {
620 return false;
621 }
622 }
623 }
624
625 if let Some(pattern_neighbors) = pattern.adj_in.get(&pv) {
626 if pattern_neighbors.contains(&pattern_vertex) {
627 if let Some(target_neighbors) = target.adj_in.get(&tv) {
628 if !target_neighbors.contains(&target_vertex) {
629 return false;
630 }
631 } else {
632 return false;
633 }
634 }
635 }
636 }
637
638 true
639 }
640
641 fn map_edges(
643 _pattern: &Graph,
644 _target: &Graph,
645 _pattern_vertex: VertexId,
646 _target_vertex: VertexId,
647 _edge_map: &mut HashMap<EdgeId, EdgeId>,
648 ) {
649 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use kotoba_core::types::*;

    /// Builds the directed path `v1 -> v2 -> v3`. Every vertex is labelled
    /// `Person` and both edges are labelled `FOLLOWS`.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::empty();

        let v1 = graph.add_vertex(VertexData {
            id: VertexId::new("v1").unwrap(),
            labels: vec!["Person".to_string()],
            props: HashMap::new(),
        });

        let v2 = graph.add_vertex(VertexData {
            id: VertexId::new("v2").unwrap(),
            labels: vec!["Person".to_string()],
            props: HashMap::new(),
        });

        let v3 = graph.add_vertex(VertexData {
            id: VertexId::new("v3").unwrap(),
            labels: vec!["Person".to_string()],
            props: HashMap::new(),
        });

        graph.add_edge(EdgeData {
            id: EdgeId::new("e1").unwrap(),
            src: v1,
            dst: v2,
            label: "FOLLOWS".to_string(),
            props: HashMap::new(),
        });

        graph.add_edge(EdgeData {
            id: EdgeId::new("e2").unwrap(),
            src: v2,
            dst: v3,
            label: "FOLLOWS".to_string(),
            props: HashMap::new(),
        });

        graph
    }

    #[test]
    fn test_dijkstra_shortest_path() {
        let graph = create_test_graph();
        let source = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let result = GraphAlgorithms::shortest_path_dijkstra(&graph, source, |_| 1).unwrap();

        assert_eq!(result.distances[&source], 0);
        assert_eq!(result.distances[&v2], 1);
        assert_eq!(result.distances[&v3], 2);
        assert_eq!(result.previous.get(&v3), Some(&v2));
        assert!(!result.previous.contains_key(&source));
    }

    #[test]
    fn test_astar_shortest_path() {
        let graph = create_test_graph();
        let v1 = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let path = GraphAlgorithms::shortest_path_astar(&graph, v1, v3, |_| 1, |_, _| 0).unwrap();
        assert_eq!(path, Some(vec![v1, v2, v3]));

        // Edges are directed, so the reverse trip is impossible.
        let reverse = GraphAlgorithms::shortest_path_astar(&graph, v3, v1, |_| 1, |_, _| 0).unwrap();
        assert_eq!(reverse, None);
    }

    #[test]
    fn test_degree_centrality() {
        let graph = create_test_graph();
        let v1 = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let result = GraphAlgorithms::degree_centrality(&graph, false);

        assert_eq!(result.algorithm, CentralityAlgorithm::Degree);
        assert_eq!(result.scores.len(), 3);
        assert_eq!(result.scores[&v1], 1.0);
        assert_eq!(result.scores[&v2], 2.0);
        assert_eq!(result.scores[&v3], 1.0);
    }

    #[test]
    fn test_betweenness_centrality() {
        let graph = create_test_graph();
        let v1 = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let result = GraphAlgorithms::betweenness_centrality(&graph, false);

        assert_eq!(result.algorithm, CentralityAlgorithm::Betweenness);
        // Only v2 lies strictly inside a shortest path (v1 -> v2 -> v3).
        assert_eq!(result.scores[&v1], 0.0);
        assert_eq!(result.scores[&v2], 1.0);
        assert_eq!(result.scores[&v3], 0.0);
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = create_test_graph();
        let v1 = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let result = GraphAlgorithms::closeness_centrality(&graph, true);

        assert_eq!(result.algorithm, CentralityAlgorithm::Closeness);
        // v1 reaches v2 and v3 at total distance 3; v2 reaches v3 at distance 1;
        // v3 reaches nothing.
        assert!((result.scores[&v1] - 2.0 / 3.0).abs() < 1e-12);
        assert!((result.scores[&v2] - 1.0).abs() < 1e-12);
        assert_eq!(result.scores[&v3], 0.0);
    }

    #[test]
    fn test_pagerank() {
        let graph = create_test_graph();
        let v1 = VertexId::new("v1").unwrap();
        let v2 = VertexId::new("v2").unwrap();
        let v3 = VertexId::new("v3").unwrap();

        let result = GraphAlgorithms::pagerank(&graph, 0.85, 10, 1e-6);

        assert_eq!(result.algorithm, CentralityAlgorithm::PageRank);
        assert_eq!(result.scores.len(), 3);
        for &score in result.scores.values() {
            assert!(score > 0.0);
        }
        // Rank flows down the path, so later vertices accumulate more of it.
        assert!(result.scores[&v1] < result.scores[&v2]);
        assert!(result.scores[&v2] < result.scores[&v3]);
    }

    #[test]
    fn test_subgraph_isomorphism() {
        let pattern = create_test_graph();
        let target = create_test_graph();

        let result = GraphAlgorithms::subgraph_isomorphism(&pattern, &target);

        // A directed path has no non-trivial automorphism: only the identity matches.
        assert_eq!(result.count, 1);
        assert_eq!(result.count, result.mappings.len());
        let mapping = &result.mappings[0];
        for name in &["v1", "v2", "v3"] {
            let id = VertexId::new(*name).unwrap();
            assert_eq!(mapping.vertex_map[&id], id);
        }
    }
}
776}