use std::collections::HashMap;
use infomap_rs::{Infomap, Network};
/// Decides whether two items are linked in the clustering graph, and with what weight.
///
/// [`cluster_with_infomap`] calls [`EdgeWeightStrategy::edge_weight`] once per unordered
/// pair of input items.
pub trait EdgeWeightStrategy {
    /// The kind of item this strategy compares.
    type Item;

    /// Returns `Some(weight)` to connect `a` and `b`, or `None` to leave them unconnected.
    fn edge_weight(&self, a: &Self::Item, b: &Self::Item) -> Option<f64>;
}
/// Post-processing and seeding options for [`cluster_with_infomap`].
#[derive(Debug, Clone)]
pub struct ClusteringConfig {
    /// Modules with fewer members than this are dropped from the result.
    pub min_community_size: usize,
    /// Communities with more members than this are truncated to this many members.
    pub max_community_size: usize,
    /// Seed passed to `Infomap::seed` before running the algorithm.
    pub seed: u64,
}

impl Default for ClusteringConfig {
    /// Keeps communities of at least two members, uses `usize::MAX` as the (effectively
    /// unbounded) upper size limit, and uses seed `42`.
    fn default() -> Self {
        let seed = 42;
        ClusteringConfig {
            min_community_size: 2,
            max_community_size: usize::MAX,
            seed,
        }
    }
}
/// A group of items that Infomap assigned to the same module.
#[derive(Debug, Clone)]
pub struct Community<T> {
    /// Clones of the input items in this module, in input order.
    pub members: Vec<T>,
    /// The module identifier reported by Infomap for these members.
    pub module_id: usize,
}
/// Clusters `items` into communities using the Infomap algorithm.
///
/// An undirected graph is built by asking `strategy` for a weight for every unordered
/// pair of items (`n * (n - 1) / 2` calls). Each accepted weight is added as the two
/// directed links `i -> j` and `j -> i`. Infomap is then run with `config.seed`, and its
/// module assignments are grouped into [`Community`] values.
///
/// Weights that are not finite or not strictly positive are ignored. Infomap models link
/// weights as flow, so zero, negative or NaN weights cannot be represented meaningfully.
///
/// Post-processing:
/// - modules with fewer than `config.min_community_size` members are dropped;
/// - larger modules keep only their first `config.max_community_size` members, in input
///   order.
///
/// The returned communities are sorted by `module_id`, so the output order is
/// reproducible rather than depending on hash-map iteration order.
///
/// Returns an empty vector when there are fewer than two items or no pair produces a
/// usable edge.
pub fn cluster_with_infomap<T, S>(
    items: &[T],
    strategy: &S,
    config: &ClusteringConfig,
) -> Vec<Community<T>>
where
    T: Clone,
    S: EdgeWeightStrategy<Item = T>,
{
    let n = items.len();
    if n < 2 {
        return Vec::new();
    }

    let mut network = Network::with_capacity(n);
    network.ensure_capacity(n);
    let mut has_edges = false;
    for (i, a) in items.iter().enumerate() {
        for (j, b) in items.iter().enumerate().skip(i + 1) {
            // Only positive, finite weights describe valid flow for Infomap.
            let usable = strategy
                .edge_weight(a, b)
                .filter(|&w| w.is_finite() && w > 0.0);
            if let Some(weight) = usable {
                network.add_edge(i, j, weight);
                network.add_edge(j, i, weight);
                has_edges = true;
            }
        }
    }
    if !has_edges {
        return Vec::new();
    }

    let result = Infomap::new(&network).seed(config.seed).run();

    // Group node indices by module. Only the first `n` assignments correspond to input
    // items, so any additional entries are ignored.
    let mut modules: HashMap<usize, Vec<usize>> = HashMap::new();
    for (node_idx, &module_id) in result.assignments.iter().enumerate().take(n) {
        modules.entry(module_id).or_default().push(node_idx);
    }

    let mut communities: Vec<Community<T>> = modules
        .into_iter()
        .filter(|(_, indices)| indices.len() >= config.min_community_size)
        .map(|(module_id, indices)| Community {
            // Truncate before cloning so oversized modules don't clone discarded items.
            members: indices
                .iter()
                .take(config.max_community_size)
                .map(|&i| items[i].clone())
                .collect(),
            module_id,
        })
        .collect();
    // HashMap iteration order is unspecified; sort so equal inputs give equal output.
    communities.sort_unstable_by_key(|c| c.module_id);
    communities
}
/// Edge strategy that links two [`EmbeddingItem`]s by the cosine similarity of their
/// embeddings, keeping only pairs whose similarity is at least `threshold`.
#[derive(Debug, Clone, Copy)]
pub struct CosineStrategy {
    /// Minimum cosine similarity (inclusive) required to create an edge.
    pub threshold: f64,
}

impl CosineStrategy {
    /// Creates a strategy that keeps pairs whose similarity is `>= threshold`.
    pub fn new(threshold: f64) -> Self {
        Self { threshold }
    }
}
/// An identified item with an embedding vector, compared by [`CosineStrategy`].
#[derive(Debug, Clone)]
pub struct EmbeddingItem {
    /// Caller-supplied identifier; not read by the clustering code in this module.
    pub id: String,
    /// The embedding vector passed to the cosine-similarity computation.
    pub embedding: Vec<f32>,
}
impl EdgeWeightStrategy for CosineStrategy {
    type Item = EmbeddingItem;

    /// Returns the cosine similarity of the two embeddings (as `f64`) when it is at least
    /// `self.threshold`, and `None` otherwise.
    ///
    /// If the similarity is NaN, this returns `None`, because every comparison with NaN
    /// is false.
    fn edge_weight(&self, a: &EmbeddingItem, b: &EmbeddingItem) -> Option<f64> {
        let (lhs, rhs) = (&a.embedding, &b.embedding);
        let similarity =
            crate::embeddings::EmbeddingProvider::cosine_similarity(lhs, rhs) as f64;
        Some(similarity).filter(|&s| s >= self.threshold)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy that links every pair of items with weight `1.0`.
    struct AlwaysConnect;

    impl EdgeWeightStrategy for AlwaysConnect {
        type Item = String;
        fn edge_weight(&self, _a: &String, _b: &String) -> Option<f64> {
            Some(1.0)
        }
    }

    #[test]
    fn test_cosine_strategy_basic() {
        let items = vec![
            EmbeddingItem { id: "a".into(), embedding: vec![1.0, 0.0, 0.0] },
            EmbeddingItem { id: "b".into(), embedding: vec![0.95, 0.1, 0.0] },
            EmbeddingItem { id: "c".into(), embedding: vec![0.0, 1.0, 0.0] },
            EmbeddingItem { id: "d".into(), embedding: vec![0.1, 0.95, 0.0] },
        ];
        let strategy = CosineStrategy::new(0.3);
        let config = ClusteringConfig {
            min_community_size: 2,
            ..Default::default()
        };
        let communities = cluster_with_infomap(&items, &strategy, &config);
        assert_eq!(communities.len(), 2, "Expected 2 communities, got {}", communities.len());
    }

    #[test]
    fn test_empty_input() {
        let items: Vec<EmbeddingItem> = vec![];
        let strategy = CosineStrategy::new(0.3);
        let config = ClusteringConfig::default();
        let communities = cluster_with_infomap(&items, &strategy, &config);
        assert!(communities.is_empty());
    }

    #[test]
    fn test_single_item_yields_no_communities() {
        let items = vec![EmbeddingItem { id: "a".into(), embedding: vec![1.0, 0.0] }];
        let strategy = CosineStrategy::new(0.3);
        let config = ClusteringConfig {
            min_community_size: 1,
            ..Default::default()
        };
        let communities = cluster_with_infomap(&items, &strategy, &config);
        assert!(communities.is_empty());
    }

    #[test]
    fn test_no_edges_above_threshold() {
        let items = vec![
            EmbeddingItem { id: "a".into(), embedding: vec![1.0, 0.0, 0.0] },
            EmbeddingItem { id: "b".into(), embedding: vec![0.0, 1.0, 0.0] },
        ];
        let strategy = CosineStrategy::new(0.9);
        let config = ClusteringConfig::default();
        let communities = cluster_with_infomap(&items, &strategy, &config);
        assert!(communities.is_empty());
    }

    #[test]
    fn test_custom_strategy() {
        let items = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let config = ClusteringConfig {
            min_community_size: 2,
            ..Default::default()
        };
        let communities = cluster_with_infomap(&items, &AlwaysConnect, &config);
        assert_eq!(communities.len(), 1);
        assert_eq!(communities[0].members.len(), 3);
    }

    #[test]
    fn test_max_community_size_truncates_in_input_order() {
        let items = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let config = ClusteringConfig {
            min_community_size: 2,
            max_community_size: 2,
            ..Default::default()
        };
        let communities = cluster_with_infomap(&items, &AlwaysConnect, &config);
        assert_eq!(communities.len(), 1);
        assert_eq!(communities[0].members, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_min_community_size_filter() {
        let items = vec![
            EmbeddingItem { id: "a".into(), embedding: vec![1.0, 0.0] },
            EmbeddingItem { id: "b".into(), embedding: vec![0.99, 0.01] },
        ];
        let strategy = CosineStrategy::new(0.3);
        let config = ClusteringConfig {
            min_community_size: 5,
            ..Default::default()
        };
        let communities = cluster_with_infomap(&items, &strategy, &config);
        assert!(communities.is_empty());
    }
}