use crate::{Error, ErrorKind, Graph};
use rayon::prelude::*;
use std::collections::HashSet;
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// Computes the group degree centrality of `group` within `graph`.
///
/// The (unnormalized) value is the number of distinct nodes *outside* `group`
/// that are adjacent to at least one member of `group`. For directed graphs
/// only successors (out-neighbors) of group members are counted; for
/// undirected graphs all neighbors are counted.
///
/// # Arguments
///
/// * `graph` - the graph to analyse.
/// * `group` - the names of the nodes forming the group.
/// * `normalized` - when `true`, the count is divided by the number of nodes
///   outside the group (`n - |group|`), yielding a value in `[0, 1]`.
///
/// # Errors
///
/// * `ErrorKind::NodeNotFound` if any member of `group` is not a node of
///   `graph`; the offending names are reported in sorted order.
/// * `ErrorKind::InvalidArgument` if `group` is empty or contains every node
///   of the graph.
pub fn group_degree_centrality<T, A>(
    graph: &Graph<T, A>,
    group: &HashSet<T>,
    normalized: bool,
) -> Result<f64, Error>
where
    T: Hash + Eq + Clone + Ord + Debug + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    // Groups with more members than this use the rayon path (when more than
    // one worker thread is available); smaller groups are handled
    // sequentially, presumably because scheduling overhead would dominate.
    const PARALLEL_THRESHOLD: usize = 10;

    let all_nodes: HashSet<T> = graph
        .get_all_nodes()
        .iter()
        .map(|node| node.name.clone())
        .collect();

    let mut missing_nodes: Vec<T> = group.difference(&all_nodes).cloned().collect();
    if !missing_nodes.is_empty() {
        // Sort so the error message does not depend on hash iteration order.
        missing_nodes.sort();
        return Err(Error {
            kind: ErrorKind::NodeNotFound,
            message: format!("The node(s) {:?} are not in the graph", missing_nodes),
        });
    }

    let n = all_nodes.len();
    let c = group.len();
    if c == 0 {
        return Err(Error {
            kind: ErrorKind::InvalidArgument,
            message: "Group cannot be empty".to_string(),
        });
    }
    if c >= n {
        return Err(Error {
            kind: ErrorKind::InvalidArgument,
            message: "Group cannot contain all nodes".to_string(),
        });
    }

    let non_group_set: HashSet<T> = all_nodes.difference(group).cloned().collect();
    // The helpers take a slice so the members can be split across rayon
    // workers; iteration order is irrelevant because the result is a set union.
    let group_nodes: Vec<T> = group.iter().cloned().collect();

    let connected_non_group_nodes =
        if group_nodes.len() > PARALLEL_THRESHOLD && rayon::current_num_threads() > 1 {
            calculate_connected_non_group_nodes_parallel(graph, &group_nodes, &non_group_set)?
        } else {
            calculate_connected_non_group_nodes_sequential(graph, &group_nodes, &non_group_set)?
        };

    let count = connected_non_group_nodes.len() as f64;
    if normalized {
        // `c < n` is guaranteed by the guard above, so the divisor is positive.
        Ok(count / (n - c) as f64)
    } else {
        Ok(count)
    }
}
fn calculate_connected_non_group_nodes_parallel<T, A>(
graph: &Graph<T, A>,
group_nodes: &[T],
non_group_set: &HashSet<T>,
) -> Result<HashSet<T>, Error>
where
T: Hash + Eq + Clone + Ord + Debug + Display + Send + Sync,
A: Clone + Send + Sync,
{
let connected_sets: Result<Vec<HashSet<T>>, Error> = group_nodes
.par_iter()
.map(|node| get_connected_non_group_nodes_for_node(graph, node, non_group_set))
.collect();
let connected_sets = connected_sets?;
let mut all_connected = HashSet::new();
for set in connected_sets {
all_connected.extend(set);
}
Ok(all_connected)
}
fn calculate_connected_non_group_nodes_sequential<T, A>(
graph: &Graph<T, A>,
group_nodes: &[T],
non_group_set: &HashSet<T>,
) -> Result<HashSet<T>, Error>
where
T: Hash + Eq + Clone + Ord + Debug + Display + Send + Sync,
A: Clone + Send + Sync,
{
let mut connected_non_group_nodes = HashSet::new();
for node in group_nodes {
let node_connected = get_connected_non_group_nodes_for_node(graph, node, non_group_set)?;
connected_non_group_nodes.extend(node_connected);
}
Ok(connected_non_group_nodes)
}
/// Returns the names of `node`'s adjacent nodes that belong to `non_group_set`.
///
/// Adjacency means successors for directed graphs and neighbors for
/// undirected graphs. If the graph lookup fails, the error is discarded and an
/// empty set is returned, so this function always returns `Ok`.
fn get_connected_non_group_nodes_for_node<T, A>(
    graph: &Graph<T, A>,
    node: &T,
    non_group_set: &HashSet<T>,
) -> Result<HashSet<T>, Error>
where
    T: Hash + Eq + Clone + Ord + Debug + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let lookup = if !graph.specs.directed {
        graph.get_neighbor_nodes(node.clone())
    } else {
        graph.get_successor_nodes(node.clone())
    };

    // NOTE(review): a failed lookup is treated as "no adjacent nodes",
    // presumably to tolerate nodes the graph reports no adjacency for —
    // confirm against `Graph` before turning this into error propagation.
    let outside: HashSet<T> = lookup
        .map(|adjacent| {
            adjacent
                .iter()
                .map(|neighbor| &neighbor.name)
                .filter(|name| non_group_set.contains(*name))
                .cloned()
                .collect()
        })
        .unwrap_or_default();

    Ok(outside)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, Graph, GraphSpecs};

    #[test]
    fn test_group_degree_basic_undirected() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph
            .add_edges(vec![
                Edge::new(0, 1),
                Edge::new(1, 2),
                Edge::new(2, 3),
                Edge::new(3, 4),
            ])
            .unwrap();
        let mut group = HashSet::new();
        group.insert(1);
        group.insert(2);
        // Nodes 0 and 3 are adjacent to the group.
        let centrality = group_degree_centrality(&graph, &group, false).unwrap();
        assert_eq!(centrality, 2.0);
    }

    #[test]
    fn test_group_degree_basic_directed() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::directed_create_missing());
        graph
            .add_edges(vec![
                Edge::new(0, 1),
                Edge::new(1, 2),
                Edge::new(2, 3),
                Edge::new(3, 4),
            ])
            .unwrap();
        let mut group = HashSet::new();
        group.insert(1);
        group.insert(2);
        // Only successors count: 2 -> 3. The edge 0 -> 1 is incoming.
        let centrality = group_degree_centrality(&graph, &group, false).unwrap();
        assert_eq!(centrality, 1.0);
    }

    #[test]
    fn test_group_degree_normalization() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph
            .add_edges(vec![Edge::new(0, 1), Edge::new(1, 2), Edge::new(2, 3)])
            .unwrap();
        let mut group = HashSet::new();
        group.insert(1);
        let unnormalized = group_degree_centrality(&graph, &group, false).unwrap();
        let normalized = group_degree_centrality(&graph, &group, true).unwrap();
        assert_eq!(unnormalized, 2.0);
        assert!((normalized - (2.0 / 3.0)).abs() < 1e-10);
    }

    #[test]
    fn test_invalid_group_node() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph.add_edges(vec![Edge::new(0, 1)]).unwrap();
        let mut group = HashSet::new();
        group.insert(999);
        let result = group_degree_centrality(&graph, &group, false);
        assert!(matches!(
            result,
            Err(Error {
                kind: ErrorKind::NodeNotFound,
                ..
            })
        ));
    }

    #[test]
    fn test_empty_group() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph.add_edges(vec![Edge::new(0, 1)]).unwrap();
        let group = HashSet::new();
        let result = group_degree_centrality(&graph, &group, false);
        assert!(matches!(
            result,
            Err(Error {
                kind: ErrorKind::InvalidArgument,
                ..
            })
        ));
    }

    #[test]
    fn test_group_contains_all_nodes() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph.add_edges(vec![Edge::new(0, 1)]).unwrap();
        let mut group = HashSet::new();
        group.insert(0);
        group.insert(1);
        let result = group_degree_centrality(&graph, &group, false);
        assert!(matches!(
            result,
            Err(Error {
                kind: ErrorKind::InvalidArgument,
                ..
            })
        ));
    }

    #[test]
    fn test_deterministic_behavior() {
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        graph
            .add_edges(vec![
                Edge::new(0, 1),
                Edge::new(1, 2),
                Edge::new(2, 3),
                Edge::new(3, 4),
                Edge::new(0, 3),
                Edge::new(1, 4),
            ])
            .unwrap();
        let mut group = HashSet::new();
        group.insert(0);
        group.insert(1);
        let mut results = Vec::new();
        for _ in 0..5 {
            let centrality = group_degree_centrality(&graph, &group, false).unwrap();
            results.push(centrality);
        }
        let first_result = results[0];
        // Nodes 2, 3 and 4 are all adjacent to {0, 1}.
        assert_eq!(first_result, 3.0);
        for &result in &results[1..] {
            assert!(
                (result - first_result).abs() < 1e-15,
                "Non-deterministic behavior detected: first={}, other={}, diff={}",
                first_result,
                result,
                (result - first_result).abs()
            );
        }
        let mut norm_results = Vec::new();
        for _ in 0..3 {
            let centrality = group_degree_centrality(&graph, &group, true).unwrap();
            norm_results.push(centrality);
        }
        let first_norm = norm_results[0];
        assert!((first_norm - 1.0).abs() < 1e-15);
        for &result in &norm_results[1..] {
            assert!(
                (result - first_norm).abs() < 1e-15,
                "Non-deterministic behavior in normalized version"
            );
        }
    }

    #[test]
    fn test_group_degree_parallel_threshold() {
        // Path 0-1-...-15 with a 12-member group {0..=11}, which exceeds the
        // parallel threshold (the rayon path is taken only with >1 thread).
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        for i in 0..15 {
            graph.add_edge(Edge::new(i, i + 1)).unwrap();
        }
        let mut group = HashSet::new();
        for i in 0..12 {
            group.insert(i);
        }
        // Only node 12 (adjacent to member 11) lies outside the group.
        let centrality = group_degree_centrality(&graph, &group, false).unwrap();
        assert_eq!(centrality, 1.0);
        // 16 nodes - 12 group members = 4 outside nodes.
        let normalized = group_degree_centrality(&graph, &group, true).unwrap();
        assert!((normalized - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_parallel_and_sequential_helpers_agree() {
        // Nodes 0..=19; each i is linked to i+1 and i+2 (no duplicate edges).
        let mut graph = Graph::<i32, ()>::new(GraphSpecs::undirected_create_missing());
        for i in 0..18 {
            graph.add_edge(Edge::new(i, i + 1)).unwrap();
            graph.add_edge(Edge::new(i, i + 2)).unwrap();
        }
        let group: Vec<i32> = (0..=14).step_by(2).collect();
        let non_group: HashSet<i32> = (0..20).filter(|v| !group.contains(v)).collect();

        let parallel =
            calculate_connected_non_group_nodes_parallel(&graph, &group, &non_group).unwrap();
        let sequential =
            calculate_connected_non_group_nodes_sequential(&graph, &group, &non_group).unwrap();
        assert_eq!(parallel, sequential);

        // Odd neighbors 1..=15 of the even members, plus 16 (reached from 14).
        let expected: HashSet<i32> = [1, 3, 5, 7, 9, 11, 13, 15, 16].iter().copied().collect();
        assert_eq!(sequential, expected);
    }
}