use super::modularity::modularity;
use super::types::CommunityResult;
use crate::base::{EdgeWeight, Graph, IndexType, Node};
use std::collections::HashMap;
use std::hash::Hash;
/// Outcome of a run of [`infomap_communities`].
#[derive(Debug, Clone)]
pub struct InfomapResult<N: Node> {
/// Community id assigned to each node. `infomap_communities` renumbers the
/// ids so that they are consecutive integers starting at 0.
pub node_communities: HashMap<N, usize>,
/// Map-equation code length of the returned partition, as computed by
/// `calculate_map_equation` (natural logarithms, i.e. measured in nats).
pub code_length: f64,
/// Modularity of the returned partition, as reported by the imported
/// `modularity` function.
pub modularity: f64,
}
/// Detects communities by greedily minimising the map equation
/// (Infomap-style local moving).
///
/// Every node starts in its own community. The nodes are then swept in
/// graph order; for each node every community that contains one of its
/// neighbours is tried, and the node is moved to the community that yields
/// the lowest code length, provided that this is strictly lower than the
/// current one. Sweeps repeat until no node moves, until a whole sweep
/// lowers the code length by less than `tolerance`, or until
/// `max_iterations` sweeps have been performed.
///
/// # Arguments
///
/// * `graph` - graph to partition.
/// * `max_iterations` - upper bound on the number of full sweeps.
/// * `tolerance` - minimum code-length decrease a sweep must achieve for the
///   search to continue.
///
/// # Returns
///
/// An [`InfomapResult`] with the node-to-community assignment (ids numbered
/// consecutively from 0 in node order), the code length of that assignment
/// and its modularity. An empty graph yields an empty assignment with both
/// numbers set to `0.0`.
#[allow(dead_code)]
pub fn infomap_communities<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    max_iterations: usize,
    tolerance: f64,
) -> InfomapResult<N>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + Copy,
    Ix: IndexType,
{
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    if nodes.is_empty() {
        return InfomapResult {
            node_communities: HashMap::new(),
            code_length: 0.0,
            modularity: 0.0,
        };
    }

    let (transition_matrix, node_weights) = build_transition_matrix(graph, &nodes);
    let stationary_probs = compute_stationary_distribution(&transition_matrix, &node_weights);

    // Start from the trivial partition: one community per node.
    let mut communities: HashMap<N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node.clone(), i))
        .collect();
    let mut current_code_length = calculate_map_equation(
        graph,
        &communities,
        &transition_matrix,
        &stationary_probs,
        &nodes,
    );

    // Moves are only accepted when they strictly lower the code length, so
    // `communities` / `current_code_length` are always the best state seen
    // so far and no separate "best" snapshot is required.
    for _ in 0..max_iterations {
        let sweep_start_code_length = current_code_length;
        let mut improved = false;

        for node in &nodes {
            let current_community = communities[node];

            // Communities of the neighbours, excluding the node's own one.
            // An ordered set makes the trial order (and therefore the
            // tie-breaking between equally good moves) reproducible.
            let mut candidates = std::collections::BTreeSet::new();
            if let Ok(neighbors) = graph.neighbors(node) {
                for neighbor in neighbors {
                    if let Some(&comm) = communities.get(&neighbor) {
                        if comm != current_community {
                            candidates.insert(comm);
                        }
                    }
                }
            }

            let mut best_community = current_community;
            let mut best_local_code_length = current_code_length;
            for &candidate in &candidates {
                // Tentatively move the node and score the resulting partition.
                if let Some(slot) = communities.get_mut(node) {
                    *slot = candidate;
                }
                let trial_code_length = calculate_map_equation(
                    graph,
                    &communities,
                    &transition_matrix,
                    &stationary_probs,
                    &nodes,
                );
                if trial_code_length < best_local_code_length {
                    best_local_code_length = trial_code_length;
                    best_community = candidate;
                }
            }

            // Commit the winning community, or undo the last trial move when
            // nothing beat the current assignment.
            if let Some(slot) = communities.get_mut(node) {
                *slot = best_community;
            }
            if best_community != current_community {
                current_code_length = best_local_code_length;
                improved = true;
            }
        }

        // Stop once a full sweep no longer pays off.
        if !improved || sweep_start_code_length - current_code_length < tolerance {
            break;
        }
    }

    // Renumber the surviving communities 0, 1, 2, ... in node order so the
    // ids do not depend on hash-map iteration order.
    let mut community_map: HashMap<usize, usize> = HashMap::new();
    for node in &nodes {
        if let Some(&comm) = communities.get(node) {
            let next_id = community_map.len();
            community_map.entry(comm).or_insert(next_id);
        }
    }
    for comm in communities.values_mut() {
        let old_id = *comm;
        *comm = community_map[&old_id];
    }

    let final_modularity = modularity(graph, &communities);
    InfomapResult {
        node_communities: communities,
        code_length: current_code_length,
        modularity: final_modularity,
    }
}
/// Builds the random-walk transition matrix of `graph` over `nodes`.
///
/// Row `i` describes a step taken from `nodes[i]`: entry `[i][j]` is the
/// weight of the edge to `nodes[j]` divided by the total weight of all edges
/// reported for `nodes[i]`. A node whose total weight is not positive gets a
/// uniform row (`1 / n` everywhere), so the walk can leave it.
///
/// # Returns
///
/// `(transition_matrix, node_weights)`, where `node_weights[i]` is the total
/// edge weight summed over the neighbours of `nodes[i]`.
#[allow(dead_code)]
fn build_transition_matrix<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    nodes: &[N],
) -> (Vec<Vec<f64>>, Vec<f64>)
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + Copy,
    Ix: IndexType,
{
    let n = nodes.len();
    let node_to_idx: HashMap<&N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node, i))
        .collect();
    let mut transition_matrix = vec![vec![0.0; n]; n];
    let mut node_weights = vec![0.0; n];

    for (i, node) in nodes.iter().enumerate() {
        // Walk the neighbourhood once, remembering the indexed edges so the
        // graph does not have to be queried a second time for normalisation.
        let mut total_weight = 0.0;
        let mut outgoing: Vec<(usize, f64)> = Vec::new();
        if let Ok(neighbors) = graph.neighbors(node) {
            for neighbor in neighbors {
                if let Ok(weight) = graph.edge_weight(node, &neighbor) {
                    let weight: f64 = weight.into();
                    total_weight += weight;
                    if let Some(&j) = node_to_idx.get(&neighbor) {
                        outgoing.push((j, weight));
                    }
                }
            }
        }
        node_weights[i] = total_weight;

        let row = &mut transition_matrix[i];
        if total_weight > 0.0 {
            for (j, weight) in outgoing {
                row[j] = weight / total_weight;
            }
        } else {
            // No usable outgoing weight: jump to any node with equal probability.
            row.fill(1.0 / n as f64);
        }
    }

    (transition_matrix, node_weights)
}
/// Estimates the stationary distribution of a random walk by power iteration.
///
/// The iteration starts from `node_weights` normalised to sum to one (or from
/// the uniform distribution when the weights do not have a positive sum) and
/// repeatedly applies `pi <- pi * transition_matrix`, renormalising after each
/// step. It stops after 1000 steps, or earlier once the L1 distance between
/// two successive iterates falls below `1e-10`.
///
/// Returns an empty vector for an empty matrix; otherwise one entry per row
/// of `transition_matrix`.
#[allow(dead_code)]
fn compute_stationary_distribution(
    transition_matrix: &[Vec<f64>],
    node_weights: &[f64],
) -> Vec<f64> {
    const MAX_STEPS: usize = 1000;
    const CONVERGENCE_THRESHOLD: f64 = 1e-10;

    let n = transition_matrix.len();
    if n == 0 {
        return Vec::new();
    }

    let weight_sum: f64 = node_weights.iter().sum();
    let mut pi: Vec<f64> = if weight_sum > 0.0 {
        node_weights.iter().map(|&w| w / weight_sum).collect()
    } else {
        vec![1.0 / n as f64; n]
    };

    for _ in 0..MAX_STEPS {
        // next[i] = sum over j of pi[j] * T[j][i], accumulated row by row.
        let mut next = vec![0.0; n];
        for (j, row) in transition_matrix.iter().enumerate() {
            let mass = pi[j];
            for (acc, &step) in next.iter_mut().zip(&row[..n]) {
                *acc += mass * step;
            }
        }

        // Renormalise so rounding drift does not accumulate.
        let total: f64 = next.iter().sum();
        if total > 0.0 {
            next.iter_mut().for_each(|p| *p /= total);
        }

        let delta: f64 = pi
            .iter()
            .zip(next.iter())
            .map(|(before, after)| (before - after).abs())
            .sum();
        pi = next;
        if delta < CONVERGENCE_THRESHOLD {
            break;
        }
    }

    pi
}
/// Evaluates the two-level map equation for a partition, in nats.
///
/// For every module (community) two quantities are accumulated:
/// * the exit flow `q`: the sum of `pi_i * transition_matrix[i][j]` over
///   every member `i` and every neighbour `j` assigned to a different module;
/// * the visit rates `pi_i` of its members.
///
/// The result is the sum of the index-codebook term
/// `-sum(q * ln(q / q_total))` and, for every module, the module-codebook
/// term `(q + p) * H`, where `p` is the sum of the member visit rates and
/// `H` is the entropy of `q` and the member visit rates, each divided by
/// `q + p`.
///
/// Nodes are processed in `nodes` order and modules in ascending id order,
/// so the floating-point result does not depend on hash-map iteration order.
/// `nodes` is expected to list every node once; nodes without an entry in
/// `communities` are ignored.
///
/// Returns `0.0` when `nodes` is empty.
#[allow(dead_code)]
fn calculate_map_equation<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    communities: &HashMap<N, usize>,
    transition_matrix: &[Vec<f64>],
    stationary_probs: &[f64],
    nodes: &[N],
) -> f64
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + Copy,
    Ix: IndexType,
{
    /// Flow statistics gathered for a single module.
    #[derive(Default)]
    struct ModuleFlow {
        /// Probability flow leaving the module per step.
        exit: f64,
        /// Stationary visit rates of the module's member nodes.
        visit_rates: Vec<f64>,
    }

    if nodes.is_empty() {
        return 0.0;
    }

    let node_to_idx: HashMap<&N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node, i))
        .collect();

    // One pass over the nodes groups everything by module, so the module
    // loop below never has to rescan the whole assignment.
    let mut modules: std::collections::BTreeMap<usize, ModuleFlow> =
        std::collections::BTreeMap::new();
    for (i, node) in nodes.iter().enumerate() {
        let comm = match communities.get(node) {
            Some(&comm) => comm,
            None => continue,
        };
        let pi_i = stationary_probs[i];
        let module = modules.entry(comm).or_default();
        module.visit_rates.push(pi_i);

        if let Ok(neighbors) = graph.neighbors(node) {
            for neighbor in neighbors {
                if let Some(&neighbor_comm) = communities.get(&neighbor) {
                    if neighbor_comm != comm {
                        if let Some(&j) = node_to_idx.get(&neighbor) {
                            // A step from i to j crosses the module boundary.
                            module.exit += pi_i * transition_matrix[i][j];
                        }
                    }
                }
            }
        }
    }

    let mut code_length = 0.0;

    // Index codebook: cost of naming the module that is entered.
    let total_exit_flow: f64 = modules.values().map(|module| module.exit).sum();
    if total_exit_flow > 0.0 {
        for module in modules.values() {
            let q_alpha = module.exit;
            if q_alpha > 0.0 {
                code_length -= q_alpha * (q_alpha / total_exit_flow).ln();
            }
        }
    }

    // Module codebooks: cost of naming nodes (and the exit) inside a module.
    for module in modules.values() {
        let q_alpha = module.exit;
        let p_alpha: f64 = module.visit_rates.iter().sum();
        let total_alpha = q_alpha + p_alpha;
        if total_alpha > 0.0 {
            let mut h_alpha = 0.0;
            if q_alpha > 0.0 {
                let exit_share = q_alpha / total_alpha;
                h_alpha -= exit_share * exit_share.ln();
            }
            for &pi_i in &module.visit_rates {
                if pi_i > 0.0 {
                    let prob_in_module = pi_i / total_alpha;
                    h_alpha -= prob_in_module * prob_in_module.ln();
                }
            }
            code_length += total_alpha * h_alpha;
        }
    }

    code_length
}