use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use cudarc::nccl::Comm;
use tokio::sync::oneshot;
use tracing::{info, warn};
use crate::completion::{CompletionStrategy, HostFnCompletion};
use crate::device::{DeviceActor, DeviceConfig, DeviceMsg};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::{AllReduceRequest, CollectiveActor, CollectiveMsg, ReduceOp};
#[derive(Debug, Clone)]
pub struct NcclWorldConfig {
pub device_ids: Vec<u32>,
pub root: usize,
}
impl NcclWorldConfig {
pub fn new(device_ids: Vec<u32>) -> Self {
Self {
device_ids,
root: 0,
}
}
}
pub enum NcclWorldMsg {
AllReduceF32 {
tensors: Vec<GpuRef<f32>>,
op: ReduceOp,
reply: oneshot::Sender<Result<(), GpuError>>,
},
ChildReady {
device_idx: usize,
device_ref: ActorRef<DeviceMsg>,
},
DeviceContextChanged {
device_idx: usize,
new_generation: u64,
},
}
pub struct NcclWorldActor {
config: NcclWorldConfig,
devices: Vec<Option<ActorRef<DeviceMsg>>>,
collectives: Vec<Option<ActorRef<CollectiveMsg>>>,
built: bool,
last_generation: Vec<u64>,
#[allow(dead_code)]
completion: Arc<dyn CompletionStrategy>,
}
impl NcclWorldActor {
pub fn props(config: NcclWorldConfig) -> Props<Self> {
Props::create(move || {
let n = config.device_ids.len();
NcclWorldActor {
config: config.clone(),
devices: (0..n).map(|_| None).collect(),
collectives: (0..n).map(|_| None).collect(),
built: false,
last_generation: vec![0; n],
completion: Arc::new(HostFnCompletion::new()),
}
})
}
async fn try_build_world(&mut self, ctx: &mut Context<Self>) {
if self.built {
return;
}
if self.devices.iter().any(|d| d.is_none()) {
return;
}
let mut snaps = Vec::with_capacity(self.devices.len());
for d in &self.devices {
let dref = d.as_ref().unwrap();
let (tx, rx) = oneshot::channel();
dref.tell(DeviceMsg::SnapshotContext { reply: tx });
match rx.await {
Ok(Some(c)) => snaps.push(c),
_ => {
warn!("NcclWorldActor: a device reported no context; aborting world-build");
return;
}
}
}
let mut streams = Vec::with_capacity(snaps.len());
for c in &snaps {
match c.new_stream() {
Ok(s) => streams.push(s),
Err(e) => {
warn!(error = %e, "NcclWorldActor: new_stream failed");
return;
}
}
}
let comms_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Comm::from_devices(streams.clone())
}));
let comms = match comms_res {
Ok(Ok(cs)) => cs,
Ok(Err(e)) => {
warn!(error = ?e, "NcclWorldActor: Comm::from_devices failed");
return;
}
Err(_) => {
warn!("NcclWorldActor: NCCL not loadable on this host");
return;
}
};
for (i, comm) in comms.into_iter().enumerate() {
let state = Arc::new(crate::device::DeviceState::new(self.config.device_ids[i]));
let comp: Arc<dyn CompletionStrategy> = Arc::new(HostFnCompletion::new());
let props = CollectiveActor::props_for_rank(comm, state, comp);
match ctx.spawn::<CollectiveActor>(props, &format!("nccl-{i}")) {
Ok(r) => self.collectives[i] = Some(r),
Err(e) => {
warn!(error = %e, "spawn CollectiveActor[{i}] failed");
return;
}
}
}
self.built = true;
info!(devices = self.devices.len(), "NcclWorldActor: world built");
}
fn dispatch_all_reduce_f32(
&self,
tensors: Vec<GpuRef<f32>>,
op: ReduceOp,
reply: oneshot::Sender<Result<(), GpuError>>,
) {
if !self.built {
let _ = reply.send(Err(GpuError::Unrecoverable(
"NcclWorldActor: world not built yet".into(),
)));
return;
}
if tensors.len() != self.config.device_ids.len() {
let _ = reply.send(Err(GpuError::Unrecoverable(format!(
"AllReduce: expected {} tensors, got {}",
self.config.device_ids.len(),
tensors.len()
))));
return;
}
for (i, t) in tensors.iter().enumerate() {
if let Some(d) = t.device_id() {
if d != self.config.device_ids[i] {
let _ = reply.send(Err(GpuError::Unrecoverable(format!(
"AllReduce: tensor[{i}] on device {d}, expected {}",
self.config.device_ids[i]
))));
return;
}
}
}
let collectives: Vec<_> = self
.collectives
.iter()
.map(|c| c.as_ref().unwrap().clone())
.collect();
tokio::spawn(async move {
let mut rxs = Vec::with_capacity(tensors.len());
for (c, t) in collectives.into_iter().zip(tensors) {
let (tx, rx) = oneshot::channel();
let op_clone = match op {
ReduceOp::Sum => ReduceOp::Sum,
ReduceOp::Prod => ReduceOp::Prod,
ReduceOp::Max => ReduceOp::Max,
ReduceOp::Min => ReduceOp::Min,
ReduceOp::Avg => ReduceOp::Avg,
};
c.tell(CollectiveMsg::Op(Box::new(AllReduceRequest::<f32> {
tensor: t,
op: op_clone,
reply: tx,
})));
rxs.push(rx);
}
let mut combined = Ok(());
for rx in rxs {
match rx.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
combined = Err(e);
break;
}
Err(_) => {
combined = Err(GpuError::Unrecoverable(
"AllReduce: a collective actor dropped its reply".into(),
));
break;
}
}
}
let _ = reply.send(combined);
});
}
}
#[async_trait]
impl Actor for NcclWorldActor {
type Msg = NcclWorldMsg;
async fn pre_start(&mut self, ctx: &mut Context<Self>) {
let world_ref = ctx.self_ref().clone();
for (i, &ord) in self.config.device_ids.iter().enumerate() {
let cfg = DeviceConfig::new(ord);
match ctx.spawn::<DeviceActor>(DeviceActor::props(cfg), &format!("dev-{i}")) {
Ok(r) => {
self.devices[i] = Some(r.clone());
let world = world_ref.clone();
let dr = r.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
world.tell(NcclWorldMsg::ChildReady {
device_idx: i,
device_ref: dr,
});
});
}
Err(e) => panic!("Unrecoverable: spawn DeviceActor[{i}]: {e}"),
}
}
}
async fn handle(&mut self, ctx: &mut Context<Self>, msg: NcclWorldMsg) {
match msg {
NcclWorldMsg::ChildReady {
device_idx,
device_ref,
} => {
self.devices[device_idx] = Some(device_ref.clone());
let world_ref = ctx.self_ref().clone();
let dr = device_ref.clone();
tokio::spawn(async move {
let watch_rx_res = dr
.ask_with(
move |tx| DeviceMsg::WatchGeneration { reply: tx },
std::time::Duration::from_secs(5),
)
.await;
let mut rx = match watch_rx_res {
Ok(rx) => rx,
Err(_) => return,
};
let mut last = *rx.borrow();
while rx.changed().await.is_ok() {
let gen = *rx.borrow();
if gen != last {
last = gen;
world_ref.tell(NcclWorldMsg::DeviceContextChanged {
device_idx,
new_generation: gen,
});
}
}
});
self.try_build_world(ctx).await;
}
NcclWorldMsg::DeviceContextChanged {
device_idx,
new_generation,
} => {
let prev = self.last_generation.get(device_idx).copied().unwrap_or(0);
if new_generation <= prev {
return;
}
self.last_generation[device_idx] = new_generation;
if !self.built {
return;
}
tracing::warn!(
device_idx,
new_generation,
"NcclWorldActor: device context rebuilt — tearing down NCCL world"
);
for c in self.collectives.iter_mut() {
if let Some(c) = c.take() {
c.stop();
}
}
self.built = false;
self.try_build_world(ctx).await;
}
NcclWorldMsg::AllReduceF32 { tensors, op, reply } => {
self.dispatch_all_reduce_f32(tensors, op, reply);
}
}
}
}