use crate::error::{CoreError, CoreResult, ErrorContext};
use std::sync::{Arc, Barrier, Mutex};
/// Acquires `m`, converting a poisoned-mutex failure into a `CoreError`
/// instead of propagating the `PoisonError` type to callers.
fn lock_or_err<T>(m: &Mutex<T>) -> CoreResult<std::sync::MutexGuard<'_, T>> {
    match m.lock() {
        Ok(guard) => Ok(guard),
        Err(_) => Err(CoreError::ComputationError(ErrorContext::new(
            "collective: mutex lock poisoned",
        ))),
    }
}
/// Thread-based all-reduce primitive: `n_workers` threads exchange
/// equally sized `f64` chunks through a shared flat buffer, meeting at a
/// barrier before and after each reduction round.
pub struct AllReduceOp {
// Number of participating worker threads, fixed at construction.
n_workers: usize,
// Rendezvous point sized to `n_workers`; every worker must reach it.
barrier: Arc<Barrier>,
// Flat staging buffer of `n_workers * chunk_len` values; worker `w`
// writes the slot starting at `w * chunk_len`. Grows on demand and is
// never shrunk between rounds.
shared_buffer: Arc<Mutex<Vec<f64>>>,
}
impl AllReduceOp {
    /// Creates an all-reduce operator for `n_workers` cooperating threads.
    ///
    /// # Panics
    ///
    /// Panics if `n_workers` is zero.
    pub fn new(n_workers: usize) -> Self {
        assert!(n_workers > 0, "n_workers must be > 0");
        Self {
            n_workers,
            barrier: Arc::new(Barrier::new(n_workers)),
            shared_buffer: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Returns the barrier shared by all workers of this operator.
    pub fn barrier(&self) -> &Arc<Barrier> {
        &self.barrier
    }

    /// Returns the number of workers this operator was configured for.
    pub fn n_workers(&self) -> usize {
        self.n_workers
    }

    /// Element-wise sum of `local_data` across all workers.
    ///
    /// Calling contract: every worker must call this exactly once per round,
    /// with the same `local_data.len()` and a distinct `worker_id` in
    /// `0..n_workers` — otherwise the internal barrier never releases and the
    /// remaining workers deadlock (arguments are validated *before* the
    /// barrier precisely so a bad call fails without stalling the others).
    ///
    /// # Errors
    ///
    /// `CoreError::ValueError` for an out-of-range `worker_id` or empty
    /// `local_data`; `CoreError::ComputationError` if the shared mutex is
    /// poisoned or the buffer is shorter than expected during the reduce.
    pub fn all_reduce_sum(&self, local_data: &[f64], worker_id: usize) -> CoreResult<Vec<f64>> {
        self.all_reduce_with(local_data, worker_id, 0.0_f64, |acc, val| *acc += val)
    }

    /// Element-wise mean across all workers (the all-reduce sum divided by
    /// `n_workers`). Same calling contract and errors as
    /// [`AllReduceOp::all_reduce_sum`].
    pub fn all_reduce_mean(&self, local_data: &[f64], worker_id: usize) -> CoreResult<Vec<f64>> {
        let sum = self.all_reduce_sum(local_data, worker_id)?;
        let n = self.n_workers as f64;
        Ok(sum.into_iter().map(|v| v / n).collect())
    }

    /// Element-wise maximum across all workers.
    ///
    /// A NaN input is never selected: the `val > acc` comparison is false for
    /// NaN, matching the original strict-greater-than behavior. Same calling
    /// contract and errors as [`AllReduceOp::all_reduce_sum`].
    pub fn all_reduce_max(&self, local_data: &[f64], worker_id: usize) -> CoreResult<Vec<f64>> {
        self.all_reduce_with(local_data, worker_id, f64::NEG_INFINITY, |acc, val| {
            if val > *acc {
                *acc = val;
            }
        })
    }

    /// Shared publish → barrier → reduce → barrier skeleton behind
    /// `all_reduce_sum` and `all_reduce_max`. `identity` seeds both the
    /// buffer (when it grows) and the result accumulator; `combine` folds one
    /// worker value into an accumulator element.
    fn all_reduce_with<F>(
        &self,
        local_data: &[f64],
        worker_id: usize,
        identity: f64,
        combine: F,
    ) -> CoreResult<Vec<f64>>
    where
        F: Fn(&mut f64, f64),
    {
        self.validate_args(local_data, worker_id)?;
        let chunk_len = local_data.len();
        {
            // Publish this worker's chunk into its dedicated slot.
            let mut buf = lock_or_err(&self.shared_buffer)?;
            let required = self.n_workers * chunk_len;
            if buf.len() < required {
                // The fill value does not affect the result: every position in
                // 0..required is overwritten by some worker before the reduce.
                buf.resize(required, identity);
            }
            let start = worker_id * chunk_len;
            buf[start..start + chunk_len].copy_from_slice(local_data);
        }
        // First rendezvous: wait until every worker has published its chunk.
        self.barrier.wait();
        let buf = lock_or_err(&self.shared_buffer)?;
        let result = self.reduce_from_buf(&buf, chunk_len, identity, combine)?;
        drop(buf);
        // Second rendezvous: stops a fast worker from starting the next round
        // (and overwriting the buffer) while others are still reading.
        self.barrier.wait();
        Ok(result)
    }

    /// Rejects out-of-range worker ids and empty chunks. Must run before any
    /// worker reaches the barrier; failing afterwards would deadlock the rest.
    fn validate_args(&self, local_data: &[f64], worker_id: usize) -> CoreResult<()> {
        if worker_id >= self.n_workers {
            return Err(CoreError::ValueError(ErrorContext::new(format!(
                "worker_id {worker_id} >= n_workers {}",
                self.n_workers
            ))));
        }
        if local_data.is_empty() {
            return Err(CoreError::ValueError(ErrorContext::new(
                "local_data must not be empty",
            )));
        }
        Ok(())
    }

    /// Folds all `n_workers` chunks of `buf` element-wise into a fresh
    /// accumulator seeded with `identity`, using `combine` per element.
    fn reduce_from_buf<F>(
        &self,
        buf: &[f64],
        chunk_len: usize,
        identity: f64,
        combine: F,
    ) -> CoreResult<Vec<f64>>
    where
        F: Fn(&mut f64, f64),
    {
        let mut result = vec![identity; chunk_len];
        for w in 0..self.n_workers {
            let start = w * chunk_len;
            let end = start + chunk_len;
            if end > buf.len() {
                return Err(CoreError::ComputationError(ErrorContext::new(
                    "shared buffer smaller than expected during reduce",
                )));
            }
            for (acc, &val) in result.iter_mut().zip(buf[start..end].iter()) {
                combine(acc, val);
            }
        }
        Ok(result)
    }
}
/// Returns one clone of `data` per worker, simulating a broadcast from
/// `src_worker` (the source id is validated but otherwise unused in this
/// single-process model).
///
/// # Panics
///
/// Panics if `n_workers` is zero or `src_worker` is out of range.
pub fn broadcast<T: Clone + Send>(data: T, src_worker: usize, n_workers: usize) -> Vec<T> {
    assert!(n_workers > 0, "n_workers must be > 0");
    assert!(
        src_worker < n_workers,
        "src_worker {src_worker} >= n_workers {n_workers}"
    );
    vec![data; n_workers]
}
/// Splits `data` into `n_workers` contiguous chunks. The first
/// `len % n_workers` workers receive one extra element, so chunk sizes
/// differ by at most one; trailing workers may receive empty chunks.
///
/// # Panics
///
/// Panics if `n_workers` is zero.
pub fn scatter<T: Clone + Send>(data: &[T], n_workers: usize) -> Vec<Vec<T>> {
    assert!(n_workers > 0, "n_workers must be > 0");
    let base = data.len() / n_workers;
    let remainder = data.len() % n_workers;
    let mut chunks = Vec::with_capacity(n_workers);
    let mut rest = data;
    for w in 0..n_workers {
        // Workers 0..remainder take base + 1 elements, the rest take base.
        let take = base + usize::from(w < remainder);
        let (head, tail) = rest.split_at(take);
        chunks.push(head.to_vec());
        rest = tail;
    }
    chunks
}
/// Collects every worker's `local_data` at worker 0.
///
/// Each worker deposits its chunk into its slot of `shared`, then waits at
/// `barrier`. Worker 0 concatenates all chunks in worker-id order and returns
/// `Some(gathered)`; every other worker returns `None`.
///
/// # Errors
///
/// `CoreError::ValueError` if `worker_id` is out of range or `shared` is not
/// sized to `n_workers`; `CoreError::ComputationError` if the mutex is
/// poisoned or a slot is still empty after the barrier.
pub fn gather<T: Clone + Send>(
    local_data: Vec<T>,
    worker_id: usize,
    n_workers: usize,
    shared: Arc<Mutex<Vec<Option<Vec<T>>>>>,
    barrier: Arc<Barrier>,
) -> CoreResult<Option<Vec<T>>> {
    if worker_id >= n_workers {
        return Err(CoreError::ValueError(ErrorContext::new(format!(
            "worker_id {worker_id} >= n_workers {n_workers}"
        ))));
    }
    {
        // Deposit this worker's chunk into its slot.
        let mut slots = shared.lock().map_err(|_| {
            CoreError::ComputationError(ErrorContext::new("gather: shared mutex poisoned"))
        })?;
        if slots.len() != n_workers {
            return Err(CoreError::ValueError(ErrorContext::new(format!(
                "shared buffer has length {} but n_workers is {n_workers}",
                slots.len()
            ))));
        }
        slots[worker_id] = Some(local_data);
    }
    // Everyone must have deposited before the root reads.
    barrier.wait();
    if worker_id != 0 {
        return Ok(None);
    }
    let slots = shared.lock().map_err(|_| {
        CoreError::ComputationError(ErrorContext::new("gather: shared mutex poisoned on read"))
    })?;
    let mut collected: Vec<T> = Vec::new();
    for slot in slots.iter() {
        let chunk = slot.as_ref().ok_or_else(|| {
            CoreError::ComputationError(ErrorContext::new(
                "gather: a worker slot was empty after barrier",
            ))
        })?;
        collected.extend(chunk.iter().cloned());
    }
    Ok(Some(collected))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;

    // Worker k (1-based) contributes [k, 2k, 3k]; the element-wise sum over
    // k = 1..=4 is [10, 20, 30] and every worker must see it.
    #[test]
    fn test_all_reduce_sum_4_workers() {
        let n = 4_usize;
        let op = Arc::new(AllReduceOp::new(n));
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let op_ref = Arc::clone(&op);
                thread::spawn(move || {
                    let local: Vec<f64> = vec![
                        (id + 1) as f64,
                        (id + 1) as f64 * 2.0,
                        (id + 1) as f64 * 3.0,
                    ];
                    op_ref.all_reduce_sum(&local, id)
                })
            })
            .collect();
        for h in handles {
            let result = h
                .join()
                .expect("thread panicked")
                .expect("all_reduce_sum failed");
            assert_eq!(result, vec![10.0, 20.0, 30.0]);
        }
    }

    // Minimal chunk length: four workers each contribute a single 1.0.
    #[test]
    fn test_all_reduce_sum_single_element() {
        let n = 4_usize;
        let op = Arc::new(AllReduceOp::new(n));
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let op_ref = Arc::clone(&op);
                thread::spawn(move || op_ref.all_reduce_sum(&[1.0_f64], id))
            })
            .collect();
        for h in handles {
            let r = h.join().expect("panic").expect("error");
            assert_eq!(r, vec![4.0]);
        }
    }

    // Identical inputs: the mean equals the common value.
    #[test]
    fn test_all_reduce_mean_4_workers() {
        let n = 4_usize;
        let op = Arc::new(AllReduceOp::new(n));
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let op_ref = Arc::clone(&op);
                thread::spawn(move || op_ref.all_reduce_mean(&[4.0_f64], id))
            })
            .collect();
        for h in handles {
            let r = h.join().expect("panic").expect("error");
            let diff = (r[0] - 4.0_f64).abs();
            assert!(diff < 1e-10, "expected 4.0, got {}", r[0]);
        }
    }

    // Distinct inputs 1, 3, 5, 7 average to 4.
    #[test]
    fn test_all_reduce_mean_heterogeneous() {
        let n = 4_usize;
        let op = Arc::new(AllReduceOp::new(n));
        let inputs: Vec<f64> = vec![1.0, 3.0, 5.0, 7.0];
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let op_ref = Arc::clone(&op);
                let val = inputs[id];
                thread::spawn(move || op_ref.all_reduce_mean(&[val], id))
            })
            .collect();
        for h in handles {
            let r = h.join().expect("panic").expect("error");
            let diff = (r[0] - 4.0_f64).abs();
            assert!(diff < 1e-10, "expected 4.0, got {}", r[0]);
        }
    }

    // Per-position maxima come from different workers (worker 2 and worker 0).
    #[test]
    fn test_all_reduce_max() {
        let n = 4_usize;
        let op = Arc::new(AllReduceOp::new(n));
        let inputs: Vec<Vec<f64>> = vec![
            vec![1.0, 9.0],
            vec![3.0, 2.0],
            vec![7.0, 5.0],
            vec![4.0, 8.0],
        ];
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let op_ref = Arc::clone(&op);
                let local = inputs[id].clone();
                thread::spawn(move || op_ref.all_reduce_max(&local, id))
            })
            .collect();
        for h in handles {
            let r = h.join().expect("panic").expect("error");
            assert_eq!(r, vec![7.0, 9.0]);
        }
    }

    #[test]
    fn test_broadcast_copies_to_all() {
        let copies = broadcast(42_u32, 0, 5);
        assert_eq!(copies.len(), 5);
        assert!(copies.iter().all(|&v| v == 42));
    }

    // Non-Copy payloads must be deep-cloned per worker.
    #[test]
    fn test_broadcast_vec_cloned() {
        let copies = broadcast(vec![1.0_f64, 2.0], 0, 3);
        assert_eq!(copies.len(), 3);
        for c in &copies {
            assert_eq!(c, &vec![1.0_f64, 2.0]);
        }
    }

    #[test]
    fn test_scatter_even_split() {
        let chunks = scatter(&[10, 20, 30, 40, 50, 60], 3);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], vec![10, 20]);
        assert_eq!(chunks[1], vec![30, 40]);
        assert_eq!(chunks[2], vec![50, 60]);
    }

    // 7 elements over 3 workers: the single remainder element goes to worker 0.
    #[test]
    fn test_scatter_uneven_split() {
        let chunks = scatter(&[1, 2, 3, 4, 5, 6, 7], 3);
        assert_eq!(chunks[0], vec![1, 2, 3]);
        assert_eq!(chunks[1], vec![4, 5]);
        assert_eq!(chunks[2], vec![6, 7]);
    }

    #[test]
    fn test_scatter_single_element_each() {
        let chunks = scatter(&[100_u8, 200, 150], 3);
        assert_eq!(chunks[0], vec![100_u8]);
        assert_eq!(chunks[1], vec![200_u8]);
        assert_eq!(chunks[2], vec![150_u8]);
    }

    // Surplus workers receive empty chunks rather than panicking.
    #[test]
    fn test_scatter_more_workers_than_elements() {
        let chunks = scatter(&[1_i32, 2], 4);
        assert_eq!(chunks[0], vec![1_i32]);
        assert_eq!(chunks[1], vec![2_i32]);
        assert_eq!(chunks[2], Vec::<i32>::new());
        assert_eq!(chunks[3], Vec::<i32>::new());
    }

    // Root (worker 0) gets the concatenation in worker-id order; the rest
    // get None.
    #[test]
    fn test_gather_root_collects() {
        let n = 4_usize;
        let barrier = Arc::new(Barrier::new(n));
        let shared: Arc<Mutex<Vec<Option<Vec<u32>>>>> =
            Arc::new(Mutex::new(vec![None; n]));
        let handles: Vec<_> = (0..n)
            .map(|id| {
                let b = Arc::clone(&barrier);
                let s = Arc::clone(&shared);
                thread::spawn(move || {
                    gather(vec![id as u32 * 10, id as u32 * 10 + 1], id, n, s, b)
                })
            })
            .collect();
        let mut root_data: Option<Vec<u32>> = None;
        for (id, h) in handles.into_iter().enumerate() {
            let res = h.join().expect("panic").expect("gather error");
            if id == 0 {
                root_data = res;
            } else {
                assert!(res.is_none(), "non-root worker should return None");
            }
        }
        let gathered = root_data.expect("root must have data");
        assert_eq!(gathered.len(), 8);
        assert_eq!(gathered[0], 0);
        assert_eq!(gathered[1], 1);
        assert_eq!(gathered[2], 10);
    }

    // An out-of-range worker_id must fail fast (before the barrier), so a
    // single-threaded call returns instead of deadlocking.
    #[test]
    fn test_all_reduce_invalid_worker_id() {
        let op = AllReduceOp::new(1);
        assert!(op.all_reduce_sum(&[1.0], 99).is_err());
    }

    #[test]
    fn test_all_reduce_empty_local_data() {
        let op = AllReduceOp::new(1);
        let err = op.all_reduce_sum(&[], 0);
        assert!(err.is_err());
    }
}