use crate::node::NodeIndex;
use std::collections::HashMap;
/// Key identifying an edge by its `(source, target)` node pair.
///
/// [`EdgeKey::new`] keeps the orientation it is given, while
/// [`EdgeKey::new_undirected`] orders the two endpoints so the argument order
/// does not matter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EdgeKey {
    /// Start node of the edge (the smaller endpoint for undirected keys).
    pub source: NodeIndex,
    /// End node of the edge (the other endpoint for undirected keys).
    pub target: NodeIndex,
}

impl EdgeKey {
    /// Builds a directed key exactly as given.
    pub fn new(source: NodeIndex, target: NodeIndex) -> Self {
        EdgeKey { source, target }
    }

    /// Builds an orientation-independent key: when `a < b` the key is
    /// `(a, b)`, otherwise it is `(b, a)`.
    pub fn new_undirected(a: NodeIndex, b: NodeIndex) -> Self {
        let (lo, hi) = if a < b { (a, b) } else { (b, a) };
        Self::new(lo, hi)
    }
}
/// Settings shared by partitioner implementations.
#[derive(Debug, Clone)]
pub struct PartitionerConfig {
    /// Number of partitions to produce.
    pub num_partitions: usize,
    /// Optional desired node count per partition (`None` unless set).
    pub target_nodes_per_partition: Option<usize>,
    /// Free-form, strategy-specific key/value options.
    pub properties: HashMap<String, String>,
}

impl PartitionerConfig {
    /// Creates a config for `num_partitions` partitions with no per-partition
    /// target and no properties.
    pub fn new(num_partitions: usize) -> Self {
        PartitionerConfig {
            properties: HashMap::default(),
            target_nodes_per_partition: None,
            num_partitions,
        }
    }

    /// Builder-style setter for the desired node count per partition.
    pub fn with_target_nodes(self, target: usize) -> Self {
        PartitionerConfig {
            target_nodes_per_partition: Some(target),
            ..self
        }
    }

    /// Builder-style setter that inserts a property, replacing any previous
    /// value stored under the same key.
    pub fn with_property(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let key: String = key.into();
        let value: String = value.into();
        self.properties.insert(key, value);
        self
    }
}
/// Identifier of a partition. The default `Partitioner` methods use it as an
/// index into the `Vec<Partition>` they build.
pub type PartitionId = usize;
/// A group of graph nodes together with data derived from them.
#[derive(Debug, Clone)]
pub struct Partition {
/// Id of this partition.
pub id: PartitionId,
/// Nodes assigned to this partition.
pub nodes: Vec<NodeIndex>,
/// Raw edge indices. `Partitioner::partition_graph` stores here the edges
/// whose source node is assigned to this partition.
pub edges: Vec<usize>,
/// Nodes of this partition with at least one neighbor outside it, as computed
/// by `Partitioner::partition_graph` (nodes whose `out_degree` lookup fails are
/// never included).
pub boundary_nodes: Vec<NodeIndex>,
/// Cached edge weights keyed by directed `EdgeKey`. Populated by
/// `Partition::cache_edge_weights_from_graph`, which stores both directions.
pub edge_weights: HashMap<EdgeKey, f64>,
}
impl Partition {
    /// Creates an empty partition with the given id.
    pub fn new(id: PartitionId) -> Self {
        Self {
            id,
            nodes: vec![],
            edges: vec![],
            boundary_nodes: vec![],
            edge_weights: HashMap::default(),
        }
    }

    /// Number of nodes assigned to this partition.
    pub fn size(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when no nodes are assigned to this partition.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Looks up the cached weight stored under the directed key
    /// `(source, target)`; `None` when nothing is cached for that exact key.
    pub fn get_edge_weight(&self, source: NodeIndex, target: NodeIndex) -> Option<f64> {
        let key = EdgeKey::new(source, target);
        self.edge_weights.get(&key).copied()
    }

    /// Looks up a cached weight under the normalized key built by
    /// `EdgeKey::new_undirected`, so the argument order does not matter.
    ///
    /// Only hits entries stored under that normalized orientation. Entries written
    /// by `cache_edge_weights_from_graph` qualify because it stores both directions.
    pub fn get_edge_weight_undirected(&self, a: NodeIndex, b: NodeIndex) -> Option<f64> {
        let key = EdgeKey::new_undirected(a, b);
        self.edge_weights.get(&key).copied()
    }

    /// Caches the weights of edges whose endpoints both lie in this partition.
    ///
    /// For every node of the partition, and every neighbor from `graph.neighbors`
    /// that is also in the partition, the node's incident edges are scanned. The
    /// first edge that connects the pair (in either direction) and whose data can
    /// be fetched is passed to `get_weight(u, v, &data)`, where `(u, v)` are its
    /// endpoints. The result is stored under both `EdgeKey::new(node, neighbor)`
    /// and `EdgeKey::new(neighbor, node)`. Existing entries are overwritten but
    /// never removed.
    ///
    /// Each qualifying neighbor triggers a fresh scan of the node's incident edges.
    /// An edge may be reached from both of its endpoints, so `get_weight` can be
    /// called more than once for the same edge.
    pub fn cache_edge_weights_from_graph<G, F>(&mut self, graph: &G, mut get_weight: F)
    where
        G: crate::vgi::VirtualGraph<NodeData = (), EdgeData = f64>,
        F: FnMut(NodeIndex, NodeIndex, &f64) -> f64,
    {
        // Dense membership table indexed by node index, sized to the largest index
        // in this partition. Out-of-range lookups mean "not a member".
        let table_len = self
            .nodes
            .iter()
            .map(|n| n.index())
            .max()
            .map_or(1, |largest| largest + 1);
        let mut member = vec![false; table_len];
        for n in self.nodes.iter() {
            member[n.index()] = true;
        }
        let is_member = |idx: usize| member.get(idx) == Some(&true);

        for &node in &self.nodes {
            for neighbor in graph.neighbors(node) {
                if !is_member(neighbor.index()) {
                    continue;
                }

                // First incident edge of `node` joining it to `neighbor` whose data is readable.
                let mut weight = None;
                for edge_idx in graph.incident_edges(node) {
                    let (u, v) = match graph.edge_endpoints(edge_idx) {
                        Ok(endpoints) => endpoints,
                        Err(_) => continue,
                    };
                    let joins_pair = (u == node && v == neighbor) || (u == neighbor && v == node);
                    if !joins_pair {
                        continue;
                    }
                    if let Ok(edge_data) = graph.get_edge(edge_idx) {
                        weight = Some(get_weight(u, v, edge_data));
                        break;
                    }
                }

                if let Some(w) = weight {
                    self.edge_weights.insert(EdgeKey::new(node, neighbor), w);
                    self.edge_weights.insert(EdgeKey::new(neighbor, node), w);
                }
            }
        }
    }
}
/// Strategy that assigns every graph node to one of a fixed number of partitions.
///
/// Implementors provide [`name`](Partitioner::name),
/// [`num_partitions`](Partitioner::num_partitions) and
/// [`partition_node`](Partitioner::partition_node). The remaining methods have
/// default implementations derived from `partition_node`.
pub trait Partitioner: Send + Sync {
    /// Short, static name of the partitioning strategy.
    fn name(&self) -> &'static str;

    /// Number of partitions this partitioner distributes nodes over.
    fn num_partitions(&self) -> usize;

    /// Partition id for `node`.
    ///
    /// The default methods silently drop nodes and edges mapped to an id
    /// `>= num_partitions()`.
    fn partition_node(&self, node: NodeIndex) -> PartitionId;

    /// Partition id of each node in `nodes`, keyed by node.
    fn partition_nodes(&self, nodes: &[NodeIndex]) -> HashMap<NodeIndex, PartitionId> {
        nodes
            .iter()
            .map(|&node| (node, self.partition_node(node)))
            .collect()
    }

    /// Splits `graph` into `num_partitions()` partitions, indexed by id.
    ///
    /// * Each node goes to `partition_node(node)`.
    /// * Each edge (stored as its raw index) goes to the partition of its source node.
    /// * A node is recorded as a boundary node when any of its `graph.neighbors`
    ///   lies outside its partition. Nodes for which `graph.out_degree` errors are skipped.
    ///
    /// Nodes and edges mapped to ids outside `0..num_partitions()` are dropped.
    fn partition_graph<G>(&self, graph: &G) -> Vec<Partition>
    where
        G: crate::vgi::VirtualGraph,
    {
        let num_partitions = self.num_partitions();
        let mut partitions: Vec<Partition> = (0..num_partitions).map(Partition::new).collect();

        for node_ref in graph.nodes() {
            let node = node_ref.index();
            // `get_mut` yields `None` for out-of-range ids, which are dropped.
            if let Some(partition) = partitions.get_mut(self.partition_node(node)) {
                partition.nodes.push(node);
            }
        }

        for edge_ref in graph.edges() {
            if let Some(partition) = partitions.get_mut(self.partition_node(edge_ref.source())) {
                partition.edges.push(edge_ref.index().index());
            }
        }

        for partition in &mut partitions {
            // Dense membership table indexed by node index, sized to the largest
            // index in this partition. Lookups past the end mean "not a member".
            let table_len = partition.nodes.iter().map(|n| n.index()).max().unwrap_or(0) + 1;
            let mut in_partition = vec![false; table_len];
            for &node in &partition.nodes {
                in_partition[node.index()] = true;
            }
            for &node in &partition.nodes {
                // NOTE(review): presumably errors for nodes the graph cannot resolve;
                // such nodes are never reported as boundary nodes.
                if graph.out_degree(node).is_err() {
                    continue;
                }
                for neighbor in graph.neighbors(node) {
                    if !in_partition.get(neighbor.index()).copied().unwrap_or(false) {
                        partition.boundary_nodes.push(node);
                        break;
                    }
                }
            }
        }

        partitions
    }

    /// Partitions `graph` via [`partition_graph`](Partitioner::partition_graph)
    /// and summarizes the partition sizes and boundary-node counts.
    ///
    /// `avg_partition_size` is the floor average (0 when there are no partitions).
    /// `balance_ratio` is `max / min` partition size, with two special cases:
    /// * If every partition is empty (empty graph, or zero partitions), all sizes
    ///   are equal and the ratio is `1.0`.
    /// * If only some partitions are empty, the ratio is `f64::INFINITY`.
    fn partition_stats<G>(&self, graph: &G) -> PartitionStats
    where
        G: crate::vgi::VirtualGraph,
    {
        let partitions = self.partition_graph(graph);
        let num_partitions = partitions.len();
        let total_nodes: usize = partitions.iter().map(Partition::size).sum();
        let min_partition_size = partitions.iter().map(Partition::size).min().unwrap_or(0);
        let max_partition_size = partitions.iter().map(Partition::size).max().unwrap_or(0);
        let avg_partition_size = total_nodes.checked_div(num_partitions).unwrap_or(0);
        let total_boundary_nodes: usize = partitions.iter().map(|p| p.boundary_nodes.len()).sum();

        // Previously an all-empty result reported INFINITY, which made
        // `is_balanced` reject an empty graph even though all sizes are equal.
        let balance_ratio = if max_partition_size == 0 {
            1.0
        } else if min_partition_size == 0 {
            f64::INFINITY
        } else {
            max_partition_size as f64 / min_partition_size as f64
        };

        PartitionStats {
            num_partitions,
            total_nodes,
            min_partition_size,
            max_partition_size,
            avg_partition_size,
            total_boundary_nodes,
            balance_ratio,
        }
    }
}
/// Summary statistics describing how nodes are spread over partitions.
#[derive(Debug, Clone)]
pub struct PartitionStats {
    /// Number of partitions produced.
    pub num_partitions: usize,
    /// Total number of nodes across all partitions.
    pub total_nodes: usize,
    /// Size of the smallest partition.
    pub min_partition_size: usize,
    /// Size of the largest partition.
    pub max_partition_size: usize,
    /// Integer (floor) average partition size.
    pub avg_partition_size: usize,
    /// Sum of boundary-node counts over all partitions.
    pub total_boundary_nodes: usize,
    /// Ratio of the largest to the smallest partition size. Larger values mean a
    /// more uneven distribution. `f64::INFINITY` is used when some partition is
    /// empty while another is not.
    pub balance_ratio: f64,
}

impl PartitionStats {
    /// Returns `true` when `balance_ratio` is strictly below `threshold`.
    ///
    /// NaN on either side is never considered balanced.
    pub fn is_balanced(&self, threshold: f64) -> bool {
        threshold > self.balance_ratio
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store the partition count, target size and properties.
    #[test]
    fn test_partition_config() {
        let cfg = PartitionerConfig::new(4)
            .with_target_nodes(1000)
            .with_property("strategy", "hash");
        assert_eq!(cfg.num_partitions, 4);
        assert_eq!(cfg.target_nodes_per_partition, Some(1000));
        assert_eq!(
            cfg.properties.get("strategy").map(String::as_str),
            Some("hash")
        );
    }

    /// A freshly created partition holds no nodes.
    #[test]
    fn test_partition() {
        let empty = Partition::new(0);
        assert_eq!(empty.size(), 0);
        assert!(empty.is_empty());
    }

    /// `is_balanced` compares the stored ratio strictly against the threshold.
    #[test]
    fn test_partition_stats() {
        let ratio = 1.5;
        let stats = PartitionStats {
            num_partitions: 4,
            total_nodes: 100,
            min_partition_size: 20,
            max_partition_size: 30,
            avg_partition_size: 25,
            total_boundary_nodes: 10,
            balance_ratio: ratio,
        };
        let loose_threshold = 2.0;
        let tight_threshold = 1.2;
        assert!(stats.is_balanced(loose_threshold));
        assert!(!stats.is_balanced(tight_threshold));
    }
}