use std::{any::Any, sync::Arc};

use anyhow::Result;
use async_trait::async_trait;
#[cfg(feature = "nccl")]
use cudarc::nccl::sys::ncclComm_t;
use futures::future::try_join_all;
use nixl_sys::NixlDescriptor;

use super::*;
use super::zmq::*;
use crate::block_manager::{
    BasicMetadata, Storage,
    block::{
        Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock,
        data::local::LocalBlockData,
        locality,
        transfer::{TransferContext, WriteTo, WriteToStrategy},
    },
    connector::scheduler::{SchedulingDecision, TransferSchedulerClient},
    offload::max_transfer_batch_size,
    storage::{DeviceStorage, DiskStorage, Local, PinnedStorage},
};
use utils::*;
use BlockTransferPool::*;
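/// How block transfer work is divided across SPMD ranks.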
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TransferMode {
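    /// Each rank independently transfers its own shard of the blocks (default).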
#[default]
Sharded,
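    /// Rank 0 performs the storage I/O and broadcasts device blocks to all
    /// ranks via NCCL.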
Replicated,
}
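/// Thin wrapper around a raw `ncclComm_t` pointer so it can be shared across
/// threads.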
#[cfg(feature = "nccl")]
#[derive(Clone, Copy)]
pub struct NcclCommHandle(ncclComm_t);
#[cfg(feature = "nccl")]
impl NcclCommHandle {
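    /// # Safety
    ///
    /// `comm` must be a valid, initialized NCCL communicator, and the caller
    /// must ensure it outlives every copy of this handle.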
pub unsafe fn new(comm: ncclComm_t) -> Self {
Self(comm)
}
pub fn as_raw(&self) -> ncclComm_t {
self.0
}
}
#[cfg(feature = "nccl")]
unsafe impl Send for NcclCommHandle {}
#[cfg(feature = "nccl")]
unsafe impl Sync for NcclCommHandle {}
#[cfg(feature = "nccl")]
#[derive(Clone, Copy)]
struct NcclConfigInner {
comm: NcclCommHandle,
rank: i32,
world_size: i32,
}
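/// Optional NCCL topology (communicator, rank, world size). When absent, or
/// when the `nccl` feature is compiled out, the handler falls back to
/// sharded transfers.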
#[derive(Clone, Copy, Default)]
pub struct NcclConfig {
#[cfg(feature = "nccl")]
inner: Option<NcclConfigInner>,
#[cfg(not(feature = "nccl"))]
_phantom: (),
}
impl NcclConfig {
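    /// Returns a config with NCCL disabled; transfers run in sharded mode.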
pub fn disabled() -> Self {
Self::default()
}
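    /// # Safety
    ///
    /// `comm` must be a valid, initialized NCCL communicator for the given
    /// `rank` and `world_size`, and must outlive this config.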
#[cfg(feature = "nccl")]
    pub unsafe fn enabled(comm: ncclComm_t, rank: i32, world_size: i32) -> Self {
        assert!(
            world_size > 0 && (0..world_size).contains(&rank),
            "NCCL topology invariant violated: required 0 <= rank < world_size, world_size > 0; got rank={rank}, world_size={world_size}"
        );
        Self {
            inner: Some(NcclConfigInner {
                // SAFETY: `comm` validity is the caller's obligation (see `# Safety`).
                comm: unsafe { NcclCommHandle::new(comm) },
                rank,
                world_size,
            }),
        }
    }
pub fn is_enabled(&self) -> bool {
#[cfg(feature = "nccl")]
{
self.inner.is_some()
}
#[cfg(not(feature = "nccl"))]
{
false
}
}
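    /// Returns this process's rank within the NCCL communicator.
    ///
    /// # Panics
    ///
    /// Panics if NCCL is not enabled or the `nccl` feature is compiled out.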
pub fn rank(&self) -> i32 {
#[cfg(feature = "nccl")]
{
self.inner.as_ref().expect("NCCL not enabled").rank
}
#[cfg(not(feature = "nccl"))]
{
panic!("NCCL feature not enabled")
}
}
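    /// Returns the number of ranks in the NCCL communicator.
    ///
    /// # Panics
    ///
    /// Panics if NCCL is not enabled or the `nccl` feature is compiled out.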
pub fn world_size(&self) -> i32 {
#[cfg(feature = "nccl")]
{
self.inner.as_ref().expect("NCCL not enabled").world_size
}
#[cfg(not(feature = "nccl"))]
{
panic!("NCCL feature not enabled")
}
}
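    /// Returns the raw communicator handle.
    ///
    /// # Panics
    ///
    /// Panics if NCCL is not enabled.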
#[cfg(feature = "nccl")]
pub fn comm(&self) -> NcclCommHandle {
self.inner.as_ref().expect("NCCL not enabled").comm
}
}
type LocalBlock<S, M> = Block<S, locality::Local, M>;
type LocalBlockDataList<S> = Vec<LocalBlockData<S>>;
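/// Splits oversized transfer requests into chunks of at most
/// `max_transfer_batch_size()` blocks and runs the chunks concurrently.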
#[derive(Clone, Debug)]
pub struct ConnectorTransferBatcher {
max_batch_size: usize,
}
impl ConnectorTransferBatcher {
pub fn new() -> Self {
Self {
max_batch_size: max_transfer_batch_size(),
}
}
pub async fn execute_batched_transfer(
&self,
handler: &BlockTransferHandler,
request: BlockTransferRequest,
) -> Result<()> {
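        // Replicated transfers go through NCCL collectives, which must be
        // issued as one unit in the same order on every rank, so they bypass
        // batching entirely.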
if handler.transfer_mode() == TransferMode::Replicated {
return handler.execute_transfer_direct(request).await;
}
let blocks = request.blocks();
let num_blocks = blocks.len();
if num_blocks <= self.max_batch_size {
return handler.execute_transfer_direct(request).await;
}
let batches = blocks.chunks(self.max_batch_size);
let batch_futures: Vec<_> = batches
.map(|batch| {
let batch_request = BlockTransferRequest {
from_pool: *request.from_pool(),
to_pool: *request.to_pool(),
blocks: batch.to_vec(),
connector_req: None,
};
handler.execute_transfer_direct(batch_request)
})
.collect();
tracing::debug!("Executing {} batches concurrently", batch_futures.len());
        try_join_all(batch_futures)
            .await
            .inspect_err(|e| tracing::error!("Batched connector transfer failed: {e}"))?;
        Ok(())
}
}
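// `new` takes no arguments, so expose the matching `Default` impl
// (clippy's `new_without_default` convention).
impl Default for ConnectorTransferBatcher {
    fn default() -> Self {
        Self::new()
    }
}
/// Moves blocks between the device, host, and disk pools, optionally
/// coordinating replicated transfers across ranks with NCCL.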
#[derive(Clone)]
pub struct BlockTransferHandler {
device: Option<LocalBlockDataList<DeviceStorage>>,
host: Option<LocalBlockDataList<PinnedStorage>>,
disk: Option<LocalBlockDataList<DiskStorage>>,
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
transfer_mode: TransferMode,
#[cfg(feature = "nccl")]
nccl_config: NcclConfig,
}
impl BlockTransferHandler {
pub fn new(
device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>,
host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>,
disk_blocks: Option<Vec<LocalBlock<DiskStorage, BasicMetadata>>>,
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
nccl_config: NcclConfig,
) -> Result<Self> {
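        // An enabled NCCL config switches the handler into replicated mode;
        // otherwise each rank shards its own transfers.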
let transfer_mode = if nccl_config.is_enabled() {
TransferMode::Replicated
} else {
TransferMode::Sharded
};
Ok(Self {
device: Self::get_local_data(device_blocks),
host: Self::get_local_data(host_blocks),
disk: Self::get_local_data(disk_blocks),
context,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
transfer_mode,
#[cfg(feature = "nccl")]
nccl_config,
})
}
pub fn transfer_mode(&self) -> TransferMode {
self.transfer_mode
}
fn get_local_data<S: Storage>(
blocks: Option<Vec<LocalBlock<S, BasicMetadata>>>,
) -> Option<LocalBlockDataList<S>> {
blocks.map(|blocks| {
blocks
.into_iter()
.map(|b| {
                    // The handler only supports local blocks; downcast the
                    // type-erased block data and fail loudly otherwise.
                    let block_data = b.block_data() as &dyn Any;
                    block_data
                        .downcast_ref::<LocalBlockData<S>>()
                        .expect("block data must be LocalBlockData")
                        .clone()
})
.collect()
})
}
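    /// Starts copying each `(source, target)` block pair in `request` from
    /// `source_pool_list` into `target_pool_list`; the returned receiver
    /// fires once the transfer completes.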
async fn begin_transfer<Source, Target>(
&self,
source_pool_list: &Option<LocalBlockDataList<Source>>,
target_pool_list: &Option<LocalBlockDataList<Target>>,
request: BlockTransferRequest,
) -> Result<tokio::sync::oneshot::Receiver<()>>
where
Source: Storage + NixlDescriptor,
Target: Storage + NixlDescriptor,
LocalBlockData<Source>:
ReadableBlock<StorageType = Source> + Local + WriteToStrategy<LocalBlockData<Target>>,
LocalBlockData<Target>: WritableBlock<StorageType = Target>,
LocalBlockData<Source>: BlockDataProvider<Locality = locality::Local>,
LocalBlockData<Target>: BlockDataProviderMut<Locality = locality::Local>,
{
let Some(source_pool_list) = source_pool_list else {
return Err(anyhow::anyhow!("Source pool manager not initialized"));
};
let Some(target_pool_list) = target_pool_list else {
return Err(anyhow::anyhow!("Target pool manager not initialized"));
};
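        // Each entry in `request.blocks()` is a (source_index, target_index)
        // pair into the respective pools.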
let source_idxs = request.blocks().iter().map(|(from, _)| *from);
let target_idxs = request.blocks().iter().map(|(_, to)| *to);
let sources: Vec<LocalBlockData<Source>> = source_idxs
.map(|idx| source_pool_list[idx].clone())
.collect();
let mut targets: Vec<LocalBlockData<Target>> = target_idxs
.map(|idx| target_pool_list[idx].clone())
.collect();
        sources
            .write_to(&mut targets, self.context.clone())
            .inspect_err(|e| tracing::error!("Failed to write to blocks: {e:?}"))
            .map_err(Into::into)
}
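    /// Public entry point: routes the request through the batcher, which may
    /// split it into concurrent chunks before calling
    /// [`Self::execute_transfer_direct`].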
pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
match self.transfer_mode {
TransferMode::Sharded => self.execute_transfer_spmd_sharded(request).await,
#[cfg(feature = "nccl")]
TransferMode::Replicated => self.execute_transfer_spmd_replicated(request).await,
#[cfg(not(feature = "nccl"))]
TransferMode::Replicated => {
Err(anyhow::anyhow!("Replicated mode requires NCCL feature"))
}
}
}
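    /// Sharded-mode transfer: this rank copies its own blocks directly.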
async fn execute_transfer_spmd_sharded(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing sharded transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}");
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
            _ => {
                return Err(anyhow::anyhow!(
                    "Invalid transfer type: {:?} -> {:?}",
                    request.from_pool(),
                    request.to_pool()
                ));
            }
}?;
notify.await?;
Ok(())
}
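    /// Replicated-mode transfer: rank 0 performs the storage I/O, and loads
    /// into the device pool are then broadcast to every rank over NCCL.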
#[cfg(feature = "nccl")]
async fn execute_transfer_spmd_replicated(&self, request: BlockTransferRequest) -> Result<()> {
assert!(
self.nccl_config.is_enabled(),
"NCCL config required for replicated mode"
);
let rank = self.nccl_config.rank();
let is_rank0 = rank == 0;
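        // Loads from host/disk into device memory are broadcast: rank 0 reads
        // from storage, then every GPU receives a copy.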
let use_bcast = request.to_pool() == &Device && request.from_pool() != &Device;
if use_bcast {
tracing::info!(
"NCCL replicated transfer: {} blocks from {:?} to {:?}, rank={}, \
rank0 will load from storage then broadcast to all GPUs",
request.blocks().len(),
request.from_pool(),
request.to_pool(),
rank
);
} else {
tracing::debug!(
"Replicated transfer: {} blocks from {:?} to {:?} (rank={}, bcast={})",
request.blocks().len(),
request.from_pool(),
request.to_pool(),
rank,
use_bcast
);
}
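        // Device-to-device copies involve no storage I/O, so every rank can
        // perform them locally, exactly as in sharded mode.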
if request.from_pool() == &Device && request.to_pool() == &Device {
return self.execute_transfer_spmd_sharded(request).await;
}
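        // No broadcast and not rank 0: only rank 0 touches storage, so this
        // rank has nothing to do.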
if !is_rank0 && !use_bcast {
return Ok(());
}
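        // Rank 0 performs the local copy and waits for it to complete before
        // any broadcast is issued.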
if is_rank0 {
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => {
self.begin_transfer(&self.device, &self.host, request.clone())
.await
}
(Device, Disk) => {
self.begin_transfer(&self.device, &self.disk, request.clone())
.await
}
(Host, Device) => {
self.begin_transfer(&self.host, &self.device, request.clone())
.await
}
(Host, Disk) => {
self.begin_transfer(&self.host, &self.disk, request.clone())
.await
}
(Disk, Device) => {
self.begin_transfer(&self.disk, &self.device, request.clone())
.await
}
                _ => {
                    return Err(anyhow::anyhow!(
                        "Invalid transfer type: {:?} -> {:?}",
                        request.from_pool(),
                        request.to_pool()
                    ));
                }
}?;
notify.await?;
}
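        // Every rank, including rank 0, joins the broadcast collective.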
if use_bcast {
self.broadcast_device_blocks(&request).await?;
}
Ok(())
}
#[cfg(feature = "nccl")]
async fn broadcast_device_blocks(&self, request: &BlockTransferRequest) -> Result<()> {
use crate::block_manager::block::transfer::{NcclGroup, bcast_block};
let device_blocks = self
.device
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Device blocks required for broadcast"))?;
let stream = self.context.stream().cu_stream();
let comm = self.nccl_config.comm();
let dst_indices: Vec<usize> = request.blocks().iter().map(|(_, to)| *to).collect();
let rank = self.nccl_config.rank();
let world_size = self.nccl_config.world_size();
tracing::info!(
"NCCL broadcast starting: rank={}/{}, num_blocks={}, block_indices={:?}",
rank,
world_size,
dst_indices.len(),
dst_indices
);
        // Group the per-block broadcasts so NCCL issues them as one batch.
        // Every rank must make the same sequence of calls; rank 0 is the root.
        let group = unsafe { NcclGroup::new()? };
        for &block_idx in &dst_indices {
            let block = &device_blocks[block_idx];
            unsafe {
                bcast_block(block, /* root = */ 0, comm.as_raw(), stream)?;
            }
        }
        group.end()?;
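        // The collective is enqueued on the CUDA stream; record an event and
        // wait on it so completion is reported only after the data has landed.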
let (tx, rx) = tokio::sync::oneshot::channel();
self.context.cuda_event(tx)?;
rx.await
.map_err(|_| anyhow::anyhow!("CUDA event channel closed"))?;
tracing::info!(
"NCCL broadcast completed: rank={}/{}, num_blocks={}",
rank,
world_size,
dst_indices.len()
);
Ok(())
}
}
#[async_trait]
impl Handler for BlockTransferHandler {
async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if message.data.len() != 1 {
return Err(anyhow::anyhow!(
"Block transfer request must have exactly one data element"
));
}
let mut request: BlockTransferRequest = serde_json::from_slice(&message.data[0])?;
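        // Requests carrying a connector handle must be scheduled first, and
        // the scheduler is told the outcome either way.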
let result = if let Some(req) = request.connector_req.take() {
let operation_id = req.uuid;
tracing::debug!(
request_id = %req.request_id,
operation_id = %operation_id,
"scheduling transfer"
);
let client = self
.scheduler_client
.as_ref()
.expect("scheduler client is required")
.clone();
let handle = client.schedule_transfer(req).await?;
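            // This handler only supports immediate execution; any other
            // scheduling decision indicates a protocol violation.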
assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute);
match self.execute_transfer(request).await {
Ok(_) => {
handle.mark_complete(Ok(())).await;
Ok(())
}
Err(e) => {
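                    // `e` itself is returned to the caller, so the scheduler
                    // gets a formatted copy of the error.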
handle.mark_complete(Err(anyhow::anyhow!("{}", e))).await;
Err(e)
}
}
} else {
self.execute_transfer(request).await
};
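        // Ack regardless of the transfer outcome; a failed ack takes
        // precedence (via `?`) over the transfer result.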
message.ack().await?;
result
}
}