use crate::client::NexarClient;
use crate::collective::helpers::{CollectiveTag, collective_recv, collective_send, step_tag};
use crate::error::{NexarError, Result};
use crate::types::DataType;
/// Ring allgather: every rank contributes `count` elements of `dtype` and ends
/// up with the contributions of all ranks, laid out in rank order.
///
/// The local contribution is staged from `send_ptr` through the client's
/// adapter, placed at this rank's slot in a host-side buffer of
/// `count * dtype.size_in_bytes() * world_size` bytes, and the remaining slots
/// are filled over `world_size - 1` steps. In each step this rank sends one
/// chunk to rank `(rank + 1) % world` and receives one chunk from rank
/// `(rank + world - 1) % world`. The assembled buffer is then handed to the
/// adapter's `receive_to_device` with `recv_ptr`.
///
/// When the world size is 0 or 1 the local chunk is simply staged and written
/// back to `recv_ptr` without any communication.
///
/// # Errors
///
/// * [`NexarError::BufferSizeMismatch`] if the staged local chunk or a chunk
///   received from the previous rank is not exactly
///   `count * dtype.size_in_bytes()` bytes long.
/// * Any error returned by the adapter's staging/receive calls or by
///   `collective_send` / `collective_recv`.
///
/// # Safety
///
/// `send_ptr` and `recv_ptr` are forwarded unchanged to the adapter's unsafe
/// `stage_for_send` and `receive_to_device` calls, so the caller must uphold
/// whatever contract those calls impose. NOTE(review): presumably `send_ptr`
/// must be readable for one chunk and `recv_ptr` writable for `world_size`
/// chunks -- confirm against the adapter documentation.
///
/// The byte-size arithmetic is unchecked; the caller must ensure that
/// `count * dtype.size_in_bytes() * world_size` does not overflow `usize`.
pub(crate) async unsafe fn ring_allgather(
    client: &NexarClient,
    send_ptr: u64,
    recv_ptr: u64,
    count: usize,
    dtype: DataType,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let rank = client.rank() as usize;
    let elem_size = dtype.size_in_bytes();
    let chunk_bytes = count * elem_size;
    let total_bytes = chunk_bytes * world;

    if world <= 1 {
        // Nothing to exchange: the gathered result is just the local chunk.
        // SAFETY: `send_ptr` is forwarded under this function's own safety contract.
        let data = unsafe { client.adapter().stage_for_send(send_ptr, chunk_bytes)? };
        // SAFETY: `recv_ptr` is forwarded under this function's own safety contract.
        unsafe { client.adapter().receive_to_device(&data, recv_ptr)? };
        return Ok(());
    }

    let mut buf = vec![0u8; total_bytes];

    // SAFETY: `send_ptr` is forwarded under this function's own safety contract.
    let own_data = unsafe { client.adapter().stage_for_send(send_ptr, chunk_bytes)? };
    // Validate the staged length the same way received chunks are validated,
    // instead of letting `copy_from_slice` panic on a mismatch.
    if own_data.len() != chunk_bytes {
        return Err(NexarError::BufferSizeMismatch {
            expected: chunk_bytes,
            actual: own_data.len(),
        });
    }
    buf[rank * chunk_bytes..(rank + 1) * chunk_bytes].copy_from_slice(&own_data);

    let next = (rank + 1) % world;
    let prev = (rank + world - 1) % world;
    let next_rank = next as u32;
    let prev_rank = prev as u32;

    // One scratch buffer reused for every step's outgoing chunk, rather than a
    // fresh allocation per step.
    let mut send_data = vec![0u8; chunk_bytes];

    for step in 0..(world - 1) {
        // At step `s` this rank forwards the chunk that originated `s` ranks
        // behind it and receives the one that originated `s + 1` ranks behind.
        let send_idx = (rank + world - step) % world;
        let recv_idx = (rank + world - step - 1) % world;

        send_data.copy_from_slice(&buf[send_idx * chunk_bytes..(send_idx + 1) * chunk_bytes]);

        let round_tag = step_tag(tag, step);
        let (_, received) = tokio::try_join!(
            collective_send(client, next_rank, &send_data, "allgather", round_tag),
            collective_recv(client, prev_rank, "allgather", round_tag),
        )?;

        if received.len() != chunk_bytes {
            return Err(NexarError::BufferSizeMismatch {
                expected: chunk_bytes,
                actual: received.len(),
            });
        }
        buf[recv_idx * chunk_bytes..(recv_idx + 1) * chunk_bytes].copy_from_slice(&received);
    }

    // SAFETY: `recv_ptr` is forwarded under this function's own safety contract.
    unsafe { client.adapter().receive_to_device(&buf, recv_ptr)? };
    Ok(())
}