use crate::base::{EdgeWeight, Graph, IndexType, Node};
use crate::error::{GraphError, Result};
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
/// A candidate link `(node_a, node_b)` together with the score a predictor assigned to it.
///
/// The `*_all` scorers in this module return these sorted by descending `score`, and
/// `evaluate_link_prediction` matches a `LinkScore` against labelled pairs in either
/// orientation.
#[derive(Debug, Clone)]
pub struct LinkScore<N: Node> {
    /// First endpoint of the candidate link.
    pub node_a: N,
    /// Second endpoint of the candidate link.
    pub node_b: N,
    /// Predictor output; a larger value ranks the link higher.
    pub score: f64,
}
/// Summary metrics produced by `evaluate_link_prediction`.
#[derive(Debug, Clone)]
pub struct LinkPredictionEval {
    /// Area under the ROC curve of the ranked, labelled predictions; `0.5` when either
    /// class is missing.
    pub auc: f64,
    /// Average precision of the ranking; `0.0` when either class is missing.
    pub average_precision: f64,
    /// Positive-labelled predictions counted over the whole ranking (0 in the degenerate
    /// cases).
    pub true_positives: usize,
    /// Negative-labelled predictions counted over the whole ranking (0 in the degenerate
    /// cases).
    pub false_positives: usize,
    /// Number of positive examples. This is the size of the positive edge set when either
    /// input set is empty; otherwise it is the number of scored pairs labelled positive.
    pub total_positives: usize,
    /// Number of negative examples, counted the same way as `total_positives`.
    pub total_negatives: usize,
}
/// Options controlling which candidate links the all-pairs scorers return.
#[derive(Debug, Clone)]
pub struct LinkPredictionConfig {
    /// Maximum number of predictions kept after sorting by descending score.
    pub max_predictions: usize,
    /// Candidates whose score is not `>=` this threshold are discarded.
    pub min_score: f64,
    /// Whether self-pairs `(u, u)` are eligible candidates in all-pairs scoring.
    pub include_self_loops: bool,
}

impl Default for LinkPredictionConfig {
    /// Keep at most 100 predictions, accept any score `>= 0.0`, and skip self-pairs.
    fn default() -> Self {
        let max_predictions = 100;
        let min_score = 0.0;
        let include_self_loops = false;
        Self {
            max_predictions,
            min_score,
            include_self_loops,
        }
    }
}
/// Counts the neighbours shared by `u` and `v`.
///
/// Each neighbour list is first collapsed into a set, so a neighbour listed more than once
/// is still counted once.
///
/// # Errors
/// Returns a node-not-found error if `u` or `v` is absent, and propagates any error from
/// `graph.neighbors`.
pub fn common_neighbors_score<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    let left = graph.neighbors(u)?.into_iter().collect::<HashSet<N>>();
    let right = graph.neighbors(v)?.into_iter().collect::<HashSet<N>>();
    let shared = left.iter().filter(|node| right.contains(*node)).count();
    Ok(shared as f64)
}
/// Scores every pair of non-adjacent nodes with [`common_neighbors_score`].
///
/// A pair whose score cannot be computed is scored `0.0`. Filtering, ordering and
/// truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn common_neighbors_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, |g, a, b| {
        common_neighbors_score(g, a, b).unwrap_or_default()
    })
}
/// Jaccard similarity `|N(u) ∩ N(v)| / |N(u) ∪ N(v)|` of the two neighbour sets.
///
/// Returns `0.0` when both nodes have no neighbours (empty union).
///
/// # Errors
/// Returns a node-not-found error if `u` or `v` is absent, and propagates any error from
/// `graph.neighbors`.
pub fn jaccard_coefficient<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    let set_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
    let set_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
    let shared = set_u.intersection(&set_v).count();
    // Inclusion–exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
    let combined = set_u.len() + set_v.len() - shared;
    Ok(match combined {
        0 => 0.0,
        total => shared as f64 / total as f64,
    })
}
/// Scores every pair of non-adjacent nodes with [`jaccard_coefficient`].
///
/// A pair whose score cannot be computed is scored `0.0`. Filtering, ordering and
/// truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn jaccard_coefficient_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, |g, a, b| match jaccard_coefficient(g, a, b) {
        Ok(s) => s,
        Err(_) => 0.0,
    })
}
/// Adamic–Adar index: sum of `1 / ln(deg(w))` over common neighbours `w` of `u` and `v`.
///
/// Common neighbours of degree 1 or less are skipped, since `ln(1) = 0` would divide by zero.
///
/// # Errors
/// Returns a node-not-found error if `u` or `v` is absent, and propagates any error from
/// `graph.neighbors`.
pub fn adamic_adar_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    let set_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
    let set_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
    let total = set_u
        .intersection(&set_v)
        .map(|w| graph.degree(w))
        .filter(|&d| d > 1)
        .fold(0.0_f64, |acc, d| acc + 1.0 / (d as f64).ln());
    Ok(total)
}
/// Scores every pair of non-adjacent nodes with [`adamic_adar_index`].
///
/// A pair whose score cannot be computed is scored `0.0`. Filtering, ordering and
/// truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn adamic_adar_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, |g, a, b| {
        adamic_adar_index(g, a, b).unwrap_or_default()
    })
}
/// Preferential-attachment score `deg(u) * deg(v)`.
///
/// The product is formed in `f64`. The integer product `deg(u) * deg(v)` can overflow
/// `usize` on 32-bit targets (e.g. wasm32) for two hubs of degree ≥ 65 536, which panics
/// in debug builds. For products that fit in an integer, the result is identical.
///
/// # Errors
/// Returns a node-not-found error if `u` or `v` is absent.
pub fn preferential_attachment<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    let deg_u = graph.degree(u) as f64;
    let deg_v = graph.degree(v) as f64;
    Ok(deg_u * deg_v)
}
/// Scores every pair of non-adjacent nodes with [`preferential_attachment`].
///
/// A pair whose score cannot be computed is scored `0.0`. Filtering, ordering and
/// truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn preferential_attachment_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, |g, a, b| {
        match preferential_attachment(g, a, b) {
            Ok(s) => s,
            Err(_) => 0.0,
        }
    })
}
/// Resource-allocation index: sum of `1 / deg(w)` over common neighbours `w` of `u` and `v`.
///
/// Common neighbours reported with degree 0 are skipped to avoid division by zero.
///
/// # Errors
/// Returns a node-not-found error if `u` or `v` is absent, and propagates any error from
/// `graph.neighbors`.
pub fn resource_allocation_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    let set_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
    let set_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
    let total = set_u
        .intersection(&set_v)
        .map(|w| graph.degree(w))
        .filter(|&d| d > 0)
        .fold(0.0_f64, |acc, d| acc + 1.0 / d as f64);
    Ok(total)
}
/// Scores every pair of non-adjacent nodes with [`resource_allocation_index`].
///
/// A pair whose score cannot be computed is scored `0.0`. Filtering, ordering and
/// truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn resource_allocation_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, |g, a, b| {
        resource_allocation_index(g, a, b).unwrap_or_default()
    })
}
/// Truncated Katz similarity between `u` and `v`.
///
/// Computes `Σ_{l=1..=max_path_length} beta^l · W_l(u, v)`, where `W_l(u, v)` is the number
/// of walks of length `l` from `u` to `v`. Walks follow the lists returned by
/// `graph.neighbors`. Walk counts are propagated one step at a time from `u`. The loop stops
/// early once no walk survives, because every remaining term would be zero.
///
/// # Errors
/// - Node-not-found if `u` or `v` is absent (checked first).
/// - `GraphError::InvalidGraph` if `beta` is not strictly inside `(0, 1)`, including NaN.
pub fn katz_similarity<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    u: &N,
    v: &N,
    beta: f64,
    max_path_length: usize,
) -> Result<f64>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    validate_nodes(graph, u, v)?;
    // NaN fails both range comparisons, so it must be rejected explicitly; otherwise it
    // would be accepted and turn every score into NaN.
    if beta.is_nan() || beta <= 0.0 || beta >= 1.0 {
        return Err(GraphError::InvalidGraph(
            "Beta must be in (0, 1)".to_string(),
        ));
    }
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let n = nodes.len();
    let node_to_idx: HashMap<N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node.clone(), i))
        .collect();
    let u_idx = node_to_idx
        .get(u)
        .ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))?;
    let v_idx = node_to_idx
        .get(v)
        .ok_or_else(|| GraphError::node_not_found(format!("{v:?}")))?;

    // Index-based adjacency lists built from `graph.neighbors`.
    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
    for (i, node) in nodes.iter().enumerate() {
        if let Ok(neighbors) = graph.neighbors(node) {
            for neighbor in &neighbors {
                if let Some(&j) = node_to_idx.get(neighbor) {
                    adj[i].push(j);
                }
            }
        }
    }

    let mut score = 0.0;
    // `current[i]` = number of walks of the current length from `u` ending at node `i`.
    let mut current = vec![0.0f64; n];
    current[*u_idx] = 1.0;
    for l in 1..=max_path_length {
        let mut next = vec![0.0f64; n];
        for (i, &count) in current.iter().enumerate() {
            if count > 0.0 {
                for &j in &adj[i] {
                    next[j] += count;
                }
            }
        }
        score += beta.powi(l as i32) * next[*v_idx];
        // No surviving walks: all later terms are exactly zero.
        if next.iter().all(|&c| c == 0.0) {
            break;
        }
        current = next;
    }
    Ok(score)
}
/// Scores every pair of non-adjacent nodes with [`katz_similarity`].
///
/// A pair whose score cannot be computed (including an invalid `beta`) is scored `0.0`.
/// Filtering, ordering and truncation follow `config` (see [`LinkPredictionConfig`]).
pub fn katz_similarity_all<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    beta: f64,
    max_path_length: usize,
    config: &LinkPredictionConfig,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    compute_all_scores(graph, config, move |g, a, b| {
        match katz_similarity(g, a, b, beta, max_path_length) {
            Ok(s) => s,
            Err(_) => 0.0,
        }
    })
}
/// Computes pairwise SimRank similarity for all nodes of `graph`.
///
/// Starting from the identity matrix, it repeatedly applies
/// `s(a, b) = decay / (|N(a)|·|N(b)|) · Σ_{x∈N(a), y∈N(b)} s(x, y)` for `a ≠ b`, with
/// `s(a, a) = 1`. Here `N(·)` is the list returned by `graph.neighbors`. Pairs involving a
/// node with no neighbours stay at 0. Iteration stops after `max_iterations` rounds, or as
/// soon as the largest change within a round is below `tolerance`.
///
/// The returned map contains every ordered pair `(a, b)` of graph nodes, including both
/// orientations and the diagonal.
///
/// Each round costs `O(n² + m²)` time for `n` nodes and `m` adjacency entries; the
/// similarity matrix needs `O(n²)` memory.
///
/// # Errors
/// Returns `GraphError::InvalidGraph` if `decay` is not in `(0, 1]`, including NaN.
pub fn simrank<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    decay: f64,
    max_iterations: usize,
    tolerance: f64,
) -> Result<HashMap<(N, N), f64>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    // NaN fails both range comparisons, so it must be rejected explicitly.
    if decay.is_nan() || decay <= 0.0 || decay > 1.0 {
        return Err(GraphError::InvalidGraph(
            "Decay must be in (0, 1]".to_string(),
        ));
    }
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let n = nodes.len();
    let node_to_idx: HashMap<N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node.clone(), i))
        .collect();

    // Index-based adjacency lists built from `graph.neighbors`.
    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
    for (i, node) in nodes.iter().enumerate() {
        if let Ok(neighbors) = graph.neighbors(node) {
            for neighbor in &neighbors {
                if let Some(&j) = node_to_idx.get(neighbor) {
                    adj[i].push(j);
                }
            }
        }
    }

    let mut sim = vec![vec![0.0f64; n]; n];
    for (i, row) in sim.iter_mut().enumerate() {
        row[i] = 1.0;
    }
    for _ in 0..max_iterations {
        let mut new_sim = vec![vec![0.0f64; n]; n];
        let mut max_diff = 0.0f64;
        for i in 0..n {
            new_sim[i][i] = 1.0;
            let deg_i = adj[i].len();
            // Only the upper triangle is computed; the matrix is mirrored as we go.
            for j in (i + 1)..n {
                let deg_j = adj[j].len();
                if deg_i == 0 || deg_j == 0 {
                    // Isolated endpoint: similarity stays at the fresh matrix's 0.0.
                    continue;
                }
                let mut sum = 0.0;
                for &ni in &adj[i] {
                    for &nj in &adj[j] {
                        sum += sim[ni][nj];
                    }
                }
                let new_val = decay * sum / (deg_i * deg_j) as f64;
                new_sim[i][j] = new_val;
                new_sim[j][i] = new_val;
                max_diff = max_diff.max((new_val - sim[i][j]).abs());
            }
        }
        sim = new_sim;
        if max_diff < tolerance {
            break;
        }
    }

    let mut result = HashMap::new();
    for i in 0..n {
        for j in i..n {
            result.insert((nodes[i].clone(), nodes[j].clone()), sim[i][j]);
            if i != j {
                result.insert((nodes[j].clone(), nodes[i].clone()), sim[i][j]);
            }
        }
    }
    Ok(result)
}
pub fn simrank_score<N, E, Ix>(
graph: &Graph<N, E, Ix>,
u: &N,
v: &N,
decay: f64,
max_iterations: usize,
) -> Result<f64>
where
N: Node + Clone + Hash + Eq + std::fmt::Debug,
E: EdgeWeight,
Ix: IndexType,
{
let all_scores = simrank(graph, decay, max_iterations, 1e-6)?;
all_scores
.get(&(u.clone(), v.clone()))
.copied()
.ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))
}
/// Evaluates a ranked list of link predictions against labelled held-out pairs.
///
/// An entry of `scores` is labelled when its pair appears, in either orientation, in
/// `positive_edges` or `negative_edges`. If a pair is in both sets, positive wins. All other
/// entries are ignored. Labelled entries are ranked by descending score, and entries with
/// equal scores form a single threshold:
///
/// * `auc`: area under the ROC curve (trapezoidal rule). A tie between a positive and a
///   negative contributes one half, as in [`compute_auc`], so the result does not depend
///   on the input order of tied entries.
/// * `average_precision`: `Σ_thresholds (positives at threshold) · precision`, divided by
///   the number of labelled positives. Without ties this is the mean of precision@k taken
///   at each positive.
///
/// A neutral result (`auc = 0.5`, `average_precision = 0.0`) is returned in two cases. If
/// either input set is empty, the totals are the input set sizes. If no scored entry falls
/// into one of the classes, the totals count the labelled entries. Otherwise
/// `true_positives`/`false_positives` are the cumulative counts after the full ranking.
///
/// Scores are ordered with `f64::total_cmp`, so NaN scores cannot break the sort.
pub fn evaluate_link_prediction<N>(
    scores: &[LinkScore<N>],
    positive_edges: &HashSet<(N, N)>,
    negative_edges: &HashSet<(N, N)>,
) -> LinkPredictionEval
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
{
    if positive_edges.is_empty() || negative_edges.is_empty() {
        return LinkPredictionEval {
            auc: 0.5,
            average_precision: 0.0,
            true_positives: 0,
            false_positives: 0,
            total_positives: positive_edges.len(),
            total_negatives: negative_edges.len(),
        };
    }

    let mut scored_labels: Vec<(f64, bool)> = scores
        .iter()
        .filter_map(|s| {
            let pair = (s.node_a.clone(), s.node_b.clone());
            let reverse_pair = (s.node_b.clone(), s.node_a.clone());
            let is_positive =
                positive_edges.contains(&pair) || positive_edges.contains(&reverse_pair);
            let is_negative =
                negative_edges.contains(&pair) || negative_edges.contains(&reverse_pair);
            (is_positive || is_negative).then_some((s.score, is_positive))
        })
        .collect();

    // Descending by score. `total_cmp` is a true total order; the previous
    // `partial_cmp(..).unwrap_or(Equal)` comparator is inconsistent with NaN, which
    // `sort_by` is allowed to panic on.
    scored_labels.sort_by(|a, b| b.0.total_cmp(&a.0));

    let total_positives = scored_labels.iter().filter(|&&(_, label)| label).count();
    let total_negatives = scored_labels.len() - total_positives;
    if total_positives == 0 || total_negatives == 0 {
        return LinkPredictionEval {
            auc: 0.5,
            average_precision: 0.0,
            true_positives: 0,
            false_positives: 0,
            total_positives,
            total_negatives,
        };
    }

    let mut tp = 0usize;
    let mut fp = 0usize;
    let mut prev_tpr = 0.0;
    let mut prev_fpr = 0.0;
    let mut auc = 0.0;
    let mut ap = 0.0;
    let mut start = 0;
    while start < scored_labels.len() {
        // Collect the run of entries sharing this score. The run always contains at least
        // one entry, so a NaN score (unequal to itself) still advances the loop.
        let threshold = scored_labels[start].0;
        let mut end = start + 1;
        while end < scored_labels.len() && scored_labels[end].0 == threshold {
            end += 1;
        }
        let group_pos = scored_labels[start..end]
            .iter()
            .filter(|&&(_, label)| label)
            .count();
        tp += group_pos;
        fp += (end - start) - group_pos;

        // One ROC point per threshold: a mixed tie becomes a diagonal segment worth 1/2.
        let tpr = tp as f64 / total_positives as f64;
        let fpr = fp as f64 / total_negatives as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        if group_pos > 0 {
            // Recall rises by group_pos / P at precision tp / end.
            ap += group_pos as f64 * (tp as f64 / end as f64);
        }
        prev_tpr = tpr;
        prev_fpr = fpr;
        start = end;
    }
    // Every labelled entry has been consumed, so the curve already ends at (1, 1).

    LinkPredictionEval {
        auc,
        average_precision: ap / total_positives as f64,
        true_positives: tp,
        false_positives: fp,
        total_positives,
        total_negatives,
    }
}
/// Estimates ROC AUC as the probability that a held-out edge outscores a non-edge.
///
/// Every `(test_edge, non_edge)` combination is compared using `score_fn`:
/// - a positive score higher by more than `1e-12` counts 1;
/// - scores within `1e-12` of each other count ½;
/// - anything else counts 0, including comparisons involving NaN or equal infinities.
///
/// The result is the mean over all combinations, or `0.5` if either slice is empty.
///
/// `score_fn` is called once per listed pair rather than once per combination, i.e.
/// `|test_edges| + |non_edges|` times. It is therefore expected to be deterministic.
pub fn compute_auc<N, E, Ix, F>(
    graph: &Graph<N, E, Ix>,
    test_edges: &[(N, N)],
    non_edges: &[(N, N)],
    score_fn: F,
) -> f64
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
    F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
{
    const TIE_EPS: f64 = 1e-12;
    if test_edges.is_empty() || non_edges.is_empty() {
        return 0.5;
    }
    // Score each non-edge once up front. Re-scoring it for every test edge multiplies the
    // cost of expensive scorers (Katz, SimRank) by |test_edges|.
    let neg_scores: Vec<f64> = non_edges
        .iter()
        .map(|(nu, nv)| score_fn(graph, nu, nv))
        .collect();

    let mut n_correct = 0usize;
    let mut n_tie = 0usize;
    for (pu, pv) in test_edges {
        let pos_score = score_fn(graph, pu, pv);
        for &neg_score in &neg_scores {
            if pos_score > neg_score + TIE_EPS {
                n_correct += 1;
            } else if (pos_score - neg_score).abs() <= TIE_EPS {
                n_tie += 1;
            }
        }
    }
    // Both slices are non-empty here, so the denominator is positive.
    let n_total = test_edges.len() as f64 * neg_scores.len() as f64;
    (n_correct as f64 + 0.5 * n_tie as f64) / n_total
}
/// Checks that both `u` and `v` exist in `graph`.
///
/// `u` is checked before `v`; the error names the first missing node.
fn validate_nodes<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<()>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    for node in [u, v] {
        if !graph.has_node(node) {
            return Err(GraphError::node_not_found(format!("{node:?}")));
        }
    }
    Ok(())
}
/// Scores every unordered pair of nodes not already joined by an edge.
///
/// Pairs are visited in `graph.nodes()` order, with `node_a` preceding `node_b`. The
/// diagonal pair `(u, u)` is a candidate only when `config.include_self_loops` is set.
/// Existing edges (including existing self-loops) are skipped. Scores that are not
/// `>= config.min_score` (NaN included) are dropped. The survivors are sorted by
/// descending score and truncated to `config.max_predictions`.
fn compute_all_scores<N, E, Ix, F>(
    graph: &Graph<N, E, Ix>,
    config: &LinkPredictionConfig,
    score_fn: F,
) -> Vec<LinkScore<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
    F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
{
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let mut scores = Vec::new();
    for (i, u) in nodes.iter().enumerate() {
        // The diagonal pair sits at offset `i`. Start the inner loop there only when
        // self-loops are requested; starting at `i + 1` would make the flag a no-op.
        let first = if config.include_self_loops { i } else { i + 1 };
        for v in nodes.iter().skip(first) {
            // Kept in case the node list ever contains repeated entries.
            if !config.include_self_loops && u == v {
                continue;
            }
            // Only absent links are prediction candidates.
            if graph.has_edge(u, v) {
                continue;
            }
            let score = score_fn(graph, u, v);
            if score >= config.min_score {
                scores.push(LinkScore {
                    node_a: u.clone(),
                    node_b: v.clone(),
                    score,
                });
            }
        }
    }
    // NaN scores were filtered out above, so this comparator is a consistent total order.
    scores.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scores.truncate(config.max_predictions);
    scores
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    // Fixture with six nodes and edges 0-1, 1-2, 0-3, 1-4, 2-5, 3-4, 4-5.
    // The expected values below treat these edges as undirected:
    //   N(0)={1,3}, N(1)={0,2,4}, N(2)={1,5}, N(3)={0,4}, N(4)={1,3,5}, N(5)={2,4}.
    // `add_edge` results are deliberately discarded.
    fn build_test_graph() -> Graph<i32, ()> {
        let mut g = create_graph::<i32, ()>();
        let _ = g.add_edge(0, 1, ());
        let _ = g.add_edge(1, 2, ());
        let _ = g.add_edge(0, 3, ());
        let _ = g.add_edge(1, 4, ());
        let _ = g.add_edge(2, 5, ());
        let _ = g.add_edge(3, 4, ());
        let _ = g.add_edge(4, 5, ());
        g
    }

    // (0,2) share {1}; (0,4) share {1,3}; (0,5) share nothing.
    #[test]
    fn test_common_neighbors() -> GraphResult<()> {
        let g = build_test_graph();
        let score = common_neighbors_score(&g, &0, &2)?;
        assert!((score - 1.0).abs() < 1e-6);
        let score = common_neighbors_score(&g, &0, &4)?;
        assert!((score - 2.0).abs() < 1e-6);
        let score = common_neighbors_score(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    // A node compared with itself has identical neighbour sets, hence Jaccard 1.
    #[test]
    fn test_jaccard_coefficient() -> GraphResult<()> {
        let g = build_test_graph();
        let score = jaccard_coefficient(&g, &0, &4)?;
        assert!(score > 0.0 && score <= 1.0);
        let score = jaccard_coefficient(&g, &0, &0)?;
        assert!((score - 1.0).abs() < 1e-6);
        Ok(())
    }

    // Non-zero only when there is a common neighbour of degree > 1.
    #[test]
    fn test_adamic_adar() -> GraphResult<()> {
        let g = build_test_graph();
        let score = adamic_adar_index(&g, &0, &4)?;
        assert!(score > 0.0);
        let score = adamic_adar_index(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    // deg(0)*deg(4) = 2*3 = 6 and deg(1)*deg(4) = 3*3 = 9.
    #[test]
    fn test_preferential_attachment() -> GraphResult<()> {
        let g = build_test_graph();
        let score = preferential_attachment(&g, &0, &4)?;
        assert!((score - 6.0).abs() < 1e-6);
        let score = preferential_attachment(&g, &1, &4)?;
        assert!((score - 9.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_resource_allocation() -> GraphResult<()> {
        let g = build_test_graph();
        let score = resource_allocation_index(&g, &0, &4)?;
        assert!(score > 0.0);
        let score = resource_allocation_index(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    // With small beta, the adjacent pair (0,1) should outscore (0,5), whose shortest
    // connecting walk has length 3.
    #[test]
    fn test_katz_similarity() -> GraphResult<()> {
        let g = build_test_graph();
        let score = katz_similarity(&g, &0, &2, 0.05, 3)?;
        assert!(score > 0.0);
        let score_near = katz_similarity(&g, &0, &1, 0.05, 3)?;
        let score_far = katz_similarity(&g, &0, &5, 0.05, 3)?;
        assert!(score_near > score_far);
        Ok(())
    }

    // Both endpoints of [0, 1] are outside the accepted open interval.
    #[test]
    fn test_katz_invalid_beta() {
        let g = build_test_graph();
        assert!(katz_similarity(&g, &0, &1, 0.0, 3).is_err());
        assert!(katz_similarity(&g, &0, &1, 1.0, 3).is_err());
    }

    // Diagonal entries are pinned to 1 and no entry may be negative.
    #[test]
    fn test_simrank() -> GraphResult<()> {
        let g = build_test_graph();
        let scores = simrank(&g, 0.8, 10, 1e-4)?;
        let self_score = scores.get(&(0, 0)).copied().unwrap_or(0.0);
        assert!((self_score - 1.0).abs() < 1e-6);
        for &score in scores.values() {
            assert!(score >= -1e-6);
        }
        Ok(())
    }

    #[test]
    fn test_simrank_score() -> GraphResult<()> {
        let g = build_test_graph();
        let score = simrank_score(&g, &0, &2, 0.8, 10)?;
        assert!(score >= 0.0);
        assert!(score <= 1.0);
        Ok(())
    }

    // Both positives (0.9, 0.8) are ranked above both negatives (0.3, 0.2), a perfect
    // separation; the assertions only check lower bounds.
    #[test]
    fn test_evaluate_link_prediction() {
        let scores = vec![
            LinkScore {
                node_a: 0,
                node_b: 1,
                score: 0.9,
            },
            LinkScore {
                node_a: 0,
                node_b: 2,
                score: 0.8,
            },
            LinkScore {
                node_a: 0,
                node_b: 3,
                score: 0.3,
            },
            LinkScore {
                node_a: 1,
                node_b: 3,
                score: 0.2,
            },
        ];
        let mut positives = HashSet::new();
        positives.insert((0, 1));
        positives.insert((0, 2));
        let mut negatives = HashSet::new();
        negatives.insert((0, 3));
        negatives.insert((1, 3));
        let eval = evaluate_link_prediction(&scores, &positives, &negatives);
        assert!(eval.auc >= 0.5); assert!(eval.true_positives > 0);
    }

    // Common neighbours: (0,4) scores 2 and (0,5) scores 0, so the positive wins.
    #[test]
    fn test_compute_auc() -> GraphResult<()> {
        let g = build_test_graph();
        let test_edges = vec![(0, 4)]; let non_edges = vec![(0, 5)];
        let auc = compute_auc(&g, &test_edges, &non_edges, |g, u, v| {
            common_neighbors_score(g, u, v).unwrap_or(0.0)
        });
        assert!(auc >= 0.5); Ok(())
    }

    // All-pairs output must exclude existing edges and be sorted by descending score.
    #[test]
    fn test_common_neighbors_all() {
        let g = build_test_graph();
        let config = LinkPredictionConfig {
            max_predictions: 10,
            min_score: 0.0,
            include_self_loops: false,
        };
        let scores = common_neighbors_all(&g, &config);
        for score in &scores {
            assert!(!g.has_edge(&score.node_a, &score.node_b));
        }
        for window in scores.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    // Node 99 is not in the fixture.
    #[test]
    fn test_invalid_nodes() {
        let g = build_test_graph();
        assert!(common_neighbors_score(&g, &0, &99).is_err());
        assert!(jaccard_coefficient(&g, &99, &0).is_err());
        assert!(adamic_adar_index(&g, &0, &99).is_err());
    }

    // A single isolated node has no candidate pairs when self-loops are disabled.
    #[test]
    fn test_empty_graph_link_prediction() -> GraphResult<()> {
        let mut g = create_graph::<i32, ()>();
        let _ = g.add_node(0);
        let config = LinkPredictionConfig::default();
        let scores = common_neighbors_all(&g, &config);
        assert!(scores.is_empty());
        Ok(())
    }

    // (0,4) share neighbours 1 and 3, so every neighbourhood-based score is positive.
    #[test]
    fn test_all_methods_consistency() -> GraphResult<()> {
        let g = build_test_graph();
        let cn = common_neighbors_score(&g, &0, &4)?;
        let jc = jaccard_coefficient(&g, &0, &4)?;
        let aa = adamic_adar_index(&g, &0, &4)?;
        let pa = preferential_attachment(&g, &0, &4)?;
        let ra = resource_allocation_index(&g, &0, &4)?;
        let kz = katz_similarity(&g, &0, &4, 0.05, 3)?;
        assert!(cn >= 0.0);
        assert!(jc >= 0.0);
        assert!(aa >= 0.0);
        assert!(pa >= 0.0);
        assert!(ra >= 0.0);
        assert!(kz >= 0.0);
        assert!(cn > 0.0);
        assert!(jc > 0.0);
        assert!(aa > 0.0);
        assert!(ra > 0.0);
        Ok(())
    }
}