mod cuda;
mod memcpy;
mod nixl;
mod strategy;
use super::nixl::{IsMutable, NixlBlockDataImmutable, NixlBlockDataMutable, RemoteBlock};
use super::*;
use crate::block_manager::storage::{
nixl::{NixlRegisterableStorage, NixlStorage},
DeviceStorage, PinnedStorage, SystemStorage,
};
use cudarc::driver::CudaStream;
use std::ops::Range;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait;
/// Marker trait (no methods) for types that can be written to.
pub trait Writable {}
/// Marker trait (no methods) for types that can be read from.
pub trait Readable {}
/// Marker trait combining [`Readable`] and [`Writable`].
pub trait Mutable: Readable + Writable {}
/// Marker trait that requires only [`Readable`]; the counterpart of [`Mutable`].
pub trait Immutable: Readable {}
/// Identifies which side of a transfer a [`TransferError`] refers to.
///
/// Derives `Copy`/`Eq` so callers can cheaply copy and compare the side reported in
/// errors such as `MismatchedBlockSetIndex` and `MismatchedWorkerID`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockTarget {
    /// The side being read from.
    Source,
    /// The side being written to.
    Destination,
}
/// Errors produced while building, validating, or executing a block transfer.
#[derive(Debug, thiserror::Error)]
pub enum TransferError {
    /// A request/builder was configured incorrectly (e.g. empty or duplicated block lists).
    #[error("Builder configuration error: {0}")]
    BuilderError(String),
    /// The transfer failed while running.
    #[error("Transfer execution failed: {0}")]
    ExecutionError(String),
    /// The block/storage combination has no usable strategy; `WriteTo::write_to` returns
    /// this for any strategy it does not dispatch.
    #[error("Incompatible block types provided: {0}")]
    IncompatibleTypes(String),
    /// Source and destination lists differ in length, as `(sources, destinations)`.
    #[error("Mismatched source/destination counts: {0} sources, {1} destinations")]
    CountMismatch(usize, usize),
    /// Wraps a [`BlockError`]; `#[from]` lets `?` convert it automatically.
    #[error("Block operation failed: {0}")]
    BlockError(#[from] BlockError),
    /// No blocks were supplied.
    #[error("No blocks provided")]
    NoBlocksProvided,
    /// Block-set indices disagree on the given side; the two values are printed as
    /// `{1} != {2}`.
    #[error("Mismatched {0:?} block set index: {1} != {2}")]
    MismatchedBlockSetIndex(BlockTarget, usize, usize),
    /// Worker IDs disagree on the given side; the two values are printed as `{1} != {2}`.
    #[error("Mismatched {0:?} worker ID: {1} != {2}")]
    MismatchedWorkerID(BlockTarget, usize, usize),
    /// Any other error, forwarded transparently (its message is shown unchanged).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
/// How a block copy between two storage types is carried out.
///
/// The strategy for a (source, destination) pair is chosen statically through
/// [`WriteToStrategy`] / [`ReadFromStrategy`], whose default is [`TransferStrategy::Invalid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
    /// Plain memory copy; this module's tests expect it between system and pinned storage.
    Memcpy,
    /// CUDA host-to-device copy, async variant (pinned -> device in this module's tests).
    /// `WriteTo::write_to` dispatches the `CudaAsync*` variants to `cuda::copy_block` with
    /// the transfer context's stream.
    CudaAsyncH2D,
    /// CUDA device-to-host copy, async variant (device -> pinned in this module's tests).
    CudaAsyncD2H,
    /// CUDA device-to-device copy, async variant (device -> device in this module's tests).
    CudaAsyncD2D,
    /// CUDA host-to-device copy, blocking variant (system -> device in this module's tests).
    /// NOTE(review): `WriteTo::write_to` currently rejects this variant as unsupported.
    CudaBlockingH2D,
    /// CUDA device-to-host copy, blocking variant (device -> system in this module's tests).
    /// NOTE(review): `WriteTo::write_to` currently rejects this variant as unsupported.
    CudaBlockingD2H,
    /// Write into `NixlStorage` (expected from every local storage type in the tests).
    NixlWrite,
    /// Read from a `Remote` block into NIXL-registerable storage.
    NixlRead,
    /// No supported strategy; the default returned by the strategy traits.
    Invalid,
}
/// Selects the [`TransferStrategy`] used when writing `Self` into `Target`.
///
/// The default returns [`TransferStrategy::Invalid`], which `WriteTo::write_to` reports as
/// an unsupported strategy.
pub trait WriteToStrategy<Target> {
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::Invalid
    }
}
/// Selects the [`TransferStrategy`] used when `Self` reads from `Source`.
///
/// The default returns [`TransferStrategy::Invalid`].
pub trait ReadFromStrategy<Source> {
    fn read_from_strategy() -> TransferStrategy {
        TransferStrategy::Invalid
    }
}
/// Block-level strategy selection.
///
/// A readable block whose storage is [`Local`] writes into a writable block using whatever
/// strategy its *storage type* reports for the destination block's storage type.
impl<SrcBlock, DstBlock> WriteToStrategy<DstBlock> for SrcBlock
where
    SrcBlock: ReadableBlock,
    DstBlock: WritableBlock,
    <SrcBlock as ReadableBlock>::StorageType:
        Local + WriteToStrategy<<DstBlock as WritableBlock>::StorageType>,
{
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        // Delegate to the storage-to-storage mapping.
        <<SrcBlock as ReadableBlock>::StorageType as WriteToStrategy<
            <DstBlock as WritableBlock>::StorageType,
        >>::write_to_strategy()
    }
}
/// A writable block whose storage is NIXL-registerable reads from any block whose storage is
/// [`Remote`] via [`TransferStrategy::NixlRead`].
impl<DstBlock, SrcBlock> ReadFromStrategy<SrcBlock> for DstBlock
where
    DstBlock: WritableBlock,
    SrcBlock: ReadableBlock,
    <SrcBlock as ReadableBlock>::StorageType: Remote,
    <DstBlock as WritableBlock>::StorageType: NixlRegisterableStorage,
{
    #[inline(always)]
    fn read_from_strategy() -> TransferStrategy {
        TransferStrategy::NixlRead
    }
}
/// Copies a block's contents into a destination of type `Target`.
pub trait WriteTo<Target> {
    /// Writes `self` into `dst`.
    ///
    /// `notify` is an optional notification string. The blanket implementation forwards it
    /// only on the NIXL write path and ignores it for memcpy/CUDA copies.
    fn write_to(&self, dst: &mut Target, notify: Option<String>) -> Result<(), TransferError>;
}
/// Dispatches a block write according to the statically selected [`TransferStrategy`].
///
/// - `Memcpy` goes to `memcpy::copy_block`.
/// - The `CudaAsync*` variants go to `cuda::copy_block`, using the source's transfer-context
///   stream.
/// - `NixlWrite` goes to `nixl::write_block_to`, which also receives `notify`.
/// - Every other strategy yields [`TransferError::IncompatibleTypes`].
///
/// NOTE(review): `CudaBlockingH2D`/`CudaBlockingD2H` (e.g. system -> device) are rejected
/// here as unsupported; confirm this is intended.
impl<SrcBlock, DstBlock> WriteTo<DstBlock> for SrcBlock
where
    SrcBlock: ReadableBlock + WriteToStrategy<DstBlock> + Local,
    DstBlock: WritableBlock,
{
    fn write_to(&self, dst: &mut DstBlock, notify: Option<String>) -> Result<(), TransferError> {
        let ctx = self.transfer_context();
        let strategy = <SrcBlock as WriteToStrategy<DstBlock>>::write_to_strategy();

        match strategy {
            TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
            TransferStrategy::CudaAsyncH2D
            | TransferStrategy::CudaAsyncD2H
            | TransferStrategy::CudaAsyncD2D => {
                cuda::copy_block(self, dst, ctx.stream().as_ref(), strategy)
            }
            TransferStrategy::NixlWrite => Ok(nixl::write_block_to(self, dst, ctx, notify)?),
            // Listed explicitly (no `_`) so a new variant forces a decision here.
            TransferStrategy::CudaBlockingH2D
            | TransferStrategy::CudaBlockingD2H
            | TransferStrategy::NixlRead
            | TransferStrategy::Invalid => Err(TransferError::IncompatibleTypes(format!(
                "Unsupported copy strategy: {:?}",
                strategy
            ))),
        }
    }
}
/// Builder for a GET transfer into local, mutable `Target` blocks.
///
/// NOTE(review): both fields are underscore-prefixed placeholders; confirm intended usage.
pub struct GetXferRequestBuilder<
    'xfer,
    Source: BlockDataProvider,
    Target: BlockDataProviderMut + Local,
> {
    _src: Option<&'xfer [Source]>,
    _dst: Option<&'xfer [Target]>,
}

// Hand-written instead of `#[derive(Default)]`: the derive would also demand
// `Source: Default` and `Target: Default`, which an all-`None` builder never needs.
impl<Source, Target> Default for GetXferRequestBuilder<'_, Source, Target>
where
    Source: BlockDataProvider,
    Target: BlockDataProviderMut + Local,
{
    fn default() -> Self {
        Self {
            _src: None,
            _dst: None,
        }
    }
}
/// Builder for a PUT transfer from local `Source` blocks into mutable `Target` blocks.
///
/// NOTE(review): both fields are underscore-prefixed placeholders; confirm intended usage.
pub struct PutXferRequestBuilder<
    'xfer,
    Source: BlockDataProvider + Local,
    Target: BlockDataProviderMut,
> {
    _src: Option<&'xfer [Source]>,
    _dst: Option<&'xfer [Target]>,
}
/// Asynchronous transfer engine moving `Source` blocks into local, mutable `Target` blocks.
#[async_trait]
pub trait AsyncBlockTransferEngine<Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
{
    /// Runs the transfer to completion, consuming the engine.
    async fn execute(self) -> anyhow::Result<()>;
}
/// Synchronous (non-async) transfer engine from `Source` blocks into mutable `Target` blocks.
pub trait BlockTransferEngineV1<Source: BlockDataProvider, Target: BlockDataProviderMut> {
    /// Optional setup step before [`execute`](Self::execute); the default does nothing.
    fn prepare(&mut self) -> Result<(), TransferError> {
        Ok(())
    }
    /// Runs the transfer, consuming the engine.
    fn execute(self) -> Result<(), TransferError>;
}
/// A PUT request pairing each local source block with the destination block at the same index.
///
/// [`TransferRequestPut::new`] builds one after checking that both slices are non-empty and
/// of equal length.
#[derive(Debug)]
pub struct TransferRequestPut<
    'a,
    Source: BlockDataProvider + Local,
    Destination: BlockDataProviderMut,
> {
    /// Local blocks to read from.
    sources: &'a [Source],
    /// Blocks to write into; `destinations[i]` is paired with `sources[i]`.
    destinations: &'a mut [Destination],
}
/// NIXL PUT execution: local, NIXL-registerable sources into mutable remote blocks.
impl<Source> BlockTransferEngineV1<Source, RemoteBlock<IsMutable>>
    for TransferRequestPut<'_, Source, RemoteBlock<IsMutable>>
where
    Source: BlockDataProvider + Local,
    Source::StorageType: NixlRegisterableStorage,
{
    /// Re-checks the counts, then walks each `(source, destination)` pair.
    ///
    /// For each pair it resolves the source block descriptor and the destination's mutable
    /// block descriptor and logs both at `trace` level. The first descriptor error aborts the
    /// request.
    ///
    /// NOTE(review): no data is moved here; descriptors are only logged before returning
    /// `Ok(())`. Confirm the actual NIXL write is issued elsewhere.
    fn execute(self) -> Result<(), TransferError> {
        self.validate_counts()?;
        tracing::info!("Executing NIXL PUT transfer request");

        let pairs = self.sources.iter().zip(self.destinations.iter_mut());
        for (source, destination) in pairs {
            let src_block_data = source.block_data(private::PrivateToken);
            let src_desc = src_block_data.as_block_descriptor()?;

            let dst_block_data = destination.block_data_mut(private::PrivateToken);
            let dst_desc = dst_block_data.as_block_descriptor_mut()?;

            tracing::trace!(?src_desc, ?dst_desc, "NIXL PUT block");
        }

        Ok(())
    }
}
impl<'a, Source, Destination> TransferRequestPut<'a, Source, Destination>
where
    Source: BlockDataProvider + Local,
    Destination: BlockDataProviderMut,
{
    /// Creates a PUT request that pairs `sources[i]` with `destinations[i]`.
    ///
    /// Only the slice lengths are checked here. Call [`Self::validate_blocks`] to also reject
    /// duplicated or overlapping blocks.
    ///
    /// # Errors
    ///
    /// * [`TransferError::CountMismatch`] if the slices differ in length.
    /// * [`TransferError::BuilderError`] if the slices are empty.
    pub fn new(
        sources: &'a [Source],
        destinations: &'a mut [Destination],
    ) -> Result<Self, TransferError> {
        let transfer_request = Self {
            sources,
            destinations,
        };
        transfer_request.validate_counts()?;
        Ok(transfer_request)
    }

    /// Checks the block identities referenced by this request.
    ///
    /// A block is identified by its `(block_set_idx, block_idx, worker_id)` triple. The
    /// request is rejected if any destination appears more than once, or if any block
    /// appears in both the source and destination lists. Repeated sources are allowed.
    ///
    /// # Errors
    ///
    /// [`TransferError::BuilderError`] describing which check failed.
    pub fn validate_blocks(&self) -> Result<(), TransferError> {
        // Each side is collected on its own rather than via `zip`, so every block is checked.
        // If the slices ever differed in length, a `zip` would stop at the shorter one and
        // misreport the difference as duplicate destinations.
        let src_set: std::collections::HashSet<_> = self
            .sources
            .iter()
            .map(|block| {
                let data = block.block_data(private::PrivateToken);
                (data.block_set_idx, data.block_idx, data.worker_id)
            })
            .collect();
        let dst_set: std::collections::HashSet<_> = self
            .destinations
            .iter()
            .map(|block| {
                let data = block.block_data(private::PrivateToken);
                (data.block_set_idx, data.block_idx, data.worker_id)
            })
            .collect();

        // Fewer unique keys than entries means some destination is listed more than once.
        if dst_set.len() != self.destinations.len() {
            return Err(TransferError::BuilderError(
                "Duplicate destination blocks".to_string(),
            ));
        }
        if !src_set.is_disjoint(&dst_set) {
            return Err(TransferError::BuilderError(
                "One or more blocks appear in both the source and destination lists".to_string(),
            ));
        }
        Ok(())
    }

    /// Ensures the request pairs a non-empty, equal number of sources and destinations.
    fn validate_counts(&self) -> Result<(), TransferError> {
        let (num_sources, num_destinations) = (self.sources.len(), self.destinations.len());
        if num_sources != num_destinations {
            return Err(TransferError::CountMismatch(num_sources, num_destinations));
        }
        // The lengths are equal past this point, so checking the sources alone covers both
        // sides.
        if num_sources == 0 {
            return Err(TransferError::BuilderError(
                "Sources cannot be empty".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that storage `S` selects `expected` when writing into storage `D`.
    ///
    /// `#[track_caller]` makes a failure report the line of the offending call below.
    #[track_caller]
    fn expect_strategy<S, D>(expected: TransferStrategy)
    where
        S: WriteToStrategy<D>,
    {
        assert_eq!(<S as WriteToStrategy<D>>::write_to_strategy(), expected);
    }

    /// Pins the strategy for every (local source storage, destination storage) pair.
    #[test]
    fn write_to_strategy() {
        // System memory as the source.
        expect_strategy::<SystemStorage, SystemStorage>(TransferStrategy::Memcpy);
        expect_strategy::<SystemStorage, PinnedStorage>(TransferStrategy::Memcpy);
        expect_strategy::<SystemStorage, DeviceStorage>(TransferStrategy::CudaBlockingH2D);
        expect_strategy::<SystemStorage, NixlStorage>(TransferStrategy::NixlWrite);

        // Pinned host memory as the source.
        expect_strategy::<PinnedStorage, SystemStorage>(TransferStrategy::Memcpy);
        expect_strategy::<PinnedStorage, PinnedStorage>(TransferStrategy::Memcpy);
        expect_strategy::<PinnedStorage, DeviceStorage>(TransferStrategy::CudaAsyncH2D);
        expect_strategy::<PinnedStorage, NixlStorage>(TransferStrategy::NixlWrite);

        // Device memory as the source.
        expect_strategy::<DeviceStorage, SystemStorage>(TransferStrategy::CudaBlockingD2H);
        expect_strategy::<DeviceStorage, PinnedStorage>(TransferStrategy::CudaAsyncD2H);
        expect_strategy::<DeviceStorage, DeviceStorage>(TransferStrategy::CudaAsyncD2D);
        expect_strategy::<DeviceStorage, NixlStorage>(TransferStrategy::NixlWrite);
    }
}