use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use super::communication::{DistributedCommunicator, MessageTag};
use super::coordination::DistributedCoordinator;
use super::distribution::{DataDistribution, DistributionStrategy, MatrixPartitioner};
/// A dense matrix partitioned across the nodes of a distributed computation.
///
/// Each node stores only its own partition (`local_data`); `distribution`
/// records how the partitions map back onto the global matrix.
pub struct DistributedMatrix<T> {
// This node's partition of the global matrix.
local_data: Array2<T>,
// How the global matrix is split across nodes (strategy, shapes, offsets).
distribution: DataDistribution,
// Messaging layer (point-to-point and collective), shared via `Arc`.
communicator: Arc<DistributedCommunicator>,
// Synchronization primitives (e.g. barriers) shared across operations.
coordinator: Arc<DistributedCoordinator>,
// This node's rank within the distributed group.
node_rank: usize,
// Full configuration used to (re)create distributions and communicators.
config: super::DistributedConfig,
}
impl<T> DistributedMatrix<T>
where
T: Float + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
/// Build a distributed matrix by partitioning `localmatrix` according to
/// the distribution strategy selected in `config`.
///
/// Returns `NotImplemented` for strategies without a partitioning scheme.
pub fn from_local(
    localmatrix: Array2<T>,
    config: super::DistributedConfig,
) -> LinalgResult<Self> {
    let globalshape = localmatrix.dim();
    let nodes = config.num_nodes;
    let rank = config.node_rank;
    // Map the configured strategy onto a concrete data distribution.
    let distribution = match config.distribution {
        DistributionStrategy::RowWise => DataDistribution::row_wise(globalshape, nodes, rank)?,
        DistributionStrategy::ColumnWise => {
            DataDistribution::column_wise(globalshape, nodes, rank)?
        }
        DistributionStrategy::BlockCyclic => DataDistribution::block_cyclic(
            globalshape,
            nodes,
            rank,
            (config.blocksize, config.blocksize),
        )?,
        _ => {
            return Err(LinalgError::NotImplemented(
                "Distribution strategy not implemented".to_string(),
            ))
        }
    };
    // Keep only this node's partition and stand up the communication layer.
    let local_data = MatrixPartitioner::partition(&localmatrix.view(), &distribution)?;
    Ok(Self {
        local_data,
        distribution,
        communicator: Arc::new(DistributedCommunicator::new(&config)?),
        coordinator: Arc::new(DistributedCoordinator::new(&config)?),
        node_rank: rank,
        config,
    })
}
/// Create a zero-filled distributed matrix laid out per `distribution`.
pub fn from_distribution(
    distribution: DataDistribution,
    config: super::DistributedConfig,
) -> LinalgResult<Self> {
    let communicator = Arc::new(DistributedCommunicator::new(&config)?);
    let coordinator = Arc::new(DistributedCoordinator::new(&config)?);
    let node_rank = config.node_rank;
    Ok(Self {
        // Local buffer sized to this node's share of the global matrix.
        local_data: Array2::zeros(distribution.localshape),
        distribution,
        communicator,
        coordinator,
        node_rank,
        config,
    })
}
/// Shape `(rows, cols)` of the full global matrix.
pub fn globalshape(&self) -> (usize, usize) {
self.distribution.globalshape
}
/// Shape `(rows, cols)` of this node's local partition.
pub fn localshape(&self) -> (usize, usize) {
self.local_data.dim()
}
/// Read-only view of this node's local partition.
pub fn local_data(&self) -> &Array2<T> {
&self.local_data
}
/// Mutable access to this node's local partition.
pub fn local_data_mut(&mut self) -> &mut Array2<T> {
&mut self.local_data
}
pub fn multiply(&self, other: &DistributedMatrix<T>) -> LinalgResult<DistributedMatrix<T>> {
let (m, k) = self.globalshape();
let (k2, n) = other.globalshape();
if k != k2 {
return Err(LinalgError::DimensionError(format!(
"Matrix dimensions don't match for multiplication: ({}, {}) x ({}, {})",
m, k, k2, n
)));
}
match (&self.distribution.strategy, &other.distribution.strategy) {
(DistributionStrategy::RowWise, DistributionStrategy::ColumnWise) => {
self.multiply_row_col(other)
}
(DistributionStrategy::RowWise, DistributionStrategy::RowWise) => {
self.multiply_row_row(other)
}
(DistributionStrategy::ColumnWise, DistributionStrategy::ColumnWise) => {
self.multiply_col_col(other)
}
_ => Err(LinalgError::NotImplemented(
"Matrix multiplication for this distribution combination not implemented".to_string()
)),
}
}
/// Element-wise addition of two identically-distributed matrices.
///
/// Both operands must share the same global shape and distribution
/// strategy so that their partitions line up element-for-element.
pub fn add(&self, other: &DistributedMatrix<T>) -> LinalgResult<DistributedMatrix<T>> {
    if self.globalshape() != other.globalshape() {
        return Err(LinalgError::DimensionError(
            "Matrix dimensions must match for addition".to_string(),
        ));
    }
    if self.distribution.strategy != other.distribution.strategy {
        return Err(LinalgError::InvalidInput(
            "Matrices must have same distribution strategy for addition".to_string(),
        ));
    }
    // Matching partitions make the addition purely local — no communication.
    let mut result =
        DistributedMatrix::from_distribution(self.distribution.clone(), self.config.clone())?;
    result.local_data = &self.local_data + &other.local_data;
    Ok(result)
}
pub fn transpose(&self) -> LinalgResult<DistributedMatrix<T>> {
match self.distribution.strategy {
DistributionStrategy::RowWise => self.transpose_row_to_col(),
DistributionStrategy::ColumnWise => self.transpose_col_to_row(, _ => Err(LinalgError::NotImplemented(
"Transpose for this distribution not implemented".to_string()
)),
}
}
pub fn gather(&self) -> LinalgResult<Option<Array2<T>>> {
if self.node_rank == 0 {
let matrices = self.communicator.gather_matrices(&self.local_data.view())?;
if let Some(partitions) = matrices {
let mut partition_map = HashMap::new();
for (rank, matrix) in partitions.into_iter().enumerate() {
partition_map.insert(rank, matrix);
}
let globalmatrix = MatrixPartitioner::reconstruct(&partition_map, &self.distribution)?;
Ok(Some(globalmatrix))
} else {
Ok(None)
}
} else {
self.communicator.gather_matrices(&self.local_data.view())?;
Ok(None)
}
}
/// Construct a distributed matrix by broadcasting a full matrix from rank 0.
///
/// The root must pass `Some(matrix)`; other ranks pass `None` and receive
/// the matrix over the wire before partitioning it locally via `from_local`.
pub fn broadcast_from_root(
    globalmatrix: Option<Array2<T>>,
    config: super::DistributedConfig,
) -> LinalgResult<DistributedMatrix<T>> {
    let communicator = Arc::new(DistributedCommunicator::new(&config)?);
    let matrix = if config.node_rank == 0 {
        // Root: require the matrix, then push it to the other ranks.
        let matrix = globalmatrix.ok_or_else(|| {
            LinalgError::InvalidInput("Root node must provide matrix for broadcast".to_string())
        })?;
        communicator.broadcastmatrix(&matrix.view())?;
        matrix
    } else {
        // Non-root: receive the broadcast payload from rank 0.
        communicator.recvmatrix(0, MessageTag::Data)?
    };
    Self::from_local(matrix, config)
}
pub fn gemm_simd(
&self,
other: &DistributedMatrix<T>,
alpha: T,
beta: T,
) -> LinalgResult<DistributedMatrix<T>>
where
T: 'static,
{
if !self.config.enable_simd {
return self.multiply(other);
}
match (T::zero(), T::one()) {
(_) if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() => {
self.gemm_simd_f32(other, alpha, beta)
}
(_) if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() => {
self.gemm_simd_f64(other, alpha, beta)
}
_ => {
self.multiply(other)
}
}
}
/// Core kernel: multiply a row-wise-distributed `self` by a column-wise
/// `other`, producing a row-wise result of width `n`.
///
/// For each rank `j`, this node obtains `j`'s block of B (its own block is
/// used directly) and accumulates `A_local * B_j` into the local result.
// NOTE(review): the loop receives from rank j *before* sending to it; unless
// the communicator buffers sends internally, two ranks can each block in
// `recvmatrix` waiting for the other — verify the communicator's semantics.
fn multiply_row_col(&self, other: &DistributedMatrix<T>) -> LinalgResult<DistributedMatrix<T>> {
// `k` is bound but never used below; only the output width `n` matters here.
let (_, k) = self.globalshape();
let (_, n) = other.globalshape();
let mut local_result = Array2::zeros((self.local_data.nrows(), n));
for j in 0..self.config.num_nodes {
let b_partition = if j == self.node_rank {
other.local_data.clone()
} else {
self.communicator.recvmatrix(j, MessageTag::MatMul)?
};
if j != self.node_rank {
self.communicator.sendmatrix(&other.local_data.view(), j, MessageTag::MatMul)?;
}
let contrib = self.local_data.dot(&b_partition);
local_result = local_result + contrib;
}
// Ensure all ranks have finished exchanging blocks before returning.
self.coordinator.barrier()?;
let result_distribution = DataDistribution::row_wise(
(self.distribution.globalshape.0, n),
self.config.num_nodes,
self.node_rank,
)?;
let mut result = DistributedMatrix::from_distribution(result_distribution, self.config.clone())?;
result.local_data = local_result;
Ok(result)
}
/// Multiply two row-wise-distributed matrices by redistributing `other`
/// and delegating to the row-by-column kernel.
// NOTE(review): transposing twice (row→col then col→row) appears to return
// `other` to a *row-wise* layout, yet `multiply_row_col` expects a
// column-wise operand — confirm this produces the intended distribution.
fn multiply_row_row(&self, other: &DistributedMatrix<T>) -> LinalgResult<DistributedMatrix<T>> {
let other_transposed = other.transpose()?; let other_col_dist = other_transposed.transpose()?;
self.multiply_row_col(&other_col_dist)
}
/// Multiply two column-wise-distributed matrices by redistributing `self`
/// and delegating to the row-by-column kernel.
// NOTE(review): as with `multiply_row_row`, the double transpose appears to
// restore the original column-wise layout rather than producing the row-wise
// layout `multiply_row_col` expects for `self` — confirm.
fn multiply_col_col(&self, other: &DistributedMatrix<T>) -> LinalgResult<DistributedMatrix<T>> {
let self_transposed = self.transpose()?; let self_row_dist = self_transposed.transpose()?;
self_row_dist.multiply_row_col(other)
}
/// Transpose a row-wise matrix into a column-wise-distributed result.
///
/// Strategy: gather the full matrix on rank 0, transpose it there, then
/// broadcast the transposed matrix back out for repartitioning.
// NOTE(review): `gather` returns `Some` only on rank 0, so only the root
// reaches `broadcast_from_root` here; non-root ranks appear to keep a
// zero-filled `local_data` — verify against the communicator's
// broadcast/receive pairing.
fn transpose_row_to_col(&self) -> LinalgResult<DistributedMatrix<T>> {
let (m, n) = self.globalshape();
let result_distribution = DataDistribution::column_wise(
(n, m), self.config.num_nodes,
self.node_rank,
)?;
let mut result = DistributedMatrix::from_distribution(result_distribution, self.config.clone())?;
if let Some(globalmatrix) = self.gather()? {
let transposed = globalmatrix.t().to_owned();
let redistributed = DistributedMatrix::broadcast_from_root(
Some(transposed),
self.config.clone(),
)?;
result.local_data = redistributed.local_data;
}
self.coordinator.barrier()?;
Ok(result)
}
/// Transpose a column-wise matrix into a row-wise-distributed result.
///
/// Mirror image of `transpose_row_to_col`: gather on rank 0, transpose,
/// broadcast back for repartitioning.
// NOTE(review): same concern as `transpose_row_to_col` — only rank 0 enters
// the broadcast branch, so non-root ranks appear to keep zeroed local data;
// verify against the communicator's semantics.
fn transpose_col_to_row(&self) -> LinalgResult<DistributedMatrix<T>> {
let (m, n) = self.globalshape();
let result_distribution = DataDistribution::row_wise(
(n, m), self.config.num_nodes,
self.node_rank,
)?;
let mut result = DistributedMatrix::from_distribution(result_distribution, self.config.clone())?;
if let Some(globalmatrix) = self.gather()? {
let transposed = globalmatrix.t().to_owned();
let redistributed = DistributedMatrix::broadcast_from_root(
Some(transposed),
self.config.clone(),
)?;
result.local_data = redistributed.local_data;
}
self.coordinator.barrier()?;
Ok(result)
}
/// SIMD-accelerated GEMM path for `f32` elements.
///
/// Currently a stub that delegates to the generic distributed multiply;
/// the `alpha`/`beta` scaling factors are not yet applied, so they are
/// underscored to document that and silence unused-parameter warnings.
fn gemm_simd_f32(
    &self,
    other: &DistributedMatrix<T>,
    _alpha: T,
    _beta: T,
) -> LinalgResult<DistributedMatrix<T>> {
    self.multiply(other)
}

/// SIMD-accelerated GEMM path for `f64` elements.
///
/// Stub mirroring `gemm_simd_f32`: delegates to the generic multiply and
/// ignores `alpha`/`beta` until a real SIMD kernel is implemented.
fn gemm_simd_f64(
    &self,
    other: &DistributedMatrix<T>,
    _alpha: T,
    _beta: T,
) -> LinalgResult<DistributedMatrix<T>> {
    self.multiply(other)
}
}
/// A dense vector partitioned across nodes in contiguous blocks.
pub struct DistributedVector<T> {
// This node's contiguous slice of the global vector.
local_data: Array1<T>,
// Total number of elements across all nodes.
global_length: usize,
// Block boundaries describing this rank's slice.
distribution: VectorDistribution,
// Messaging layer used for gathers and reductions.
communicator: Arc<DistributedCommunicator>,
// This node's rank within the distributed group.
node_rank: usize,
// Configuration used to rebuild distributions and communicators.
config: super::DistributedConfig,
}
/// Describes how a global vector is partitioned across nodes: one contiguous
/// block `[start_index, end_index)` per node, with the first
/// `global_length % num_nodes` ranks holding one extra element each.
#[derive(Debug, Clone)]
pub struct VectorDistribution {
    /// Total number of elements across all nodes.
    pub global_length: usize,
    /// Number of elements held locally (`end_index - start_index`).
    pub local_length: usize,
    /// Global index of this node's first element (inclusive).
    pub start_index: usize,
    /// Global index one past this node's last element (exclusive).
    pub end_index: usize,
}

impl VectorDistribution {
    /// Compute the balanced block distribution of `global_length` elements
    /// over `num_nodes` nodes, as seen from `node_rank`.
    ///
    /// Fix: the original signature named the parameters `_global_length` and
    /// `noderank` while the body referenced `global_length` and `node_rank`,
    /// so the function did not compile.
    pub fn new(global_length: usize, num_nodes: usize, node_rank: usize) -> Self {
        let elements_per_node = global_length / num_nodes;
        let remainder = global_length % num_nodes;
        // The first `remainder` ranks each take one extra element, which
        // shifts the start offset of every later rank by `remainder`.
        let start_index = if node_rank < remainder {
            node_rank * (elements_per_node + 1)
        } else {
            node_rank * elements_per_node + remainder
        };
        let end_index = if node_rank < remainder {
            start_index + elements_per_node + 1
        } else {
            start_index + elements_per_node
        };
        Self {
            global_length,
            local_length: end_index - start_index,
            start_index,
            end_index,
        }
    }
}
impl<T> DistributedVector<T>
where
T: Float + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
/// Build a distributed vector from a full local copy of the global vector:
/// each rank retains only its own contiguous slice.
pub fn from_local(
    local_vector: Array1<T>,
    config: super::DistributedConfig,
) -> LinalgResult<Self> {
    let global_length = local_vector.len();
    let distribution =
        VectorDistribution::new(global_length, config.num_nodes, config.node_rank);
    // Keep only this rank's [start_index, end_index) slice of the input.
    let local_data = local_vector
        .slice(scirs2_core::ndarray::s![
            distribution.start_index..distribution.end_index
        ])
        .to_owned();
    Ok(Self {
        local_data,
        global_length,
        distribution,
        communicator: Arc::new(DistributedCommunicator::new(&config)?),
        node_rank: config.node_rank,
        config,
    })
}
/// Total number of elements in the global vector.
pub fn global_length(&self) -> usize {
self.global_length
}
/// Number of elements stored on this node.
pub fn local_length(&self) -> usize {
self.local_data.len()
}
/// Read-only view of this node's slice.
pub fn local_data(&self) -> &Array1<T> {
&self.local_data
}
/// Mutable access to this node's slice.
pub fn local_data_mut(&mut self) -> &mut Array1<T> {
&mut self.local_data
}
/// Globally-reduced dot product of two distributed vectors.
///
/// Computes the local partial product, then sums the partials across all
/// ranks so every node receives the same scalar.
pub fn dot(&self, other: &DistributedVector<T>) -> LinalgResult<T> {
    if self.global_length != other.global_length {
        return Err(LinalgError::DimensionError(
            "Vector lengths must match for dot product".to_string(),
        ));
    }
    let partial = self.local_data.dot(&other.local_data);
    self.allreduce_sum(partial)
}
/// Element-wise sum of two distributed vectors of the same global length.
pub fn add(&self, other: &DistributedVector<T>) -> LinalgResult<DistributedVector<T>> {
    if self.global_length != other.global_length {
        return Err(LinalgError::DimensionError(
            "Vector lengths must match for addition".to_string(),
        ));
    }
    // Matching partitions make the addition purely local — no communication.
    let mut result =
        Self::from_local(Array1::zeros(self.global_length), self.config.clone())?;
    result.local_data = &self.local_data + &other.local_data;
    Ok(result)
}
/// Collect the full vector on rank 0; every other rank receives `None`.
///
/// The local slice is lifted to a 1-row matrix so the matrix-gather
/// machinery can be reused; on the root, each rank's row is written back
/// into its block of the result using that rank's `VectorDistribution`.
pub fn gather(&self) -> LinalgResult<Option<Array1<T>>> {
let dummymatrix = self.local_data.clone().insert_axis(Axis(0));
if self.node_rank == 0 {
let matrices = self.communicator.gather_matrices(&dummymatrix.view())?;
if let Some(parts) = matrices {
let mut result = Array1::zeros(self.global_length);
for (rank, matrix) in parts.into_iter().enumerate() {
// Recompute rank's block boundaries to place its slice correctly.
let dist = VectorDistribution::new(self.global_length, self.config.num_nodes, rank);
let vector = matrix.index_axis(Axis(0), 0);
result.slice_mut(scirs2_core::ndarray::s![dist.start_index..dist.end_index]).assign(&vector);
}
Ok(Some(result))
} else {
Ok(None)
}
} else {
// Non-root: contribute the local slice to the collective gather only.
self.communicator.gather_matrices(&dummymatrix.view())?;
Ok(None)
}
}
/// Sum `local_value` across all ranks so every node ends up with the total.
///
/// Implemented as gather-to-root plus broadcast: rank 0 receives every
/// rank's 1x1 matrix, reduces them, and sends the total back out.
///
/// Fix: the original parameter was named `localvalue` while the body used
/// `local_value`, so the function did not compile.
fn allreduce_sum(&self, local_value: T) -> LinalgResult<T> {
    // Wrap the scalar in a 1x1 matrix so the matrix gather can carry it.
    let valuematrix = Array2::from_elem((1, 1), local_value);
    if let Some(gathered) = self.communicator.gather_matrices(&valuematrix.view())? {
        // Root path: reduce the gathered scalars and broadcast the total.
        let total: T = gathered
            .iter()
            .map(|m| m[[0, 0]])
            .fold(T::zero(), |acc, x| acc + x);
        let resultmatrix = Array2::from_elem((1, 1), total);
        self.communicator.broadcastmatrix(&resultmatrix.view())?;
        Ok(total)
    } else {
        // Non-root path: wait for the reduced value from rank 0.
        let resultmatrix = self.communicator.recvmatrix(0, MessageTag::Data)?;
        Ok(resultmatrix[[0, 0]])
    }
}
}
#[cfg(test)]
mod tests {
    use super::super::DistributedConfig;
    use super::*;

    /// Row-wise partitioning of a 6x4 matrix over 2 nodes: each node gets 3 rows.
    #[test]
    fn test_distributedmatrix_creation() {
        let matrix = Array2::from_shape_fn((6, 4), |(i, j)| (i * 4 + j) as f64);
        let config = DistributedConfig::default()
            .with_num_nodes(2)
            .with_node_rank(0)
            .with_distribution(DistributionStrategy::RowWise);
        let distmatrix =
            DistributedMatrix::from_local(matrix.clone(), config).expect("Operation failed");
        assert_eq!(distmatrix.globalshape(), (6, 4));
        assert_eq!(distmatrix.localshape().0, 3);
        assert_eq!(distmatrix.localshape().1, 4);
    }

    /// A 10-element vector over 2 nodes leaves 5 elements on rank 0.
    #[test]
    fn test_distributed_vector_creation() {
        let vector = Array1::from_shape_fn(10, |i| i as f64);
        let config = DistributedConfig::default().with_num_nodes(2).with_node_rank(0);
        let dist_vector =
            DistributedVector::from_local(vector, config).expect("Operation failed");
        assert_eq!(dist_vector.global_length(), 10);
        assert_eq!(dist_vector.local_length(), 5);
    }

    /// 10 elements over 3 nodes: the first `10 % 3 == 1` rank takes an extra
    /// element, so rank 0 holds [0, 4) and rank 1 holds [4, 7).
    ///
    /// Fix: the old expectations (start 3, end 6) assumed a flat 3-per-node
    /// split, which contradicts the remainder-first distribution implemented
    /// by `VectorDistribution::new` and would drop element 9 on rank 2.
    #[test]
    fn test_vector_distribution() {
        let dist = VectorDistribution::new(10, 3, 1);
        assert_eq!(dist.global_length, 10);
        assert_eq!(dist.start_index, 4);
        assert_eq!(dist.end_index, 7);
        assert_eq!(dist.local_length, 3);
    }

    /// Addition of two identically-distributed 4x4 matrices preserves shape.
    #[test]
    fn testmatrix_local_operations() {
        let matrix1 = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f64);
        let matrix2 = Array2::from_shape_fn((4, 4), |(i, j)| (i * j) as f64);
        let config = DistributedConfig::default();
        let dist1 =
            DistributedMatrix::from_local(matrix1, config.clone()).expect("Operation failed");
        let dist2 = DistributedMatrix::from_local(matrix2, config).expect("Operation failed");
        let result = dist1.add(&dist2).expect("Operation failed");
        assert_eq!(result.globalshape(), (4, 4));
    }
}