1use crate::base::{EdgeWeight, Graph, IndexType, Node};
18use crate::error::{GraphError, Result};
19use std::collections::{HashMap, HashSet};
20use std::hash::Hash;
21
/// A candidate link between two nodes together with its predicted score.
///
/// The scoring helpers in this module rank candidates by descending `score`,
/// so larger values denote links considered more likely. When evaluated with
/// `evaluate_link_prediction`, the pair is matched against the label sets in
/// either orientation.
#[derive(Debug, Clone)]
pub struct LinkScore<N: Node> {
    /// First endpoint of the candidate pair.
    pub node_a: N,
    /// Second endpoint of the candidate pair.
    pub node_b: N,
    /// Score assigned to the pair by the predictor that produced it.
    pub score: f64,
}
32
/// Summary metrics returned by `evaluate_link_prediction`.
#[derive(Debug, Clone)]
pub struct LinkPredictionEval {
    /// Area under the ROC curve; reported as `0.5` when it cannot be computed
    /// (no positive or no negative examples).
    pub auc: f64,
    /// Average precision over the ranked, labelled predictions; reported as
    /// `0.0` when it cannot be computed.
    pub average_precision: f64,
    /// Number of positive-labelled predictions counted while sweeping the
    /// ranked list (`0` when the metrics could not be computed).
    pub true_positives: usize,
    /// Number of negative-labelled predictions counted while sweeping the
    /// ranked list (`0` when the metrics could not be computed).
    pub false_positives: usize,
    /// Number of positive examples taken into account.
    pub total_positives: usize,
    /// Number of negative examples taken into account.
    pub total_negatives: usize,
}
49
/// Options controlling which candidate pairs the `*_all` scoring functions
/// return.
#[derive(Debug, Clone)]
pub struct LinkPredictionConfig {
    /// Maximum number of highest-scoring candidates to keep.
    pub max_predictions: usize,
    /// Candidates whose score is below this threshold are discarded.
    pub min_score: f64,
    /// When `false`, pairs of a node with itself are skipped.
    pub include_self_loops: bool,
}
60
61impl Default for LinkPredictionConfig {
62 fn default() -> Self {
63 Self {
64 max_predictions: 100,
65 min_score: 0.0,
66 include_self_loops: false,
67 }
68 }
69}
70
71pub fn common_neighbors_score<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
82where
83 N: Node + Clone + Hash + Eq + std::fmt::Debug,
84 E: EdgeWeight,
85 Ix: IndexType,
86{
87 validate_nodes(graph, u, v)?;
88
89 let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
90 let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
91
92 let common = neighbors_u.intersection(&neighbors_v).count();
93 Ok(common as f64)
94}
95
96pub fn common_neighbors_all<N, E, Ix>(
98 graph: &Graph<N, E, Ix>,
99 config: &LinkPredictionConfig,
100) -> Vec<LinkScore<N>>
101where
102 N: Node + Clone + Hash + Eq + std::fmt::Debug,
103 E: EdgeWeight,
104 Ix: IndexType,
105{
106 compute_all_scores(graph, config, |g, u, v| {
107 common_neighbors_score(g, u, v).unwrap_or(0.0)
108 })
109}
110
111pub fn jaccard_coefficient<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
121where
122 N: Node + Clone + Hash + Eq + std::fmt::Debug,
123 E: EdgeWeight,
124 Ix: IndexType,
125{
126 validate_nodes(graph, u, v)?;
127
128 let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
129 let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
130
131 let intersection = neighbors_u.intersection(&neighbors_v).count();
132 let union = neighbors_u.union(&neighbors_v).count();
133
134 if union == 0 {
135 Ok(0.0)
136 } else {
137 Ok(intersection as f64 / union as f64)
138 }
139}
140
141pub fn jaccard_coefficient_all<N, E, Ix>(
143 graph: &Graph<N, E, Ix>,
144 config: &LinkPredictionConfig,
145) -> Vec<LinkScore<N>>
146where
147 N: Node + Clone + Hash + Eq + std::fmt::Debug,
148 E: EdgeWeight,
149 Ix: IndexType,
150{
151 compute_all_scores(graph, config, |g, u, v| {
152 jaccard_coefficient(g, u, v).unwrap_or(0.0)
153 })
154}
155
156pub fn adamic_adar_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
167where
168 N: Node + Clone + Hash + Eq + std::fmt::Debug,
169 E: EdgeWeight,
170 Ix: IndexType,
171{
172 validate_nodes(graph, u, v)?;
173
174 let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
175 let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
176
177 let mut score = 0.0;
178 for common in neighbors_u.intersection(&neighbors_v) {
179 let degree = graph.degree(common);
180 if degree > 1 {
181 score += 1.0 / (degree as f64).ln();
182 }
183 }
184
185 Ok(score)
186}
187
188pub fn adamic_adar_all<N, E, Ix>(
190 graph: &Graph<N, E, Ix>,
191 config: &LinkPredictionConfig,
192) -> Vec<LinkScore<N>>
193where
194 N: Node + Clone + Hash + Eq + std::fmt::Debug,
195 E: EdgeWeight,
196 Ix: IndexType,
197{
198 compute_all_scores(graph, config, |g, u, v| {
199 adamic_adar_index(g, u, v).unwrap_or(0.0)
200 })
201}
202
203pub fn preferential_attachment<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
214where
215 N: Node + Clone + Hash + Eq + std::fmt::Debug,
216 E: EdgeWeight,
217 Ix: IndexType,
218{
219 validate_nodes(graph, u, v)?;
220
221 let deg_u = graph.degree(u);
222 let deg_v = graph.degree(v);
223
224 Ok((deg_u * deg_v) as f64)
225}
226
227pub fn preferential_attachment_all<N, E, Ix>(
229 graph: &Graph<N, E, Ix>,
230 config: &LinkPredictionConfig,
231) -> Vec<LinkScore<N>>
232where
233 N: Node + Clone + Hash + Eq + std::fmt::Debug,
234 E: EdgeWeight,
235 Ix: IndexType,
236{
237 compute_all_scores(graph, config, |g, u, v| {
238 preferential_attachment(g, u, v).unwrap_or(0.0)
239 })
240}
241
242pub fn resource_allocation_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
253where
254 N: Node + Clone + Hash + Eq + std::fmt::Debug,
255 E: EdgeWeight,
256 Ix: IndexType,
257{
258 validate_nodes(graph, u, v)?;
259
260 let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
261 let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
262
263 let mut score = 0.0;
264 for common in neighbors_u.intersection(&neighbors_v) {
265 let degree = graph.degree(common);
266 if degree > 0 {
267 score += 1.0 / degree as f64;
268 }
269 }
270
271 Ok(score)
272}
273
274pub fn resource_allocation_all<N, E, Ix>(
276 graph: &Graph<N, E, Ix>,
277 config: &LinkPredictionConfig,
278) -> Vec<LinkScore<N>>
279where
280 N: Node + Clone + Hash + Eq + std::fmt::Debug,
281 E: EdgeWeight,
282 Ix: IndexType,
283{
284 compute_all_scores(graph, config, |g, u, v| {
285 resource_allocation_index(g, u, v).unwrap_or(0.0)
286 })
287}
288
289pub fn katz_similarity<N, E, Ix>(
306 graph: &Graph<N, E, Ix>,
307 u: &N,
308 v: &N,
309 beta: f64,
310 max_path_length: usize,
311) -> Result<f64>
312where
313 N: Node + Clone + Hash + Eq + std::fmt::Debug,
314 E: EdgeWeight,
315 Ix: IndexType,
316{
317 validate_nodes(graph, u, v)?;
318
319 if beta <= 0.0 || beta >= 1.0 {
320 return Err(GraphError::InvalidGraph(
321 "Beta must be in (0, 1)".to_string(),
322 ));
323 }
324
325 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
326 let n = nodes.len();
327 let node_to_idx: HashMap<N, usize> = nodes
328 .iter()
329 .enumerate()
330 .map(|(i, n)| (n.clone(), i))
331 .collect();
332
333 let u_idx = node_to_idx
334 .get(u)
335 .ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))?;
336 let v_idx = node_to_idx
337 .get(v)
338 .ok_or_else(|| GraphError::node_not_found(format!("{v:?}")))?;
339
340 let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
342 for (i, node) in nodes.iter().enumerate() {
343 if let Ok(neighbors) = graph.neighbors(node) {
344 for neighbor in &neighbors {
345 if let Some(&j) = node_to_idx.get(neighbor) {
346 adj[i].push(j);
347 }
348 }
349 }
350 }
351
352 let mut score = 0.0;
355 let mut current = vec![0.0f64; n];
356 current[*u_idx] = 1.0;
357
358 for l in 1..=max_path_length {
359 let mut next = vec![0.0f64; n];
360 for (i, &count) in current.iter().enumerate() {
361 if count > 0.0 {
362 for &j in &adj[i] {
363 next[j] += count;
364 }
365 }
366 }
367
368 let beta_l = beta.powi(l as i32);
369 score += beta_l * next[*v_idx];
370 current = next;
371 }
372
373 Ok(score)
374}
375
376pub fn katz_similarity_all<N, E, Ix>(
378 graph: &Graph<N, E, Ix>,
379 beta: f64,
380 max_path_length: usize,
381 config: &LinkPredictionConfig,
382) -> Vec<LinkScore<N>>
383where
384 N: Node + Clone + Hash + Eq + std::fmt::Debug,
385 E: EdgeWeight,
386 Ix: IndexType,
387{
388 compute_all_scores(graph, config, |g, u, v| {
389 katz_similarity(g, u, v, beta, max_path_length).unwrap_or(0.0)
390 })
391}
392
393pub fn simrank<N, E, Ix>(
409 graph: &Graph<N, E, Ix>,
410 decay: f64,
411 max_iterations: usize,
412 tolerance: f64,
413) -> Result<HashMap<(N, N), f64>>
414where
415 N: Node + Clone + Hash + Eq + std::fmt::Debug,
416 E: EdgeWeight,
417 Ix: IndexType,
418{
419 if decay <= 0.0 || decay > 1.0 {
420 return Err(GraphError::InvalidGraph(
421 "Decay must be in (0, 1]".to_string(),
422 ));
423 }
424
425 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
426 let n = nodes.len();
427 let node_to_idx: HashMap<N, usize> = nodes
428 .iter()
429 .enumerate()
430 .map(|(i, n)| (n.clone(), i))
431 .collect();
432
433 let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
435 for (i, node) in nodes.iter().enumerate() {
436 if let Ok(neighbors) = graph.neighbors(node) {
437 for neighbor in &neighbors {
438 if let Some(&j) = node_to_idx.get(neighbor) {
439 adj[i].push(j);
440 }
441 }
442 }
443 }
444
445 let mut sim = vec![vec![0.0f64; n]; n];
447 for i in 0..n {
448 sim[i][i] = 1.0;
449 }
450
451 for _ in 0..max_iterations {
453 let mut new_sim = vec![vec![0.0f64; n]; n];
454 let mut max_diff = 0.0f64;
455
456 for i in 0..n {
457 new_sim[i][i] = 1.0;
458 for j in (i + 1)..n {
459 let deg_i = adj[i].len();
460 let deg_j = adj[j].len();
461
462 if deg_i == 0 || deg_j == 0 {
463 new_sim[i][j] = 0.0;
464 new_sim[j][i] = 0.0;
465 continue;
466 }
467
468 let mut sum = 0.0;
469 for &ni in &adj[i] {
470 for &nj in &adj[j] {
471 sum += sim[ni][nj];
472 }
473 }
474
475 let new_val = decay * sum / (deg_i * deg_j) as f64;
476 new_sim[i][j] = new_val;
477 new_sim[j][i] = new_val;
478
479 let diff = (new_val - sim[i][j]).abs();
480 if diff > max_diff {
481 max_diff = diff;
482 }
483 }
484 }
485
486 sim = new_sim;
487
488 if max_diff < tolerance {
489 break;
490 }
491 }
492
493 let mut result = HashMap::new();
495 for i in 0..n {
496 for j in i..n {
497 result.insert((nodes[i].clone(), nodes[j].clone()), sim[i][j]);
498 if i != j {
499 result.insert((nodes[j].clone(), nodes[i].clone()), sim[i][j]);
500 }
501 }
502 }
503
504 Ok(result)
505}
506
507pub fn simrank_score<N, E, Ix>(
509 graph: &Graph<N, E, Ix>,
510 u: &N,
511 v: &N,
512 decay: f64,
513 max_iterations: usize,
514) -> Result<f64>
515where
516 N: Node + Clone + Hash + Eq + std::fmt::Debug,
517 E: EdgeWeight,
518 Ix: IndexType,
519{
520 let all_scores = simrank(graph, decay, max_iterations, 1e-6)?;
521 all_scores
522 .get(&(u.clone(), v.clone()))
523 .copied()
524 .ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))
525}
526
527pub fn evaluate_link_prediction<N>(
541 scores: &[LinkScore<N>],
542 positive_edges: &HashSet<(N, N)>,
543 negative_edges: &HashSet<(N, N)>,
544) -> LinkPredictionEval
545where
546 N: Node + Clone + Hash + Eq + std::fmt::Debug,
547{
548 if positive_edges.is_empty() || negative_edges.is_empty() {
549 return LinkPredictionEval {
550 auc: 0.5,
551 average_precision: 0.0,
552 true_positives: 0,
553 false_positives: 0,
554 total_positives: positive_edges.len(),
555 total_negatives: negative_edges.len(),
556 };
557 }
558
559 let mut scored_labels: Vec<(f64, bool)> = Vec::new();
561
562 for score in scores {
563 let pair = (score.node_a.clone(), score.node_b.clone());
564 let reverse_pair = (score.node_b.clone(), score.node_a.clone());
565
566 let is_positive = positive_edges.contains(&pair) || positive_edges.contains(&reverse_pair);
567 let is_negative = negative_edges.contains(&pair) || negative_edges.contains(&reverse_pair);
568
569 if is_positive || is_negative {
570 scored_labels.push((score.score, is_positive));
571 }
572 }
573
574 scored_labels.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
576
577 let total_positives = scored_labels.iter().filter(|(_, label)| *label).count();
579 let total_negatives = scored_labels.iter().filter(|(_, label)| !*label).count();
580
581 if total_positives == 0 || total_negatives == 0 {
582 return LinkPredictionEval {
583 auc: 0.5,
584 average_precision: 0.0,
585 true_positives: 0,
586 false_positives: 0,
587 total_positives,
588 total_negatives,
589 };
590 }
591
592 let mut auc = 0.0;
593 let mut tp = 0usize;
594 let mut fp = 0usize;
595 let mut prev_fpr = 0.0;
596 let mut prev_tpr = 0.0;
597
598 let mut ap = 0.0;
600 let mut running_tp = 0;
601
602 for (i, &(_, is_positive)) in scored_labels.iter().enumerate() {
603 if is_positive {
604 tp += 1;
605 running_tp += 1;
606 ap += running_tp as f64 / (i + 1) as f64;
607 } else {
608 fp += 1;
609 }
610
611 let tpr = tp as f64 / total_positives as f64;
612 let fpr = fp as f64 / total_negatives as f64;
613
614 auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
616 prev_fpr = fpr;
617 prev_tpr = tpr;
618 }
619
620 auc += (1.0 - prev_fpr) * (1.0 + prev_tpr) / 2.0;
622
623 let average_precision = if total_positives > 0 {
624 ap / total_positives as f64
625 } else {
626 0.0
627 };
628
629 LinkPredictionEval {
630 auc,
631 average_precision,
632 true_positives: tp,
633 false_positives: fp,
634 total_positives,
635 total_negatives,
636 }
637}
638
639pub fn compute_auc<N, E, Ix, F>(
642 graph: &Graph<N, E, Ix>,
643 test_edges: &[(N, N)],
644 non_edges: &[(N, N)],
645 score_fn: F,
646) -> f64
647where
648 N: Node + Clone + Hash + Eq + std::fmt::Debug,
649 E: EdgeWeight,
650 Ix: IndexType,
651 F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
652{
653 if test_edges.is_empty() || non_edges.is_empty() {
654 return 0.5;
655 }
656
657 let mut n_correct = 0usize;
658 let mut n_tie = 0usize;
659 let mut n_total = 0usize;
660
661 for (pu, pv) in test_edges {
662 let pos_score = score_fn(graph, pu, pv);
663 for (nu, nv) in non_edges {
664 let neg_score = score_fn(graph, nu, nv);
665 n_total += 1;
666 if pos_score > neg_score + 1e-12 {
667 n_correct += 1;
668 } else if (pos_score - neg_score).abs() <= 1e-12 {
669 n_tie += 1;
670 }
671 }
672 }
673
674 if n_total == 0 {
675 return 0.5;
676 }
677
678 (n_correct as f64 + 0.5 * n_tie as f64) / n_total as f64
679}
680
681fn validate_nodes<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<()>
686where
687 N: Node + std::fmt::Debug,
688 E: EdgeWeight,
689 Ix: IndexType,
690{
691 if !graph.has_node(u) {
692 return Err(GraphError::node_not_found(format!("{u:?}")));
693 }
694 if !graph.has_node(v) {
695 return Err(GraphError::node_not_found(format!("{v:?}")));
696 }
697 Ok(())
698}
699
700fn compute_all_scores<N, E, Ix, F>(
701 graph: &Graph<N, E, Ix>,
702 config: &LinkPredictionConfig,
703 score_fn: F,
704) -> Vec<LinkScore<N>>
705where
706 N: Node + Clone + Hash + Eq + std::fmt::Debug,
707 E: EdgeWeight,
708 Ix: IndexType,
709 F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
710{
711 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
712 let mut scores = Vec::new();
713
714 for (i, u) in nodes.iter().enumerate() {
715 for v in nodes.iter().skip(i + 1) {
716 if !config.include_self_loops && u == v {
717 continue;
718 }
719 if graph.has_edge(u, v) {
721 continue;
722 }
723
724 let score = score_fn(graph, u, v);
725 if score >= config.min_score {
726 scores.push(LinkScore {
727 node_a: u.clone(),
728 node_b: v.clone(),
729 score,
730 });
731 }
732 }
733 }
734
735 scores.sort_by(|a, b| {
737 b.score
738 .partial_cmp(&a.score)
739 .unwrap_or(std::cmp::Ordering::Equal)
740 });
741
742 scores.truncate(config.max_predictions);
744 scores
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    const EPS: f64 = 1e-9;

    /// Builds a 2×3 grid graph:
    ///
    /// ```text
    /// 0 - 1 - 2
    /// |   |   |
    /// 3 - 4 - 5
    /// ```
    ///
    /// Nodes 1 and 4 have degree 3; every other node has degree 2.
    fn build_test_graph() -> Graph<i32, ()> {
        let mut g = create_graph::<i32, ()>();
        for &(a, b) in &[(0, 1), (1, 2), (0, 3), (1, 4), (2, 5), (3, 4), (4, 5)] {
            let _ = g.add_edge(a, b, ());
        }
        g
    }

    #[test]
    fn test_common_neighbors() -> GraphResult<()> {
        let g = build_test_graph();
        // 0 and 2 share only node 1.
        assert!((common_neighbors_score(&g, &0, &2)? - 1.0).abs() < EPS);
        // 0 and 4 share nodes 1 and 3.
        assert!((common_neighbors_score(&g, &0, &4)? - 2.0).abs() < EPS);
        // 0 and 5 share nothing.
        assert!(common_neighbors_score(&g, &0, &5)?.abs() < EPS);
        Ok(())
    }

    #[test]
    fn test_jaccard_coefficient() -> GraphResult<()> {
        let g = build_test_graph();
        // N(0) = {1, 3}, N(4) = {1, 3, 5}: |∩| / |∪| = 2 / 3.
        assert!((jaccard_coefficient(&g, &0, &4)? - 2.0 / 3.0).abs() < EPS);
        // A node's neighbourhood is identical to itself.
        assert!((jaccard_coefficient(&g, &0, &0)? - 1.0).abs() < EPS);
        Ok(())
    }

    #[test]
    fn test_adamic_adar() -> GraphResult<()> {
        let g = build_test_graph();
        // Common neighbours of 0 and 4: node 1 (degree 3) and node 3 (degree 2).
        let expected = 1.0 / 3f64.ln() + 1.0 / 2f64.ln();
        assert!((adamic_adar_index(&g, &0, &4)? - expected).abs() < EPS);
        assert!(adamic_adar_index(&g, &0, &5)?.abs() < EPS);
        Ok(())
    }

    #[test]
    fn test_preferential_attachment() -> GraphResult<()> {
        let g = build_test_graph();
        // deg(0) = 2, deg(4) = 3.
        assert!((preferential_attachment(&g, &0, &4)? - 6.0).abs() < EPS);
        // deg(1) = 3, deg(4) = 3.
        assert!((preferential_attachment(&g, &1, &4)? - 9.0).abs() < EPS);
        Ok(())
    }

    #[test]
    fn test_resource_allocation() -> GraphResult<()> {
        let g = build_test_graph();
        // Common neighbours of 0 and 4: node 1 (degree 3) and node 3 (degree 2).
        let expected = 1.0 / 3.0 + 1.0 / 2.0;
        assert!((resource_allocation_index(&g, &0, &4)? - expected).abs() < EPS);
        assert!(resource_allocation_index(&g, &0, &5)?.abs() < EPS);
        Ok(())
    }

    #[test]
    fn test_katz_similarity() -> GraphResult<()> {
        let g = build_test_graph();
        let beta: f64 = 0.05;

        // 0 -> 2: one walk of length 2, none of length 1 or 3.
        let score = katz_similarity(&g, &0, &2, beta, 3)?;
        assert!((score - beta.powi(2)).abs() < EPS);

        // 0 -> 1: one walk of length 1, none of length 2, five of length 3.
        let score_near = katz_similarity(&g, &0, &1, beta, 3)?;
        assert!((score_near - (beta + 5.0 * beta.powi(3))).abs() < EPS);

        // 0 -> 5: three walks of length 3 only.
        let score_far = katz_similarity(&g, &0, &5, beta, 3)?;
        assert!((score_far - 3.0 * beta.powi(3)).abs() < EPS);

        assert!(score_near > score_far);
        Ok(())
    }

    #[test]
    fn test_katz_invalid_beta() {
        let g = build_test_graph();
        assert!(katz_similarity(&g, &0, &1, 0.0, 3).is_err());
        assert!(katz_similarity(&g, &0, &1, 1.0, 3).is_err());
        assert!(katz_similarity(&g, &0, &1, -0.5, 3).is_err());
    }

    #[test]
    fn test_simrank() -> GraphResult<()> {
        let g = build_test_graph();
        let scores = simrank(&g, 0.8, 10, 1e-4)?;

        // Six nodes -> 36 ordered pairs.
        assert_eq!(scores.len(), 36);

        for node in 0..6 {
            let self_score = scores.get(&(node, node)).copied().unwrap_or(0.0);
            assert!((self_score - 1.0).abs() < 1e-6);
        }

        for (&(a, b), &score) in &scores {
            assert!((-1e-6..=1.0 + 1e-6).contains(&score));
            let mirrored = scores.get(&(b, a)).copied().unwrap_or(f64::NAN);
            assert!((score - mirrored).abs() < 1e-12);
        }
        Ok(())
    }

    #[test]
    fn test_simrank_invalid_decay() {
        let g = build_test_graph();
        assert!(simrank(&g, 0.0, 10, 1e-4).is_err());
        assert!(simrank(&g, 1.5, 10, 1e-4).is_err());
    }

    #[test]
    fn test_simrank_score() -> GraphResult<()> {
        let g = build_test_graph();
        let score = simrank_score(&g, &0, &2, 0.8, 10)?;
        assert!((0.0..=1.0).contains(&score));
        assert!((simrank_score(&g, &3, &3, 0.8, 10)? - 1.0).abs() < EPS);
        assert!(simrank_score(&g, &0, &99, 0.8, 10).is_err());
        Ok(())
    }

    #[test]
    fn test_evaluate_link_prediction() {
        let scores = vec![
            LinkScore {
                node_a: 0,
                node_b: 1,
                score: 0.9,
            },
            LinkScore {
                node_a: 0,
                node_b: 2,
                score: 0.8,
            },
            LinkScore {
                node_a: 0,
                node_b: 3,
                score: 0.3,
            },
            LinkScore {
                node_a: 1,
                node_b: 3,
                score: 0.2,
            },
        ];

        let positives: HashSet<(i32, i32)> = [(0, 1), (0, 2)].iter().copied().collect();
        let negatives: HashSet<(i32, i32)> = [(0, 3), (1, 3)].iter().copied().collect();

        let eval = evaluate_link_prediction(&scores, &positives, &negatives);
        // Every positive outranks every negative: a perfect ranking.
        assert!((eval.auc - 1.0).abs() < EPS);
        assert!((eval.average_precision - 1.0).abs() < EPS);
        assert_eq!(eval.true_positives, 2);
        assert_eq!(eval.false_positives, 2);
        assert_eq!(eval.total_positives, 2);
        assert_eq!(eval.total_negatives, 2);
    }

    #[test]
    fn test_evaluate_link_prediction_without_negatives() {
        let scores = vec![LinkScore {
            node_a: 0,
            node_b: 1,
            score: 0.9,
        }];
        let positives: HashSet<(i32, i32)> = [(0, 1)].iter().copied().collect();

        let eval = evaluate_link_prediction(&scores, &positives, &HashSet::new());
        assert!((eval.auc - 0.5).abs() < EPS);
        assert_eq!(eval.total_positives, 1);
        assert_eq!(eval.total_negatives, 0);
    }

    #[test]
    fn test_compute_auc() {
        let g = build_test_graph();
        // (0, 4) has two common neighbours, (0, 5) none: the positive always wins.
        let test_edges = vec![(0, 4)];
        let non_edges = vec![(0, 5)];
        let auc = compute_auc(&g, &test_edges, &non_edges, |g, u, v| {
            common_neighbors_score(g, u, v).unwrap_or(0.0)
        });
        assert!((auc - 1.0).abs() < EPS);
    }

    #[test]
    fn test_common_neighbors_all() {
        let g = build_test_graph();
        let config = LinkPredictionConfig {
            max_predictions: 10,
            min_score: 0.0,
            include_self_loops: false,
        };

        let scores = common_neighbors_all(&g, &config);
        assert!(!scores.is_empty());
        assert!(scores.len() <= config.max_predictions);

        for score in &scores {
            // Only absent, non-self links are predicted.
            assert_ne!(score.node_a, score.node_b);
            assert!(!g.has_edge(&score.node_a, &score.node_b));
            assert!(score.score >= config.min_score);
        }
        for window in scores.windows(2) {
            // Sorted by descending score.
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_invalid_nodes() {
        let g = build_test_graph();
        assert!(common_neighbors_score(&g, &0, &99).is_err());
        assert!(jaccard_coefficient(&g, &99, &0).is_err());
        assert!(adamic_adar_index(&g, &0, &99).is_err());
        assert!(preferential_attachment(&g, &99, &0).is_err());
        assert!(resource_allocation_index(&g, &0, &99).is_err());
        assert!(katz_similarity(&g, &0, &99, 0.05, 3).is_err());
    }

    #[test]
    fn test_empty_graph_link_prediction() {
        let mut g = create_graph::<i32, ()>();
        let _ = g.add_node(0);

        // A single node yields no candidate pairs.
        let config = LinkPredictionConfig::default();
        let scores = common_neighbors_all(&g, &config);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_all_methods_consistency() -> GraphResult<()> {
        let g = build_test_graph();

        let cn = common_neighbors_score(&g, &0, &4)?;
        let jc = jaccard_coefficient(&g, &0, &4)?;
        let aa = adamic_adar_index(&g, &0, &4)?;
        let pa = preferential_attachment(&g, &0, &4)?;
        let ra = resource_allocation_index(&g, &0, &4)?;
        let kz = katz_similarity(&g, &0, &4, 0.05, 3)?;

        for value in [cn, jc, aa, pa, ra, kz].iter() {
            assert!(*value >= 0.0);
        }

        // 0 and 4 share neighbours, so every neighbourhood-based score is positive.
        assert!(cn > 0.0);
        assert!(jc > 0.0);
        assert!(aa > 0.0);
        assert!(ra > 0.0);
        Ok(())
    }
}
1011}