use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::error::{ClusteringError, Result};
/// Weighted graph stored as one adjacency list per node.
///
/// Edges inserted through `Graph::add_edge` are recorded in the lists of both
/// endpoints, so the structure is used as an undirected graph by the
/// algorithms in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph<F: Float> {
/// Number of nodes; `Graph::new` allocates exactly this many adjacency lists.
pub n_nodes: usize,
/// `adjacency[u]` holds `(neighbor index, edge weight)` pairs for node `u`.
pub adjacency: Vec<Vec<(usize, F)>>,
/// Optional feature matrix; `Graph::from_knngraph` stores its input data here
/// (one row per node).
pub node_features: Option<Array2<F>>,
}
impl<
        F: Float
            + FromPrimitive
            + Debug
            + ScalarOperand
            + std::iter::Sum
            + std::cmp::Eq
            + std::hash::Hash
            + 'static,
    > Graph<F>
{
    /// Creates a graph with `_nnodes` isolated nodes and no node features.
    pub fn new(_nnodes: usize) -> Self {
        Self {
            n_nodes: _nnodes,
            adjacency: vec![Vec::new(); _nnodes],
            node_features: None,
        }
    }

    /// Builds an undirected graph from a square adjacency matrix.
    ///
    /// Each unordered pair `{i, j}` with `i != j` yields at most one edge.
    /// The edge exists when `matrix[i, j]` or `matrix[j, i]` is strictly
    /// positive; when both are positive the larger value is used as the
    /// weight. Diagonal entries (self-loops) are ignored.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` if the matrix is not square.
    pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
        let n_nodes = _adjacencymatrix.shape()[0];
        if _adjacencymatrix.shape()[1] != n_nodes {
            return Err(ClusteringError::InvalidInput(
                "Adjacency matrix must be square".to_string(),
            ));
        }
        let zero = F::zero();
        let mut graph = Self::new(n_nodes);
        for i in 0..n_nodes {
            // Only visit the upper triangle: `add_edge` already records the
            // edge for both endpoints, so visiting (j, i) as well would store
            // every edge of a symmetric matrix twice.
            for j in (i + 1)..n_nodes {
                let forward = _adjacencymatrix[[i, j]];
                let backward = _adjacencymatrix[[j, i]];
                let weight = match (forward > zero, backward > zero) {
                    (true, true) => {
                        if backward > forward {
                            backward
                        } else {
                            forward
                        }
                    }
                    (true, false) => forward,
                    (false, true) => backward,
                    (false, false) => continue,
                };
                graph.add_edge(i, j, weight)?;
            }
        }
        Ok(graph)
    }

    /// Builds a k-nearest-neighbour graph from the rows of `data`.
    ///
    /// Every sample is connected to its `k` closest samples (Euclidean
    /// distance); the edge weight is the similarity `1 / (1 + distance)`.
    /// The input matrix is stored in `node_features`.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` if a pairwise distance is NaN.
    pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
        let n_samples = data.shape()[0];
        let mut graph = Self::new(n_samples);
        graph.node_features = Some(data.to_owned());
        for i in 0..n_samples {
            let mut distances: Vec<(usize, F)> = Vec::with_capacity(n_samples.saturating_sub(1));
            for j in 0..n_samples {
                if i == j {
                    continue;
                }
                let dist = euclidean_distance(data.row(i), data.row(j));
                if dist.is_nan() {
                    return Err(ClusteringError::InvalidInput(
                        "Distance between samples is NaN; input data must not contain NaN"
                            .to_string(),
                    ));
                }
                distances.push((j, dist));
            }
            // NaN has been rejected above, so the fallback ordering is never used;
            // it only exists to keep this path panic-free.
            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            for &(neighbor_idx, distance) in distances.iter().take(k) {
                // Mutual neighbours would otherwise be inserted once from each
                // side, duplicating the (symmetric) edge.
                if graph.has_edge(i, neighbor_idx) {
                    continue;
                }
                let similarity = F::one() / (F::one() + distance);
                graph.add_edge(i, neighbor_idx, similarity)?;
            }
        }
        Ok(graph)
    }

    /// Adds an undirected edge between `node1` and `node2`.
    ///
    /// The edge is appended to the adjacency lists of both endpoints.
    /// Self-loops (`node1 == node2`) are silently ignored. No check for an
    /// already existing edge is made.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` if either index is out of bounds.
    pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
        if node1 >= self.n_nodes || node2 >= self.n_nodes {
            return Err(ClusteringError::InvalidInput(
                "Node index out of bounds".to_string(),
            ));
        }
        if node1 != node2 {
            self.adjacency[node1].push((node2, weight));
            self.adjacency[node2].push((node1, weight));
        }
        Ok(())
    }

    /// Number of adjacency entries of `node`, or `0` if `node` is out of range.
    pub fn degree(&self, node: usize) -> usize {
        if node < self.n_nodes {
            self.adjacency[node].len()
        } else {
            0
        }
    }

    /// Sum of the weights of all edges incident to `node`, or zero if `node`
    /// is out of range.
    pub fn weighted_degree(&self, node: usize) -> F {
        if node < self.n_nodes {
            self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
        } else {
            F::zero()
        }
    }

    /// Adjacency list of `node` as `(neighbor, weight)` pairs; empty if `node`
    /// is out of range.
    pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
        if node < self.n_nodes {
            &self.adjacency[node]
        } else {
            &[]
        }
    }

    /// Computes the modularity of the partition given by `communities`.
    ///
    /// `communities[i]` is the community label of node `i`. The value is
    /// `1/(2m) * sum_{i,j in same community} (A_ij - k_i * k_j / (2m))`, where
    /// `m` is the total edge weight and `k_i` the weighted degree of node `i`.
    /// Returns zero for a graph without edge weight.
    ///
    /// # Panics
    ///
    /// Panics if the graph has edges and `communities` has fewer than
    /// `n_nodes` entries.
    pub fn modularity(&self, communities: &[usize]) -> F {
        let total_weight = self.total_edge_weight();
        if total_weight == F::zero() {
            return F::zero();
        }
        assert!(
            communities.len() >= self.n_nodes,
            "communities has {} labels but the graph has {} nodes",
            communities.len(),
            self.n_nodes
        );
        let two_m = F::from(2.0).expect("Failed to convert constant to float") * total_weight;
        // Degrees are loop-invariant; computing them once avoids re-summing the
        // adjacency lists for every node pair.
        let degrees: Vec<F> = (0..self.n_nodes)
            .map(|node| self.weighted_degree(node))
            .collect();
        let mut modularity = F::zero();
        for i in 0..self.n_nodes {
            for j in 0..self.n_nodes {
                if communities[i] == communities[j] {
                    let edge_weight = self.get_edge_weight(i, j);
                    let expected = degrees[i] * degrees[j] / two_m;
                    modularity = modularity + edge_weight - expected;
                }
            }
        }
        modularity / two_m
    }

    /// Returns `true` if `node1`'s adjacency list contains `node2`.
    fn has_edge(&self, node1: usize, node2: usize) -> bool {
        node1 < self.n_nodes
            && self.adjacency[node1]
                .iter()
                .any(|&(neighbor, _)| neighbor == node2)
    }

    /// Weight of the first adjacency entry of `node1` that points at `node2`,
    /// or zero if there is none (or `node1` is out of range).
    fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
        if node1 < self.n_nodes {
            for &(neighbor_, weight) in &self.adjacency[node1] {
                if neighbor_ == node2 {
                    return weight;
                }
            }
        }
        F::zero()
    }

    /// Total edge weight `m`: the sum over all adjacency entries divided by
    /// two, because every edge is stored once per endpoint.
    fn total_edge_weight(&self) -> F {
        let mut total = F::zero();
        for node in 0..self.n_nodes {
            for &(_, weight) in &self.adjacency[node] {
                total = total + weight;
            }
        }
        total / F::from(2.0).expect("Failed to convert constant to float")
    }
}
/// Local-moving phase of the Louvain community detection method.
///
/// Every node starts in its own community. In each sweep a node is moved to
/// the neighbouring community with the largest strictly positive modularity
/// gain (as computed by `modularity_gain`). Sweeps repeat until no node moves
/// or `max_iterations` sweeps have run. No graph aggregation step is performed.
///
/// Candidate communities are evaluated in ascending label order, so ties in
/// gain are resolved towards the smallest label and the result is
/// reproducible between runs.
///
/// Returns the community label of every node; labels are not renumbered.
#[allow(dead_code)]
pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
    f64: From<F>,
{
    let n_nodes = graph.n_nodes;
    let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
    let mut improved = true;
    let mut iteration = 0;
    while improved && iteration < max_iterations {
        improved = false;
        iteration += 1;
        for node in 0..n_nodes {
            let current_community = communities[node];

            // Distinct neighbouring communities in a fixed (sorted) order;
            // a hash set here would make tie-breaking depend on hash seeds.
            let mut candidates: Vec<usize> = graph
                .neighbor_s(node)
                .iter()
                .map(|&(neighbor_id, _)| communities[neighbor_id])
                .filter(|&community| community != current_community)
                .collect();
            candidates.sort_unstable();
            candidates.dedup();

            let mut best_community = current_community;
            let mut best_gain = F::zero();
            for candidate_community in candidates {
                let gain = modularity_gain(
                    graph,
                    &communities,
                    node,
                    current_community,
                    candidate_community,
                    resolution,
                );
                if gain > best_gain {
                    best_gain = gain;
                    best_community = candidate_community;
                }
            }

            // `best_community` only changes on a strictly positive gain, so
            // no separate gain check is needed here.
            if best_community != current_community {
                communities[node] = best_community;
                improved = true;
            }
        }
    }
    Ok(communities)
}
/// Modularity change (up to a constant factor) of moving `node` out of
/// `from_community` and into `to_community`.
///
/// The value is the benefit of joining the target community minus the benefit
/// of staying in the source community, each being "edge weight towards the
/// community" minus the resolution-scaled expected weight. Returns zero when
/// the graph carries no edge weight.
#[allow(dead_code)]
fn modularity_gain<F>(
    graph: &Graph<F>,
    communities: &Array1<usize>,
    node: usize,
    from_community: usize,
    to_community: usize,
    resolution: f64,
) -> F
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
    f64: From<F>,
{
    let m = graph.total_edge_weight();
    if m == F::zero() {
        return F::zero();
    }
    let k_node = graph.weighted_degree(node);
    let gamma = F::from(resolution).expect("Failed to convert to float");

    // One pass over the neighbours, accumulating the weight that points into
    // the target community and the weight that stays inside the source one.
    let (toward_target, within_source) = graph.neighbor_s(node).iter().fold(
        (F::zero(), F::zero()),
        |(target_acc, source_acc), &(other, w)| {
            let label = communities[other];
            let target_acc = if label == to_community {
                target_acc + w
            } else {
                target_acc
            };
            let source_acc = if label == from_community && other != node {
                source_acc + w
            } else {
                source_acc
            };
            (target_acc, source_acc)
        },
    );

    let target_total = calculate_community_weight(graph, communities, to_community);
    let source_total = calculate_community_weight(graph, communities, from_community);
    let two = F::from(2.0).expect("Failed to convert constant to float");

    let expected_in_target = gamma * k_node * target_total / (two * m);
    // The node's own degree is excluded from its current community's total.
    let expected_in_source = gamma * k_node * (source_total - k_node) / (two * m);

    let join_benefit = toward_target - expected_in_target;
    let stay_benefit = within_source - expected_in_source;
    join_benefit - stay_benefit
}
/// Sum of the weighted degrees of all nodes currently labelled `community`.
#[allow(dead_code)]
fn calculate_community_weight<F>(
    graph: &Graph<F>,
    communities: &Array1<usize>,
    community: usize,
) -> F
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    // Explicit fold from `F::zero()` in node order.
    (0..graph.n_nodes)
        .filter(|&member| communities[member] == community)
        .fold(F::zero(), |total, member| {
            total + graph.weighted_degree(member)
        })
}
/// Synchronous label propagation.
///
/// Every node starts with its own label. In each round every node adopts the
/// label carrying the largest total edge weight among its neighbours, judged
/// on the labels of the previous round. Ties are resolved towards the smallest
/// label so results are reproducible. Nodes without neighbours keep their
/// label.
///
/// Iteration stops after `max_iterations` rounds, as soon as a round changes
/// no label, or when the fraction of changed nodes drops below `tolerance`.
///
/// Returns labels renumbered to `0..k` in ascending order of the surviving
/// original labels. An empty graph yields an empty array.
#[allow(dead_code)]
pub fn label_propagation<F>(
    graph: &Graph<F>,
    max_iterations: usize,
    tolerance: f64,
) -> Result<Array1<usize>>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
    f64: From<F>,
{
    let n_nodes = graph.n_nodes;
    let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
    if n_nodes == 0 {
        return Ok(labels);
    }

    for _iteration in 0..max_iterations {
        let mut new_labels = labels.clone();
        let mut changed_nodes = 0usize;

        // Updates read `labels` and write `new_labels`, so the order in which
        // nodes are visited cannot influence the outcome.
        for node in 0..n_nodes {
            let mut label_weights: HashMap<usize, F> = HashMap::new();
            for &(neighbor_, weight) in graph.neighbor_s(node) {
                let entry = label_weights.entry(labels[neighbor_]).or_insert_with(F::zero);
                *entry = *entry + weight;
            }

            // Heaviest label; equal weights fall back to the smaller label so
            // the choice does not depend on hash-map iteration order.
            let mut best: Option<(usize, F)> = None;
            for (&label, &weight) in &label_weights {
                let better = match best {
                    None => true,
                    Some((best_label, best_weight)) => {
                        weight > best_weight || (weight == best_weight && label < best_label)
                    }
                };
                if better {
                    best = Some((label, weight));
                }
            }

            if let Some((best_label, _)) = best {
                if best_label != labels[node] {
                    new_labels[node] = best_label;
                    changed_nodes += 1;
                }
            }
        }

        labels = new_labels;
        let change_ratio = changed_nodes as f64 / n_nodes as f64;
        // The explicit zero check makes a fixed point terminate even when
        // `tolerance` is 0.0.
        if changed_nodes == 0 || change_ratio < tolerance {
            break;
        }
    }

    // Compact the surviving labels to 0..k in ascending order.
    let mut unique_labels: Vec<usize> = labels.iter().copied().collect();
    unique_labels.sort_unstable();
    unique_labels.dedup();
    let label_mapping: HashMap<usize, usize> = unique_labels
        .into_iter()
        .enumerate()
        .map(|(new_label, old_label)| (old_label, new_label))
        .collect();
    for label in labels.iter_mut() {
        *label = label_mapping[label];
    }
    Ok(labels)
}
/// Girvan–Newman divisive community detection.
///
/// Repeatedly removes the edge with the highest edge betweenness from a
/// working copy of `graph` until the number of connected components reaches
/// `ncommunities` or no edges remain. When several edges share the highest
/// betweenness, the one with the smallest `(u, v)` key is removed, which keeps
/// the result reproducible.
///
/// Returns the connected-component label of every node.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` if `ncommunities` exceeds the
/// number of nodes, and propagates any error from the betweenness computation.
#[allow(dead_code)]
pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    if ncommunities > graph.n_nodes {
        return Err(ClusteringError::InvalidInput(
            "Number of communities cannot exceed number of nodes".to_string(),
        ));
    }
    let mut workinggraph = graph.clone();
    let mut communities = find_connected_components(&workinggraph);
    while count_communities(&communities) < ncommunities && has_edges(&workinggraph) {
        let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
        // Highest betweenness wins; among equals the reversed key comparison
        // makes the smallest edge key the maximum, independent of the map's
        // iteration order.
        let busiest = edge_betweenness
            .iter()
            .max_by(|a, b| {
                a.1.partial_cmp(b.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| b.0.cmp(a.0))
            })
            .map(|(&edge, _)| edge);
        match busiest {
            Some((u, v)) => {
                remove_edge(&mut workinggraph, u, v);
                communities = find_connected_components(&workinggraph);
            }
            None => break,
        }
    }
    Ok(Array1::from_vec(communities))
}
/// Edge betweenness based on unweighted shortest paths.
///
/// For every node pair `source < target`, each of the shortest paths between
/// them contributes `1 / (number of shortest paths)` to every edge it uses.
/// Edges are keyed as `(smaller index, larger index)`; every edge of the graph
/// is present in the result, with `0.0` if no shortest path uses it.
#[allow(dead_code)]
fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    // Seed every edge (seen from its smaller endpoint) with a zero score.
    let mut scores: HashMap<(usize, usize), f64> = (0..graph.n_nodes)
        .flat_map(|u| {
            graph
                .neighbor_s(u)
                .iter()
                .filter(move |&&(v, _)| u < v)
                .map(move |&(v, _)| ((u, v), 0.0))
        })
        .collect();

    for source in 0..graph.n_nodes {
        for target in (source + 1)..graph.n_nodes {
            let routes = find_all_shortest_paths(graph, source, target);
            if routes.is_empty() {
                continue;
            }
            let share = 1.0 / routes.len() as f64;
            for route in &routes {
                for hop in route.windows(2) {
                    let key = (hop[0].min(hop[1]), hop[0].max(hop[1]));
                    *scores.entry(key).or_insert(0.0) += share;
                }
            }
        }
    }
    Ok(scores)
}
/// Enumerates every shortest path (by hop count) from `source` to `target`.
///
/// A breadth-first search from `source` records, for each reached node, all
/// predecessors that lie one hop closer to `source`. The paths are then
/// rebuilt backwards from `target`. Each returned path lists the nodes from
/// `source` to `target`; the result is empty when `target` is unreachable.
#[allow(dead_code)]
fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    let node_count = graph.n_nodes;
    let mut hops = vec![None; node_count];
    let mut parents: Vec<Vec<usize>> = vec![Vec::new(); node_count];
    let mut frontier = VecDeque::new();

    hops[source] = Some(0);
    frontier.push_back(source);
    while let Some(here) = frontier.pop_front() {
        let next_hop = hops[here].expect("Operation failed") + 1;
        for &(there, _) in graph.neighbor_s(here) {
            match hops[there] {
                // First visit: fix the distance and keep exploring from there.
                None => {
                    hops[there] = Some(next_hop);
                    parents[there].push(here);
                    frontier.push_back(there);
                }
                // Reached again at the same depth: another shortest route.
                Some(depth) if depth == next_hop => parents[there].push(here),
                Some(_) => {}
            }
        }
    }

    if hops[target].is_none() {
        return Vec::new();
    }

    // Grow partial paths backwards from `target`, one predecessor level per round.
    let mut finished = Vec::new();
    let mut partial = vec![vec![target]];
    while !partial.is_empty() {
        let mut extended = Vec::new();
        for mut walk in partial {
            let tail = walk[walk.len() - 1];
            if tail == source {
                walk.reverse();
                finished.push(walk);
                continue;
            }
            for &parent in &parents[tail] {
                let mut longer = walk.clone();
                longer.push(parent);
                extended.push(longer);
            }
        }
        partial = extended;
    }
    finished
}
/// Deletes every adjacency entry linking `node1` and `node2`, in both directions.
///
/// Indexes `graph.adjacency` directly, so an out-of-range node index panics.
#[allow(dead_code)]
fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    for &(owner, other) in &[(node1, node2), (node2, node1)] {
        graph.adjacency[owner].retain(|&(neighbor, _)| neighbor != other);
    }
}
/// Returns `true` if at least one node has a non-empty adjacency list.
#[allow(dead_code)]
fn has_edges<F>(graph: &Graph<F>) -> bool
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    let every_list_empty = graph.adjacency.iter().all(|entries| entries.is_empty());
    !every_list_empty
}
/// Labels every node with the index of its connected component.
///
/// Components are numbered from zero in the order in which their
/// lowest-indexed node is encountered.
#[allow(dead_code)]
fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    let node_count = graph.n_nodes;
    let mut seen = vec![false; node_count];
    let mut membership = vec![0; node_count];
    let mut next_id = 0;
    for start in 0..node_count {
        if seen[start] {
            continue;
        }
        // `start` is the first unseen node, so it opens a new component.
        dfs_component(graph, start, next_id, &mut seen, &mut membership);
        next_id += 1;
    }
    membership
}
/// Marks every node reachable from `node` as visited and assigns it
/// `component_id`.
///
/// The traversal uses an explicit stack instead of recursion so that long
/// chains of nodes cannot exhaust the call stack. `visited` and `components`
/// are indexed by node id and are updated in place.
#[allow(dead_code)]
fn dfs_component<F>(
    graph: &Graph<F>,
    node: usize,
    component_id: usize,
    visited: &mut [bool],
    components: &mut [usize],
) where
    F: Float
        + FromPrimitive
        + Debug
        + ScalarOperand
        + std::iter::Sum
        + std::cmp::Eq
        + std::hash::Hash
        + 'static,
{
    visited[node] = true;
    components[node] = component_id;
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        for &(neighbor_, _) in graph.neighbor_s(current) {
            if !visited[neighbor_] {
                // Mark on push so a node is never queued twice.
                visited[neighbor_] = true;
                components[neighbor_] = component_id;
                pending.push(neighbor_);
            }
        }
    }
}
/// Number of distinct labels in `communities`.
#[allow(dead_code)]
fn count_communities(communities: &[usize]) -> usize {
    communities
        .iter()
        .copied()
        .collect::<HashSet<usize>>()
        .len()
}
/// Euclidean (L2) distance between two 1-D views.
///
/// The views are subtracted directly, which allocates only the difference
/// vector rather than copying both inputs first.
#[allow(dead_code)]
fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
where
    F: Float + std::iter::Sum + 'static,
{
    let delta = &a - &b;
    delta.dot(&delta).sqrt()
}
/// Settings consumed by `graph_clustering` to select and parameterise an algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphClusteringConfig {
/// Which community detection algorithm to run.
pub algorithm: GraphClusteringAlgorithm,
/// Iteration cap passed to `louvain` and `label_propagation`.
pub max_iterations: usize,
/// Convergence threshold passed to `label_propagation`.
pub tolerance: f64,
/// Resolution parameter passed to `louvain`.
pub resolution: f64,
/// Target community count for `girvan_newman`; `graph_clustering` uses 2 when `None`.
pub ncommunities: Option<usize>,
}
/// Community detection algorithms that `graph_clustering` can dispatch to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GraphClusteringAlgorithm {
/// Dispatches to `louvain`.
Louvain,
/// Dispatches to `label_propagation`.
LabelPropagation,
/// Dispatches to `girvan_newman`.
GirvanNewman,
}
impl Default for GraphClusteringConfig {
fn default() -> Self {
Self {
algorithm: GraphClusteringAlgorithm::Louvain,
max_iterations: 100,
tolerance: 1e-6,
resolution: 1.0,
ncommunities: None,
}
}
}
#[allow(dead_code)]
pub fn graph_clustering<F>(
graph: &Graph<F>,
config: &GraphClusteringConfig,
) -> Result<Array1<usize>>
where
F: Float
+ FromPrimitive
+ Debug
+ ScalarOperand
+ std::iter::Sum
+ std::cmp::Eq
+ std::hash::Hash
+ 'static,
f64: From<F>,
{
match config.algorithm {
GraphClusteringAlgorithm::Louvain => {
louvain(graph, config.resolution, config.max_iterations)
}
GraphClusteringAlgorithm::LabelPropagation => {
label_propagation(graph, config.max_iterations, config.tolerance)
}
GraphClusteringAlgorithm::GirvanNewman => {
let ncommunities = config.ncommunities.unwrap_or(2);
girvan_newman(graph, ncommunities)
}
}
}
// Unit-test module; it only brings names into scope and defines no test cases.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
}