use crate::client::NexarClient;
use crate::collective::helpers::{
ChunkLayout, CollectiveTag, collective_recv, collective_send, step_tag,
};
use crate::error::{NexarError, Result};
use crate::reduce::reduce_slice;
use crate::types::{DataType, ReduceOp};
/// Dispatches an allreduce over `count` elements of `dtype` at device
/// pointer `ptr` to the most suitable algorithm:
///
/// * payloads of at least `cfg.large_msg_bytes` go to the pipelined ring,
/// * small worlds (`<= cfg.ring_max_world`) — or payloads with fewer
///   elements than ranks — use the plain ring,
/// * everything else uses recursive halving/doubling.
///
/// # Safety
///
/// `ptr` must be a valid device pointer owned by `client`'s adapter,
/// addressing at least `count * dtype.size_in_bytes()` readable and
/// writable bytes for the duration of the call.
pub(crate) async unsafe fn ring_allreduce(
    client: &NexarClient,
    ptr: u64,
    count: usize,
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let total_bytes = count * dtype.size_in_bytes();
    let cfg = client.config();
    if total_bytes >= cfg.large_msg_bytes {
        // Large payloads: overlap chunked transfers via the pipelined ring.
        unsafe {
            crate::collective::pipelined_allreduce::pipelined_ring_allreduce(
                client, ptr, count, dtype, op, tag,
            )
            .await
        }
    } else if world <= cfg.ring_max_world || count < world {
        // The plain ring wins for small worlds; it is also the fallback when
        // there are fewer elements than ranks, where a per-rank split for
        // halving/doubling would leave some ranks with empty slices.
        unsafe { ring_allreduce_impl(client, ptr, count, dtype, op, tag).await }
    } else {
        unsafe { halving_doubling_allreduce(client, ptr, count, dtype, op, tag).await }
    }
}
/// Plain ring allreduce over host-staged data.
///
/// Runs a reduce-scatter (`world - 1` rounds, each rank reducing one chunk
/// received from its left neighbor) followed by an allgather (`world - 1`
/// rounds circulating the fully reduced chunks), then writes the result back
/// to the device buffer.
///
/// # Safety
///
/// `ptr` must be a valid device pointer owned by `client`'s adapter,
/// addressing at least `count * dtype.size_in_bytes()` readable and writable
/// bytes for the duration of the call.
async unsafe fn ring_allreduce_impl(
    client: &NexarClient,
    ptr: u64,
    count: usize,
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let rank = client.rank() as usize;
    if world <= 1 {
        // Single rank: the device buffer is already the reduced result.
        return Ok(());
    }
    let elem_size = dtype.size_in_bytes();
    let total_bytes = count * elem_size;
    // SAFETY: caller guarantees `ptr` addresses `total_bytes` valid bytes.
    let mut buf = unsafe { client.adapter().stage_for_send(ptr, total_bytes)? };
    let layout = ChunkLayout::new(count, world);
    let next = (rank + 1) % world;
    let prev = (rank + world - 1) % world;
    // Reduce-scatter phase: after `world - 1` rounds each rank holds the
    // fully reduced chunk at index `(rank + 1) % world`.
    for step in 0..(world - 1) {
        let send_idx = (rank + world - step) % world;
        let send_off = layout.offsets[send_idx] * elem_size;
        let send_len = layout.chunk_count(send_idx) * elem_size;
        let recv_idx = (rank + world - step - 1) % world;
        let recv_off = layout.offsets[recv_idx] * elem_size;
        let recv_count = layout.chunk_count(recv_idx);
        let recv_len = recv_count * elem_size;
        // Snapshot the outgoing chunk so the concurrent receive can reduce
        // into `buf` without aliasing the in-flight send.
        let send_snapshot = buf[send_off..send_off + send_len].to_vec();
        let round_tag = step_tag(tag, step);
        let (_, received) = tokio::try_join!(
            collective_send(client, next as u32, &send_snapshot, "allreduce", round_tag),
            collective_recv(client, prev as u32, "allreduce", round_tag),
        )?;
        if received.len() != recv_len {
            return Err(NexarError::BufferSizeMismatch {
                expected: recv_len,
                actual: received.len(),
            });
        }
        let dst_slice = &mut buf[recv_off..recv_off + recv_len];
        reduce_slice(dst_slice, &received, recv_count, dtype, op)?;
    }
    // Allgather phase: circulate the reduced chunks around the ring. Tags
    // continue after the reduce-scatter rounds so the two phases cannot
    // collide.
    for step in 0..(world - 1) {
        let send_idx = (rank + world + 1 - step) % world;
        let send_off = layout.offsets[send_idx] * elem_size;
        let send_len = layout.chunk_count(send_idx) * elem_size;
        let recv_idx = (rank + world - step) % world;
        let recv_off = layout.offsets[recv_idx] * elem_size;
        let recv_len = layout.chunk_count(recv_idx) * elem_size;
        let send_snapshot = buf[send_off..send_off + send_len].to_vec();
        let round_tag = step_tag(tag, world - 1 + step);
        let (_, received) = tokio::try_join!(
            collective_send(client, next as u32, &send_snapshot, "allreduce", round_tag),
            collective_recv(client, prev as u32, "allreduce", round_tag),
        )?;
        if received.len() != recv_len {
            return Err(NexarError::BufferSizeMismatch {
                expected: recv_len,
                actual: received.len(),
            });
        }
        buf[recv_off..recv_off + recv_len].copy_from_slice(&received);
    }
    // SAFETY: caller guarantees `ptr` addresses `total_bytes` writable bytes.
    unsafe { client.adapter().receive_to_device(&buf, ptr)? };
    Ok(())
}
/// Recursive halving/doubling allreduce over host-staged data.
///
/// For non-power-of-two worlds, the `excess = world - p2` ranks above the
/// largest power of two `p2` first fold their buffers into a partner
/// (`rank - p2`) and sit out the core exchange; the result is pushed back to
/// them at the end. The `p2` participating ranks run a recursive-halving
/// reduce-scatter followed by a recursive-doubling allgather.
///
/// # Safety
///
/// `ptr` must be a valid device pointer owned by `client`'s adapter,
/// addressing at least `count * dtype.size_in_bytes()` readable and writable
/// bytes for the duration of the call.
async unsafe fn halving_doubling_allreduce(
    client: &NexarClient,
    ptr: u64,
    count: usize,
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let rank = client.rank() as usize;
    if world <= 1 {
        // Single rank: the device buffer is already the reduced result.
        return Ok(());
    }
    let elem_size = dtype.size_in_bytes();
    let total_bytes = count * elem_size;
    // SAFETY: caller guarantees `ptr` addresses `total_bytes` valid bytes.
    let mut buf = unsafe { client.adapter().stage_for_send(ptr, total_bytes)? };
    // Largest power of two <= world.
    let p2 = world.next_power_of_two() >> if world.is_power_of_two() { 0 } else { 1 };
    let excess = world - p2;
    let mut virtual_rank: Option<usize> = None;
    let mut step_counter = 0usize;
    if rank < excess {
        // Absorb the buffer of the paired excess rank before the exchange.
        let partner = rank + p2;
        let received = collective_recv(
            client,
            partner as u32,
            "allreduce",
            step_tag(tag, step_counter),
        )
        .await?;
        if received.len() != total_bytes {
            return Err(NexarError::BufferSizeMismatch {
                expected: total_bytes,
                actual: received.len(),
            });
        }
        reduce_slice(&mut buf, &received, count, dtype, op)?;
        virtual_rank = Some(rank);
    } else if rank >= p2 {
        // Excess rank: hand the data to `rank - p2` and wait for the result
        // in the final exchange below (no virtual rank → skips the core).
        let partner = rank - p2;
        collective_send(
            client,
            partner as u32,
            &buf,
            "allreduce",
            step_tag(tag, step_counter),
        )
        .await?;
    } else {
        virtual_rank = Some(rank);
    }
    step_counter += 1;
    if let Some(vrank) = virtual_rank {
        let log2 = p2.trailing_zeros() as usize;
        let mut slice_start = 0usize;
        let mut slice_len = count;
        // Slice owned *before* each halving round, indexed by round number.
        // Replayed in the doubling phase to derive exact expected receive
        // sizes and panic-free copy offsets.
        let mut parents: Vec<(usize, usize)> = Vec::with_capacity(log2);
        // Reduce-scatter: each round halves the owned slice, exchanging the
        // other half with the partner differing in bit `round`.
        for round in 0..log2 {
            let partner_vrank = vrank ^ (1 << round);
            // Virtual ranks < p2 map 1:1 onto real ranks.
            let partner_real = partner_vrank;
            parents.push((slice_start, slice_len));
            let half = slice_len / 2;
            let half_rem = slice_len - half;
            // The lower vrank keeps the front half; the higher keeps the
            // (possibly one-element-longer) back half.
            let (send_start, send_len, keep_start, keep_len) = if vrank < partner_vrank {
                (slice_start + half, half_rem, slice_start, half)
            } else {
                (slice_start, half, slice_start + half, half_rem)
            };
            let send_off = send_start * elem_size;
            let send_bytes = send_len * elem_size;
            let keep_off = keep_start * elem_size;
            let keep_bytes = keep_len * elem_size;
            let send_data = buf[send_off..send_off + send_bytes].to_vec();
            let round_tag = step_tag(tag, step_counter + round);
            let (_, received) = tokio::try_join!(
                collective_send(
                    client,
                    partner_real as u32,
                    &send_data,
                    "allreduce",
                    round_tag
                ),
                collective_recv(client, partner_real as u32, "allreduce", round_tag),
            )?;
            if received.len() != keep_bytes {
                return Err(NexarError::BufferSizeMismatch {
                    expected: keep_bytes,
                    actual: received.len(),
                });
            }
            let dst = &mut buf[keep_off..keep_off + keep_bytes];
            reduce_slice(dst, &received, keep_len, dtype, op)?;
            slice_start = keep_start;
            slice_len = keep_len;
        }
        let ag_base = step_counter + log2;
        // Allgather: replay the rounds in reverse, doubling the owned slice
        // back up to the full buffer.
        for round in (0..log2).rev() {
            let partner_vrank = vrank ^ (1 << round);
            let partner_real = partner_vrank;
            // The slice both partners shared before halving round `round`.
            let (parent_start, parent_len) = parents[round];
            let send_off = slice_start * elem_size;
            let send_bytes = slice_len * elem_size;
            let send_data = buf[send_off..send_off + send_bytes].to_vec();
            let round_tag = step_tag(tag, ag_base + (log2 - 1 - round));
            let (_, received) = tokio::try_join!(
                collective_send(
                    client,
                    partner_real as u32,
                    &send_data,
                    "allreduce",
                    round_tag
                ),
                collective_recv(client, partner_real as u32, "allreduce", round_tag),
            )?;
            // The partner owns exactly the remainder of this round's parent
            // slice. Reject anything else so a mis-sized message yields an
            // error instead of an out-of-bounds panic (or an underflow when
            // computing the receive offset for the higher vrank).
            let expected_bytes = (parent_len - slice_len) * elem_size;
            if received.len() != expected_bytes {
                return Err(NexarError::BufferSizeMismatch {
                    expected: expected_bytes,
                    actual: received.len(),
                });
            }
            // Lower vrank holds the front of the parent slice, so the
            // partner's data lands right after it; the higher vrank holds
            // the back, so the data lands at the parent's start.
            let recv_start = if vrank < partner_vrank {
                slice_start + slice_len
            } else {
                parent_start
            };
            let recv_off = recv_start * elem_size;
            buf[recv_off..recv_off + received.len()].copy_from_slice(&received);
            // The owned slice is now the whole parent slice.
            slice_start = parent_start;
            slice_len = parent_len;
        }
    }
    // Final exchange: mirror of the pre-step, pushing the full result back to
    // the excess ranks. Step 999 is a sentinel safely above any per-round
    // step used here (`1 + 2 * log2` never approaches it for any usize-sized
    // world); all ranks compute the same value, so tags always match.
    let final_tag = step_tag(tag, 999);
    if rank < excess {
        let partner = rank + p2;
        collective_send(client, partner as u32, &buf, "allreduce", final_tag).await?;
    } else if rank >= p2 {
        let partner = rank - p2;
        let received = collective_recv(client, partner as u32, "allreduce", final_tag).await?;
        if received.len() != total_bytes {
            return Err(NexarError::BufferSizeMismatch {
                expected: total_bytes,
                actual: received.len(),
            });
        }
        buf.copy_from_slice(&received);
    }
    // SAFETY: caller guarantees `ptr` addresses `total_bytes` writable bytes.
    unsafe { client.adapter().receive_to_device(&buf, ptr)? };
    Ok(())
}