pub trait Communicator: Send + Sync {
    // Required methods
    fn world_size(&self) -> usize;
    fn rank(&self) -> usize;
    unsafe fn all_reduce(
        &self,
        ptr: u64,
        count: usize,
        dtype: DType,
        op: ReduceOp,
    ) -> Result<()>;
    unsafe fn broadcast(
        &self,
        ptr: u64,
        count: usize,
        dtype: DType,
        root: usize,
    ) -> Result<()>;
    unsafe fn all_gather(
        &self,
        send_ptr: u64,
        recv_ptr: u64,
        count: usize,
        dtype: DType,
    ) -> Result<()>;
    unsafe fn reduce_scatter(
        &self,
        send_ptr: u64,
        recv_ptr: u64,
        count: usize,
        dtype: DType,
        op: ReduceOp,
    ) -> Result<()>;
    unsafe fn send(
        &self,
        ptr: u64,
        count: usize,
        dtype: DType,
        dest: usize,
        tag: u32,
    ) -> Result<()>;
    unsafe fn recv(
        &self,
        ptr: u64,
        count: usize,
        dtype: DType,
        src: usize,
        tag: u32,
    ) -> Result<()>;
    fn sync(&self) -> Result<()>;
    fn barrier(&self) -> Result<()>;

    // Provided methods
    fn split(
        &self,
        _color: u32,
        _key: u32,
    ) -> Result<Option<Box<dyn Communicator>>> { ... }
    fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps> { ... }
}
Multi-device collective communication
Operates on device pointers (u64) + element count + DType, matching
NCCL’s and MPI’s native calling conventions. The u64 pointer is the
same abstraction as Runtime::allocate() / Runtime::deallocate().
DType provides unambiguous type information so backends can dispatch
to the correct reduction unit (e.g., f16 vs bf16 vs i16 are all 2 bytes
but require different hardware reduction units).
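To make the point about same-width types concrete, here is a minimal sketch. The `DType` enum below is a hypothetical stand-in with only a few variants, and `size_in_bytes` and `buffer_bytes` are illustrative helper names, not confirmed parts of the crate's API.

```rust
/// Hypothetical stand-in for the crate's `DType`; the real enum likely has more variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DType {
    F16,
    BF16,
    I16,
    F32,
    F64,
}

impl DType {
    /// Size of one element in bytes (illustrative helper name).
    pub fn size_in_bytes(self) -> usize {
        match self {
            DType::F16 | DType::BF16 | DType::I16 => 2,
            DType::F32 => 4,
            DType::F64 => 8,
        }
    }
}

/// Byte length of a buffer holding `count` elements, or `None` if the
/// multiplication overflows `usize`.
pub fn buffer_bytes(count: usize, dtype: DType) -> Option<usize> {
    count.checked_mul(dtype.size_in_bytes())
}
```

A byte length alone cannot tell `F16`, `BF16`, and `I16` apart, which is why the methods take an element count plus a `DType` rather than a size in bytes.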
§Safety
All pointer-based methods are unsafe fn because passing an invalid u64
(dangling, wrong device, wrong provenance) causes undefined behavior.
Callers MUST ensure:
- NCCL: pointers are GPU device pointers from the same CUDA context
- MPI: pointers are valid host pointers
- Pointer provenance matches the communicator backend
- Buffers remain allocated until sync() or barrier()
Higher-level wrappers (boostr’s distributed patterns) accept Tensor<R>
and extract pointers internally, providing a safe public API.
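The wrapper idea can be sketched as follows. Everything here is a reduced stand-in: the trait carries only the methods the sketch needs, `SingleRank` is a hypothetical one-participant backend, the error type is simplified to `String`, and a host slice plays the role that `Tensor<R>` plays in the real wrappers.

```rust
type Result<T> = std::result::Result<T, String>;

#[derive(Clone, Copy)]
pub enum DType {
    F32,
}

#[derive(Clone, Copy)]
pub enum ReduceOp {
    Sum,
}

/// Reduced stand-in for `Communicator` (assumption: only three methods shown).
pub trait Communicator {
    fn world_size(&self) -> usize;
    unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()>;
    fn sync(&self) -> Result<()>;
}

/// Hypothetical single-rank backend: reducing over one participant
/// leaves the buffer unchanged.
pub struct SingleRank;

impl Communicator for SingleRank {
    fn world_size(&self) -> usize {
        1
    }
    unsafe fn all_reduce(&self, _ptr: u64, _count: usize, _dtype: DType, _op: ReduceOp) -> Result<()> {
        Ok(())
    }
    fn sync(&self) -> Result<()> {
        Ok(())
    }
}

/// Safe wrapper: the exclusive borrow of `data` keeps the buffer alive and
/// unaliased until `sync()` has returned.
pub fn all_reduce_sum_f32(comm: &dyn Communicator, data: &mut [f32]) -> Result<()> {
    let ptr = data.as_mut_ptr() as u64;
    // SAFETY: `ptr` points to `data.len()` valid f32 elements, and the borrow
    // outlives the operation because we sync before returning.
    unsafe { comm.all_reduce(ptr, data.len(), DType::F32, ReduceOp::Sum)? };
    comm.sync()
}
```

The caller never sees a raw `u64`; the pointer, count, and dtype are all derived from one typed buffer, so they cannot disagree.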
§Drop contract
Dropping with pending non-blocking operations attempts best-effort sync
with a bounded timeout. On failure the destructor logs the error
(via tracing::error!) and proceeds — it never panics.
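One way such a destructor could be structured is sketched below. This is an assumption-laden illustration, not the crate's implementation: `PendingOps` and `best_effort_sync` are hypothetical names, completion is modeled as a channel message, and `eprintln!` stands in for `tracing::error!` to keep the block dependency-free.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::Duration;

/// Hypothetical handle for outstanding non-blocking operations.
pub struct PendingOps {
    /// Receives one message once the outstanding operations have completed.
    pub done: Receiver<Result<(), String>>,
    /// Upper bound on how long the destructor may wait.
    pub timeout: Duration,
}

impl PendingOps {
    /// Waits for completion, but never longer than `timeout`.
    pub fn best_effort_sync(&self) -> Result<(), String> {
        match self.done.recv_timeout(self.timeout) {
            Ok(result) => result,
            Err(RecvTimeoutError::Timeout) => {
                Err(format!("sync timed out after {:?}", self.timeout))
            }
            Err(RecvTimeoutError::Disconnected) => {
                Err("completion channel closed".to_string())
            }
        }
    }
}

impl Drop for PendingOps {
    fn drop(&mut self) {
        // Log and continue: a destructor that panics during unwinding
        // would abort the process.
        if let Err(e) = self.best_effort_sync() {
            eprintln!("communicator drop: {e}");
        }
    }
}
```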
§Thread safety
The trait requires Send + Sync so that a communicator can be stored in Arc. If multiple threads call
send()/recv() concurrently, submission order is implementation-defined.
For deterministic ordering, serialize submissions externally.
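A minimal sketch of external serialization, assuming a hypothetical `RecordingComm` that only records the tags it is asked to send. A shared lock is held for each thread's whole batch, so batches never interleave. Which batch goes first still depends on thread scheduling; a fully fixed order would need a single submitting thread or an explicit sequence.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in that records the order in which sends are submitted.
pub struct RecordingComm {
    log: Mutex<Vec<u32>>,
}

impl RecordingComm {
    pub fn new() -> Self {
        Self { log: Mutex::new(Vec::new()) }
    }

    /// Stand-in for `send()`: only the tag is recorded.
    pub fn send(&self, tag: u32) {
        self.log.lock().unwrap().push(tag);
    }

    pub fn submitted(&self) -> Vec<u32> {
        self.log.lock().unwrap().clone()
    }
}

/// Runs one thread per batch. Each thread holds `order` while it submits,
/// so the tags of one batch stay contiguous in the submission log.
pub fn submit_batches(comm: Arc<RecordingComm>, batches: Vec<Vec<u32>>) -> Vec<u32> {
    let order = Arc::new(Mutex::new(()));
    let handles: Vec<_> = batches
        .into_iter()
        .map(|batch| {
            let comm = Arc::clone(&comm);
            let order = Arc::clone(&order);
            thread::spawn(move || {
                let _guard = order.lock().unwrap();
                for tag in batch {
                    comm.send(tag);
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    comm.submitted()
}
```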
Required Methods§
fn world_size(&self) -> usize
Number of participants
unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()>
AllReduce in-place: reduce across all ranks, result on all ranks.
Completion semantics are implementation-defined. On NCCL the operation
is non-blocking (stream-ordered). Portable code must call sync()
before reading the result buffer.
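The expected result of an in-place AllReduce can be modeled on the host. In this sketch, `ranks[r]` plays the role of rank `r`'s buffer and the reduction is a sum; the function name is illustrative and no communicator is involved.

```rust
/// Host-side reference for an in-place AllReduce with a sum.
/// After the call, every rank's buffer holds the element-wise total.
pub fn all_reduce_sum(ranks: &mut [Vec<f32>]) {
    let count = ranks.first().map_or(0, |buf| buf.len());
    assert!(
        ranks.iter().all(|buf| buf.len() == count),
        "every rank must contribute the same number of elements"
    );

    // Reduce: accumulate each rank's contribution element by element.
    let mut total = vec![0.0f32; count];
    for buf in ranks.iter() {
        for (acc, x) in total.iter_mut().zip(buf) {
            *acc += *x;
        }
    }

    // Distribute: every rank ends up with the same reduced buffer.
    for buf in ranks.iter_mut() {
        buf.copy_from_slice(&total);
    }
}
```

A reference like this can be handy in tests that compare a backend's output buffer, read after `sync()`, against the expected values.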
§Safety
ptr must be a valid device pointer with at least count elements of dtype.
unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()>
Broadcast from root rank to all other ranks.
§Safety
ptr must be a valid device pointer with at least count elements of dtype.
unsafe fn all_gather(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DType) -> Result<()>
AllGather: each rank contributes count elements, result is
count * world_size elements on all ranks.
§Safety
- send_ptr must point to at least count elements
- recv_ptr must point to at least count * world_size elements
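The buffer sizes above follow from the shape of the result, which this host-side sketch models. It assumes rank `r`'s contribution lands at offset `r * count` in the receive buffer, the usual NCCL and MPI layout; the trait documentation does not state the ordering, so treat that as an assumption.

```rust
/// Host-side reference for AllGather. `send[r]` is rank r's contribution;
/// the return value holds one receive buffer per rank.
pub fn all_gather<T: Clone>(send: &[Vec<T>]) -> Vec<Vec<T>> {
    let count = send.first().map_or(0, |buf| buf.len());
    assert!(
        send.iter().all(|buf| buf.len() == count),
        "every rank must contribute the same number of elements"
    );

    // Concatenate contributions in rank order (assumed layout).
    let gathered: Vec<T> = send.iter().flat_map(|buf| buf.iter().cloned()).collect();

    // Every rank receives the same count * world_size elements.
    vec![gathered; send.len()]
}
```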
unsafe fn reduce_scatter(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()>
ReduceScatter: reduce across all ranks, then scatter. Each rank receives a different count-element slice of the reduced result.
§Safety
- send_ptr must point to at least count * world_size elements
- recv_ptr must point to at least count elements
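A host-side model of the operation with a sum reduction is sketched below. It assumes rank `r` receives the `r`-th `count`-element slice of the reduced buffer, which matches NCCL's convention but is not stated in the trait documentation.

```rust
/// Host-side reference for ReduceScatter with a sum. `send[r]` is rank r's
/// send buffer of `count * world_size` elements; the return value holds one
/// `count`-element receive buffer per rank.
pub fn reduce_scatter_sum(send: &[Vec<f32>], count: usize) -> Vec<Vec<f32>> {
    let world = send.len();
    assert!(
        send.iter().all(|buf| buf.len() == count * world),
        "every send buffer must hold count * world_size elements"
    );

    (0..world)
        .map(|rank| {
            (0..count)
                // Element i of rank's slice is the sum over all send buffers
                // at position rank * count + i (assumed slice assignment).
                .map(|i| send.iter().map(|buf| buf[rank * count + i]).sum::<f32>())
                .collect()
        })
        .collect()
}
```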
unsafe fn send(&self, ptr: u64, count: usize, dtype: DType, dest: usize, tag: u32) -> Result<()>
Point-to-point send to a specific rank (non-blocking).
The send buffer must NOT be modified or deallocated until sync().
tag is used for message matching on MPI. On NCCL, tag is accepted
but ignored (stream-ordered submission determines matching).
§Safety
ptr must be a valid device pointer with at least count elements of dtype.
unsafe fn recv(&self, ptr: u64, count: usize, dtype: DType, src: usize, tag: u32) -> Result<()>
Point-to-point receive from a specific rank (non-blocking).
The recv buffer contains valid data only after sync() or barrier().
§Safety
ptr must be a valid device pointer with at least count elements of dtype.
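The difference between tag-based matching (MPI) and submission-order matching (NCCL) for send()/recv() pairs can be illustrated with a toy in-process model. `Mailbox` and `Matching` are hypothetical names, and real backends are considerably more involved; this only shows which message a given `recv` would pair with.

```rust
use std::collections::{HashMap, VecDeque};

/// How a receive is paired with a pending send (toy model).
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Matching {
    /// MPI-style: the tag is part of the matching key.
    ByTag,
    /// NCCL-style: the tag is ignored, messages pair up in submission order.
    BySubmissionOrder,
}

pub struct Mailbox {
    mode: Matching,
    queues: HashMap<(usize, usize, u32), VecDeque<Vec<u8>>>,
}

impl Mailbox {
    pub fn new(mode: Matching) -> Self {
        Self { mode, queues: HashMap::new() }
    }

    /// Builds the matching key; the tag is dropped when it is ignored.
    fn key(&self, src: usize, dest: usize, tag: u32) -> (usize, usize, u32) {
        match self.mode {
            Matching::ByTag => (src, dest, tag),
            Matching::BySubmissionOrder => (src, dest, 0),
        }
    }

    pub fn send(&mut self, src: usize, dest: usize, tag: u32, payload: Vec<u8>) {
        let key = self.key(src, dest, tag);
        self.queues.entry(key).or_default().push_back(payload);
    }

    /// Returns the oldest pending message that matches, if any.
    pub fn recv(&mut self, src: usize, dest: usize, tag: u32) -> Option<Vec<u8>> {
        let key = self.key(src, dest, tag);
        self.queues.get_mut(&key)?.pop_front()
    }
}
```

In this model, code that relies on tags to reorder messages gets different pairings under the two modes, so portable callers would post receives in the same order as the matching sends.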
Provided Methods§
fn split(&self, _color: u32, _key: u32) -> Result<Option<Box<dyn Communicator>>>
Split this communicator into sub-communicators by color and key.
All ranks must call split() collectively. Ranks with the same color
end up in the same sub-communicator, ordered by key.
Returns Ok(None) for backends that do not support splitting (e.g., NCCL
without ncclCommSplit, or the no-op communicator).
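The grouping rule can be sketched as a pure function over every rank's `(color, key)` pair. `split_ranks` is an illustrative name, and breaking key ties by original rank follows MPI_Comm_split; whether a given backend does the same is an assumption.

```rust
/// For each original rank (the slice index), returns `(color, new_rank)`:
/// the sub-communicator it joins and its rank within that sub-communicator.
pub fn split_ranks(color_key: &[(u32, u32)]) -> Vec<(u32, usize)> {
    // Order ranks by color, then key, then original rank (assumed tie-break).
    let mut order: Vec<usize> = (0..color_key.len()).collect();
    order.sort_by_key(|&rank| (color_key[rank].0, color_key[rank].1, rank));

    let mut out = vec![(0u32, 0usize); color_key.len()];
    let mut current_color: Option<u32> = None;
    let mut next_rank = 0usize;
    for rank in order {
        let color = color_key[rank].0;
        // Numbering restarts at 0 whenever a new color group begins.
        if current_color != Some(color) {
            current_color = Some(color);
            next_rank = 0;
        }
        out[rank] = (color, next_rank);
        next_rank += 1;
    }
    out
}
```

For example, four ranks passing `(0, 1)`, `(1, 0)`, `(0, 0)`, `(1, 5)` form two groups: original ranks 2 and 0 become ranks 0 and 1 of color 0, and original ranks 1 and 3 become ranks 0 and 1 of color 1.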
fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps>
Downcast to StreamSyncOps if this communicator supports CUDA
stream/event synchronization for compute-communication overlap.
Returns None by default. Backends with separate communication
streams (e.g., NCCL) override this to return Some(self).
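The optional-capability pattern can be sketched with reduced stand-in traits. The `label` method on `StreamSyncOps` is hypothetical and exists only so the sketch has something to call; `HostComm`, `StreamComm`, and `overlap_strategy` are likewise illustrative names.

```rust
/// Stand-in for the real `StreamSyncOps`; `label` is a hypothetical method.
pub trait StreamSyncOps {
    fn label(&self) -> &'static str;
}

/// Reduced stand-in for `Communicator` showing only the provided method.
pub trait Communicator {
    /// Default: the backend has no separate communication stream.
    fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps> {
        None
    }
}

/// Hypothetical backend that keeps the default.
pub struct HostComm;
impl Communicator for HostComm {}

/// Hypothetical backend with its own communication stream.
pub struct StreamComm;

impl StreamSyncOps for StreamComm {
    fn label(&self) -> &'static str {
        "stream/event sync"
    }
}

impl Communicator for StreamComm {
    fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps> {
        Some(self)
    }
}

/// Picks a synchronization strategy based on what the backend offers.
pub fn overlap_strategy(comm: &dyn Communicator) -> &'static str {
    match comm.as_stream_sync() {
        Some(ops) => ops.label(),
        None => "blocking sync",
    }
}
```

Callers that hold only a `&dyn Communicator` can use this to opt into compute-communication overlap when available and fall back to `sync()` otherwise.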