1use std::collections::{BinaryHeap, HashMap, HashSet};
25
/// A node (entity) stored in a [`KnowledgeGraph`].
#[derive(Debug, Clone, PartialEq)]
pub struct KgNode {
    /// Unique identifier; used as the lookup key and as edge endpoint id.
    pub id: String,
    /// Human-readable display label.
    pub label: String,
    /// Free-form type tag (e.g. `"Entity"`). Nodes auto-created by
    /// `KnowledgeGraph::add_edge` get `"Unknown"`.
    pub node_type: String,
    /// Relative importance; summed over a path's nodes when scoring paths.
    pub importance: f64,
}
42
/// A directed, labelled edge between two nodes of a [`KnowledgeGraph`].
#[derive(Debug, Clone, PartialEq)]
pub struct KgEdge {
    /// Id of the source node.
    pub from_id: String,
    /// Id of the target node.
    pub to_id: String,
    /// Relation label (e.g. `"knows"`); recorded per hop in `KgPath::edges`.
    pub relation: String,
    /// Edge strength. Summed into a path's `total_weight`; shortest-path
    /// search treats higher weights as cheaper (step cost `1 / weight`) and
    /// never traverses edges whose weight is not strictly positive.
    pub weight: f64,
    /// Confidence that the relation holds; multiplied along a path when scoring.
    pub confidence: f64,
}
57
/// A walk through a [`KnowledgeGraph`], as produced by `PathRanker`.
///
/// For paths built by the ranker, `edges.len() == hop_count` and
/// `nodes.len() == hop_count + 1`.
#[derive(Debug, Clone, PartialEq)]
pub struct KgPath {
    /// Node ids in traversal order, starting node first.
    pub nodes: Vec<String>,
    /// Relation label of each traversed edge, in traversal order.
    pub edges: Vec<String>,
    /// Sum of the raw `weight` of every traversed edge.
    pub total_weight: f64,
    /// Number of edges traversed.
    pub hop_count: usize,
}
70
71impl KgPath {
72 pub fn endpoint_pair(&self) -> (&str, &str) {
74 let start = self.nodes.first().map(String::as_str).unwrap_or("");
75 let end = self.nodes.last().map(String::as_str).unwrap_or("");
76 (start, end)
77 }
78}
79
/// Tunable coefficients used by `PathRanker::score_path`.
///
/// A path's score is
/// `total_weight * weight_factor * hop_penalty^hop_count
///  * (product of edge confidences) * confidence_factor
///  + (sum of node importance) * importance_bonus`.
#[derive(Debug, Clone, PartialEq)]
pub struct PathRankingConfig {
    /// Multiplier applied to the path's summed edge weight.
    pub weight_factor: f64,
    /// Per-hop multiplier raised to the hop count; a value in `(0, 1)`
    /// shrinks the score of longer paths.
    pub hop_penalty: f64,
    /// Multiplier applied to the product of the path's edge confidences.
    pub confidence_factor: f64,
    /// Multiplier for the summed node importance, which is *added* to the
    /// weight/hop/confidence product rather than multiplied into it.
    pub importance_bonus: f64,
}
96
97impl Default for PathRankingConfig {
98 fn default() -> Self {
99 Self {
100 weight_factor: 1.0,
101 hop_penalty: 0.9,
102 confidence_factor: 1.0,
103 importance_bonus: 0.1,
104 }
105 }
106}
107
108pub struct KnowledgeGraph {
114 nodes: HashMap<String, KgNode>,
115 adj: HashMap<String, Vec<usize>>,
117 edges: Vec<KgEdge>,
118}
119
120impl KnowledgeGraph {
121 pub fn new() -> Self {
123 Self {
124 nodes: HashMap::new(),
125 adj: HashMap::new(),
126 edges: Vec::new(),
127 }
128 }
129
130 pub fn add_node(&mut self, node: KgNode) {
132 self.adj.entry(node.id.clone()).or_default();
133 self.nodes.insert(node.id.clone(), node);
134 }
135
136 pub fn add_edge(&mut self, edge: KgEdge) {
139 for id in [&edge.from_id, &edge.to_id] {
141 if !self.nodes.contains_key(id.as_str()) {
142 let n = KgNode {
143 id: id.to_string(),
144 label: id.to_string(),
145 node_type: "Unknown".to_string(),
146 importance: 0.0,
147 };
148 self.nodes.insert(id.to_string(), n);
149 self.adj.entry(id.to_string()).or_default();
150 }
151 }
152
153 let idx = self.edges.len();
154 self.adj.entry(edge.from_id.clone()).or_default().push(idx);
155 self.edges.push(edge);
156 }
157
158 pub fn neighbors<'a>(&'a self, node_id: &str) -> Vec<(&'a KgEdge, &'a KgNode)> {
160 let Some(edge_indices) = self.adj.get(node_id) else {
161 return Vec::new();
162 };
163 edge_indices
164 .iter()
165 .filter_map(|&idx| {
166 let edge = &self.edges[idx];
167 let node = self.nodes.get(&edge.to_id)?;
168 Some((edge, node))
169 })
170 .collect()
171 }
172
173 pub fn node_count(&self) -> usize {
175 self.nodes.len()
176 }
177
178 pub fn edge_count(&self) -> usize {
180 self.edges.len()
181 }
182
183 pub fn get_node(&self, id: &str) -> Option<&KgNode> {
185 self.nodes.get(id)
186 }
187}
188
189impl Default for KnowledgeGraph {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
/// Stateless namespace for path search and ranking over a [`KnowledgeGraph`];
/// all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct PathRanker;
201
202impl PathRanker {
203 pub fn find_paths(
212 graph: &KnowledgeGraph,
213 start: &str,
214 end: &str,
215 max_hops: usize,
216 ) -> Vec<KgPath> {
217 if max_hops == 0 {
218 return Vec::new();
219 }
220
221 let mut results: Vec<KgPath> = Vec::new();
222
223 type StackItem = (String, Vec<String>, Vec<String>, f64, HashSet<String>);
225
226 let mut stack: Vec<StackItem> = Vec::new();
227 let mut initial_visited = HashSet::new();
228 initial_visited.insert(start.to_string());
229 stack.push((
230 start.to_string(),
231 vec![start.to_string()],
232 Vec::new(),
233 0.0,
234 initial_visited,
235 ));
236
237 while let Some((current, nodes, edges_so_far, weight, visited)) = stack.pop() {
238 if current == end && !edges_so_far.is_empty() {
239 results.push(KgPath {
240 hop_count: edges_so_far.len(),
241 nodes: nodes.clone(),
242 edges: edges_so_far.clone(),
243 total_weight: weight,
244 });
245 }
247
248 if edges_so_far.len() >= max_hops {
249 continue;
250 }
251
252 for (edge, neighbor) in graph.neighbors(¤t) {
253 if visited.contains(&neighbor.id) {
254 continue; }
256
257 let mut new_nodes = nodes.clone();
258 new_nodes.push(neighbor.id.clone());
259
260 let mut new_edges = edges_so_far.clone();
261 new_edges.push(edge.relation.clone());
262
263 let mut new_visited = visited.clone();
264 new_visited.insert(neighbor.id.clone());
265
266 stack.push((
267 neighbor.id.clone(),
268 new_nodes,
269 new_edges,
270 weight + edge.weight,
271 new_visited,
272 ));
273 }
274 }
275
276 results
277 }
278
279 pub fn score_path(graph: &KnowledgeGraph, path: &KgPath, config: &PathRankingConfig) -> f64 {
290 let weight_score = path.total_weight * config.weight_factor;
292
293 let hop_multiplier = config.hop_penalty.powi(path.hop_count as i32);
295
296 let confidence_product: f64 = Self::edge_confidences_product(graph, path);
298 let confidence_score = confidence_product * config.confidence_factor;
299
300 let importance_sum: f64 = path
302 .nodes
303 .iter()
304 .filter_map(|id| graph.get_node(id))
305 .map(|n| n.importance)
306 .sum();
307 let importance_score = importance_sum * config.importance_bonus;
308
309 (weight_score * hop_multiplier * confidence_score) + importance_score
310 }
311
312 pub fn rank_paths(
314 graph: &KnowledgeGraph,
315 paths: Vec<KgPath>,
316 config: &PathRankingConfig,
317 ) -> Vec<(KgPath, f64)> {
318 let mut scored: Vec<(KgPath, f64)> = paths
319 .into_iter()
320 .map(|p| {
321 let s = Self::score_path(graph, &p, config);
322 (p, s)
323 })
324 .collect();
325
326 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
327 scored
328 }
329
330 pub fn shortest_path_dijkstra(
339 graph: &KnowledgeGraph,
340 start: &str,
341 end: &str,
342 ) -> Option<KgPath> {
343 if !graph.nodes.contains_key(start) || !graph.nodes.contains_key(end) {
344 return None;
345 }
346 if start == end {
347 return Some(KgPath {
348 nodes: vec![start.to_string()],
349 edges: Vec::new(),
350 total_weight: 0.0,
351 hop_count: 0,
352 });
353 }
354
355 #[derive(PartialEq)]
358 struct HeapItem {
359 neg_dist: f64,
360 node: String,
361 nodes: Vec<String>,
362 edges: Vec<String>,
363 acc_weight: f64,
364 }
365
366 impl Eq for HeapItem {}
367
368 impl PartialOrd for HeapItem {
369 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
370 Some(self.cmp(other))
371 }
372 }
373
374 impl Ord for HeapItem {
375 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
376 self.neg_dist
377 .partial_cmp(&other.neg_dist)
378 .unwrap_or(std::cmp::Ordering::Equal)
379 }
380 }
381
382 let mut dist_map: HashMap<String, f64> = HashMap::new();
383 dist_map.insert(start.to_string(), 0.0);
384
385 let mut heap = BinaryHeap::new();
386 heap.push(HeapItem {
387 neg_dist: 0.0,
388 node: start.to_string(),
389 nodes: vec![start.to_string()],
390 edges: Vec::new(),
391 acc_weight: 0.0,
392 });
393
394 while let Some(item) = heap.pop() {
395 let current_dist = -item.neg_dist;
396
397 if item.node == end {
398 return Some(KgPath {
399 hop_count: item.edges.len(),
400 nodes: item.nodes,
401 edges: item.edges,
402 total_weight: item.acc_weight,
403 });
404 }
405
406 if let Some(&best) = dist_map.get(&item.node) {
408 if current_dist > best + 1e-12 {
409 continue;
410 }
411 }
412
413 for (edge, neighbor) in graph.neighbors(&item.node) {
414 if item.nodes.contains(&neighbor.id) {
416 continue;
417 }
418
419 let step_dist = if edge.weight > 0.0 {
420 1.0 / edge.weight
421 } else {
422 f64::INFINITY
423 };
424 let new_dist = current_dist + step_dist;
425
426 let best = dist_map.entry(neighbor.id.clone()).or_insert(f64::INFINITY);
427 if new_dist < *best - 1e-12 {
428 *best = new_dist;
429
430 let mut new_nodes = item.nodes.clone();
431 new_nodes.push(neighbor.id.clone());
432 let mut new_edges = item.edges.clone();
433 new_edges.push(edge.relation.clone());
434
435 heap.push(HeapItem {
436 neg_dist: -new_dist,
437 node: neighbor.id.clone(),
438 nodes: new_nodes,
439 edges: new_edges,
440 acc_weight: item.acc_weight + edge.weight,
441 });
442 }
443 }
444 }
445
446 None }
448
449 pub fn most_relevant_paths(
455 graph: &KnowledgeGraph,
456 start: &str,
457 end: &str,
458 max_hops: usize,
459 top_k: usize,
460 config: &PathRankingConfig,
461 ) -> Vec<(KgPath, f64)> {
462 let paths = Self::find_paths(graph, start, end, max_hops);
463 let mut ranked = Self::rank_paths(graph, paths, config);
464 ranked.truncate(top_k);
465 ranked
466 }
467
468 fn edge_confidences_product(graph: &KnowledgeGraph, path: &KgPath) -> f64 {
477 if path.edges.is_empty() {
478 return 1.0;
479 }
480 let mut product = 1.0;
481 for (i, relation) in path.edges.iter().enumerate() {
482 let from = &path.nodes[i];
483 let confidence = graph
484 .neighbors(from)
485 .into_iter()
486 .find(|(e, _)| &e.relation == relation)
487 .map(|(e, _)| e.confidence)
488 .unwrap_or(1.0);
489 product *= confidence;
490 }
491 product
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a [`KgNode`].
    fn node(id: &str, label: &str, node_type: &str, importance: f64) -> KgNode {
        KgNode {
            id: id.into(),
            label: label.into(),
            node_type: node_type.into(),
            importance,
        }
    }

    /// Shorthand constructor for a [`KgEdge`].
    fn edge(from: &str, to: &str, relation: &str, weight: f64, confidence: f64) -> KgEdge {
        KgEdge {
            from_id: from.into(),
            to_id: to.into(),
            relation: relation.into(),
            weight,
            confidence,
        }
    }

    /// A -> B -> C plus a direct A -> C edge.
    fn triangle_graph() -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        g.add_node(node("A", "Alpha", "Entity", 1.0));
        g.add_node(node("B", "Beta", "Entity", 0.5));
        g.add_node(node("C", "Gamma", "Entity", 0.8));
        g.add_edge(edge("A", "B", "knows", 1.0, 0.9));
        g.add_edge(edge("B", "C", "related", 2.0, 0.8));
        g.add_edge(edge("A", "C", "direct", 0.5, 0.95));
        g
    }

    // ---- graph construction ----

    #[test]
    fn test_graph_node_count() {
        let g = triangle_graph();
        assert_eq!(g.node_count(), 3);
    }

    #[test]
    fn test_graph_edge_count() {
        let g = triangle_graph();
        assert_eq!(g.edge_count(), 3);
    }

    #[test]
    fn test_graph_neighbors() {
        let g = triangle_graph();
        let nb = g.neighbors("A");
        let mut targets: Vec<&str> = nb.iter().map(|(_, n)| n.id.as_str()).collect();
        targets.sort_unstable();
        assert_eq!(targets, vec!["B", "C"]);
        assert!(g.neighbors("C").is_empty());
        assert!(g.neighbors("missing").is_empty());
    }

    #[test]
    fn test_graph_add_node_replaces_existing() {
        let mut g = KnowledgeGraph::new();
        g.add_node(node("X", "X", "T", 1.0));
        g.add_node(node("X", "X2", "T2", 2.0));
        assert_eq!(g.node_count(), 1);
        // The most recent insertion wins.
        assert_eq!(g.get_node("X").map(|n| n.label.as_str()), Some("X2"));
    }

    #[test]
    fn test_auto_create_nodes_on_add_edge() {
        let mut g = KnowledgeGraph::new();
        g.add_edge(edge("P", "Q", "r", 1.0, 1.0));
        assert_eq!(g.node_count(), 2);
        let q = g.get_node("Q").expect("Q should be auto-created");
        assert_eq!(q.label, "Q");
        assert_eq!(q.node_type, "Unknown");
        assert_eq!(q.importance, 0.0);
    }

    #[test]
    fn test_get_node() {
        let g = triangle_graph();
        let n = g.get_node("B").expect("B is in the triangle graph");
        assert_eq!(n.label, "Beta");
    }

    #[test]
    fn test_get_node_missing() {
        let g = triangle_graph();
        assert!(g.get_node("Z").is_none());
    }

    // ---- path enumeration ----

    #[test]
    fn test_find_paths_direct() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 1);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].edges, vec!["direct"]);
        assert_eq!(paths[0].nodes, vec!["A", "C"]);
        assert_eq!(paths[0].total_weight, 0.5);
    }

    #[test]
    fn test_find_paths_two_hops() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 2);
        assert_eq!(paths.len(), 2);
        let mut relations: Vec<Vec<String>> = paths.iter().map(|p| p.edges.clone()).collect();
        relations.sort();
        assert_eq!(
            relations,
            vec![
                vec!["direct".to_string()],
                vec!["knows".to_string(), "related".to_string()]
            ]
        );
    }

    #[test]
    fn test_find_paths_no_path() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "C", "A", 5);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_zero_max_hops() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 0);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_cycle_avoidance() {
        let mut g = KnowledgeGraph::new();
        // A <-> B forms a cycle the search must not loop around.
        g.add_edge(edge("A", "B", "ab", 1.0, 1.0));
        g.add_edge(edge("B", "A", "ba", 1.0, 1.0));
        // C does not exist yet: the search must still terminate, empty-handed.
        assert!(PathRanker::find_paths(&g, "A", "C", 10).is_empty());

        // Once C is reachable, exactly one simple path exists.
        g.add_edge(edge("B", "C", "bc", 1.0, 1.0));
        let paths = PathRanker::find_paths(&g, "A", "C", 10);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_find_paths_self_loop_ignored() {
        let mut g = KnowledgeGraph::new();
        g.add_node(node("A", "A", "E", 1.0));
        g.add_node(node("B", "B", "E", 1.0));
        g.add_edge(edge("A", "A", "self", 1.0, 1.0));
        g.add_edge(edge("A", "B", "ab", 1.0, 1.0));
        let paths = PathRanker::find_paths(&g, "A", "B", 2);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].edges, vec!["ab"]);
    }

    #[test]
    fn test_find_paths_hop_count_correct() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 2);
        let hops: Vec<usize> = paths.iter().map(|p| p.hop_count).collect();
        assert!(hops.contains(&1));
        assert!(hops.contains(&2));
        for p in &paths {
            assert_eq!(p.hop_count, p.edges.len());
            assert_eq!(p.nodes.len(), p.hop_count + 1);
        }
    }

    #[test]
    fn test_endpoint_pair() {
        let path = KgPath {
            nodes: vec!["A".into(), "B".into(), "C".into()],
            edges: vec!["r1".into(), "r2".into()],
            total_weight: 3.0,
            hop_count: 2,
        };
        assert_eq!(path.endpoint_pair(), ("A", "C"));

        let empty = KgPath {
            nodes: Vec::new(),
            edges: Vec::new(),
            total_weight: 0.0,
            hop_count: 0,
        };
        assert_eq!(empty.endpoint_pair(), ("", ""));
    }

    // ---- scoring & ranking ----

    #[test]
    fn test_score_path_direct_higher_with_low_hop_penalty() {
        // Zero importance everywhere so only weight, hops and confidence matter.
        let mut g = KnowledgeGraph::new();
        g.add_node(node("A", "A", "E", 0.0));
        g.add_node(node("B", "B", "E", 0.0));
        g.add_node(node("C", "C", "E", 0.0));
        g.add_edge(edge("A", "B", "ab", 1.0, 0.9));
        g.add_edge(edge("B", "C", "bc", 1.0, 0.9));
        g.add_edge(edge("A", "C", "direct", 10.0, 0.99));

        let config = PathRankingConfig {
            weight_factor: 1.0,
            hop_penalty: 0.5,
            confidence_factor: 1.0,
            importance_bonus: 0.0,
        };
        let paths = PathRanker::find_paths(&g, "A", "C", 2);
        assert_eq!(paths.len(), 2, "expected 2 paths");
        let scores: Vec<f64> = paths
            .iter()
            .map(|p| PathRanker::score_path(&g, p, &config))
            .collect();
        // Discovery order is not part of the contract; pick by hop count.
        let (direct, two_hop) = if paths[0].hop_count == 1 {
            (scores[0], scores[1])
        } else {
            (scores[1], scores[0])
        };
        assert!(direct > two_hop, "direct={direct}, two_hop={two_hop}");
    }

    #[test]
    fn test_score_path_direct_value() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 1);
        let s = PathRanker::score_path(&g, &paths[0], &PathRankingConfig::default());
        // 0.5 * 0.9^1 * 0.95 + (1.0 + 0.8) * 0.1
        assert!((s - 0.6075).abs() < 1e-9, "score={s}");
    }

    #[test]
    fn test_rank_paths_sorted_descending() {
        let g = triangle_graph();
        let paths = PathRanker::find_paths(&g, "A", "C", 2);
        let config = PathRankingConfig::default();
        let ranked = PathRanker::rank_paths(&g, paths, &config);
        assert_eq!(ranked.len(), 2);
        assert!(ranked[0].1 >= ranked[1].1);
        // With default tuning the heavier two-hop route beats the light direct edge.
        assert_eq!(ranked[0].0.hop_count, 2);
    }

    #[test]
    fn test_rank_paths_empty() {
        let g = triangle_graph();
        let ranked = PathRanker::rank_paths(&g, vec![], &PathRankingConfig::default());
        assert!(ranked.is_empty());
    }

    // ---- shortest path ----

    #[test]
    fn test_dijkstra_direct_path() {
        let g = triangle_graph();
        let p = PathRanker::shortest_path_dijkstra(&g, "A", "B").expect("A -> B exists");
        assert_eq!(p.nodes, vec!["A", "B"]);
        assert_eq!(p.edges, vec!["knows"]);
        assert_eq!(p.hop_count, 1);
        assert_eq!(p.total_weight, 1.0);
    }

    #[test]
    fn test_dijkstra_same_node() {
        let g = triangle_graph();
        let p = PathRanker::shortest_path_dijkstra(&g, "A", "A").expect("trivial path");
        assert_eq!(p.nodes, vec!["A"]);
        assert_eq!(p.hop_count, 0);
    }

    #[test]
    fn test_dijkstra_no_path() {
        let g = triangle_graph();
        assert!(PathRanker::shortest_path_dijkstra(&g, "C", "A").is_none());
    }

    #[test]
    fn test_dijkstra_missing_node() {
        let g = triangle_graph();
        assert!(PathRanker::shortest_path_dijkstra(&g, "A", "Z").is_none());
    }

    #[test]
    fn test_dijkstra_prefers_high_weight_edge() {
        // Cost is 1/weight: A->B->C costs 1.0 + 0.1 = 1.1, A->C costs 0.5.
        let mut g = KnowledgeGraph::new();
        g.add_edge(edge("A", "B", "r1", 1.0, 1.0));
        g.add_edge(edge("B", "C", "r2", 10.0, 1.0));
        g.add_edge(edge("A", "C", "direct", 2.0, 1.0));
        let path = PathRanker::shortest_path_dijkstra(&g, "A", "C").expect("A -> C exists");
        assert_eq!(path.hop_count, 1);
        assert_eq!(path.edges[0], "direct");
    }

    // ---- end-to-end ----

    #[test]
    fn test_most_relevant_paths_top_k() {
        let g = triangle_graph();
        let config = PathRankingConfig::default();
        let results = PathRanker::most_relevant_paths(&g, "A", "C", 3, 1, &config);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.hop_count, 2);
    }

    #[test]
    fn test_most_relevant_paths_no_path() {
        let g = triangle_graph();
        let config = PathRankingConfig::default();
        let results = PathRanker::most_relevant_paths(&g, "C", "A", 5, 10, &config);
        assert!(results.is_empty());
    }

    #[test]
    fn test_most_relevant_paths_all_returned_when_k_large() {
        let g = triangle_graph();
        let config = PathRankingConfig::default();
        let results = PathRanker::most_relevant_paths(&g, "A", "C", 2, 100, &config);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_disconnected_graph() {
        let mut g = KnowledgeGraph::new();
        g.add_node(node("X", "X", "E", 1.0));
        g.add_node(node("Y", "Y", "E", 1.0));
        assert!(PathRanker::find_paths(&g, "X", "Y", 5).is_empty());
        assert!(PathRanker::shortest_path_dijkstra(&g, "X", "Y").is_none());
    }

    #[test]
    fn test_score_increases_with_importance_bonus() {
        let g = triangle_graph();
        let config_low = PathRankingConfig {
            importance_bonus: 0.0,
            ..Default::default()
        };
        let config_high = PathRankingConfig {
            importance_bonus: 10.0,
            ..Default::default()
        };
        let paths = PathRanker::find_paths(&g, "A", "C", 1);
        assert!(!paths.is_empty());
        let s_low = PathRanker::score_path(&g, &paths[0], &config_low);
        let s_high = PathRanker::score_path(&g, &paths[0], &config_high);
        assert!(s_high > s_low);
    }

    #[test]
    fn test_score_path_single_node_path() {
        let g = triangle_graph();
        let single = KgPath {
            nodes: vec!["A".into()],
            edges: Vec::new(),
            total_weight: 0.0,
            hop_count: 0,
        };
        let config = PathRankingConfig::default();
        let s = PathRanker::score_path(&g, &single, &config);
        // Only A's importance (1.0) scaled by the default bonus (0.1) remains.
        assert!((s - 0.1).abs() < 1e-9, "score={s}");
    }
}