use super::traits::CommunityDetection;
use crate::error::{Error, Result};
use graphops::{leiden_seeded, leiden_weighted_seeded, Graph, PetgraphRef, WeightedGraph};
use petgraph::graph::UnGraph;
use petgraph::visit::EdgeRef;
/// Configuration for Leiden community detection.
///
/// Construct with [`Leiden::new`] (or `Default`) and adjust with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct Leiden {
    /// Resolution parameter passed to the graphops Leiden routines.
    resolution: f64,
    // Set by the builder but never read in this module: the detection entry points only pass
    // `resolution` and `seed` to graphops, hence the dead_code allowance.
    #[allow(dead_code)]
    max_iter: usize,
    // Same situation as `max_iter`: stored, but not forwarded to graphops by this module.
    #[allow(dead_code)]
    min_gain: f64,
    /// Seed passed to the seeded graphops Leiden routines, making runs reproducible.
    seed: u64,
}
impl Leiden {
pub fn new() -> Self {
Self {
resolution: 1.0,
max_iter: 100,
min_gain: 1e-7,
seed: 42,
}
}
pub fn with_resolution(mut self, resolution: f64) -> Self {
self.resolution = resolution;
self
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn with_min_gain(mut self, min_gain: f64) -> Self {
self.min_gain = min_gain;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
#[deprecated(since = "0.2.0", note = "Use with_max_iter instead")]
pub fn with_refinement(self, n: usize) -> Self {
self.with_max_iter(n)
}
pub fn detect_weighted<N>(&self, graph: &UnGraph<N, f32>) -> Result<Vec<usize>> {
let n = graph.node_count();
if n == 0 {
return Err(Error::EmptyInput);
}
if graph.edge_count() == 0 {
return Ok((0..n).collect());
}
let adapter = F32WeightedAdapter::from_graph(graph);
Ok(leiden_weighted_seeded(&adapter, self.resolution, self.seed))
}
}
impl Default for Leiden {
fn default() -> Self {
Self::new()
}
}
impl CommunityDetection for Leiden {
fn detect<N, E>(&self, graph: &UnGraph<N, E>) -> Result<Vec<usize>> {
let n = graph.node_count();
if n == 0 {
return Err(Error::EmptyInput);
}
if graph.edge_count() == 0 {
return Ok((0..n).collect());
}
let adapter = PetgraphRef::from_graph(graph);
Ok(leiden_seeded(&adapter, self.resolution, self.seed))
}
fn resolution(&self) -> f64 {
self.resolution
}
}
/// Owned adjacency-list view of a petgraph `UnGraph<N, f32>`, with weights widened to `f64`,
/// so it can be handed to the graphops `Graph` / `WeightedGraph` traits.
struct F32WeightedAdapter {
    /// `adj[i]` holds the `(neighbor_index, weight)` entries of node `i`.
    adj: Vec<Vec<(usize, f64)>>,
}
impl F32WeightedAdapter {
    /// Builds a symmetric adjacency list from a petgraph undirected graph.
    ///
    /// Every undirected edge `{i, j}` of weight `w` contributes `w` to the `j` entry of row `i`
    /// and to the `i` entry of row `j`. Neighbours keep the order in which they are first seen,
    /// so graphs without parallel edges produce exactly the same rows as a naive push would.
    ///
    /// Parallel edges (which `UnGraph::add_edge` permits) are merged into a single entry whose
    /// weight is their sum. `edge_weight` can only report one value per pair, so a pair must
    /// appear once in `neighbors` for per-neighbour sums of `edge_weight` to equal the true
    /// total weight. Unmerged, k parallel edges would count the first weight k times and drop
    /// the others.
    ///
    /// A self-loop of weight `w` on node `i` is visited from both of its (identical) endpoints
    /// and ends up as one `(i, 2w)` entry. The row sum is the same as storing `(i, w)` twice.
    fn from_graph<N>(graph: &UnGraph<N, f32>) -> Self {
        use std::collections::hash_map::{Entry, HashMap};

        let n = graph.node_count();
        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
        // (from, to) -> position of the `to` entry inside `adj[from]`, used to merge duplicates.
        let mut slot: HashMap<(usize, usize), usize> =
            HashMap::with_capacity(2 * graph.edge_count());

        let mut add = |from: usize, to: usize, w: f64| match slot.entry((from, to)) {
            Entry::Occupied(e) => adj[from][*e.get()].1 += w,
            Entry::Vacant(e) => {
                e.insert(adj[from].len());
                adj[from].push((to, w));
            }
        };

        for edge in graph.edge_references() {
            let i = edge.source().index();
            let j = edge.target().index();
            // Lossless widening; avoids an `as` cast.
            let w = f64::from(*edge.weight());
            add(i, j, w);
            add(j, i, w);
        }

        Self { adj }
    }
}
impl Graph for F32WeightedAdapter {
    /// Number of nodes, i.e. the number of adjacency rows.
    fn node_count(&self) -> usize {
        let Self { adj } = self;
        adj.len()
    }

    /// Neighbour indices of `node`, one per stored adjacency entry, in stored order.
    ///
    /// # Panics
    /// Panics if `node` is not a valid node index.
    fn neighbors(&self, node: usize) -> Vec<usize> {
        let row = &self.adj[node];
        row.iter().map(|&(neighbor, _weight)| neighbor).collect()
    }
}
impl WeightedGraph for F32WeightedAdapter {
    /// Weight of the first `target` entry in `source`'s adjacency row, or `0.0` if the two
    /// nodes are not adjacent.
    ///
    /// Runs a linear scan over the row of `source`.
    ///
    /// # Panics
    /// Panics if `source` is not a valid node index.
    fn edge_weight(&self, source: usize, target: usize) -> f64 {
        for &(neighbor, weight) in &self.adj[source] {
            if neighbor == target {
                return weight;
            }
        }
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::NodeIndex;
    use std::collections::{HashMap, HashSet, VecDeque};

    /// Builds an unweighted undirected graph with `n` nodes and `edges` added in order.
    fn unweighted_graph(n: usize, edges: &[(usize, usize)]) -> UnGraph<(), ()> {
        let mut graph = UnGraph::<(), ()>::new_undirected();
        for _ in 0..n {
            let _ = graph.add_node(());
        }
        for &(a, b) in edges {
            let _ = graph.add_edge(NodeIndex::new(a), NodeIndex::new(b), ());
        }
        graph
    }

    /// Builds an `f32`-weighted undirected graph with `n` nodes and `edges` added in order.
    fn weighted_graph(n: usize, edges: &[(usize, usize, f32)]) -> UnGraph<(), f32> {
        let mut graph = UnGraph::<(), f32>::new_undirected();
        for _ in 0..n {
            let _ = graph.add_node(());
        }
        for &(a, b, w) in edges {
            let _ = graph.add_edge(NodeIndex::new(a), NodeIndex::new(b), w);
        }
        graph
    }

    /// Asserts that every community with more than one member induces a connected subgraph.
    fn assert_communities_connected(graph: &UnGraph<(), ()>, communities: &[usize]) {
        let mut by_community: HashMap<usize, Vec<usize>> = HashMap::new();
        for (node, &comm) in communities.iter().enumerate() {
            by_community.entry(comm).or_default().push(node);
        }
        for nodes in by_community.values().filter(|nodes| nodes.len() > 1) {
            let members: HashSet<usize> = nodes.iter().copied().collect();
            let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
            for edge in graph.edge_references() {
                let (i, j) = (edge.source().index(), edge.target().index());
                if members.contains(&i) && members.contains(&j) {
                    adj.entry(i).or_default().push(j);
                    adj.entry(j).or_default().push(i);
                }
            }
            let mut visited = HashSet::new();
            let mut queue = VecDeque::new();
            queue.push_back(nodes[0]);
            while let Some(node) = queue.pop_front() {
                if !visited.insert(node) {
                    continue;
                }
                for &next in adj.get(&node).into_iter().flatten() {
                    if !visited.contains(&next) {
                        queue.push_back(next);
                    }
                }
            }
            assert_eq!(
                visited.len(),
                nodes.len(),
                "Community is not fully connected!"
            );
        }
    }

    #[test]
    fn test_leiden_basic() {
        let graph = unweighted_graph(3, &[(0, 1), (1, 2), (0, 2)]);
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[1], communities[2]);
    }

    #[test]
    fn test_leiden_two_cliques() {
        // Two triangles {0,1,2} and {3,4,5} joined by the bridge 2-3.
        let graph = unweighted_graph(
            6,
            &[(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)],
        );
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_eq!(communities.len(), 6);
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[1], communities[2]);
        assert_eq!(communities[3], communities[4]);
        assert_eq!(communities[4], communities[5]);
        assert_ne!(communities[0], communities[3]);
    }

    #[test]
    fn test_leiden_disconnected_within_community() {
        // Two separate components: the path 0-1-2 and the edge 3-4.
        let graph = unweighted_graph(5, &[(0, 1), (1, 2), (3, 4)]);
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[1], communities[2]);
        assert_eq!(communities[3], communities[4]);
        assert_ne!(communities[0], communities[3]);
    }

    #[test]
    fn test_leiden_empty_graph() {
        let graph = UnGraph::<(), ()>::new_undirected();
        assert!(Leiden::new().detect(&graph).is_err());
    }

    #[test]
    fn test_leiden_single_node() {
        let graph = unweighted_graph(1, &[]);
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_eq!(communities, vec![0]);
    }

    #[test]
    fn test_leiden_no_edges_gives_singletons() {
        let graph = unweighted_graph(3, &[]);
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_eq!(communities, vec![0, 1, 2]);
    }

    #[test]
    fn test_leiden_resolution_parameter() {
        let path: Vec<(usize, usize)> = (0..9).map(|i| (i, i + 1)).collect();
        let graph = unweighted_graph(10, &path);
        let comms_low = Leiden::new().with_resolution(0.5).detect(&graph).unwrap();
        let comms_high = Leiden::new().with_resolution(2.0).detect(&graph).unwrap();
        assert_eq!(comms_low.len(), 10);
        assert_eq!(comms_high.len(), 10);
    }

    #[test]
    fn test_leiden_connectivity_guarantee() {
        // Path 0..=15 with chords 0-5 and 10-15; nodes 16..20 are isolated.
        let mut edges: Vec<(usize, usize)> = (0..15).map(|i| (i, i + 1)).collect();
        edges.push((0, 5));
        edges.push((10, 15));
        let graph = unweighted_graph(20, &edges);
        let communities = Leiden::new().detect(&graph).unwrap();
        assert_communities_connected(&graph, &communities);
    }

    #[test]
    fn test_leiden_weighted_empty_graph() {
        let graph = UnGraph::<(), f32>::new_undirected();
        assert!(Leiden::new().detect_weighted(&graph).is_err());
    }

    #[test]
    fn test_leiden_weighted_no_edges_gives_singletons() {
        let graph = weighted_graph(3, &[]);
        let communities = Leiden::new().detect_weighted(&graph).unwrap();
        assert_eq!(communities, vec![0, 1, 2]);
    }

    #[test]
    fn test_leiden_weighted_two_cliques() {
        // Two heavy triangles joined by a very light bridge 2-3.
        let graph = weighted_graph(
            6,
            &[
                (0, 1, 10.0),
                (1, 2, 10.0),
                (0, 2, 10.0),
                (3, 4, 10.0),
                (4, 5, 10.0),
                (3, 5, 10.0),
                (2, 3, 0.1),
            ],
        );
        let communities = Leiden::new().detect_weighted(&graph).unwrap();
        assert_eq!(communities.len(), 6);
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[1], communities[2]);
        assert_eq!(communities[3], communities[4]);
        assert_eq!(communities[4], communities[5]);
        assert_ne!(communities[0], communities[3]);
    }
}