use crate::backend::{Backend, MockBackend, ReduceOp};
use axonml_tensor::Tensor;
use std::sync::Arc;
/// A set of ranks that communicate through a shared [`Backend`].
///
/// By default a group spans every rank in the backend's world; sub-groups
/// over a subset of ranks can be built with [`ProcessGroup::with_ranks`].
pub struct ProcessGroup {
// Communication backend shared (via `Arc`) by all clones of this group.
backend: Arc<dyn Backend>,
// Ranks belonging to this group; `new` fills it with `0..world_size`.
ranks: Vec<usize>,
}
impl ProcessGroup {
    /// Creates a process group spanning every rank in the backend's world.
    #[must_use]
    pub fn new(backend: Arc<dyn Backend>) -> Self {
        let world_size = backend.world_size();
        Self {
            backend,
            ranks: (0..world_size).collect(),
        }
    }

    /// Creates a process group restricted to an explicit set of ranks.
    #[must_use]
    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
        Self { backend, ranks }
    }

    /// Creates a single-rank group backed by a mock backend, for tests
    /// and single-process runs.
    #[must_use]
    pub fn mock() -> Self {
        Self::new(Arc::new(MockBackend::single()))
    }

    /// Returns the communication backend backing this group.
    #[must_use]
    pub fn backend(&self) -> &dyn Backend {
        self.backend.as_ref()
    }

    /// Returns this process's rank as reported by the backend.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.backend.rank()
    }

    /// Returns the total number of processes in the backend's world.
    #[must_use]
    pub fn world_size(&self) -> usize {
        self.backend.world_size()
    }

    /// Returns the number of ranks in this group.
    /// May be smaller than [`Self::world_size`] for sub-groups.
    #[must_use]
    pub fn size(&self) -> usize {
        self.ranks.len()
    }

    /// Returns the ranks that belong to this group.
    #[must_use]
    pub fn ranks(&self) -> &[usize] {
        &self.ranks
    }

    /// Returns `true` if `rank` is a member of this group.
    #[must_use]
    pub fn contains(&self, rank: usize) -> bool {
        self.ranks.contains(&rank)
    }

    /// Blocks until every process has reached the backend's barrier.
    pub fn barrier(&self) {
        self.backend.barrier();
    }

    /// All-reduces `tensor` in place across the world using `op`.
    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
        let mut data = tensor.to_vec();
        self.backend.all_reduce(&mut data, op);
        // The element count and shape are unchanged, so rebuilding cannot fail.
        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
    }

    /// Overwrites `tensor` in place with rank `src`'s copy.
    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
        let mut data = tensor.to_vec();
        self.backend.broadcast(&mut data, src);
        // The element count and shape are unchanged, so rebuilding cannot fail.
        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
    }

    /// Gathers `send_tensor` from every rank into a single tensor whose
    /// leading dimension is the world size (shape `[world_size, ..shape]`).
    #[must_use]
    pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
        let send_data = send_tensor.to_vec();
        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
        self.backend.all_gather(&send_data, &mut recv_data);
        let mut new_shape = vec![self.world_size()];
        new_shape.extend(send_tensor.shape());
        Tensor::from_vec(recv_data, &new_shape).unwrap()
    }

    /// Reduces `send_tensor` across all ranks with `op`, then scatters one
    /// equal chunk to each rank; the leading dimension shrinks accordingly.
    ///
    /// # Panics
    ///
    /// Panics if the world size does not evenly divide the tensor's leading
    /// dimension (the reshaped chunk would not match the received length).
    #[must_use]
    pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
        let send_data = send_tensor.to_vec();
        let chunk_size = send_data.len() / self.world_size();
        let mut recv_data = vec![0.0; chunk_size];
        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
        let original_shape = send_tensor.shape();
        let mut new_shape = original_shape.to_vec();
        if !new_shape.is_empty() {
            new_shape[0] /= self.world_size();
        }
        Tensor::from_vec(recv_data, &new_shape).unwrap()
    }

    /// Sends `tensor`'s contents to rank `dst` (tag 0).
    ///
    /// Sending never modifies the tensor, so this takes a shared reference;
    /// existing callers passing `&mut Tensor` still compile via the implicit
    /// `&mut T -> &T` coercion.
    pub fn send_tensor(&self, tensor: &Tensor<f32>, dst: usize) {
        let data = tensor.to_vec();
        self.backend.send(&data, dst, 0);
    }

    /// Receives a tensor of the given `shape` from rank `src` (tag 0).
    ///
    /// The buffer is sized to `shape.iter().product()`, so the final
    /// reshape always matches and cannot fail.
    #[must_use]
    pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
        let size: usize = shape.iter().product();
        let mut data = vec![0.0; size];
        self.backend.recv(&mut data, src, 0);
        Tensor::from_vec(data, shape).unwrap()
    }
}
/// Top-level handle to the distributed environment.
///
/// Wraps a default [`ProcessGroup`] spanning all ranks and offers
/// convenience accessors plus sub-group creation.
pub struct World {
// Group over every rank; sub-groups share its backend.
default_group: ProcessGroup,
}
impl World {
    /// Initializes a world whose default group spans every rank of `backend`.
    pub fn init(backend: Arc<dyn Backend>) -> Self {
        let default_group = ProcessGroup::new(backend);
        Self { default_group }
    }

    /// Builds a single-rank mock world for tests and local runs.
    #[must_use]
    pub fn mock() -> Self {
        let default_group = ProcessGroup::mock();
        Self { default_group }
    }

    /// Returns the group that contains every rank.
    #[must_use]
    pub fn default_group(&self) -> &ProcessGroup {
        &self.default_group
    }

    /// Returns this process's rank in the default group.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.default_group.rank()
    }

    /// Returns the total number of processes in the world.
    #[must_use]
    pub fn world_size(&self) -> usize {
        self.default_group.world_size()
    }

    /// Returns `true` when this process is rank 0.
    #[must_use]
    pub fn is_main(&self) -> bool {
        self.rank() == 0
    }

    /// Synchronizes every process via the default group's barrier.
    pub fn barrier(&self) {
        self.default_group.barrier();
    }

    /// Creates a sub-group over `ranks` that shares this world's backend.
    #[must_use]
    pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
        let shared_backend = Arc::clone(&self.default_group.backend);
        ProcessGroup::with_ranks(shared_backend, ranks)
    }
}
impl Clone for ProcessGroup {
    /// Cheap clone: bumps the backend's refcount and copies the rank list.
    fn clone(&self) -> Self {
        let backend = Arc::clone(&self.backend);
        let ranks = self.ranks.to_vec();
        Self { backend, ranks }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);
        assert_eq!(subgroup.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Only rank 0 of a 2-rank mock world participates, so the
        // reduction leaves the values untouched.
        let backends = MockBackend::create_world(2);
        let backend0 = backends.into_iter().next().unwrap();
        let group = ProcessGroup::new(Arc::new(backend0));
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        group.all_reduce_tensor(&mut tensor, ReduceOp::Sum);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut tensor, 0);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = group.all_gather_tensor(&tensor);
        // World size 1 prepends a leading dimension of 1.
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        let group = ProcessGroup::mock();
        group.barrier();
    }

    #[test]
    fn test_world_barrier() {
        let world = World::mock();
        world.barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}