use super::comm::{CommunicationError, Message, MessageTag};
use super::process::{Communicator, ProcessError};
use serde::{Deserialize, Serialize};
use std::ops::{Add, Mul};
use thiserror::Error;
/// Errors produced by the collective operations in this module.
#[derive(Error, Debug)]
pub enum CollectiveError {
/// Wraps a [`CommunicationError`]; `#[from]` lets `?` convert it automatically.
#[error("Communication error: {0}")]
Communication(#[from] CommunicationError),
/// Wraps a [`ProcessError`]; `#[from]` lets `?` convert it automatically.
#[error("Process error: {0}")]
Process(#[from] ProcessError),
/// The requested root is not a valid rank of the communicator (`root >= size`).
#[error("Invalid root rank {root}, must be < {size}")]
InvalidRoot { root: usize, size: usize },
/// A buffer's length differed from the length the operation required.
#[error("Data size mismatch: expected {expected}, got {actual}")]
SizeMismatch { expected: usize, actual: usize },
/// A collective failed for the reason given in the contained message.
#[error("Collective operation failed: {0}")]
OperationFailed(String),
/// A collective operation did not complete in time.
#[error("Timeout during collective operation")]
Timeout,
}
/// Element-wise reduction operator used by [`reduce`].
///
/// Operators are applied pairwise through [`ReduceOp::apply`]; sequences are folded
/// left to right by [`ReduceOp::apply_slice`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReduceOp {
    /// Arithmetic sum, `a + b`.
    Sum,
    /// Arithmetic product, `a * b`.
    Product,
    /// The larger operand.
    Max,
    /// The smaller operand.
    Min,
    /// Logical AND, computed as the lattice meet (the smaller operand).
    ///
    /// This is exactly boolean AND when values are `0`/`1` flags (or any encoding in
    /// which "false" orders below "true"). It is *not* C-style truthiness for arbitrary
    /// values, e.g. `And.apply(-1, 0)` yields `-1`.
    And,
    /// Logical OR, computed as the lattice join (the larger operand).
    ///
    /// Same caveat as [`ReduceOp::And`]: exact for `0`/`1` flags.
    Or,
}

impl ReduceOp {
    /// Combines `a` and `b` with this operator.
    ///
    /// For the ordering-based operators (`Max`, `Min`, `And`, `Or`), ties and
    /// incomparable pairs (e.g. involving `NaN`, where `PartialOrd` comparisons are
    /// `false`) yield `b`.
    pub fn apply<T>(&self, a: T, b: T) -> T
    where
        T: Add<Output = T> + Mul<Output = T> + PartialOrd + Clone,
    {
        // Deliberately exhaustive (no `_` arm): adding a variant must force a decision here
        // instead of silently returning one of the operands.
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            // With `false < true`, the join (max) of two flags is their OR.
            ReduceOp::Max | ReduceOp::Or => {
                if a > b {
                    a
                } else {
                    b
                }
            }
            // With `false < true`, the meet (min) of two flags is their AND.
            ReduceOp::Min | ReduceOp::And => {
                if a < b {
                    a
                } else {
                    b
                }
            }
        }
    }

    /// Folds `values` left to right with [`ReduceOp::apply`].
    ///
    /// Returns `None` for an empty slice, otherwise `Some` of the reduced value.
    pub fn apply_slice<T>(&self, values: &[T]) -> Option<T>
    where
        T: Add<Output = T> + Mul<Output = T> + PartialOrd + Clone,
    {
        let (first, rest) = values.split_first()?;
        Some(
            rest.iter()
                .cloned()
                .fold(first.clone(), |acc, value| self.apply(acc, value)),
        )
    }
}
// Message tags for collective traffic, each named after the collective it is meant for.
// NOTE(review): none of the collective functions in this module reference these tags yet
// (their point-to-point exchanges are not implemented). The 1000+ range is presumably meant
// to keep collective messages from matching user point-to-point receives — confirm.
const TAG_REDUCE: MessageTag = 1000;
const TAG_BROADCAST: MessageTag = 1001;
const TAG_GATHER: MessageTag = 1002;
const TAG_SCATTER: MessageTag = 1003;
const TAG_BARRIER: MessageTag = 1004;
/// Reduces `data` element-wise across all ranks with `op`, delivering the result to `root`.
///
/// Returns the reduced vector on `root` and an empty `Vec` on every other rank.
///
/// NOTE(review): contributions from other ranks are not received yet. The root folds only
/// its own `data`, so with more than one process it currently returns a copy of its local
/// `data` rather than the true reduction.
///
/// # Errors
///
/// Returns [`CollectiveError::InvalidRoot`] if `root >= comm.size()`.
pub async fn reduce<T>(
    data: &[T],
    op: ReduceOp,
    root: usize,
    comm: &Communicator,
) -> Result<Vec<T>, CollectiveError>
where
    T: Serialize
        + for<'de> Deserialize<'de>
        + Clone
        + Add<Output = T>
        + Mul<Output = T>
        + PartialOrd
        + Send
        + 'static,
{
    let (me, world) = (comm.rank(), comm.size());
    if root >= world {
        return Err(CollectiveError::InvalidRoot { root, size: world });
    }
    if me != root {
        return Ok(Vec::new());
    }
    // TODO: receive one contribution from every rank in `0..world` except `root`.
    let remote: Vec<Vec<T>> = Vec::new();
    let mut acc = data.to_vec();
    for contribution in &remote {
        // `zip` stops at the shorter side, so extra trailing elements are ignored.
        for (slot, value) in acc.iter_mut().zip(contribution) {
            *slot = op.apply(slot.clone(), value.clone());
        }
    }
    Ok(acc)
}
pub async fn allreduce<T>(
data: &[T],
op: ReduceOp,
comm: &Communicator,
) -> Result<Vec<T>, CollectiveError>
where
T: Serialize
+ for<'de> Deserialize<'de>
+ Clone
+ Add<Output = T>
+ Mul<Output = T>
+ PartialOrd
+ Send
+ 'static,
{
let reduced = reduce(data, op, 0, comm).await?;
let result = if comm.is_root() { reduced } else { vec![] };
Ok(result)
}
/// Broadcasts `data` from `root` so that every rank ends up holding the root's values.
///
/// NOTE(review): no messages are sent or received yet. After `root` is validated, every
/// rank returns `Ok(())` with its `data` unchanged.
///
/// # Errors
///
/// Returns [`CollectiveError::InvalidRoot`] if `root >= comm.size()`.
pub async fn broadcast<T>(
    data: &mut [T],
    root: usize,
    comm: &Communicator,
) -> Result<(), CollectiveError>
where
    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
    let me = comm.rank();
    let world = comm.size();
    if root >= world {
        return Err(CollectiveError::InvalidRoot { root, size: world });
    }
    if me == root {
        // TODO: send `data` to each of these peers.
        for _peer in (0..world).filter(|&peer| peer != root) {}
    }
    // TODO: non-root ranks should overwrite `data` with the root's values.
    Ok(())
}
/// Gathers every rank's `data` onto `root`, with the root's own slice placed first.
///
/// Returns the concatenation on `root` and an empty `Vec` on every other rank.
///
/// NOTE(review): contributions from other ranks are not received yet, so the root
/// currently returns only a copy of its own `data`.
///
/// # Errors
///
/// Returns [`CollectiveError::InvalidRoot`] if `root >= comm.size()`.
pub async fn gather<T>(
    data: &[T],
    root: usize,
    comm: &Communicator,
) -> Result<Vec<T>, CollectiveError>
where
    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
    let me = comm.rank();
    let world = comm.size();
    if root >= world {
        return Err(CollectiveError::InvalidRoot { root, size: world });
    }
    if me != root {
        return Ok(Vec::new());
    }
    let collected = data.to_vec();
    for _peer in (0..world).filter(|&peer| peer != root) {
        // TODO: receive `_peer`'s slice and append it to `collected`.
    }
    Ok(collected)
}
pub async fn allgather<T>(data: &[T], comm: &Communicator) -> Result<Vec<T>, CollectiveError>
where
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
let gathered = gather(data, 0, comm).await?;
Ok(gathered)
}
/// Splits `send_data` (read only on `root`) into `comm.size()` contiguous chunks and
/// returns the chunk belonging to the calling rank.
///
/// With `len = send_data.len()` and `p = comm.size()`, rank `r` receives `len / p`
/// elements, plus one more when `r < len % p`. Chunks are laid out in rank order.
///
/// NOTE(review): chunks are not sent to other ranks yet. Only `root` returns its chunk;
/// every other rank returns an empty `Vec`.
///
/// # Errors
///
/// Returns [`CollectiveError::InvalidRoot`] if `root >= comm.size()`.
pub async fn scatter<T>(
    send_data: &[T],
    root: usize,
    comm: &Communicator,
) -> Result<Vec<T>, CollectiveError>
where
    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
    let me = comm.rank();
    let world = comm.size();
    if root >= world {
        return Err(CollectiveError::InvalidRoot { root, size: world });
    }
    if me != root {
        return Ok(Vec::new());
    }
    let base = send_data.len() / world;
    let extra = send_data.len() % world;
    // Index range of rank `r`'s chunk: the first `extra` ranks each take one leftover element.
    let bounds = |r: usize| {
        let start = r * base + r.min(extra);
        start..start + base + usize::from(r < extra)
    };
    for peer in (0..world).filter(|&peer| peer != root) {
        let _chunk = &send_data[bounds(peer)];
        // TODO: send `_chunk` to `peer`.
    }
    Ok(send_data[bounds(root)].to_vec())
}
pub async fn allscatter<T>(send_data: &[T], comm: &Communicator) -> Result<Vec<T>, CollectiveError>
where
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
let rank = comm.rank();
let size = comm.size();
let chunk_size = send_data.len() / size;
let mut result = Vec::new();
result.extend_from_slice(&send_data[rank * chunk_size..(rank + 1) * chunk_size]);
Ok(result)
}
/// Blocks until every rank in `comm` has reached this barrier.
///
/// Delegates to the communicator's own barrier, discards its success value, and converts
/// its error into a [`CollectiveError`].
pub async fn barrier(comm: &Communicator) -> Result<(), CollectiveError> {
    match comm.barrier().await {
        Ok(_) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ReduceOp::apply: each operator on an integer pair and on an f64 pair.

    #[test]
    fn test_reduce_op_sum() {
        let sum = ReduceOp::Sum;
        assert_eq!(sum.apply(10, 5), 15);
        assert_eq!(sum.apply(2.0, 3.0), 5.0);
    }

    #[test]
    fn test_reduce_op_product() {
        let product = ReduceOp::Product;
        assert_eq!(product.apply(4, 5), 20);
        assert_eq!(product.apply(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_max() {
        let max = ReduceOp::Max;
        assert_eq!(max.apply(10, 5), 10);
        assert_eq!(max.apply(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_op_min() {
        let min = ReduceOp::Min;
        assert_eq!(min.apply(10, 5), 5);
        assert_eq!(min.apply(2.0, 3.0), 2.0);
    }

    // ReduceOp::apply_slice: folding a sequence, and the empty case.

    #[test]
    fn test_reduce_op_apply_slice() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expectations = [
            (ReduceOp::Sum, 15.0),
            (ReduceOp::Product, 120.0),
            (ReduceOp::Max, 5.0),
            (ReduceOp::Min, 1.0),
        ];
        for &(op, expected) in &expectations {
            assert_eq!(op.apply_slice(&values), Some(expected), "operator {:?}", op);
        }
    }

    #[test]
    fn test_reduce_op_empty_slice() {
        let nothing: [f64; 0] = [];
        assert_eq!(ReduceOp::Sum.apply_slice(&nothing), None);
    }

    // CollectiveError display strings.

    #[test]
    fn test_collective_error_invalid_root() {
        let message = CollectiveError::InvalidRoot { root: 5, size: 4 }.to_string();
        assert!(message.contains("Invalid root"));
    }

    #[test]
    fn test_collective_error_size_mismatch() {
        let message = CollectiveError::SizeMismatch {
            expected: 10,
            actual: 5,
        }
        .to_string();
        assert!(message.contains("Data size mismatch"));
    }
}