use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
pub use cudarc::nccl::ReduceOp;
use cudarc::nccl::{group_end, group_start, Comm};
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx};
pub mod all_to_all;
pub mod allgather;
pub mod allreduce;
pub mod broadcast;
pub mod capabilities;
pub mod custom_op;
pub mod group;
pub mod p2p;
pub mod reduce;
pub mod reduce_scatter;
pub use all_to_all::{AllToAllRequest, AllToAllvRequest};
pub use allgather::AllGatherRequest;
pub use allreduce::AllReduceRequest;
pub use broadcast::BroadcastRequest;
pub use capabilities::{probe_capabilities, NcclCapabilities};
pub use custom_op::PreMulSumOp;
pub use group::GroupGuard;
pub use p2p::{RecvRequest, SendRequest};
pub use reduce::ReduceRequest;
pub use reduce_scatter::ReduceScatterRequest;
/// Library tag placed in the `lib` field of `GpuError::LibraryError` values
/// raised by this module.
pub(crate) const LIB: &str = "nccl";
/// Element types usable with the collective requests in this module.
///
/// Each implementor is a `cudarc` NCCL element type that also reports which
/// `DispatchDType` tag identifies it.
pub trait NcclReduceSupported: cudarc::nccl::NcclType + Copy + Send + Sync + 'static {
/// Returns the `DispatchDType` tag corresponding to `Self`.
fn dispatch_dtype() -> crate::kernel::dispatch::DispatchDType;
}
// Implements `NcclReduceSupported` for one element type, mapping it to the
// `DispatchDType` variant named by the second argument.
macro_rules! impl_nccl_reduce_supported {
    ($elem:ty, $variant:ident) => {
        impl NcclReduceSupported for $elem {
            fn dispatch_dtype() -> crate::kernel::dispatch::DispatchDType {
                use crate::kernel::dispatch::DispatchDType as Tag;
                Tag::$variant
            }
        }
    };
}

// Floating-point element types.
impl_nccl_reduce_supported!(f32, F32);
impl_nccl_reduce_supported!(f64, F64);

// Integer element types.
impl_nccl_reduce_supported!(i8, I8);
impl_nccl_reduce_supported!(u8, U8);
impl_nccl_reduce_supported!(i32, I32);
impl_nccl_reduce_supported!(u32, U32);
impl_nccl_reduce_supported!(i64, I64);
impl_nccl_reduce_supported!(u64, U64);

// Half-precision element types, only available with the `f16` feature.
#[cfg(feature = "f16")]
impl_nccl_reduce_supported!(half::f16, F16);
#[cfg(feature = "f16")]
impl_nccl_reduce_supported!(half::bf16, Bf16);
/// Messages accepted by [`CollectiveActor`].
pub enum CollectiveMsg {
/// A type-erased collective request; the actor calls its `dispatch` method.
Op(Box<dyn CollectiveDispatch>),
/// Asks the actor to call NCCL `group_start`; the outcome is sent on `reply`.
BeginGroup {
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Asks the actor to call NCCL `group_end`; the outcome is sent on `reply`.
EndGroup {
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Asks for the NCCL capabilities. A real actor answers with
/// `probe_capabilities()`, a mock actor with `NcclCapabilities::zeroed()`.
QueryCapabilities {
reply: oneshot::Sender<NcclCapabilities>,
},
/// Legacy `f32` all-reduce message, kept for older callers; it is forwarded
/// as an `AllReduceRequest::<f32>`.
#[deprecated(
note = "use CollectiveMsg::Op(Box::new(AllReduceRequest::<f32> { ... })) instead"
)]
AllReduceF32 {
tensor: GpuRef<f32>,
op: ReduceOp,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Legacy `f32` broadcast message, kept for older callers; it is forwarded
/// as a `BroadcastRequest::<f32>`.
#[deprecated(
note = "use CollectiveMsg::Op(Box::new(BroadcastRequest::<f32> { ... })) instead"
)]
BroadcastF32 {
data: GpuRef<f32>,
root: usize,
reply: oneshot::Sender<Result<(), GpuError>>,
},
}
/// Actor that processes [`CollectiveMsg`] values.
///
/// It is either backed by a real NCCL communicator or runs in mock mode, as
/// recorded in `inner`.
pub struct CollectiveActor {
// Real-or-mock backing state; chosen by the `Props` factory that built the actor.
inner: CollectiveInner,
}
/// Newtype around the NCCL [`Comm`] that is marked `Send` and `Sync` by hand.
pub(crate) struct SendComm(pub(crate) Comm);
// SAFETY: NOTE(review): these impls assert that the wrapped `Comm` may be moved
// to, and referenced from, other threads. Within this file the communicator is
// only parked in the `props_for_rank` factory slot and then used from the
// owning actor's `handle(&mut self, ..)`. Confirm that no other crate-internal
// user shares a `SendComm` across threads concurrently.
unsafe impl Send for SendComm {}
unsafe impl Sync for SendComm {}
/// Backing state of a [`CollectiveActor`].
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file;
// presumably some build configuration leaves a field or variant unused — confirm.
#[allow(dead_code)]
enum CollectiveInner {
/// Backed by a real NCCL communicator.
Real {
// Communicator handed to `CollectiveDispatchCtx` and the group calls.
comm: SendComm,
// Device state passed to dispatched operations; also the source of the
// device id compared against an operation's `device_id()`.
state: Arc<DeviceState>,
// Completion strategy passed to dispatched operations.
completion: Arc<dyn CompletionStrategy>,
},
/// No communicator; messages are answered by `mock_reply`.
Mock,
}
impl CollectiveActor {
pub fn props_for_rank(
comm: Comm,
state: Arc<DeviceState>,
completion: Arc<dyn CompletionStrategy>,
) -> Props<Self> {
use parking_lot::Mutex;
let comm_slot = Arc::new(Mutex::new(Some(SendComm(comm))));
Props::create(move || {
let comm = comm_slot
.lock()
.take()
.expect("Unrecoverable: CollectiveActor restart with consumed Comm — NcclWorldActor must rebuild the world");
CollectiveActor {
inner: CollectiveInner::Real {
comm,
state: state.clone(),
completion: completion.clone(),
},
}
})
}
pub fn mock_props() -> Props<Self> {
Props::create(|| CollectiveActor {
inner: CollectiveInner::Mock,
})
}
}
#[async_trait]
impl Actor for CollectiveActor {
type Msg = CollectiveMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CollectiveMsg) {
match (&self.inner, msg) {
(CollectiveInner::Mock, msg) => mock_reply(msg),
(
CollectiveInner::Real {
comm,
state,
completion,
},
CollectiveMsg::Op(boxed),
) => {
if let Some(dev) = boxed.device_id() {
if dev != state.device_id() {
tracing::warn!(
expected = state.device_id(),
got = dev,
"collective op on wrong device"
);
}
}
let ctx = CollectiveDispatchCtx {
comm: &comm.0,
state,
completion,
};
boxed.dispatch(&ctx);
}
(CollectiveInner::Real { comm, .. }, msg) => {
handle_legacy(comm, msg);
}
}
}
}
/// Answers `msg` on behalf of an actor running in mock mode.
///
/// * `Op` requests are logged and dropped without a reply.
/// * `QueryCapabilities` receives `NcclCapabilities::zeroed()`.
/// * Every other message receives an `Unrecoverable` error.
///
/// A failed send (receiver already gone) is ignored.
fn mock_reply(msg: CollectiveMsg) {
    match msg {
        CollectiveMsg::Op(op) => {
            tracing::warn!(
                dtype = ?op.dtype_kind(),
                "CollectiveActor mock: dropping boxed op without reply"
            );
            drop(op);
        }
        CollectiveMsg::QueryCapabilities { reply } => {
            let _ = reply.send(NcclCapabilities::zeroed());
        }
        // All remaining variants carry the same reply channel type and get the
        // same error, so they share one arm.
        #[allow(deprecated)]
        CollectiveMsg::BeginGroup { reply }
        | CollectiveMsg::EndGroup { reply }
        | CollectiveMsg::AllReduceF32 { reply, .. }
        | CollectiveMsg::BroadcastF32 { reply, .. } => {
            let failure = GpuError::Unrecoverable("CollectiveActor in mock mode".into());
            let _ = reply.send(Err(failure));
        }
    }
}
#[allow(deprecated)]
fn handle_legacy(comm: &SendComm, msg: CollectiveMsg) {
match msg {
CollectiveMsg::Op(_) => unreachable!("Op handled in handle()"),
CollectiveMsg::BeginGroup { reply } => {
let res = group_start()
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("group_start: {e:?}"),
})
.map(|_| ());
let _ = reply.send(res);
}
CollectiveMsg::EndGroup { reply } => {
let res = group_end()
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("group_end: {e:?}"),
})
.map(|_| ());
let _ = reply.send(res);
}
CollectiveMsg::QueryCapabilities { reply } => {
let _ = reply.send(probe_capabilities());
}
CollectiveMsg::AllReduceF32 { tensor, op, reply } => {
let req = AllReduceRequest::<f32> { tensor, op, reply };
let dummy_state = Arc::new(crate::device::DeviceState::new(0));
let dummy_comp: Arc<dyn CompletionStrategy> =
Arc::new(crate::completion::HostFnCompletion::new());
let ctx = CollectiveDispatchCtx {
comm: &comm.0,
state: &dummy_state,
completion: &dummy_comp,
};
Box::new(req).dispatch(&ctx);
}
CollectiveMsg::BroadcastF32 { data, root, reply } => {
let req = BroadcastRequest::<f32> { data, root, reply };
let dummy_state = Arc::new(crate::device::DeviceState::new(0));
let dummy_comp: Arc<dyn CompletionStrategy> =
Arc::new(crate::completion::HostFnCompletion::new());
let ctx = CollectiveDispatchCtx {
comm: &comm.0,
state: &dummy_state,
completion: &dummy_comp,
};
Box::new(req).dispatch(&ctx);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceState;
    use std::sync::Arc as StdArc;

    /// Smoke test: the pieces a legacy `AllReduceF32` message is built from
    /// (reply channel, device state) can be created, and `CollectiveMsg` is a
    /// sized, `'static` type. It does not build the variant itself.
    #[test]
    #[allow(deprecated)]
    fn deprecated_allreduce_f32_alias_still_constructs() {
        let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
        let state = StdArc::new(DeviceState::new(0));
        // Touch both values so they count as used; nothing is sent.
        let _ = (&tx, &state);
        let _msg_size = std::mem::size_of::<CollectiveMsg>();
        let _msg_type = std::any::TypeId::of::<CollectiveMsg>();
    }
}