use ahash::AHashMap;
use crate::progress::ProgressCallback;
use crate::{errors::SqliteGraphError, graph::SqliteGraph};
/// Groups entities into communities using in-place (asynchronous) label propagation.
///
/// Every entity starts with its own id as its label. In each pass the entities
/// are visited in ascending id order and each one adopts the label that occurs
/// most often among its neighbors. Both outgoing and incoming neighbors are
/// counted, so edge direction is ignored. When several labels are equally
/// frequent, the numerically smallest label wins, which keeps the result
/// independent of hash-map iteration order. Label changes take effect
/// immediately, so nodes visited later in the same pass already see them.
///
/// The loop stops after `max_iterations` passes or as soon as a full pass
/// changes no label, whichever comes first.
///
/// # Returns
///
/// One `Vec<i64>` per community. Each community is sorted ascending, and the
/// communities are ordered by their first (smallest) member. An empty graph
/// yields an empty vector.
///
/// # Errors
///
/// Propagates any error returned by `all_entity_ids`, `fetch_outgoing` or
/// `fetch_incoming`.
pub fn label_propagation(
    graph: &SqliteGraph,
    max_iterations: usize,
) -> Result<Vec<Vec<i64>>, SqliteGraphError> {
    let all_ids = graph.all_entity_ids()?;
    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Initial state: every node is its own community.
    let mut labels: AHashMap<i64, i64> = all_ids.iter().map(|&id| (id, id)).collect();

    // A fixed ascending visiting order makes the in-place updates reproducible.
    let mut node_order = all_ids;
    node_order.sort_unstable();

    for _ in 0..max_iterations {
        let mut any_changed = false;

        for &node in &node_order {
            let outgoing = graph.fetch_outgoing(node)?;
            let incoming = graph.fetch_incoming(node)?;

            // Tally the current labels of all neighbors; a neighbor reachable
            // in both directions is counted once per edge.
            let mut label_counts: AHashMap<i64, usize> = AHashMap::new();
            for &neighbor in outgoing.iter().chain(incoming.iter()) {
                // A neighbor without a label entry falls back to its own id.
                let neighbor_label = labels.get(&neighbor).copied().unwrap_or(neighbor);
                *label_counts.entry(neighbor_label).or_insert(0) += 1;
            }

            // Single pass: highest count wins, ties go to the smallest label.
            // `None` means the node has no neighbors and keeps its label.
            let best_label = label_counts
                .iter()
                .map(|(&label, &count)| (count, label))
                .max_by(|a, b| a.0.cmp(&b.0).then_with(|| b.1.cmp(&a.1)))
                .map(|(_, label)| label);

            if let Some(best_label) = best_label {
                if let Some(current_label) = labels.get_mut(&node) {
                    if *current_label != best_label {
                        *current_label = best_label;
                        any_changed = true;
                    }
                }
            }
        }

        // Converged: a whole pass left every label untouched.
        if !any_changed {
            break;
        }
    }

    // Invert node -> label into label -> members.
    let mut communities_map: AHashMap<i64, Vec<i64>> = AHashMap::new();
    for (&node, &label) in &labels {
        communities_map.entry(label).or_default().push(node);
    }

    let mut communities: Vec<Vec<i64>> = communities_map.into_values().collect();
    for community in &mut communities {
        community.sort_unstable();
    }
    // Order communities by their smallest member for a deterministic result.
    communities.sort_by(|a, b| a.first().cmp(&b.first()));

    Ok(communities)
}
/// Detects communities with a Louvain-style greedy local-moving procedure.
///
/// Every entity starts in its own community. In each pass the entities are
/// visited in ascending id order, and each one is moved to the neighboring
/// community with the largest positive gain
///
/// ```text
/// delta = (2 * k_in - k_i * sigma_tot / m) / (2 * m)
/// ```
///
/// where `k_in` is the number of edges between the node and the candidate
/// community, `k_i` is the node's degree (outgoing plus incoming edges),
/// `sigma_tot` is the summed degree of the candidate's current members and
/// `m` is half the summed degree of all entities. A node stays where it is
/// when no candidate has a strictly positive gain. When several candidates
/// share the best gain, the smallest community id wins, so the result does
/// not depend on hash-map iteration order.
///
/// Only this node-moving phase is performed here; the function contains no
/// graph-aggregation step. The loop stops after `max_iterations` passes or
/// as soon as a full pass moves no node.
///
/// # Returns
///
/// One `Vec<i64>` per community. Each community is sorted ascending, and the
/// communities are ordered by their first (smallest) member. An empty graph
/// yields an empty vector, and a graph without edges yields one singleton
/// community per entity.
///
/// # Errors
///
/// Propagates any error returned by `all_entity_ids`, `fetch_outgoing` or
/// `fetch_incoming`.
pub fn louvain_communities(
    graph: &SqliteGraph,
    max_iterations: usize,
) -> Result<Vec<Vec<i64>>, SqliteGraphError> {
    let all_ids = graph.all_entity_ids()?;
    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Degree of every node, counting outgoing and incoming edges alike.
    let mut total_degree = 0usize;
    let mut degrees: AHashMap<i64, usize> = AHashMap::new();
    for &id in &all_ids {
        let degree = graph.fetch_outgoing(id)?.len() + graph.fetch_incoming(id)?.len();
        degrees.insert(id, degree);
        total_degree += degree;
    }

    // Each edge contributes to two degrees, hence the division by two.
    let m = total_degree as f64 / 2.0;
    if m == 0.0 {
        // No edges at all: every node is its own community.
        let mut communities: Vec<Vec<i64>> = all_ids.iter().map(|&id| vec![id]).collect();
        communities.sort();
        return Ok(communities);
    }

    let mut communities: AHashMap<i64, i64> = all_ids.iter().map(|&id| (id, id)).collect();

    // Running sum of member degrees per community. Initially each community
    // holds exactly one node, so the totals equal the node degrees. Keeping
    // this up to date on every move avoids rescanning all nodes per candidate.
    let mut community_degrees: AHashMap<i64, usize> =
        degrees.iter().map(|(&id, &degree)| (id, degree)).collect();

    let mut node_order = all_ids;
    node_order.sort_unstable();

    for _ in 0..max_iterations {
        let mut any_moved = false;

        for &node in &node_order {
            let current_community = communities.get(&node).copied().unwrap_or(node);
            let node_degree_count = degrees.get(&node).copied().unwrap_or(0);
            let node_degree = node_degree_count as f64;

            let outgoing = graph.fetch_outgoing(node)?;
            let incoming = graph.fetch_incoming(node)?;

            // Number of edges from `node` into each neighboring community.
            let mut community_connections: AHashMap<i64, f64> = AHashMap::new();
            for &neighbor in outgoing.iter().chain(incoming.iter()) {
                let neighbor_community = communities.get(&neighbor).copied().unwrap_or(neighbor);
                *community_connections
                    .entry(neighbor_community)
                    .or_insert(0.0) += 1.0;
            }

            let mut best_community = current_community;
            let mut best_delta = 0.0f64;
            for (&target_community, &edges_to_community) in &community_connections {
                if target_community == current_community {
                    continue;
                }
                let community_degree =
                    community_degrees.get(&target_community).copied().unwrap_or(0) as f64;
                let delta =
                    (2.0 * edges_to_community - node_degree * community_degree / m) / (2.0 * m);

                // Equal gains are resolved towards the smaller community id so
                // the choice is independent of map iteration order. A zero
                // gain never displaces the node from its current community.
                let wins_tie = delta == best_delta
                    && best_community != current_community
                    && target_community < best_community;
                if delta > best_delta || wins_tie {
                    best_delta = delta;
                    best_community = target_community;
                }
            }

            if best_community != current_community {
                communities.insert(node, best_community);
                // The node's degree is part of its old community's total, so
                // this subtraction never actually saturates.
                if let Some(total) = community_degrees.get_mut(&current_community) {
                    *total = total.saturating_sub(node_degree_count);
                }
                *community_degrees.entry(best_community).or_insert(0) += node_degree_count;
                any_moved = true;
            }
        }

        // Converged: a whole pass moved no node.
        if !any_moved {
            break;
        }
    }

    // Invert node -> community into community -> members.
    let mut communities_map: AHashMap<i64, Vec<i64>> = AHashMap::new();
    for (&node, &community) in &communities {
        communities_map.entry(community).or_default().push(node);
    }

    let mut result: Vec<Vec<i64>> = communities_map.into_values().collect();
    for community in &mut result {
        community.sort_unstable();
    }
    // Order communities by their smallest member for a deterministic result.
    result.sort_by(|a, b| a.first().cmp(&b.first()));

    Ok(result)
}
/// Detects communities with a Louvain-style greedy local-moving procedure
/// and reports each pass through `progress`.
///
/// Every entity starts in its own community. In each pass the entities are
/// visited in ascending id order, and each one is moved to the neighboring
/// community with the largest positive gain
///
/// ```text
/// delta = (2 * k_in - k_i * sigma_tot / m) / (2 * m)
/// ```
///
/// where `k_in` is the number of edges between the node and the candidate
/// community, `k_i` is the node's degree (outgoing plus incoming edges),
/// `sigma_tot` is the summed degree of the candidate's current members and
/// `m` is half the summed degree of all entities. A node stays where it is
/// when no candidate has a strictly positive gain. When several candidates
/// share the best gain, the smallest community id wins, so the result does
/// not depend on hash-map iteration order.
///
/// Only this node-moving phase is performed here; the function contains no
/// graph-aggregation step. The loop stops after `max_iterations` passes or
/// as soon as a full pass moves no node.
///
/// # Progress reporting
///
/// At the start of pass `n` (1-based) this calls
/// `progress.on_progress(n, None, "Louvain pass n")`. `progress.on_complete()`
/// is called once before every successful return, including the empty-graph
/// and edgeless-graph early exits. It is not called when an error is
/// propagated.
///
/// # Returns
///
/// One `Vec<i64>` per community. Each community is sorted ascending, and the
/// communities are ordered by their first (smallest) member. An empty graph
/// yields an empty vector, and a graph without edges yields one singleton
/// community per entity.
///
/// # Errors
///
/// Propagates any error returned by `all_entity_ids`, `fetch_outgoing` or
/// `fetch_incoming`.
pub fn louvain_communities_with_progress<F>(
    graph: &SqliteGraph,
    max_iterations: usize,
    progress: &F,
) -> Result<Vec<Vec<i64>>, SqliteGraphError>
where
    F: ProgressCallback,
{
    let all_ids = graph.all_entity_ids()?;
    if all_ids.is_empty() {
        progress.on_complete();
        return Ok(Vec::new());
    }

    // Degree of every node, counting outgoing and incoming edges alike.
    let mut total_degree = 0usize;
    let mut degrees: AHashMap<i64, usize> = AHashMap::new();
    for &id in &all_ids {
        let degree = graph.fetch_outgoing(id)?.len() + graph.fetch_incoming(id)?.len();
        degrees.insert(id, degree);
        total_degree += degree;
    }

    // Each edge contributes to two degrees, hence the division by two.
    let m = total_degree as f64 / 2.0;
    if m == 0.0 {
        // No edges at all: every node is its own community.
        progress.on_complete();
        let mut communities: Vec<Vec<i64>> = all_ids.iter().map(|&id| vec![id]).collect();
        communities.sort();
        return Ok(communities);
    }

    let mut communities: AHashMap<i64, i64> = all_ids.iter().map(|&id| (id, id)).collect();

    // Running sum of member degrees per community. Initially each community
    // holds exactly one node, so the totals equal the node degrees. Keeping
    // this up to date on every move avoids rescanning all nodes per candidate.
    let mut community_degrees: AHashMap<i64, usize> =
        degrees.iter().map(|(&id, &degree)| (id, degree)).collect();

    let mut node_order = all_ids;
    node_order.sort_unstable();

    for iteration in 0..max_iterations {
        progress.on_progress(
            iteration + 1,
            None,
            &format!("Louvain pass {}", iteration + 1),
        );

        let mut any_moved = false;

        for &node in &node_order {
            let current_community = communities.get(&node).copied().unwrap_or(node);
            let node_degree_count = degrees.get(&node).copied().unwrap_or(0);
            let node_degree = node_degree_count as f64;

            let outgoing = graph.fetch_outgoing(node)?;
            let incoming = graph.fetch_incoming(node)?;

            // Number of edges from `node` into each neighboring community.
            let mut community_connections: AHashMap<i64, f64> = AHashMap::new();
            for &neighbor in outgoing.iter().chain(incoming.iter()) {
                let neighbor_community = communities.get(&neighbor).copied().unwrap_or(neighbor);
                *community_connections
                    .entry(neighbor_community)
                    .or_insert(0.0) += 1.0;
            }

            let mut best_community = current_community;
            let mut best_delta = 0.0f64;
            for (&target_community, &edges_to_community) in &community_connections {
                if target_community == current_community {
                    continue;
                }
                let community_degree =
                    community_degrees.get(&target_community).copied().unwrap_or(0) as f64;
                let delta =
                    (2.0 * edges_to_community - node_degree * community_degree / m) / (2.0 * m);

                // Equal gains are resolved towards the smaller community id so
                // the choice is independent of map iteration order. A zero
                // gain never displaces the node from its current community.
                let wins_tie = delta == best_delta
                    && best_community != current_community
                    && target_community < best_community;
                if delta > best_delta || wins_tie {
                    best_delta = delta;
                    best_community = target_community;
                }
            }

            if best_community != current_community {
                communities.insert(node, best_community);
                // The node's degree is part of its old community's total, so
                // this subtraction never actually saturates.
                if let Some(total) = community_degrees.get_mut(&current_community) {
                    *total = total.saturating_sub(node_degree_count);
                }
                *community_degrees.entry(best_community).or_insert(0) += node_degree_count;
                any_moved = true;
            }
        }

        // Converged: a whole pass moved no node.
        if !any_moved {
            break;
        }
    }

    progress.on_complete();

    // Invert node -> community into community -> members.
    let mut communities_map: AHashMap<i64, Vec<i64>> = AHashMap::new();
    for (&node, &community) in &communities {
        communities_map.entry(community).or_default().push(node);
    }

    let mut result: Vec<Vec<i64>> = communities_map.into_values().collect();
    for community in &mut result {
        community.sort_unstable();
    }
    // Order communities by their smallest member for a deterministic result.
    result.sort_by(|a, b| a.first().cmp(&b.first()));

    Ok(result)
}