use crate::error::{ClusteringError, Result as ClusterResult};
use std::collections::HashSet;
/// Strategy used by `BigClam` to build the membership matrix that fitting starts from.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BigClamInit {
/// Every entry is drawn independently from a pseudo-random stream seeded by
/// `BigClamConfig::seed`, so the start is reproducible for a fixed seed.
Random,
/// Entries derived from a short power iteration on the degree-normalised adjacency
/// matrix combined with a seeded random positive projection. Falls back to `Random`
/// for graphs with fewer than two nodes.
SpectralWarmStart,
}
/// Hyper-parameters for `BigClam::fit`; combine with `..Default::default()` to
/// override only some fields.
#[derive(Debug, Clone)]
pub struct BigClamConfig {
/// Number of communities `k` (columns of the membership matrix); `fit` rejects 0.
pub n_communities: usize,
/// Maximum number of full sweeps over all nodes.
pub max_iter: usize,
/// Step size of each projected gradient-ascent update of a membership value.
pub learning_rate: f64,
/// Updated membership values below this threshold are set to exactly 0.
pub min_membership: f64,
/// Coefficient of the shrinkage term `reg_lambda * F[u][c]` subtracted from each gradient.
pub reg_lambda: f64,
/// How the membership matrix is initialised before fitting.
pub init: BigClamInit,
/// Seed for the pseudo-random initialisation streams.
pub seed: u64,
/// Fitting stops once the relative change of the membership matrix between two
/// sweeps, `‖F_new - F_old‖ / ‖F_old‖`, falls below this value.
pub tol: f64,
}
impl Default for BigClamConfig {
fn default() -> Self {
Self {
n_communities: 5,
max_iter: 100,
learning_rate: 0.005,
min_membership: 1e-3,
reg_lambda: 0.01,
init: BigClamInit::Random,
seed: 42,
tol: 1e-4,
}
}
}
/// Dense node × community affiliation matrix, as produced by `BigClam::fit`.
#[derive(Debug, Clone)]
pub struct MembershipMatrix {
/// `memberships[u][c]` is the affiliation strength of node `u` with community `c`
/// (non-negative when produced by `fit`).
pub memberships: Vec<Vec<f64>>,
/// Number of rows (nodes).
pub n_nodes: usize,
/// Number of columns (communities).
pub n_communities: usize,
}
impl MembershipMatrix {
    /// Creates an all-zero matrix with `n_nodes` rows and `n_communities` columns.
    pub fn new(n_nodes: usize, n_communities: usize) -> Self {
        let zero_row = vec![0.0; n_communities];
        Self {
            n_nodes,
            n_communities,
            memberships: vec![zero_row; n_nodes],
        }
    }

    /// Returns, in ascending order, every node whose membership in `community` is at
    /// least `threshold`. A community index past the end of a row counts as `0.0`.
    pub fn community_members(&self, community: usize, threshold: f64) -> Vec<usize> {
        let mut members = Vec::new();
        for (node, row) in self.memberships.iter().enumerate() {
            let weight = match row.get(community) {
                Some(&w) => w,
                None => 0.0,
            };
            if weight >= threshold {
                members.push(node);
            }
        }
        members
    }

    /// Returns, in ascending order, the communities in which `node` has membership of
    /// at least `threshold`. An out-of-range `node` yields an empty list.
    pub fn node_communities(&self, node: usize, threshold: f64) -> Vec<usize> {
        self.memberships.get(node).map_or_else(Vec::new, |row| {
            row.iter()
                .enumerate()
                .filter_map(|(community, &weight)| {
                    if weight >= threshold {
                        Some(community)
                    } else {
                        None
                    }
                })
                .collect()
        })
    }

    /// Assigns each node to its strongest community (arg-max of its row).
    ///
    /// Among equal maxima the last one wins, values that cannot be ordered (NaN) are
    /// treated as equal, and an empty row maps to community 0.
    pub fn to_hard_partition(&self) -> Vec<usize> {
        self.memberships
            .iter()
            .map(|row| {
                let mut best: Option<(usize, f64)> = None;
                for (community, &weight) in row.iter().enumerate() {
                    best = match best {
                        // Keep the current winner only when it is strictly greater.
                        Some((_, top))
                            if top.partial_cmp(&weight) == Some(std::cmp::Ordering::Greater) =>
                        {
                            best
                        }
                        _ => Some((community, weight)),
                    };
                }
                best.map_or(0, |(community, _)| community)
            })
            .collect()
    }
}
/// BigCLAM overlapping community detection: fits non-negative node–community
/// affiliations by gradient ascent on the BigCLAM log-likelihood.
pub struct BigClam {
/// Hyper-parameters used by every call to `fit`.
config: BigClamConfig,
}
impl BigClam {
    /// Creates a model that uses `config` for every call to [`BigClam::fit`].
    pub fn new(config: BigClamConfig) -> Self {
        Self { config }
    }

    /// Fits a non-negative node–community affiliation matrix `F` (one row per node,
    /// `n_communities` columns) to the graph described by `adj`.
    ///
    /// `adj[u]` lists the neighbours of node `u`. The graph is treated as undirected:
    /// an edge listed in either direction (or both) counts once, and self-loops and
    /// duplicate entries are ignored.
    ///
    /// Each sweep visits the nodes in order and takes a projected gradient-ascent step
    /// on one coordinate `F[u][c]` at a time, so later coordinates of a row see the
    /// earlier ones already updated. The gradient is that of the BigCLAM
    /// log-likelihood `Σ_edges ln(1 - exp(-F_u·F_v)) - Σ_non-edges F_u·F_v`, minus the
    /// shrinkage term `reg_lambda * F[u][c]`. Updated values are clamped at 0, and values
    /// below `min_membership` are set to 0. Fitting stops after `max_iter` sweeps, or
    /// earlier once the relative change `‖F_new - F_old‖ / ‖F_old‖` drops below `tol`.
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::InvalidInput`] in any of these cases:
    /// - `adj` is empty;
    /// - `n_communities` is 0;
    /// - `learning_rate` is not finite and positive;
    /// - `reg_lambda` is not finite and non-negative;
    /// - any neighbour index is `>= adj.len()`.
    pub fn fit(&self, adj: &[Vec<usize>]) -> ClusterResult<MembershipMatrix> {
        let n = adj.len();
        if n == 0 {
            return Err(ClusteringError::InvalidInput(
                "Adjacency list is empty".to_string(),
            ));
        }
        let k = self.config.n_communities;
        if k == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_communities must be ≥ 1".to_string(),
            ));
        }
        let lr = self.config.learning_rate;
        if !(lr.is_finite() && lr > 0.0) {
            return Err(ClusteringError::InvalidInput(
                "learning_rate must be finite and > 0".to_string(),
            ));
        }
        let lambda = self.config.reg_lambda;
        if !(lambda.is_finite() && lambda >= 0.0) {
            return Err(ClusteringError::InvalidInput(
                "reg_lambda must be finite and ≥ 0".to_string(),
            ));
        }
        let tol = self.config.tol;
        let min_m = self.config.min_membership;

        let nbrs = Self::undirected_neighbours(adj)?;
        let mut f = self.init_f(n, k, &nbrs);
        let mut f_old: Vec<Vec<f64>> = Vec::with_capacity(n);
        // `dots[i]` caches F_u · F_{nbrs[u][i]} for the node `u` currently being updated.
        let mut dots: Vec<f64> = Vec::new();

        for _ in 0..self.config.max_iter {
            // Reuses the previous snapshot's allocations instead of cloning afresh.
            f_old.clone_from(&f);
            // Recomputed every sweep so the incremental updates below cannot accumulate
            // rounding drift across iterations.
            let mut community_sums = compute_community_sums(&f, k);

            for (u, nbrs_u) in nbrs.iter().enumerate() {
                dots.clear();
                dots.extend(nbrs_u.iter().map(|&v| dot_product(&f[u], &f[v])));

                for c in 0..k {
                    // Edge term: Σ_v F_vc · exp(-F_u·F_v) / (1 - exp(-F_u·F_v)).
                    let mut pos_grad = 0.0;
                    let mut nbr_sum_c = 0.0;
                    for (&v, &dp) in nbrs_u.iter().zip(&dots) {
                        let f_vc = f[v][c];
                        let exp_neg = (-dp).exp();
                        // The floor keeps the ratio finite when F_u·F_v is (near) zero.
                        pos_grad += f_vc * exp_neg / (1.0 - exp_neg).max(1e-10);
                        nbr_sum_c += f_vc;
                    }
                    let old = f[u][c];
                    // Non-edge term: Σ F_vc over every node other than `u` and its
                    // neighbours, derived from the column total instead of a scan.
                    let neg_grad = community_sums[c] - old - nbr_sum_c;
                    let grad = pos_grad - neg_grad - lambda * old;
                    let stepped = (old + lr * grad).max(0.0);
                    let new_val = if stepped < min_m { 0.0 } else { stepped };

                    let delta = new_val - old;
                    if delta != 0.0 {
                        f[u][c] = new_val;
                        community_sums[c] += delta;
                        // Keep the cached dot products in step with the changed
                        // coordinate, exactly as a full recomputation would.
                        for (&v, dp) in nbrs_u.iter().zip(dots.iter_mut()) {
                            *dp += delta * f[v][c];
                        }
                    }
                }
            }

            if has_converged(&f, &f_old, tol) {
                break;
            }
        }

        Ok(MembershipMatrix {
            memberships: f,
            n_nodes: n,
            n_communities: k,
        })
    }

    /// Evaluates the BigCLAM log-likelihood of memberships `f` for the graph `adj`:
    /// `Σ_{edges (u,v)} ln(1 - exp(-F_u·F_v)) - Σ_{non-edges u<v} F_u·F_v`.
    ///
    /// The graph is treated as undirected, matching [`BigClam::fit`]. An edge listed in
    /// either direction counts once, and self-loops and neighbour indices
    /// `>= adj.len()` are ignored. Edge probabilities are floored at `1e-15` before the log.
    ///
    /// Runs in `O((n + m) · k)` for `n` nodes, `m` edges and `k` communities. It does not
    /// visit every pair; instead it subtracts the edge pairs from the all-pairs sum
    /// `(‖Σ_u F_u‖² - Σ_u ‖F_u‖²) / 2`.
    ///
    /// # Panics
    ///
    /// Panics if `f` has fewer rows than `adj` has nodes; extra rows are ignored.
    pub fn log_likelihood(&self, adj: &[Vec<usize>], f: &[Vec<f64>]) -> f64 {
        let n = adj.len();
        let rows = &f[..n];

        // Σ_{u<v} F_u·F_v over all pairs. Rows of differing length behave as if
        // zero-padded, which matches `dot_product`'s truncation.
        let width = rows.iter().map(Vec::len).max().unwrap_or(0);
        let mut col_sums = vec![0.0f64; width];
        let mut self_dots = 0.0f64;
        for row in rows {
            for (total, &x) in col_sums.iter_mut().zip(row) {
                *total += x;
            }
            self_dots += dot_product(row, row);
        }
        let all_pairs = 0.5 * (dot_product(&col_sums, &col_sums) - self_dots);

        // Start from "every pair is a non-edge", then swap in the edge term for each
        // distinct undirected edge. Edges are visited in adjacency order, so the
        // floating-point result is deterministic.
        let mut ll = -all_pairs;
        let mut seen: HashSet<(usize, usize)> = HashSet::new();
        for (u, neighbours) in adj.iter().enumerate() {
            for &v in neighbours {
                if v == u || v >= n || !seen.insert((u.min(v), u.max(v))) {
                    continue;
                }
                let dp = dot_product(&rows[u], &rows[v]);
                let prob = (1.0_f64 - (-dp).exp()).max(1e-15);
                ll += prob.ln() + dp;
            }
        }
        ll
    }

    /// Converts `adj` into sorted, de-duplicated, symmetric neighbour lists without
    /// self-loops, so every undirected edge contributes exactly once to each endpoint.
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::InvalidInput`] if a neighbour index is `>= adj.len()`.
    fn undirected_neighbours(adj: &[Vec<usize>]) -> ClusterResult<Vec<Vec<usize>>> {
        let n = adj.len();
        let mut nbrs: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (u, row) in adj.iter().enumerate() {
            for &v in row {
                if v >= n {
                    return Err(ClusteringError::InvalidInput(format!(
                        "node {u} lists neighbour {v}, but the graph has only {n} nodes"
                    )));
                }
                if v != u {
                    nbrs[u].push(v);
                    nbrs[v].push(u);
                }
            }
        }
        for row in &mut nbrs {
            row.sort_unstable();
            row.dedup();
        }
        Ok(nbrs)
    }

    /// Builds the starting membership matrix according to `config.init`. The spectral
    /// warm start needs at least two nodes and otherwise falls back to random init.
    fn init_f(&self, n: usize, k: usize, adj: &[Vec<usize>]) -> Vec<Vec<f64>> {
        match self.config.init {
            BigClamInit::Random => self.init_random(n, k),
            BigClamInit::SpectralWarmStart if n < 2 => self.init_random(n, k),
            BigClamInit::SpectralWarmStart => self.init_spectral(n, k, adj),
        }
    }

    /// Draws every entry independently as `0.01 + 0.5 * r`, where `r` in `[0, 1]` comes
    /// from a xorshift64 stream seeded by `config.seed`, so results are reproducible.
    fn init_random(&self, n: usize, k: usize) -> Vec<Vec<f64>> {
        let mut state = Self::xorshift_state(self.config.seed, 1);
        (0..n)
            .map(|_| {
                (0..k)
                    .map(|_| Self::xorshift_next(&mut state) * 0.5 + 0.01)
                    .collect()
            })
            .collect()
    }

    /// Spectral-flavoured starting point.
    ///
    /// Two normalised power-iteration steps of `D^{-1/2} A D^{-1/2}`, starting from the
    /// uniform vector, give one score `x[u]` per node. That score is spread over the
    /// communities by a seeded, unit-length, strictly positive vector `proj`:
    /// `F[u][c] = |x[u] · proj[c]| + 0.01`.
    ///
    /// NOTE(review): `D^{1/2}·1` is the leading eigenvector of that matrix, so `x`
    /// drifts toward a degree signal. Also, `proj` is shared by all nodes, so the rows
    /// differ mainly in scale and community structure has to emerge during fitting.
    /// A k-dimensional subspace iteration would give a more informative warm start.
    fn init_spectral(&self, n: usize, k: usize, adj: &[Vec<usize>]) -> Vec<Vec<f64>> {
        let deg: Vec<f64> = adj.iter().map(|row| row.len() as f64).collect();
        let inv_sqrt_n = 1.0 / (n as f64).sqrt();
        let mut x: Vec<f64> = vec![inv_sqrt_n; n];
        for _ in 0..2 {
            let mut y = vec![0.0f64; n];
            for (u, neighbours) in adj.iter().enumerate() {
                // `max(1.0)` guards against dividing by zero for isolated nodes.
                let d_u = deg[u].max(1.0).sqrt();
                for &v in neighbours {
                    let d_v = deg[v].max(1.0).sqrt();
                    y[u] += x[v] / (d_u * d_v);
                }
            }
            let norm = y.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-15);
            for val in &mut y {
                *val /= norm;
            }
            x = y;
        }

        let mut state = Self::xorshift_state(self.config.seed, 0xDEAD_BEEF);
        // Box–Muller transform: two uniforms -> one standard-normal sample.
        let mut rand_normal = || {
            let u1 = Self::xorshift_next(&mut state).max(1e-15);
            let u2 = Self::xorshift_next(&mut state);
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        };
        let mut proj: Vec<f64> = (0..k).map(|_| rand_normal().abs() + 0.1).collect();
        let pnorm = proj.iter().map(|&v| v * v).sum::<f64>().sqrt().max(1e-15);
        for v in &mut proj {
            *v /= pnorm;
        }

        (0..n)
            .map(|u| {
                (0..k)
                    .map(|c| {
                        let base = (x[u] * proj[c]).abs();
                        // The outer `max` also maps a NaN base to the 0.01 floor.
                        (base + 0.01).max(0.01)
                    })
                    .collect()
            })
            .collect()
    }

    /// Starting state for a xorshift64 stream derived from `seed` and a per-use `salt`.
    ///
    /// xorshift maps the all-zero state to itself forever, which would make every draw
    /// 0. That single state is therefore replaced by a fixed non-zero constant; every
    /// other seed yields the same stream as `seed.wrapping_add(salt)`.
    fn xorshift_state(seed: u64, salt: u64) -> u64 {
        match seed.wrapping_add(salt) {
            0 => 0x9E37_79B9_7F4A_7C15,
            state => state,
        }
    }

    /// Advances a xorshift64 (13, 7, 17) generator and returns the new state scaled to `[0, 1]`.
    fn xorshift_next(state: &mut u64) -> f64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        (*state as f64) / (u64::MAX as f64)
    }
}
/// Inner product of `a` and `b`, summed over the first `min(a.len(), b.len())`
/// positions; any trailing entries of the longer slice are ignored.
#[inline]
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let pairwise = a.iter().zip(b);
    pairwise.map(|(lhs, rhs)| lhs * rhs).sum::<f64>()
}
/// Column totals of the membership matrix: entry `c` is `Σ_u f[u][c]`.
///
/// The result always has length `k`; rows are expected to have at most `k` entries.
fn compute_community_sums(f: &[Vec<f64>], k: usize) -> Vec<f64> {
    f.iter().fold(vec![0.0f64; k], |mut totals, membership_row| {
        for (community, weight) in membership_row.iter().enumerate() {
            totals[community] += *weight;
        }
        totals
    })
}
/// Relative-change convergence test between two membership matrices.
///
/// Returns `true` when `‖f_new - f_old‖² / (‖f_old‖² + 1e-15) < tol²` (Frobenius norms
/// over the pairwise-zipped rows and entries). Equivalently, the relative change is
/// below `|tol|`; the small constant avoids division by zero for an all-zero `f_old`.
fn has_converged(f_new: &[Vec<f64>], f_old: &[Vec<f64>], tol: f64) -> bool {
    let (change_sq, base_sq) = f_new
        .iter()
        .zip(f_old)
        .flat_map(|(fresh_row, stale_row)| fresh_row.iter().zip(stale_row))
        .fold((0.0f64, 0.0f64), |(change, base), (&fresh, &stale)| {
            let delta = fresh - stale;
            (change + delta * delta, base + stale * stale)
        });
    change_sq / (base_sq + 1e-15) < tol * tol
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two disjoint triangles: {0, 1, 2} and {3, 4, 5}.
    fn two_triangles() -> Vec<Vec<usize>> {
        vec![
            vec![1, 2],
            vec![0, 2],
            vec![0, 1],
            vec![4, 5],
            vec![3, 5],
            vec![3, 4],
        ]
    }

    /// Two 4-cliques, {0, 1, 2, 3} and {3, 4, 5, 6}, overlapping in node 3.
    fn two_cliques_bridge() -> Vec<Vec<usize>> {
        let mut adj = vec![vec![]; 7];
        for &(u, v) in &[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] {
            adj[u].push(v);
            adj[v].push(u);
        }
        for &(u, v) in &[(3, 4), (3, 5), (3, 6), (4, 5), (4, 6), (5, 6)] {
            adj[u].push(v);
            adj[v].push(u);
        }
        adj
    }

    /// Asserts that `mm` is an `n × k` matrix whose entries are all non-negative.
    fn assert_shape_and_nonnegative(mm: &MembershipMatrix, n: usize, k: usize) {
        assert_eq!(mm.n_nodes, n);
        assert_eq!(mm.n_communities, k);
        assert_eq!(mm.memberships.len(), n);
        for row in &mm.memberships {
            assert_eq!(row.len(), k);
            for &v in row {
                assert!(v >= 0.0, "membership must be non-negative, got {v}");
            }
        }
    }

    #[test]
    fn test_bigclam_membership_nonnegative() {
        let adj = two_triangles();
        let config = BigClamConfig {
            n_communities: 2,
            max_iter: 10,
            ..Default::default()
        };
        let mm = BigClam::new(config).fit(&adj).expect("fit should succeed");
        assert_shape_and_nonnegative(&mm, 6, 2);
    }

    #[test]
    fn test_bigclam_on_two_cliques() {
        let adj = two_cliques_bridge();
        let config = BigClamConfig {
            n_communities: 2,
            max_iter: 80,
            learning_rate: 0.01,
            ..Default::default()
        };
        let mm = BigClam::new(config).fit(&adj).expect("fit should succeed");
        assert_eq!(mm.n_nodes, 7);
        assert_eq!(mm.n_communities, 2);
        let node3 = &mm.memberships[3];
        let max_m = node3.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!(max_m > 0.0, "bridge node should have positive membership");
    }

    #[test]
    fn test_bigclam_convergence() {
        let adj = two_triangles();
        let config = BigClamConfig {
            n_communities: 2,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };
        let mm = BigClam::new(config).fit(&adj).expect("fit should succeed");
        assert_shape_and_nonnegative(&mm, 6, 2);
    }

    #[test]
    fn test_bigclam_is_deterministic_for_fixed_seed() {
        let adj = two_cliques_bridge();
        let config = BigClamConfig {
            n_communities: 3,
            max_iter: 30,
            seed: 7,
            ..Default::default()
        };
        let first = BigClam::new(config.clone()).fit(&adj).expect("fit should succeed");
        let second = BigClam::new(config).fit(&adj).expect("fit should succeed");
        assert_eq!(first.memberships, second.memberships);
    }

    #[test]
    fn test_bigclam_n_communities() {
        let adj = two_triangles();
        for k in 1..=4 {
            let config = BigClamConfig {
                n_communities: k,
                max_iter: 5,
                ..Default::default()
            };
            let mm = BigClam::new(config).fit(&adj).expect("fit should succeed");
            assert_shape_and_nonnegative(&mm, 6, k);
        }
    }

    #[test]
    fn test_membership_matrix_community_members() {
        let mut mm = MembershipMatrix::new(4, 2);
        mm.memberships[0][0] = 0.9;
        mm.memberships[1][0] = 0.8;
        mm.memberships[2][1] = 0.7;
        mm.memberships[3][0] = 0.1;
        let members = mm.community_members(0, 0.5);
        assert!(members.contains(&0));
        assert!(members.contains(&1));
        assert!(!members.contains(&2));
        assert!(!members.contains(&3));
    }

    #[test]
    fn test_membership_matrix_node_communities() {
        let mut mm = MembershipMatrix::new(2, 3);
        mm.memberships[0] = vec![0.6, 0.2, 0.9];
        assert_eq!(mm.node_communities(0, 0.5), vec![0, 2]);
        assert!(mm.node_communities(1, 0.5).is_empty());
        assert!(mm.node_communities(5, 0.0).is_empty());
    }

    #[test]
    fn test_membership_matrix_to_hard_partition() {
        let mut mm = MembershipMatrix::new(3, 2);
        mm.memberships[0] = vec![0.8, 0.2];
        mm.memberships[1] = vec![0.1, 0.9];
        mm.memberships[2] = vec![0.5, 0.5];
        let hard = mm.to_hard_partition();
        assert_eq!(hard[0], 0);
        assert_eq!(hard[1], 1);
        assert!(hard[2] == 0 || hard[2] == 1);
    }

    #[test]
    fn test_dot_product_uses_shorter_length() {
        assert_eq!(dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0]), 14.0);
        assert_eq!(dot_product(&[1.0, 2.0], &[3.0, 4.0]), 11.0);
    }

    #[test]
    fn test_bigclam_log_likelihood_improves() {
        let adj = two_cliques_bridge();
        let bc = BigClam::new(BigClamConfig {
            n_communities: 2,
            max_iter: 1,
            ..Default::default()
        });
        let f_init = bc.fit(&adj).expect("fit should succeed").memberships;
        let ll_early = bc.log_likelihood(&adj, &f_init);
        let bc2 = BigClam::new(BigClamConfig {
            n_communities: 2,
            max_iter: 100,
            ..Default::default()
        });
        let f_trained = bc2.fit(&adj).expect("fit should succeed").memberships;
        let ll_trained = bc2.log_likelihood(&adj, &f_trained);
        assert!(
            ll_trained >= ll_early - 0.5,
            "trained LL {ll_trained} should not be much worse than early LL {ll_early}"
        );
    }

    #[test]
    fn test_bigclam_spectral_init() {
        let adj = two_cliques_bridge();
        let config = BigClamConfig {
            n_communities: 2,
            max_iter: 20,
            init: BigClamInit::SpectralWarmStart,
            ..Default::default()
        };
        let mm = BigClam::new(config)
            .fit(&adj)
            .expect("spectral init fit should succeed");
        assert_shape_and_nonnegative(&mm, 7, 2);
    }

    #[test]
    fn test_bigclam_empty_graph_error() {
        let adj: Vec<Vec<usize>> = vec![];
        let result = BigClam::new(BigClamConfig::default()).fit(&adj);
        assert!(result.is_err());
    }

    #[test]
    fn test_bigclam_zero_communities_error() {
        let adj = two_triangles();
        let config = BigClamConfig {
            n_communities: 0,
            ..Default::default()
        };
        let result = BigClam::new(config).fit(&adj);
        assert!(result.is_err());
    }
}