use crate::block_manager::v2::memory::StorageKind;
use super::TransferCapabilities;
use crate::block_manager::v2::physical::{layout::PhysicalLayout, transfer::TransferContext};
/// Concrete mechanism used to carry out a single transfer step.
///
/// The per-variant notes describe where the selector functions in this file
/// choose each variant; how a variant is executed is not defined here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
/// Chosen for local transfers where both sides are `System` or `Pinned` storage.
Memcpy,
/// Chosen for `Pinned` → `Device`, and as the second hop of a staged
/// `Disk` → `Device` transfer.
CudaAsyncH2D,
/// Chosen for `Device` → `Pinned`, and as the first hop when device memory
/// is staged through pinned memory on its way to disk or a remote agent.
CudaAsyncD2H,
/// Chosen for `Device` → `Device`.
CudaAsyncD2D,
/// Chosen for `System` → `Device`.
CudaBlockingH2D,
/// Chosen for `Device` → `System`.
CudaBlockingD2H,
/// Chosen for `Disk` → `Device` when the capabilities allow a direct
/// device/disk transfer.
NixlRead,
/// Chosen for writes to disk and for transfers from a local source to a
/// remote destination.
NixlWrite,
/// Not selected by any function in this file.
/// NOTE(review): presumably the write-side counterpart of `NixlReadFlipped` - confirm.
NixlWriteFlipped,
/// Chosen for `Disk` → `System`/`Pinned` and for transfers from a remote
/// source into a local destination.
NixlReadFlipped,
/// Not selected by any function in this file.
/// NOTE(review): presumably a sentinel for unsupported combinations - confirm.
Invalid,
}
/// Result of strategy selection: either one transfer step or two steps
/// staged through an intermediate storage location.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferPlan {
/// The whole transfer is performed with a single strategy.
Direct(TransferStrategy),
/// The transfer is split into two steps around a bounce buffer.
///
/// NOTE(review): presumably `first` moves data from the source into
/// `bounce_location` and `second` moves it on to the destination; the code
/// that executes plans is not part of this file - confirm there.
TwoHop {
/// Strategy for the first step.
first: TransferStrategy,
/// Kind of storage used for the intermediate buffer.
bounce_location: StorageKind,
/// Strategy for the second step.
second: TransferStrategy,
},
}
/// Chooses the transfer plan for copying from `src` to `dst`.
///
/// A layout is treated as local when the agent name in its NIXL metadata
/// equals the name of the NIXL agent held by `ctx`.
///
/// # Errors
///
/// Returns an error when neither layout is local, or when
/// `select_remote_strategy_v2` rejects a transfer in which exactly one side
/// is remote.
pub(crate) fn select_strategy(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    ctx: &TransferContext,
) -> anyhow::Result<TransferPlan> {
    let src_local = src.nixl_metadata().agent_name() == ctx.nixl_agent().name();
    let dst_local = dst.nixl_metadata().agent_name() == ctx.nixl_agent().name();

    match (src_local, dst_local) {
        // Nothing on this agent takes part in the transfer.
        (false, false) => Err(anyhow::anyhow!(
            "Both src and dst are remote - this is not supported."
        )),
        // Purely local transfer: pick by storage kind, never via the remote path.
        (true, true) => Ok(select_direct_strategy(
            src.location(),
            dst.location(),
            false,
            ctx.capabilities(),
        )),
        // Exactly one side is remote.
        _ => select_remote_strategy_v2(
            src.location(),
            src_local,
            dst.location(),
            dst_local,
            ctx.capabilities(),
        ),
    }
}
fn select_direct_strategy(
src: StorageKind,
dst: StorageKind,
dst_is_remote: bool,
capabilities: &TransferCapabilities,
) -> TransferPlan {
use StorageKind::*;
use TransferStrategy::*;
if dst_is_remote {
return select_remote_strategy(src, capabilities);
}
match (src, dst) {
(System, System) | (System, Pinned) | (Pinned, System) | (Pinned, Pinned) => {
TransferPlan::Direct(Memcpy)
}
(System, Device(_)) => TransferPlan::Direct(CudaBlockingH2D),
(Pinned, Device(_)) => TransferPlan::Direct(CudaAsyncH2D),
(Device(_), System) => TransferPlan::Direct(CudaBlockingD2H),
(Device(_), Pinned) => TransferPlan::Direct(CudaAsyncD2H),
(Device(_), Device(_)) => TransferPlan::Direct(CudaAsyncD2D),
(System, Disk(_)) | (Pinned, Disk(_)) => TransferPlan::Direct(NixlWrite),
(Disk(_), System) | (Disk(_), Pinned) => TransferPlan::Direct(NixlReadFlipped),
(Disk(_), Disk(_)) => TransferPlan::TwoHop {
first: NixlReadFlipped,
bounce_location: Pinned,
second: NixlWrite,
},
(Device(_), Disk(_)) => {
if capabilities.allows_device_disk_direct() {
TransferPlan::Direct(NixlWrite)
} else {
TransferPlan::TwoHop {
first: CudaAsyncD2H,
bounce_location: Pinned,
second: NixlWrite,
}
}
}
(Disk(_), Device(_)) => {
if capabilities.allows_device_disk_direct() {
TransferPlan::Direct(NixlRead)
} else {
TransferPlan::TwoHop {
first: NixlReadFlipped,
bounce_location: Pinned,
second: CudaAsyncH2D,
}
}
}
}
}
fn select_remote_strategy(src: StorageKind, capabilities: &TransferCapabilities) -> TransferPlan {
use StorageKind::*;
use TransferStrategy::*;
match src {
System | Pinned => TransferPlan::Direct(NixlWrite),
Device(_) => {
if capabilities.allows_device_remote_direct() {
TransferPlan::Direct(NixlWrite)
} else {
TransferPlan::TwoHop {
first: CudaAsyncD2H,
bounce_location: Pinned,
second: NixlWrite,
}
}
}
Disk(_) => TransferPlan::TwoHop {
first: NixlWrite,
bounce_location: Pinned,
second: NixlWrite,
},
}
}
/// Plans a NIXL transfer in which exactly one of `src` / `dst` is remote.
///
/// Only direct plans are produced: `NixlWrite` when the source is local and
/// the destination remote, `NixlReadFlipped` when the destination is local and
/// the source remote.
///
/// # Errors
///
/// Returns an error, checked in this order, when:
/// - either side is `Disk` storage;
/// - either side is `Device` storage while `capabilities.allow_gpu_rdma` is false;
/// - both sides are local, or both sides are remote.
///
/// The last case previously hit `unreachable!`; since this function already
/// reports failures through `anyhow::Result`, a bad flag combination is now an
/// ordinary error instead of a panic.
fn select_remote_strategy_v2(
    src: StorageKind,
    is_src_local: bool,
    dst: StorageKind,
    is_dst_local: bool,
    capabilities: &TransferCapabilities,
) -> anyhow::Result<TransferPlan> {
    let involves_disk =
        matches!(src, StorageKind::Disk(_)) || matches!(dst, StorageKind::Disk(_));
    if involves_disk {
        return Err(anyhow::anyhow!(
            "Neither local nor remote disk transfers are supported over NIXL at this time."
        ));
    }

    let involves_device =
        matches!(src, StorageKind::Device(_)) || matches!(dst, StorageKind::Device(_));
    if involves_device && !capabilities.allow_gpu_rdma {
        return Err(anyhow::anyhow!(
            "GPU RDMA is disabled - this transfer requires GPU RDMA."
        ));
    }

    match (is_src_local, is_dst_local) {
        (true, false) => Ok(TransferPlan::Direct(TransferStrategy::NixlWrite)),
        (false, true) => Ok(TransferPlan::Direct(TransferStrategy::NixlReadFlipped)),
        (true, true) => Err(anyhow::anyhow!(
            "Both src and dst are local - a remote transfer strategy does not apply."
        )),
        (false, false) => Err(anyhow::anyhow!(
            "Both src and dst are remote - this is not supported."
        )),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for strategy selection.
    //!
    //! Plans are compared as whole values (`TransferPlan` derives `PartialEq`
    //! and `Debug`), so a failing assertion prints both the expected and the
    //! actual plan.

    use super::*;

    /// Capabilities with every optional feature left at its default.
    fn default_caps() -> TransferCapabilities {
        TransferCapabilities::default()
    }

    /// Expected plan for a transfer staged through pinned host memory.
    fn via_pinned(first: TransferStrategy, second: TransferStrategy) -> TransferPlan {
        TransferPlan::TwoHop {
            first,
            bounce_location: StorageKind::Pinned,
            second,
        }
    }

    #[test]
    fn test_host_to_host_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::System, StorageKind::System, false, &caps),
            TransferPlan::Direct(TransferStrategy::Memcpy)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::System, StorageKind::Pinned, false, &caps),
            TransferPlan::Direct(TransferStrategy::Memcpy)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Pinned, StorageKind::System, false, &caps),
            TransferPlan::Direct(TransferStrategy::Memcpy)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Pinned, StorageKind::Pinned, false, &caps),
            TransferPlan::Direct(TransferStrategy::Memcpy)
        );
    }

    #[test]
    fn test_host_to_device_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::System, StorageKind::Device(0), false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaBlockingH2D)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Pinned, StorageKind::Device(0), false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaAsyncH2D)
        );
    }

    #[test]
    fn test_device_to_host_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::System, false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaBlockingD2H)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::Pinned, false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaAsyncD2H)
        );
    }

    #[test]
    fn test_device_to_device_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::Device(1), false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaAsyncD2D)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Device(3), StorageKind::Device(3), false, &caps),
            TransferPlan::Direct(TransferStrategy::CudaAsyncD2D)
        );
    }

    #[test]
    fn test_disk_to_host_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Disk(42), StorageKind::System, false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlReadFlipped)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Disk(42), StorageKind::Pinned, false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlReadFlipped)
        );
    }

    #[test]
    fn test_host_to_disk_transfers() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::System, StorageKind::Disk(42), false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Pinned, StorageKind::Disk(42), false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_device_to_disk_without_gds() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::Disk(42), false, &caps),
            via_pinned(TransferStrategy::CudaAsyncD2H, TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_disk_to_device_without_gds() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Disk(42), StorageKind::Device(0), false, &caps),
            via_pinned(TransferStrategy::NixlReadFlipped, TransferStrategy::CudaAsyncH2D)
        );
    }

    #[test]
    fn test_device_to_disk_with_gds() {
        let caps = TransferCapabilities::default().with_gds(true);
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::Disk(42), false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_disk_to_device_with_gds() {
        let caps = TransferCapabilities::default().with_gds(true);
        assert_eq!(
            select_direct_strategy(StorageKind::Disk(42), StorageKind::Device(0), false, &caps),
            TransferPlan::Direct(TransferStrategy::NixlRead)
        );
    }

    #[test]
    fn test_host_to_remote() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::System, StorageKind::System, true, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
        assert_eq!(
            select_direct_strategy(StorageKind::Pinned, StorageKind::Pinned, true, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_device_to_remote_without_rdma() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::System, true, &caps),
            via_pinned(TransferStrategy::CudaAsyncD2H, TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_device_to_remote_with_rdma() {
        let caps = TransferCapabilities::default().with_gpu_rdma(true);
        assert_eq!(
            select_direct_strategy(StorageKind::Device(0), StorageKind::Device(0), true, &caps),
            TransferPlan::Direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_disk_to_remote() {
        let caps = default_caps();
        assert_eq!(
            select_direct_strategy(StorageKind::Disk(42), StorageKind::System, true, &caps),
            via_pinned(TransferStrategy::NixlWrite, TransferStrategy::NixlWrite)
        );
    }

    // `select_remote_strategy_v2` is the remote path actually used by
    // `select_strategy`. The cases below involve host memory or disk only, so
    // they do not depend on the GPU RDMA setting.

    #[test]
    fn test_v2_local_host_to_remote_host() {
        let caps = default_caps();
        let plan = select_remote_strategy_v2(
            StorageKind::Pinned,
            true,
            StorageKind::System,
            false,
            &caps,
        )
        .expect("local host to remote host should be supported");
        assert_eq!(plan, TransferPlan::Direct(TransferStrategy::NixlWrite));
    }

    #[test]
    fn test_v2_remote_host_to_local_host() {
        let caps = default_caps();
        let plan = select_remote_strategy_v2(
            StorageKind::System,
            false,
            StorageKind::Pinned,
            true,
            &caps,
        )
        .expect("remote host to local host should be supported");
        assert_eq!(plan, TransferPlan::Direct(TransferStrategy::NixlReadFlipped));
    }

    #[test]
    fn test_v2_rejects_disk() {
        let caps = default_caps();
        assert!(select_remote_strategy_v2(
            StorageKind::Disk(42),
            true,
            StorageKind::System,
            false,
            &caps
        )
        .is_err());
        assert!(select_remote_strategy_v2(
            StorageKind::System,
            false,
            StorageKind::Disk(42),
            true,
            &caps
        )
        .is_err());
    }
}