use rand::prelude::*;
use serde::{Deserialize, Serialize};
use crate::graph::SparseGraph;
use crate::traits::ImportanceScorer;
use crate::types::EdgeImportance;
/// Monte Carlo estimator of the effective resistance between two vertices, computed from
/// sampled random-walk commute times.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EffectiveResistanceEstimator {
    /// Maximum number of steps a single one-way walk may take before it is cut off.
    pub max_walk_length: usize,
    /// Number of round trips (one walk `u -> v` plus one walk `v -> u`) averaged per estimate.
    pub num_walks: usize,
}
impl Default for EffectiveResistanceEstimator {
    /// Returns an estimator that caps each walk at 100 steps and averages 10 round trips.
    fn default() -> Self {
        const DEFAULT_MAX_WALK_LENGTH: usize = 100;
        const DEFAULT_NUM_WALKS: usize = 10;
        Self::new(DEFAULT_MAX_WALK_LENGTH, DEFAULT_NUM_WALKS)
    }
}
impl EffectiveResistanceEstimator {
    /// Creates an estimator that averages `num_walks` round trips per query and cuts off each
    /// one-way walk after `max_walk_length` steps.
    pub fn new(max_walk_length: usize, num_walks: usize) -> Self {
        Self {
            max_walk_length,
            num_walks,
        }
    }

    /// Estimates the effective resistance between `u` and `v` from sampled commute times.
    ///
    /// Each of the `num_walks` rounds runs one random walk from `u` until it first hits `v`,
    /// and one from `v` until it first hits `u`. The mean round-trip length is divided by
    /// `2 * graph.total_weight()`, which follows the commute-time identity
    /// `C(u, v) = vol(G) * R_eff(u, v)`. This assumes `total_weight` is the sum of undirected
    /// edge weights, so that `vol(G) = 2 * total_weight` (NOTE(review): confirm against
    /// `SparseGraph`).
    ///
    /// Walks are truncated at `max_walk_length` steps, so the estimate is biased low when
    /// true hitting times exceed that bound.
    ///
    /// # Returns
    /// - `0.0` when `u == v`.
    /// - `f64::MAX` when no estimate can be made: `num_walks` or `max_walk_length` is zero,
    ///   the graph's total weight is not positive, or either endpoint has degree zero.
    /// - Otherwise, a non-negative finite estimate.
    pub fn estimate(&self, graph: &SparseGraph, u: usize, v: usize) -> f64 {
        if u == v {
            return 0.0;
        }
        // With zero samples the average below would be 0 / 0 = NaN, which poisons any
        // downstream comparison or sort. With zero-length walks every sample is 0 steps,
        // which would falsely report zero resistance. Neither carries information.
        if self.num_walks == 0 || self.max_walk_length == 0 {
            return f64::MAX;
        }
        let total_w = graph.total_weight();
        if total_w <= 0.0 {
            return f64::MAX;
        }
        if graph.degree(u) == 0 || graph.degree(v) == 0 {
            return f64::MAX;
        }
        let mut rng = rand::thread_rng();
        let mut total_steps = 0u64;
        for _ in 0..self.num_walks {
            total_steps += self.walk_to_target(graph, u, v, &mut rng) as u64;
            total_steps += self.walk_to_target(graph, v, u, &mut rng) as u64;
        }
        let avg_commute = total_steps as f64 / self.num_walks as f64;
        avg_commute / (2.0 * total_w)
    }

    /// Runs a weighted random walk from `start` and returns the number of steps taken to
    /// first reach `target`. Returns `max_walk_length` if the walk is cut off first.
    fn walk_to_target<R: Rng>(
        &self,
        graph: &SparseGraph,
        start: usize,
        target: usize,
        rng: &mut R,
    ) -> usize {
        let mut current = start;
        for step in 1..=self.max_walk_length {
            current = self.random_neighbor(graph, current, rng);
            if current == target {
                return step;
            }
        }
        self.max_walk_length
    }

    /// Samples a neighbor of `u` with probability proportional to the connecting edge weight.
    ///
    /// Returns `u` itself only when `u` has no positive weighted degree, meaning the walk
    /// cannot move.
    fn random_neighbor<R: Rng>(&self, graph: &SparseGraph, u: usize, rng: &mut R) -> usize {
        let w_deg = graph.weighted_degree(u);
        if w_deg <= 0.0 {
            return u;
        }
        // `gen::<f64>()` lies in [0, 1), so `threshold` lies in [0, w_deg). Using a strict `>`
        // means zero-weight neighbors are never selected, even when `threshold == 0`.
        let threshold = rng.gen::<f64>() * w_deg;
        let mut cumulative = 0.0;
        // If float rounding leaves the running sum just below `threshold`, fall back to the
        // last positive-weight neighbor. Staying at `u` would waste a step and inflate
        // commute times.
        let mut fallback = u;
        for (v, w) in graph.neighbors(u) {
            cumulative += w;
            if w > 0.0 {
                fallback = v;
            }
            if cumulative > threshold {
                return v;
            }
        }
        fallback
    }
}
/// [`ImportanceScorer`] that scores each edge using a random-walk estimate of the effective
/// resistance between its endpoints.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LocalImportanceScorer {
    /// Estimator used to approximate the effective resistance of each scored edge.
    pub estimator: EffectiveResistanceEstimator,
}
impl LocalImportanceScorer {
    /// Builds a scorer whose estimator caps each walk at `walk_length` steps and averages
    /// `num_walks` round trips per edge.
    pub fn new(walk_length: usize, num_walks: usize) -> Self {
        let estimator = EffectiveResistanceEstimator::new(walk_length, num_walks);
        Self { estimator }
    }

    /// Returns the raw importance `weight * R_eff(u, v)` of edge `(u, v)`, where `R_eff` comes
    /// from [`EffectiveResistanceEstimator::estimate`].
    pub fn importance_score(&self, graph: &SparseGraph, u: usize, v: usize, weight: f64) -> f64 {
        weight * self.estimator.estimate(graph, u, v)
    }
}
impl ImportanceScorer for LocalImportanceScorer {
    /// Scores edge `(u, v)` of weight `weight`. The estimated effective resistance is passed to
    /// [`EdgeImportance::new`] together with the endpoints and the weight.
    fn score(
        &self,
        graph: &SparseGraph,
        u: usize,
        v: usize,
        weight: f64,
    ) -> EdgeImportance {
        let r_eff = self.estimator.estimate(graph, u, v);
        EdgeImportance::new(u, v, weight, r_eff)
    }

    /// Scores every edge yielded by `graph.edges()` and returns the results in that order.
    ///
    /// Both paths delegate to [`ImportanceScorer::score`], so they always build results the
    /// same way. Graphs with more than `PARALLEL_THRESHOLD` edges are scored in parallel with
    /// rayon, because each edge costs up to `2 * num_walks * max_walk_length` walk steps.
    /// Smaller graphs are scored sequentially to avoid thread-pool overhead.
    fn score_all(&self, graph: &SparseGraph) -> Vec<EdgeImportance> {
        const PARALLEL_THRESHOLD: usize = 100;
        let edges: Vec<(usize, usize, f64)> = graph.edges().collect();
        if edges.len() > PARALLEL_THRESHOLD {
            use rayon::prelude::*;
            // `par_iter` over a `Vec` is indexed, so `collect` preserves edge order.
            edges
                .par_iter()
                .map(|&(u, v, w)| self.score(graph, u, v, w))
                .collect()
        } else {
            edges
                .iter()
                .map(|&(u, v, w)| self.score(graph, u, v, w))
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A vertex has zero effective resistance to itself.
    #[test]
    fn test_self_loop_resistance() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0)]);
        let estimator = EffectiveResistanceEstimator::new(50, 5);
        let resistance = estimator.estimate(&graph, 0, 0);
        assert!(resistance.abs() < 1e-12);
    }

    /// Opposite corners of a unit-weight 4-cycle have strictly positive resistance.
    #[test]
    fn test_resistance_positive() {
        let graph = SparseGraph::from_edges(&[
            (0, 1, 1.0),
            (1, 2, 1.0),
            (2, 3, 1.0),
            (3, 0, 1.0),
        ]);
        let estimator = EffectiveResistanceEstimator::new(200, 20);
        let resistance = estimator.estimate(&graph, 0, 2);
        assert!(resistance > 0.0);
    }

    /// `score_all` yields one positive score per edge of a small path graph.
    #[test]
    fn test_scorer_all() {
        let graph = SparseGraph::from_edges(&[(0, 1, 1.0), (1, 2, 2.0)]);
        let scorer = LocalImportanceScorer::new(50, 5);
        let scores = scorer.score_all(&graph);
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|entry| entry.score > 0.0));
    }
}