use crate::client::NexarClient;
use crate::collective::allreduce::ring_allreduce;
use crate::collective::helpers::CollectiveTag;
use crate::error::Result;
use crate::types::{DataType, IoVec, ReduceOp};
/// Performs a bucketed allreduce over several non-contiguous memory regions.
///
/// Each `(ptr, count)` entry in `entries` describes one bucket of `count`
/// elements of `dtype`. The buckets are gathered into a single contiguous
/// host buffer, reduced in one [`ring_allreduce`] pass, and the reduced data
/// is then scattered back to the original regions.
///
/// Requires a host-offload capable adapter (e.g. `CpuAdapter`); GPU users
/// should use nexar-nccl's on-device bucketed operations instead.
///
/// # Errors
///
/// Returns `NexarError::CollectiveFailed` when the adapter does not support
/// host offload, and propagates any failure from staging, the ring
/// allreduce itself, or the final scatter back to device memory.
///
/// # Safety
///
/// Every `(ptr, count)` pair in `entries` must describe a valid, properly
/// aligned region of `count` elements of `dtype` that remains live and is
/// not concurrently mutated for the entire duration of this call (including
/// across the internal `.await`).
pub(crate) async unsafe fn allreduce_bucketed(
    client: &NexarClient,
    entries: &[(u64, usize)],
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    // Nothing to reduce: trivially succeed without touching the adapter.
    if entries.is_empty() {
        return Ok(());
    }
    if !client.adapter().supports_host_offload() {
        return Err(crate::error::NexarError::CollectiveFailed {
            operation: "allreduce_bucketed",
            rank: client.rank(),
            reason:
                "bucketed allreduce requires a host-offload capable adapter (e.g. CpuAdapter); \
                 GPU users should use nexar-nccl's on-device bucketed operations"
                    .into(),
        });
    }
    let elem_size = dtype.size_in_bytes();
    // Describe every bucket as an iovec so the adapter can gather them all
    // into one flat host buffer.
    let regions: Vec<IoVec> = entries
        .iter()
        .map(|&(ptr, count)| IoVec {
            ptr,
            len: count * elem_size,
        })
        .collect();
    // SAFETY: the caller guarantees each region in `regions` is valid, live,
    // and not concurrently mutated (see `# Safety`).
    let mut buf = unsafe { client.adapter().stage_for_send_iov(&regions)? };
    let total_count: usize = entries.iter().map(|&(_, c)| c).sum();
    let total_bytes = total_count * elem_size;
    // The staged buffer must cover exactly the sum of all bucket sizes.
    debug_assert_eq!(buf.len(), total_bytes);
    let buf_ptr = buf.as_mut_ptr() as u64;
    // SAFETY: `buf_ptr` addresses `total_count` elements of `dtype` owned by
    // `buf`, which stays alive across the await below.
    unsafe {
        ring_allreduce(client, buf_ptr, total_count, dtype, op, tag).await?;
    }
    // SAFETY: `regions` still describes the caller's live regions, and `buf`
    // holds exactly `total_bytes` of reduced data to scatter back.
    unsafe { client.adapter().receive_to_device_iov(&buf, &regions)? };
    Ok(())
}