use crate::{CsrGraph, NodeId};
use anyhow::Result;
use aprender::graph::Graph as AprenderGraph;
/// Outcome of a community-detection run over a graph.
///
/// Each inner vector of `communities` lists the members of one community, and
/// a community's id is its index in the outer vector.
///
/// Only `PartialEq` (not `Eq`) is derived because `modularity` is an `f64`;
/// two results whose modularity is NaN therefore never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct CommunityDetectionResult {
    /// Members of every community, indexed by community id.
    pub communities: Vec<Vec<NodeId>>,
    /// Number of communities; `louvain` sets this to `communities.len()`.
    pub num_communities: usize,
    /// Modularity score reported for the partition.
    pub modularity: f64,
}
impl CommunityDetectionResult {
    /// Looks up the community that `node` belongs to.
    ///
    /// Communities are scanned in order and the index of the first one that
    /// lists `node` is returned; `None` means no community contains it.
    #[must_use]
    pub fn get_community(&self, node: NodeId) -> Option<usize> {
        self.communities
            .iter()
            .position(|members| members.contains(&node))
    }

    /// Borrows the members of community `comm_id`, or returns `None` when
    /// `comm_id` is out of range.
    #[must_use]
    pub fn get_community_nodes(&self, comm_id: usize) -> Option<&[NodeId]> {
        let members = self.communities.get(comm_id)?;
        Some(members.as_slice())
    }

    /// Returns how many nodes community `comm_id` holds, or `None` when
    /// `comm_id` is out of range.
    #[must_use]
    pub fn community_size(&self, comm_id: usize) -> Option<usize> {
        let members = self.communities.get(comm_id)?;
        Some(members.len())
    }
}
pub fn louvain(graph: &CsrGraph) -> Result<CommunityDetectionResult> {
let aprender_graph = convert_to_aprender(graph);
let communities = aprender_graph.louvain();
let modularity = aprender_graph.modularity(&communities);
let converted_communities: Vec<Vec<NodeId>> = communities
.into_iter()
.map(|community| {
community
.into_iter()
.filter_map(|node_id| u32::try_from(node_id).ok().map(NodeId))
.collect()
})
.filter(|community: &Vec<NodeId>| !community.is_empty()) .collect();
let num_communities = converted_communities.len();
Ok(CommunityDetectionResult {
communities: converted_communities,
num_communities,
modularity,
})
}
/// Builds the aprender graph that mirrors `graph`'s weighted adjacency.
///
/// Every `(source, target, weight)` entry yielded by `iter_adjacency` becomes
/// one `(usize, usize, f64)` edge tuple passed to
/// `AprenderGraph::from_weighted_edges`.
fn convert_to_aprender(graph: &CsrGraph) -> AprenderGraph {
    let mut weighted_edges = Vec::new();
    for (source, neighbours, edge_weights) in graph.iter_adjacency() {
        let outgoing = neighbours
            .iter()
            .zip(edge_weights.iter())
            .map(|(target, weight)| (source.0 as usize, *target as usize, f64::from(*weight)));
        weighted_edges.extend(outgoing);
    }
    // NOTE(review): the `false` flag presumably requests an undirected graph —
    // confirm against aprender's `from_weighted_edges` documentation.
    AprenderGraph::from_weighted_edges(&weighted_edges, false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph holding one weight-`1.0` edge per `(source, target)` pair.
    fn graph_from_pairs(pairs: &[(u32, u32)]) -> CsrGraph {
        let mut graph = CsrGraph::new();
        for &(source, target) in pairs {
            graph.add_edge(NodeId(source), NodeId(target), 1.0).unwrap();
        }
        graph
    }

    /// A graph without edges yields no communities.
    #[test]
    fn test_louvain_empty_graph() {
        let detection = louvain(&CsrGraph::new()).unwrap();
        assert_eq!(detection.num_communities, 0);
        assert_eq!(detection.communities.len(), 0);
    }

    /// All three nodes of a lone triangle end up in the same single community.
    #[test]
    fn test_louvain_single_triangle() {
        let graph = graph_from_pairs(&[(0, 1), (1, 2), (2, 0)]);
        let detection = louvain(&graph).unwrap();
        assert_eq!(
            detection.num_communities, 1,
            "Triangle should form 1 community"
        );

        let first = detection.get_community(NodeId(0));
        let second = detection.get_community(NodeId(1));
        let third = detection.get_community(NodeId(2));
        assert_eq!(first, second);
        assert_eq!(second, third);
    }

    /// Two triangles joined by a single bridge edge show positive modularity.
    #[test]
    fn test_louvain_two_triangles_connected() {
        let graph = graph_from_pairs(&[
            (0, 1),
            (1, 2),
            (2, 0),
            (3, 4),
            (4, 5),
            (5, 3),
            // Bridge between the two triangles.
            (2, 3),
        ]);
        let detection = louvain(&graph).unwrap();
        assert!(
            detection.num_communities >= 1,
            "Should find at least 1 community"
        );
        assert!(
            detection.modularity > 0.0,
            "Modularity should be positive for community structure"
        );
    }

    /// Two triangles with no connecting edge are reported as separate communities.
    #[test]
    fn test_louvain_disconnected_components() {
        let graph = graph_from_pairs(&[(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)]);
        let detection = louvain(&graph).unwrap();
        assert_eq!(
            detection.num_communities, 2,
            "Should find 2 communities for 2 disconnected components"
        );

        let left = detection.get_community(NodeId(0));
        let right = detection.get_community(NodeId(3));
        assert!(left.is_some() && right.is_some());
        assert_ne!(
            left, right,
            "Disconnected components should be in different communities"
        );
    }

    /// Exercises the lookup helpers on a hand-built result.
    #[test]
    fn test_community_detection_result_api() {
        let first = vec![NodeId(0), NodeId(1), NodeId(2)];
        let second = vec![NodeId(3), NodeId(4)];
        let detection = CommunityDetectionResult {
            communities: vec![first.clone(), second.clone()],
            num_communities: 2,
            modularity: 0.42,
        };

        assert_eq!(detection.get_community(NodeId(0)), Some(0));
        assert_eq!(detection.get_community(NodeId(3)), Some(1));
        assert_eq!(detection.get_community(NodeId(99)), None);

        assert_eq!(detection.get_community_nodes(0), Some(first.as_slice()));
        assert_eq!(detection.get_community_nodes(1), Some(second.as_slice()));
        assert_eq!(detection.get_community_nodes(2), None);

        assert_eq!(detection.community_size(0), Some(3));
        assert_eq!(detection.community_size(1), Some(2));
        assert_eq!(detection.community_size(2), None);
    }

    /// Every node of a five-node path appears in some community.
    #[test]
    fn test_louvain_all_nodes_assigned() {
        let graph = graph_from_pairs(&[(0, 1), (1, 2), (2, 3), (3, 4)]);
        let detection = louvain(&graph).unwrap();

        let assigned_nodes: std::collections::HashSet<NodeId> =
            detection.communities.iter().flatten().copied().collect();

        assert_eq!(assigned_nodes.len(), 5);
        for id in 0..5 {
            assert!(assigned_nodes.contains(&NodeId(id)));
        }
    }
}