//! Convenience wrappers around [`ProcessGroup`] collectives (all-reduce,
//! broadcast, gather/scatter, barrier) plus gradient-synchronization helpers
//! for data-parallel training.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_tensor::Tensor;
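
/// Sums `tensor` element-wise across all processes in the group, in place.
///
/// # Examples
///
/// A minimal sketch using the single-process mock group (as in this module's
/// tests); with a real multi-process group each element becomes the sum over
/// all ranks.
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// all_reduce_sum(&mut t, &pg);
/// assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]); // unchanged with a single rank
/// ```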
pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
pg.all_reduce_tensor(tensor, ReduceOp::Sum);
}
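
/// Averages `tensor` element-wise across all processes in the group, in place.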
pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
pg.all_reduce_tensor(tensor, ReduceOp::Average);
}
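
/// Takes the element-wise minimum of `tensor` across all processes, in place.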
pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
pg.all_reduce_tensor(tensor, ReduceOp::Min);
}
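
/// Takes the element-wise maximum of `tensor` across all processes, in place.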
pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
pg.all_reduce_tensor(tensor, ReduceOp::Max);
}
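
/// Multiplies `tensor` element-wise across all processes, in place.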
pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
pg.all_reduce_tensor(tensor, ReduceOp::Product);
}
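
/// Broadcasts `tensor` from rank 0 to every process in the group.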
pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
broadcast_from(tensor, 0, pg);
}
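
/// Broadcasts `tensor` from the process with rank `src` to every other process.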
pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
pg.broadcast_tensor(tensor, src);
}
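
/// Gathers `tensor` from every process; the result holds one slice per rank,
/// stacked along a new leading dimension by the underlying all-gather.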
#[must_use]
pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
pg.all_gather_tensor(tensor)
}
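
/// Sums `tensor` across all processes and scatters the result, returning this
/// rank's shard.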
#[must_use]
pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
}
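
/// Averages `tensor` across all processes and scatters the result, returning
/// this rank's shard.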
#[must_use]
pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
}
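
/// Blocks until every process in the group has reached this point.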
pub fn barrier(pg: &ProcessGroup) {
pg.barrier();
}
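
/// Returns `true` on the main process (rank 0).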
#[must_use]
pub fn is_main_process(pg: &ProcessGroup) -> bool {
pg.rank() == 0
}
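
/// Returns the number of processes in the group.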
#[must_use]
pub fn world_size(pg: &ProcessGroup) -> usize {
pg.world_size()
}
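
/// Returns this process's rank within the group.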
#[must_use]
pub fn rank(pg: &ProcessGroup) -> usize {
pg.rank()
}
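
/// Returns this rank's chunk of `tensor`, split into `world_size` equal pieces
/// along `dim`.
///
/// Only 1-D tensors and 2-D tensors split along dimension 0 are handled; if
/// `dim` is out of range, the dimension is not evenly divisible by the world
/// size, or the shape is otherwise unsupported, a clone of the full tensor is
/// returned unchanged.
///
/// # Examples
///
/// A minimal sketch with the single-process mock group, where the single rank
/// receives the whole tensor; with `N` ranks each rank would receive a
/// contiguous chunk of `len / N` elements.
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
/// let chunk = scatter_tensor(&t, 0, &pg);
/// assert_eq!(chunk.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
/// ```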
#[must_use]
pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let shape = tensor.shape();
    // Unsupported dimension: return the tensor unchanged.
    if dim >= shape.len() {
        return tensor.clone();
    }
    let world_size = pg.world_size();
    let rank = pg.rank();
    let dim_size = shape[dim];
    // Only split dimensions that divide evenly across the group.
    if dim_size % world_size != 0 {
        return tensor.clone();
    }
    let chunk_size = dim_size / world_size;
    let start = rank * chunk_size;
    let end = start + chunk_size;
    // 1-D: take this rank's contiguous slice of elements.
    if shape.len() == 1 && dim == 0 {
        let data = tensor.to_vec();
        let chunk = data[start..end].to_vec();
        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
    }
    // 2-D split along dim 0: copy this rank's rows.
    if shape.len() == 2 && dim == 0 {
        let data = tensor.to_vec();
        let cols = shape[1];
        let mut chunk = Vec::with_capacity(chunk_size * cols);
        for row in start..end {
            let row_start = row * cols;
            let row_end = row_start + cols;
            chunk.extend_from_slice(&data[row_start..row_end]);
        }
        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
    }
    // Other shapes/dims are not handled: return the tensor unchanged.
    tensor.clone()
}
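
/// Gathers `tensor` from every process along `dim`. 1-D inputs gathered along
/// dimension 0 are flattened back into a single 1-D tensor; other shapes are
/// returned as produced by the underlying all-gather.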
#[must_use]
pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let gathered = pg.all_gather_tensor(tensor);
    let world_size = pg.world_size();
    let shape = tensor.shape();
    // For 1-D inputs gathered along dim 0, flatten the stacked result back to 1-D.
    if shape.len() == 1 && dim == 0 {
        let data = gathered.to_vec();
        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
    }
    gathered
}
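
/// Averages every gradient tensor across all processes (all-reduce mean),
/// the usual gradient-synchronization step in data-parallel training.
///
/// # Examples
///
/// A minimal sketch with the single-process mock group; in real training the
/// process group would come from your distributed launcher and the averaged
/// gradients would then be fed to the optimizer.
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut grads = vec![
///     Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap(),
///     Tensor::from_vec(vec![0.3, 0.4], &[2]).unwrap(),
/// ];
/// sync_gradients(&mut grads, &pg);
/// ```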
pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
for grad in gradients.iter_mut() {
all_reduce_mean(grad, pg);
}
}
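
/// Averages a single gradient tensor across all processes (all-reduce mean).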
pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
all_reduce_mean(gradient, pg);
}
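
/// All-reduces a raw `f32` slice across the group. Despite the name, this
/// delegates to the backend's all-reduce; whether a ring algorithm is used is
/// up to the backend. A no-op for a single-process group.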
pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
    let world_size = pg.world_size();
    if world_size == 1 {
        // Single process: nothing to reduce.
        return;
    }
    // Delegate to the backend's all-reduce implementation.
    pg.backend().all_reduce(data, op);
}

#[cfg(test)]
mod tests {
use super::*;
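
    // `ProcessGroup::mock()` is a single-process group, so every collective is
    // effectively a no-op and outputs equal the inputs.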
#[test]
fn test_all_reduce_sum() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
all_reduce_sum(&mut tensor, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_all_reduce_mean() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
all_reduce_mean(&mut tensor, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_all_reduce_min() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
all_reduce_min(&mut tensor, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_all_reduce_max() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
all_reduce_max(&mut tensor, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_broadcast() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
broadcast(&mut tensor, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
}
#[test]
fn test_broadcast_from() {
let pg = ProcessGroup::mock();
let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
broadcast_from(&mut tensor, 0, &pg);
assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
}
#[test]
fn test_all_gather() {
let pg = ProcessGroup::mock();
let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let gathered = all_gather(&tensor, &pg);
assert_eq!(gathered.shape(), &[1, 2]);
}
#[test]
fn test_reduce_scatter_sum() {
let pg = ProcessGroup::mock();
let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let scattered = reduce_scatter_sum(&tensor, &pg);
assert_eq!(scattered.shape(), &[2]);
}
#[test]
fn test_barrier() {
let pg = ProcessGroup::mock();
        barrier(&pg);
    }
#[test]
fn test_is_main_process() {
let pg = ProcessGroup::mock();
assert!(is_main_process(&pg));
}
#[test]
fn test_world_size() {
let pg = ProcessGroup::mock();
assert_eq!(world_size(&pg), 1);
}
#[test]
fn test_rank() {
let pg = ProcessGroup::mock();
assert_eq!(rank(&pg), 0);
}
#[test]
fn test_scatter_tensor_1d() {
let pg = ProcessGroup::mock();
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
let scattered = scatter_tensor(&tensor, 0, &pg);
assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_gather_tensor() {
let pg = ProcessGroup::mock();
let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let gathered = gather_tensor(&tensor, 0, &pg);
assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
}
#[test]
fn test_sync_gradients() {
let pg = ProcessGroup::mock();
let mut grads = vec![
Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
];
sync_gradients(&mut grads, &pg);
assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
}
#[test]
fn test_sync_gradient() {
let pg = ProcessGroup::mock();
let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
sync_gradient(&mut grad, &pg);
assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_ring_all_reduce() {
let pg = ProcessGroup::mock();
let mut data = vec![1.0, 2.0, 3.0];
ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
assert_eq!(data, vec![1.0, 2.0, 3.0]);
}
}