use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use crate::block_manager::v2::physical::{
layout::LayoutConfig,
transfer::{
PhysicalLayout, TransferOptions, TransportManager, executor::execute_transfer,
nixl_agent::NixlAgent,
},
};
/// Process-wide cache of the GDS probe outcome.
///
/// Filled in lazily by `TransferCapabilities::with_gds_if_supported`: the
/// stored value is `true` when the probe transfer returned `Ok` and `false`
/// otherwise. Once initialized the value is reused for every later call.
static GDS_SUPPORTED: OnceLock<bool> = OnceLock::new();
/// Flags describing which direct transfer paths may be used.
///
/// Both flags are `false` in the `Default` value. The type is `Copy`, and the
/// builder-style methods on it take `self` by value and return the updated
/// copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TransferCapabilities {
/// Whether GDS may be used (presumably GPUDirect Storage — TODO confirm).
/// Reported by `allows_device_disk_direct`.
pub allow_gds: bool,
/// Whether GPU RDMA may be used. Reported by `allows_device_remote_direct`.
pub allow_gpu_rdma: bool,
}
impl TransferCapabilities {
    /// Creates a capability set with every flag disabled.
    ///
    /// Equivalent to `TransferCapabilities::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a capability set with both `allow_gds` and `allow_gpu_rdma`
    /// enabled.
    #[must_use]
    pub fn all_enabled() -> Self {
        Self {
            allow_gds: true,
            allow_gpu_rdma: true,
        }
    }

    /// Returns the capability set with `allow_gds` set to `enabled`.
    ///
    /// The receiver is taken by value; because the type is `Copy`, discarding
    /// the return value leaves the caller's original untouched.
    #[must_use = "builder methods return the updated value instead of mutating in place"]
    pub fn with_gds(mut self, enabled: bool) -> Self {
        self.allow_gds = enabled;
        self
    }

    /// Runs a one-block probe transfer from device memory to disk through an
    /// agent that requires the `GDS_MT` backend.
    ///
    /// The probe uses a minimal layout (1 block, 1 layer, outer dim 1, page
    /// size 1, inner dim 4096), allocates the source on device 0 and the
    /// destination on disk, and transfers block 0 to block 0.
    ///
    /// This does not depend on any instance state, so it is an associated
    /// function rather than a method.
    ///
    /// # Errors
    ///
    /// Propagates the first failure from creating the agent, building the
    /// layout configuration, allocating either layout, building the transport
    /// manager, or executing the transfer.
    fn test_gds_transfer() -> anyhow::Result<()> {
        let agent = NixlAgent::require_backends("agent", &["GDS_MT"])?;

        let config = LayoutConfig::builder()
            .num_blocks(1)
            .num_layers(1)
            .outer_dim(1)
            .page_size(1)
            .inner_dim(4096)
            .build()?;

        // Source lives on device 0.
        let src = PhysicalLayout::builder(agent.clone())
            .with_config(config.clone())
            .fully_contiguous()
            .allocate_device(0)
            .build()?;

        // Destination lives on disk; `None` leaves the location to the builder.
        let dst = PhysicalLayout::builder(agent.clone())
            .with_config(config)
            .fully_contiguous()
            .allocate_disk(None)
            .build()?;

        let src_blocks = vec![0];
        let dst_blocks = vec![0];

        let ctx = TransportManager::builder()
            .worker_id(0)
            .nixl_agent(agent)
            .cuda_device_id(0)
            .build()?;

        // NOTE(review): the value returned by `execute_transfer` is discarded.
        // If it is a completion handle, this probe only verifies that the
        // transfer could be issued, not that it finished — confirm.
        execute_transfer(
            &src,
            &dst,
            &src_blocks,
            &dst_blocks,
            TransferOptions::default(),
            ctx.context(),
        )?;

        Ok(())
    }

    /// Returns the capability set with `allow_gds` set according to whether
    /// the GDS probe transfer succeeds.
    ///
    /// The probe result is cached in the process-wide `GDS_SUPPORTED` cell, so
    /// the probe is not repeated once a result has been stored. Only success
    /// or failure is recorded; the probe's error, if any, is dropped.
    #[must_use = "builder methods return the updated value instead of mutating in place"]
    pub fn with_gds_if_supported(mut self) -> Self {
        self.allow_gds = *GDS_SUPPORTED.get_or_init(|| Self::test_gds_transfer().is_ok());
        self
    }

    /// Returns the capability set with `allow_gpu_rdma` set to `enabled`.
    ///
    /// The receiver is taken by value; because the type is `Copy`, discarding
    /// the return value leaves the caller's original untouched.
    #[must_use = "builder methods return the updated value instead of mutating in place"]
    pub fn with_gpu_rdma(mut self, enabled: bool) -> Self {
        self.allow_gpu_rdma = enabled;
        self
    }

    /// Returns `true` when direct device-to-disk transfers are allowed,
    /// i.e. when `allow_gds` is set.
    #[must_use]
    pub fn allows_device_disk_direct(&self) -> bool {
        self.allow_gds
    }

    /// Returns `true` when direct device-to-remote transfers are allowed,
    /// i.e. when `allow_gpu_rdma` is set.
    #[must_use]
    pub fn allows_device_remote_direct(&self) -> bool {
        self.allow_gpu_rdma
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default value leaves both flags, and both queries, off.
    #[test]
    fn test_default_capabilities() {
        let defaults = TransferCapabilities::default();

        assert_eq!(
            (defaults.allow_gds, defaults.allow_gpu_rdma),
            (false, false)
        );
        assert_eq!(
            (
                defaults.allows_device_disk_direct(),
                defaults.allows_device_remote_direct(),
            ),
            (false, false)
        );
    }

    /// `all_enabled` turns on both flags, and both queries report it.
    #[test]
    fn test_all_enabled() {
        let everything = TransferCapabilities::all_enabled();

        assert_eq!(
            (everything.allow_gds, everything.allow_gpu_rdma),
            (true, true)
        );
        assert_eq!(
            (
                everything.allows_device_disk_direct(),
                everything.allows_device_remote_direct(),
            ),
            (true, true)
        );
    }

    /// Chained builder calls each set their own flag.
    #[test]
    fn test_builder_pattern() {
        let start = TransferCapabilities::new();
        let with_gds_on = start.with_gds(true);
        let built = with_gds_on.with_gpu_rdma(false);

        assert_eq!((built.allow_gds, built.allow_gpu_rdma), (true, false));
    }

    /// Enabling one capability must not enable the other.
    #[test]
    fn test_selective_enablement() {
        let gds_only = TransferCapabilities::new().with_gds(true);
        assert_eq!(
            (
                gds_only.allows_device_disk_direct(),
                gds_only.allows_device_remote_direct(),
            ),
            (true, false)
        );

        let rdma_only = TransferCapabilities::new().with_gpu_rdma(true);
        assert_eq!(
            (
                rdma_only.allows_device_disk_direct(),
                rdma_only.allows_device_remote_direct(),
            ),
            (false, true)
        );
    }
}