use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::numeric::{Float, SparseElement};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
/// Detects communities with the local-moving phase of the Louvain method.
///
/// Every node starts in its own community. Nodes are then visited in index order. Each node
/// is moved to the neighbouring community whose modularity gain (resolution-weighted, up to a
/// positive constant factor) is largest and strictly positive. Sweeps repeat until one makes
/// no move or `max_iter` sweeps have run.
///
/// Only the local-moving phase is performed; the graph is never aggregated into super-nodes.
/// Candidate communities are examined in ascending id order, so ties go to the smallest id.
///
/// # Arguments
/// * `graph` - square weighted adjacency matrix; entries are read with `get(i, j)`.
/// * `resolution` - resolution parameter scaling the null-model term `k_i * Σ_tot / 2m`.
/// * `max_iter` - maximum number of full sweeps over the nodes.
///
/// # Returns
/// `(num_communities, labels)` where `labels[i]` is the community of node `i`, densely
/// numbered in `0..num_communities`. If the total edge weight is zero, every node is returned
/// in its own community, i.e. `(n, [0, 1, ..., n - 1])`.
///
/// # Errors
/// * `SparseError::ValueError` if `graph` is not square.
/// * `SparseError::ComputationError` if `2.0` cannot be represented in `T`.
pub fn louvain_communities<T, S>(
    graph: &S,
    resolution: T,
    max_iter: usize,
) -> SparseResult<(usize, Vec<usize>)>
where
    T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
    S: SparseArray<T>,
{
    let n = graph.shape().0;
    if graph.shape().0 != graph.shape().1 {
        return Err(SparseError::ValueError(
            "Graph matrix must be square".to_string(),
        ));
    }
    let mut communities = (0..n).collect::<Vec<_>>();

    // Weighted degree k_i (row sum, diagonal included) and the grand total of all entries.
    let mut degrees = vec![T::sparse_zero(); n];
    let mut sum_all_weights = T::sparse_zero();
    for i in 0..n {
        for j in 0..n {
            let weight = graph.get(i, j);
            if !scirs2_core::SparseElement::is_zero(&weight) {
                degrees[i] = degrees[i] + weight;
                sum_all_weights = sum_all_weights + weight;
            }
        }
    }
    let two = T::from(2.0)
        .ok_or_else(|| SparseError::ComputationError("Cannot convert 2.0".to_string()))?;
    let m = sum_all_weights / two;
    if scirs2_core::SparseElement::is_zero(&m) {
        return Ok((n, communities));
    }
    let two_m = two * m;

    // Per-node scratch space indexed by community id. Ids are always node indices (a node only
    // ever adopts an existing id), so length `n` suffices. The buffers are reused so each node
    // costs a single row scan instead of one scan per candidate community.
    let mut link_weight = vec![T::sparse_zero(); n];
    let mut community_degree = vec![T::sparse_zero(); n];
    let mut candidate_set: HashSet<usize> = HashSet::new();
    let mut candidates: Vec<usize> = Vec::new();

    let mut improvement = true;
    let mut iteration = 0;
    while improvement && iteration < max_iter {
        improvement = false;
        iteration += 1;
        for node in 0..n {
            let current_community = communities[node];

            // One pass over the row: total edge weight from `node` into each community
            // (self-loop excluded) and the set of communities reachable by a non-zero edge.
            link_weight.fill(T::sparse_zero());
            candidate_set.clear();
            for neighbor in 0..n {
                if neighbor == node {
                    continue;
                }
                let weight = graph.get(node, neighbor);
                if !scirs2_core::SparseElement::is_zero(&weight) {
                    let community = communities[neighbor];
                    link_weight[community] = link_weight[community] + weight;
                    candidate_set.insert(community);
                }
            }

            // Sum of degrees per community with `node` itself left out, so the entry for the
            // current community is its total after `node` is removed.
            community_degree.fill(T::sparse_zero());
            for (other, &community) in communities.iter().enumerate() {
                if other != node {
                    community_degree[community] = community_degree[community] + degrees[other];
                }
            }

            let k_i = degrees[node];
            // Gain lost by taking `node` out of its current community.
            let remove_cost = link_weight[current_community]
                - resolution * k_i * community_degree[current_community] / two_m;

            candidates.clear();
            candidates.extend(candidate_set.iter().copied());
            // Ascending order makes ties resolve to the smallest community id.
            candidates.sort_unstable();

            let mut best_community = current_community;
            let mut best_delta = T::sparse_zero();
            for &community in &candidates {
                if community == current_community {
                    continue;
                }
                let add_gain = link_weight[community]
                    - resolution * k_i * community_degree[community] / two_m;
                let delta = add_gain - remove_cost;
                if delta > best_delta {
                    best_delta = delta;
                    best_community = community;
                }
            }
            if best_community != current_community {
                communities[node] = best_community;
                improvement = true;
            }
        }
    }

    let community_map = renumber_communities(&communities);
    let final_communities: Vec<usize> = communities.iter().map(|&c| community_map[&c]).collect();
    let num_communities = community_map.len();
    Ok((num_communities, final_communities))
}
/// Maps each distinct community id in `communities` to a dense id in `0..k`.
///
/// Ids are assigned in order of first appearance, so the community of node 0 always becomes
/// `0`. Two label vectors describing the same partition therefore always map to the same
/// result. (Enumerating a `HashSet` instead would make the numbering differ between runs,
/// because its iteration order is randomized per instance.)
fn renumber_communities(communities: &[usize]) -> HashMap<usize, usize> {
    let mut community_map = HashMap::new();
    for &old_id in communities {
        let next_id = community_map.len();
        community_map.entry(old_id).or_insert(next_id);
    }
    community_map
}
/// Detects communities by asynchronous label propagation.
///
/// Every node starts with its own index as its label. Nodes are visited in index order, and
/// each one adopts the label with the largest total edge weight among its neighbours
/// (self-loops ignored). Sweeps repeat until one changes nothing or `max_iter` sweeps have run.
///
/// Ties are resolved deterministically:
/// * A node keeps its current label if that label is among the heaviest. Without this, a node
///   could keep flipping between equally good labels, and the loop would never report
///   convergence before `max_iter`.
/// * Otherwise the smallest tied label is taken.
///
/// A node with no neighbours, or whose heaviest label has non-positive total weight, keeps
/// its label.
///
/// # Returns
/// `(num_communities, labels)` where `labels[i]` is densely numbered in `0..num_communities`.
///
/// # Errors
/// `SparseError::ValueError` if `graph` is not square.
pub fn label_propagation<T, S>(graph: &S, max_iter: usize) -> SparseResult<(usize, Vec<usize>)>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let n = graph.shape().0;
    if graph.shape().0 != graph.shape().1 {
        return Err(SparseError::ValueError(
            "Graph matrix must be square".to_string(),
        ));
    }
    let zero = T::sparse_zero();
    let mut labels = (0..n).collect::<Vec<_>>();
    // Tally of neighbour weight per label; cleared and reused for every node.
    let mut label_weights: HashMap<usize, T> = HashMap::new();
    let mut changed = true;
    let mut iteration = 0;
    while changed && iteration < max_iter {
        changed = false;
        iteration += 1;
        for node in 0..n {
            label_weights.clear();
            for neighbor in 0..n {
                if neighbor == node {
                    continue;
                }
                let weight = graph.get(node, neighbor);
                if !scirs2_core::SparseElement::is_zero(&weight) {
                    let total = label_weights.entry(labels[neighbor]).or_insert(zero);
                    *total = *total + weight;
                }
            }

            // Heaviest tally; only strictly positive support can move a node.
            let best_weight = label_weights
                .values()
                .fold(zero, |acc, &w| if w > acc { w } else { acc });
            if best_weight <= zero {
                continue;
            }

            let current = labels[node];
            if matches!(label_weights.get(&current), Some(&w) if w == best_weight) {
                continue;
            }

            let best_label = label_weights
                .iter()
                .filter(|&(_, &w)| w == best_weight)
                .map(|(&label, _)| label)
                .min();
            if let Some(best_label) = best_label {
                labels[node] = best_label;
                changed = true;
            }
        }
    }
    let community_map = renumber_communities(&labels);
    let final_communities: Vec<usize> = labels.iter().map(|&c| community_map[&c]).collect();
    let num_communities = community_map.len();
    Ok((num_communities, final_communities))
}
/// Computes the Newman–Girvan modularity of a partition.
///
/// `Q = (1 / 2m) * Σ_ij [A_ij - k_i * k_j / 2m] * δ(c_i, c_j)`
///
/// Here `k_i` is the sum of row `i` (diagonal included), `2m` is the sum of all entries, and
/// `c_i = communities[i]`. Returns zero when the graph carries no weight.
///
/// # Errors
/// * `SparseError::ValueError` if `graph` is not square, or if `communities.len()` differs
///   from the number of nodes.
/// * `SparseError::ComputationError` if `2.0` cannot be represented in `T`.
pub fn modularity<T, S>(graph: &S, communities: &[usize]) -> SparseResult<T>
where
    T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
    S: SparseArray<T>,
{
    let shape = graph.shape();
    let n = shape.0;
    if n != shape.1 {
        return Err(SparseError::ValueError(
            "Graph matrix must be square".to_string(),
        ));
    }
    if n != communities.len() {
        return Err(SparseError::ValueError(
            "Communities vector must match graph size".to_string(),
        ));
    }
    let two = T::from(2.0)
        .ok_or_else(|| SparseError::ComputationError("Cannot convert 2.0".to_string()))?;

    // Row sums and the running grand total, both accumulated entry by entry in row-major order.
    let mut degrees = vec![T::sparse_zero(); n];
    let mut total_weight = T::sparse_zero();
    for (row, degree) in degrees.iter_mut().enumerate() {
        for col in 0..n {
            let entry = graph.get(row, col);
            if scirs2_core::SparseElement::is_zero(&entry) {
                continue;
            }
            *degree = *degree + entry;
            total_weight = total_weight + entry;
        }
    }

    let half_total = total_weight / two;
    if scirs2_core::SparseElement::is_zero(&half_total) {
        return Ok(T::sparse_zero());
    }
    let two_m = two * half_total;

    // Sum over every ordered pair that shares a community, visited row-major.
    let raw = (0..n)
        .flat_map(|i| (0..n).map(move |j| (i, j)))
        .filter(|&(i, j)| communities[i] == communities[j])
        .fold(T::sparse_zero(), |acc, (i, j)| {
            let observed = graph.get(i, j);
            let expected_numerator = degrees[i] * degrees[j];
            acc + (observed - expected_numerator / two_m)
        });
    Ok(raw / two_m)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Two triangles {0, 1, 2} and {3, 4, 5} joined by the single bridge edge 2-3, stored
    /// symmetrically with unit weights (14 stored entries, so 2m = 14).
    fn create_two_community_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 2, 3];
        let cols = vec![1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4, 3, 2];
        let data = vec![1.0; 14];
        CsrArray::from_triplets(&rows, &cols, &data, (6, 6), false).expect("Failed to create")
    }

    /// A 2x3 matrix, used to check that non-square input is rejected.
    fn create_non_square_graph() -> CsrArray<f64> {
        let rows = vec![0, 1];
        let cols = vec![2, 0];
        let data = vec![1.0, 1.0];
        CsrArray::from_triplets(&rows, &cols, &data, (2, 3), false).expect("Failed to create")
    }

    #[test]
    fn test_louvain_communities() {
        let graph = create_two_community_graph();
        let (num_communities, communities) = louvain_communities(&graph, 1.0, 10).expect("Failed");
        assert_eq!(communities.len(), 6);
        // Each triangle becomes one community, and the two communities differ.
        assert_eq!(num_communities, 2);
        assert!(communities[..3].iter().all(|&c| c == communities[0]));
        assert!(communities[3..].iter().all(|&c| c == communities[3]));
        assert_ne!(communities[0], communities[3]);
        assert!(communities.iter().all(|&c| c < num_communities));
    }

    #[test]
    fn test_label_propagation() {
        let graph = create_two_community_graph();
        let (num_communities, communities) = label_propagation(&graph, 10).expect("Failed");
        assert_eq!(communities.len(), 6);
        assert!(num_communities >= 1);
        assert!(num_communities <= 6);
        assert!(communities.iter().all(|&c| c < num_communities));
    }

    #[test]
    fn test_modularity() {
        let graph = create_two_community_graph();
        let communities = vec![0, 0, 0, 1, 1, 1];
        let q: f64 = modularity(&graph, &communities).expect("Failed");
        // Per triangle: internal weight 6/14, degree share (7/14)^2, so Q = 2 * (6/14 - 1/4).
        assert!((q - 5.0 / 14.0).abs() < 1e-12);
        let random_communities = vec![0, 1, 0, 1, 0, 1];
        let q_random: f64 = modularity(&graph, &random_communities).expect("Failed");
        assert!(q > q_random);
    }

    #[test]
    fn test_single_node_communities() {
        let graph = create_two_community_graph();
        let communities = vec![0, 1, 2, 3, 4, 5];
        let q: f64 = modularity(&graph, &communities).expect("Failed");
        // No self-loops, so Q = -Σ k_i^2 / (2m)^2 = -34 / 196.
        assert!(q < 0.0);
        assert!((q + 34.0 / 196.0).abs() < 1e-12);
    }

    #[test]
    fn test_all_same_community() {
        let graph = create_two_community_graph();
        let communities = vec![0, 0, 0, 0, 0, 0];
        let q: f64 = modularity(&graph, &communities).expect("Failed");
        assert!(q.abs() < 1e-12);
    }

    #[test]
    fn test_non_square_graph_rejected() {
        let graph = create_non_square_graph();
        assert!(louvain_communities(&graph, 1.0, 10).is_err());
        assert!(label_propagation(&graph, 10).is_err());
        assert!(modularity::<f64, _>(&graph, &[0, 1]).is_err());
    }

    #[test]
    fn test_modularity_rejects_wrong_community_length() {
        let graph = create_two_community_graph();
        assert!(modularity::<f64, _>(&graph, &[0, 0, 1]).is_err());
    }
}