use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
pub use cudarc::nccl::ReduceOp;
use cudarc::nccl::{group_end, group_start, Comm};
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
// Library tag stored in the `lib` field of every `GpuError::LibraryError`
// constructed in this file.
const LIB: &str = "nccl";
/// Requests handled by [`CollectiveActor`].
///
/// Every variant carries a one-shot `reply` channel on which the actor sends
/// the outcome as `Result<(), GpuError>`. In mock mode every request is
/// answered with `GpuError::Unrecoverable`.
pub enum CollectiveMsg {
/// In-place all-reduce of an `f32` buffer.
AllReduceF32 {
/// Buffer handed to the communicator's `all_reduce_in_place`.
tensor: GpuRef<f32>,
/// Reduction operator forwarded to `all_reduce_in_place`.
op: ReduceOp,
/// Receives `Ok(())` on success or the error that stopped the request.
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Broadcast of an `f32` buffer.
///
/// NOTE: the real-mode handler in this file is a skeleton. It never issues
/// a broadcast and always answers with an error.
BroadcastF32 {
/// Buffer to broadcast.
data: GpuRef<f32>,
/// Presumably the rank that sources the data (TODO confirm); currently
/// ignored by the handler.
root: usize,
/// Receives the outcome of the request.
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Calls `group_start()` and reports its result.
BeginGroup {
/// Receives the outcome of `group_start()`.
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Calls `group_end()` and reports its result.
EndGroup {
/// Receives the outcome of `group_end()`.
reply: oneshot::Sender<Result<(), GpuError>>,
},
}
/// Actor that serves [`CollectiveMsg`] requests.
///
/// It is backed either by a real communicator or by a mock that rejects every
/// request; construct it through `props_for_rank` or `mock_props`.
pub struct CollectiveActor {
// Selects real vs. mock behavior for every incoming message.
inner: CollectiveInner,
}
/// Newtype around [`Comm`] that is declared `Send` and `Sync` so it can be
/// stored inside the actor and inside the shared slot built by
/// `props_for_rank`.
struct SendComm(Comm);
// SAFETY: NOTE(review) — unverified from this file. These impls assert that a
// `Comm` may be moved to, and referenced from, other threads. Presumably the
// wrapper exists because `Comm` does not provide these auto traits itself and
// the actor only ever touches the communicator from one message handler at a
// time; confirm against the cudarc / NCCL threading rules.
unsafe impl Send for SendComm {}
unsafe impl Sync for SendComm {}
/// Backing mode of a [`CollectiveActor`].
// NOTE(review): the reason for this `allow` is not stated. Presumably it covers
// the `completion` field, which is stored but never read in the code visible
// here — confirm.
#[allow(dead_code)]
enum CollectiveInner {
/// Requests are executed against a communicator.
Real {
/// Communicator used for the collective calls.
comm: SendComm,
/// Device state; its `device_id()` is compared with the device id of
/// incoming tensors before a request is executed.
state: Arc<DeviceState>,
/// Completion strategy supplied at construction; not read by the
/// handlers in this file.
completion: Arc<dyn CompletionStrategy>,
},
/// Every request is answered with an `Unrecoverable` error.
Mock,
}
impl CollectiveActor {
pub fn props_for_rank(
comm: Comm,
state: Arc<DeviceState>,
completion: Arc<dyn CompletionStrategy>,
) -> Props<Self> {
use parking_lot::Mutex;
let comm_slot = Arc::new(Mutex::new(Some(SendComm(comm))));
Props::create(move || {
let comm = comm_slot
.lock()
.take()
.expect("Unrecoverable: CollectiveActor restart with consumed Comm — NcclWorldActor must rebuild the world");
CollectiveActor {
inner: CollectiveInner::Real {
comm,
state: state.clone(),
completion: completion.clone(),
},
}
})
}
pub fn mock_props() -> Props<Self> {
Props::create(|| CollectiveActor {
inner: CollectiveInner::Mock,
})
}
}
#[async_trait]
impl Actor for CollectiveActor {
    type Msg = CollectiveMsg;

    /// Dispatches a single request.
    ///
    /// * Mock mode: the request is rejected via `mock_reply`.
    /// * Real mode: if the message carries a tensor whose device id differs
    ///   from this actor's device, the request is rejected with an
    ///   `Unrecoverable` error; otherwise it is forwarded to `handle_real`.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CollectiveMsg) {
        let (comm, state) = match &self.inner {
            CollectiveInner::Mock => {
                mock_reply(msg);
                return;
            }
            CollectiveInner::Real { comm, state, .. } => (comm, state),
        };

        match msg_device_id(&msg) {
            // The tensor lives on a different device than this actor serves.
            Some(actual) if actual != state.device_id() => {
                let mismatch = GpuError::Unrecoverable(format!(
                    "collective tensor on wrong device: expected {}, got {}",
                    state.device_id(),
                    actual
                ));
                msg_reply_err(msg, mismatch);
            }
            // No tensor attached, no device id known, or the device matches.
            _ => handle_real(comm, msg),
        }
    }
}
/// Returns the device id reported by the buffer attached to `msg`.
///
/// Group-control messages carry no buffer and therefore yield `None`.
fn msg_device_id(msg: &CollectiveMsg) -> Option<u32> {
    match msg {
        CollectiveMsg::AllReduceF32 { tensor: buffer, .. }
        | CollectiveMsg::BroadcastF32 { data: buffer, .. } => buffer.device_id(),
        CollectiveMsg::BeginGroup { .. } | CollectiveMsg::EndGroup { .. } => None,
    }
}
/// Rejects `msg` with an `Unrecoverable` error stating that the actor runs in
/// mock mode.
fn mock_reply(msg: CollectiveMsg) {
    msg_reply_err(
        msg,
        GpuError::Unrecoverable("CollectiveActor in mock mode".into()),
    );
}
/// Consumes `msg` and sends `Err(e)` on its reply channel, whichever variant
/// it is.
///
/// A failed send (receiver already dropped) is ignored.
fn msg_reply_err(msg: CollectiveMsg, e: GpuError) {
    // Every variant carries a reply sender of the same type.
    let reply = match msg {
        CollectiveMsg::AllReduceF32 { reply, .. }
        | CollectiveMsg::BroadcastF32 { reply, .. }
        | CollectiveMsg::BeginGroup { reply }
        | CollectiveMsg::EndGroup { reply } => reply,
    };
    let _ = reply.send(Err(e));
}
fn handle_real(comm: &SendComm, msg: CollectiveMsg) {
match msg {
CollectiveMsg::AllReduceF32 { tensor, op, reply } => {
let slice = match tensor.access() {
Ok(s) => s.clone(),
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let mut owned = match Arc::try_unwrap(slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"AllReduce tensor has multiple live references".into(),
)));
return;
}
};
let res =
comm.0
.all_reduce_in_place(&mut owned, &op)
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("all_reduce: {e:?}"),
});
let _ = reply.send(res.map(|_| ()));
drop(owned);
}
CollectiveMsg::BroadcastF32 { data, root, reply } => {
let slice = match data.access() {
Ok(s) => s.clone(),
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let owned = match Arc::try_unwrap(slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"Broadcast data has multiple live references".into(),
)));
return;
}
};
let _ = root;
let _ = comm;
let _ = reply.send(Err(GpuError::LibraryError {
lib: LIB,
msg: "BroadcastF32: F4 skeleton; needs separate send/recv buffers".into(),
}));
drop(owned);
}
CollectiveMsg::BeginGroup { reply } => {
let res = group_start()
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("group_start: {e:?}"),
})
.map(|_| ());
let _ = reply.send(res);
}
CollectiveMsg::EndGroup { reply } => {
let res = group_end()
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("group_end: {e:?}"),
})
.map(|_| ());
let _ = reply.send(res);
}
}
}