use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, Axis};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistributionStrategy {
RowWise,
ColumnWise,
BlockCyclic,
LoadBased,
Custom,
}
#[derive(Debug, Clone)]
pub struct DataDistribution {
pub strategy: DistributionStrategy,
pub globalshape: (usize, usize),
pub localshape: (usize, usize),
pub owned_indices: IndexRange,
pub index_map: HashMap<(usize, usize), usize>,
pub blocksize: (usize, usize),
}
#[derive(Debug, Clone)]
pub struct IndexRange {
pub rows: (usize, usize),
pub columns: (usize, usize),
}
impl IndexRange {
pub fn new(_row_start: usize, row_end: usize, col_start: usize, colend: usize) -> Self {
Self {
rows: (_row_start, row_end),
columns: (col_start, col_end),
}
}
pub fn contains(&self, row: usize, col: usize) -> bool {
row >= self.rows.0 && row < self.rows.1 && col >= self.columns.0 && col < self.columns.1
}
pub fn num_rows(&self) -> usize {
self.rows.1 - self.rows.0
}
pub fn num_cols(&self) -> usize {
self.columns.1 - self.columns.0
}
pub fn size(&self) -> usize {
self.num_rows() * self.num_cols()
}
}
impl DataDistribution {
pub fn row_wise(
globalshape: (usize, usize),
num_nodes: usize,
node_rank: usize,
) -> LinalgResult<Self> {
let (global_rows, global_cols) = globalshape;
let rows_per_node = global_rows / num_nodes;
let remainder = global_rows % num_nodes;
let start_row = if node_rank < remainder {
node_rank * (rows_per_node + 1)
} else {
node_rank * rows_per_node + remainder
};
let end_row = if node_rank < remainder {
start_row + rows_per_node + 1
} else {
start_row + rows_per_node
};
let local_rows = end_row - start_row;
let localshape = (local_rows, global_cols);
let owned_indices = IndexRange::new(start_row, end_row, 0, global_cols);
let mut index_map = HashMap::new();
for node in 0..num_nodes {
let node_start = if node < remainder {
node * (rows_per_node + 1)
} else {
node * rows_per_node + remainder
};
let node_end = if node < remainder {
node_start + rows_per_node + 1
} else {
node_start + rows_per_node
};
for row in node_start..node_end {
for col in 0..global_cols {
index_map.insert((row, col), node);
}
}
}
Ok(Self {
strategy: DistributionStrategy::RowWise,
globalshape,
localshape,
owned_indices,
index_map,
blocksize: (1, global_cols),
})
}
pub fn column_wise(
globalshape: (usize, usize),
num_nodes: usize,
node_rank: usize,
) -> LinalgResult<Self> {
let (global_rows, global_cols) = globalshape;
let cols_per_node = global_cols / num_nodes;
let remainder = global_cols % num_nodes;
let start_col = if node_rank < remainder {
node_rank * (cols_per_node + 1)
} else {
node_rank * cols_per_node + remainder
};
let end_col = if node_rank < remainder {
start_col + cols_per_node + 1
} else {
start_col + cols_per_node
};
let local_cols = end_col - start_col;
let localshape = (global_rows, local_cols);
let owned_indices = IndexRange::new(0, global_rows, start_col, end_col);
let mut index_map = HashMap::new();
for node in 0..num_nodes {
let node_start = if node < remainder {
node * (cols_per_node + 1)
} else {
node * cols_per_node + remainder
};
let node_end = if node < remainder {
node_start + cols_per_node + 1
} else {
node_start + cols_per_node
};
for row in 0..global_rows {
for col in node_start..node_end {
index_map.insert((row, col), node);
}
}
}
Ok(Self {
strategy: DistributionStrategy::ColumnWise,
globalshape,
localshape,
owned_indices,
index_map,
blocksize: (global_rows, 1),
})
}
pub fn block_cyclic(
globalshape: (usize, usize),
num_nodes: usize,
node_rank: usize,
blocksize: (usize, usize),
) -> LinalgResult<Self> {
let (global_rows, global_cols) = globalshape;
let (block_rows, block_cols) = blocksize;
let grid_rows = (global_rows + block_rows - 1) / block_rows;
let grid_cols = (global_cols + block_cols - 1) / block_cols;
let proc_grid_rows = (num_nodes as f64).sqrt() as usize;
let proc_grid_cols = (num_nodes + proc_grid_rows - 1) / proc_grid_rows;
let proc_row = node_rank / proc_grid_cols;
let proc_col = node_rank % proc_grid_cols;
let mut owned_blocks = Vec::new();
for grid_r in (proc_row..grid_rows).step_by(proc_grid_rows) {
for grid_c in (proc_col..grid_cols).step_by(proc_grid_cols) {
owned_blocks.push((grid_r, grid_c));
}
}
let local_rows = owned_blocks.iter()
.map(|(gr_)| {
let start_row = gr * block_rows;
let end_row = ((gr + 1) * block_rows).min(global_rows);
end_row - start_row
})
.sum::<usize>();
let local_cols = owned_blocks.iter()
.map(|(_, gc)| {
let start_col = gc * block_cols;
let end_col = ((gc + 1) * block_cols).min(global_cols);
end_col - start_col
})
.max()
.unwrap_or(0);
let localshape = (local_rows, local_cols);
let owned_indices = if !owned_blocks.is_empty() {
let (first_gr, first_gc) = owned_blocks[0];
let start_row = first_gr * block_rows;
let end_row = ((first_gr + 1) * block_rows).min(global_rows);
let start_col = first_gc * block_cols;
let end_col = ((first_gc + 1) * block_cols).min(global_cols);
IndexRange::new(start_row, end_row, start_col, end_col)
} else {
IndexRange::new(0, 0, 0, 0)
};
let mut index_map = HashMap::new();
for &(gr, gc) in &owned_blocks {
let start_row = gr * block_rows;
let end_row = ((gr + 1) * block_rows).min(global_rows);
let start_col = gc * block_cols;
let end_col = ((gc + 1) * block_cols).min(global_cols);
for row in start_row..end_row {
for col in start_col..end_col {
index_map.insert((row, col), node_rank);
}
}
}
Ok(Self {
strategy: DistributionStrategy::BlockCyclic,
globalshape,
localshape,
owned_indices,
index_map,
blocksize,
})
}
pub fn get_owner(&self, row: usize, col: usize) -> Option<usize> {
self.index_map.get(&(row, col)).copied()
}
pub fn owns(&self, row: usize, col: usize) -> bool {
self.owned_indices.contains(row, col)
}
pub fn global_to_local(&self, global_row: usize, globalcol: usize) -> Option<(usize, usize)> {
if !self.owns(global_row, global_col) {
return None;
}
match self.strategy {
DistributionStrategy::RowWise => {
let local_row = global_row - self.owned_indices.rows.0;
Some((local_row, global_col))
},
DistributionStrategy::ColumnWise => {
let local_col = global_col - self.owned_indices.columns.0;
Some((global_row, local_col))
},
DistributionStrategy::BlockCyclic => {
let local_row = global_row - self.owned_indices.rows.0;
let local_col = global_col - self.owned_indices.columns.0;
Some((local_row, local_col))
}_ => None,
}
}
pub fn local_to_global(&self, local_row: usize, localcol: usize) -> (usize, usize) {
match self.strategy {
DistributionStrategy::RowWise => {
(local_row + self.owned_indices.rows.0, local_col)
},
DistributionStrategy::ColumnWise => {
(local_row, local_col + self.owned_indices.columns.0)
},
DistributionStrategy::BlockCyclic => {
(local_row + self.owned_indices.rows.0, local_col + self.owned_indices.columns.0)
}_ => (local_row, local_col),
}
}
}
pub struct LoadBalancer {
node_capabilities: HashMap<usize, f64>,
performance_history: HashMap<usize, Vec<f64>>,
workload_distribution: HashMap<usize, f64>,
}
impl LoadBalancer {
pub fn new(config: &super::DistributedConfig) -> LinalgResult<Self> {
let mut node_capabilities = HashMap::new();
for rank in 0.._config.num_nodes {
node_capabilities.insert(rank, 1.0);
}
Ok(Self {
node_capabilities,
performance_history: HashMap::new(),
workload_distribution: HashMap::new(),
})
}
pub fn update_capability(&mut self, node_rank: usize, performancemetric: f64) {
self.performance_history
.entry(node_rank)
.or_insert_with(Vec::new)
.push(performance_metric);
let history = self.performance_history.get_mut(&node_rank).expect("Operation failed");
if history.len() > 10 {
history.drain(0..history.len() - 10);
}
let avg_performance = history.iter().sum::<f64>() / history.len() as f64;
self.node_capabilities.insert(node_rank, avg_performance);
}
pub fn calculate_workload_distribution(&self, totalwork: f64) -> HashMap<usize, f64> {
let total_capability: f64 = self.node_capabilities.values().sum();
if total_capability == 0.0 {
return HashMap::new();
}
let mut distribution = HashMap::new();
for (&node_rank, &capability) in &self.node_capabilities {
let work_fraction = capability / total_capability;
distribution.insert(node_rank, total_work * work_fraction);
}
distribution
}
pub fn suggest_redistribution(&self) -> Option<RedistributionPlan> {
let workloads: Vec<f64> = self.workload_distribution.values().cloned().collect();
if workloads.is_empty() {
return None;
}
let avg_workload = workloads.iter().sum::<f64>() / workloads.len() as f64;
let max_workload = workloads.iter().cloned().fold(0.0, f64::max);
let min_workload = workloads.iter().cloned().fold(f64::INFINITY, f64::min);
let imbalance_ratio = (max_workload - min_workload) / avg_workload;
if imbalance_ratio > 0.2 {
let mut overloaded = Vec::new();
let mut underloaded = Vec::new();
for (&node_rank, &workload) in &self.workload_distribution {
if workload > avg_workload * 1.1 {
overloaded.push(node_rank);
} else if workload < avg_workload * 0.9 {
underloaded.push(node_rank);
}
}
Some(RedistributionPlan {
from_nodes: overloaded,
to_nodes: underloaded,
suggested_strategy: DistributionStrategy::LoadBased,
imbalance_ratio,
})
} else {
None
}
}
pub fn record_workload(&mut self, noderank: usize, workload: f64) {
self.workload_distribution.insert(node_rank, workload);
}
pub fn get_efficiency(&self) -> f64 {
if self.workload_distribution.is_empty() {
return 1.0;
}
let workloads: Vec<f64> = self.workload_distribution.values().cloned().collect();
let avg_workload = workloads.iter().sum::<f64>() / workloads.len() as f64;
if avg_workload == 0.0 {
return 1.0;
}
let variance: f64 = workloads
.iter()
.map(|w| (w - avg_workload).powi(2))
.sum::<f64>() / workloads.len() as f64;
let coefficient_of_variation = variance.sqrt() / avg_workload;
1.0 / (1.0 + coefficient_of_variation)
}
}
#[derive(Debug, Clone)]
pub struct RedistributionPlan {
pub from_nodes: Vec<usize>,
pub to_nodes: Vec<usize>,
pub suggested_strategy: DistributionStrategy,
pub imbalance_ratio: f64,
}
pub struct MatrixPartitioner;
impl MatrixPartitioner {
pub fn partition<T>(
matrix: &ArrayView2<T>,
distribution: &DataDistribution,
) -> LinalgResult<Array2<T>>
where
T: Clone,
{
let (global_rows, global_cols) = matrix.dim();
if (global_rows, global_cols) != distribution.globalshape {
return Err(LinalgError::DimensionError(format!(
"Matrix shape {:?} doesn't match distribution shape {:?}",
(global_rows, global_cols),
distribution.globalshape
)));
}
let IndexRange { rows, columns } = &distribution.owned_indices;
let localmatrix = matrix.slice(scirs2_core::ndarray::s![rows.0..rows.1, columns.0..columns.1]);
Ok(localmatrix.to_owned())
}
pub fn reconstruct<T>(
partitions: &HashMap<usize, Array2<T>>,
distribution: &DataDistribution,
) -> LinalgResult<Array2<T>>
where
T: Clone + Default,
{
let mut globalmatrix = Array2::default(distribution.globalshape);
for (&node_rank, partition) in partitions {
if let Some(range) = Self::get_node_range(node_rank, distribution) {
let target_slice = globalmatrix.slice_mut(scirs2_core::ndarray::s![
range.rows.0..range.rows.1,
range.columns.0..range.columns.1
]);
if target_slice.shape() == partition.shape() {
target_slice.assign(partition);
}
}
}
Ok(globalmatrix)
}
fn get_node_range(_noderank: usize, distribution: &DataDistribution) -> Option<IndexRange> {
Some(distribution.owned_indices.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_index_range() {
let range = IndexRange::new(10, 20, 5, 15);
assert_eq!(range.num_rows(), 10);
assert_eq!(range.num_cols(), 10);
assert_eq!(range.size(), 100);
assert!(range.contains(15, 10));
assert!(!range.contains(5, 10));
assert!(!range.contains(15, 20));
}
#[test]
fn test_row_wise_distribution() {
let distribution = DataDistribution::row_wise((100, 50), 4, 1).expect("Operation failed");
assert_eq!(distribution.strategy, DistributionStrategy::RowWise);
assert_eq!(distribution.globalshape, (100, 50));
assert_eq!(distribution.localshape.1, 50);
assert!(distribution.owns(30, 10));
assert!(!distribution.owns(10, 10));
}
#[test]
fn test_column_wise_distribution() {
let distribution = DataDistribution::column_wise((100, 50), 4, 2).expect("Operation failed");
assert_eq!(distribution.strategy, DistributionStrategy::ColumnWise);
assert_eq!(distribution.globalshape, (100, 50));
assert_eq!(distribution.localshape.0, 100);
assert!(distribution.localshape.1 > 0);
}
#[test]
fn test_load_balancer() {
use super::super::DistributedConfig;
let config = DistributedConfig::default().with_num_nodes(3);
let mut balancer = LoadBalancer::new(&config).expect("Operation failed");
balancer.update_capability(0, 1.0);
balancer.update_capability(1, 2.0);
balancer.update_capability(2, 0.5);
let distribution = balancer.calculate_workload_distribution(100.0);
assert!(distribution.len() == 3);
assert!(distribution[&1] > distribution[&0]); assert!(distribution[&2] < distribution[&0]); }
#[test]
fn testmatrix_partitioner() {
let matrix = Array2::from_shape_fn((10, 8), |(i, j)| (i * 8 + j) as f64);
let distribution = DataDistribution::row_wise((10, 8), 2, 0).expect("Operation failed");
let partition = MatrixPartitioner::partition(&matrix.view(), &distribution).expect("Operation failed");
assert_eq!(partition.nrows(), 5); assert_eq!(partition.ncols(), 8); }
}