#[cfg(all(feature = "testing-cuda", feature = "testing-nixl"))]
mod local_transfers;
use super::{NixlAgent, PhysicalLayout};
use crate::block_manager::v2::physical::layout::{
LayoutConfig,
builder::{HasConfig, NoLayout, NoMemory, PhysicalLayoutBuilder},
};
/// Builds the fixed `LayoutConfig` shared by the layout tests.
///
/// Only the block count varies per test. Every config uses 2 layers, an outer
/// dimension of 2, a page size of 16, an inner dimension of 128, and a
/// `dtype_width_bytes` of 2.
///
/// # Panics
/// Panics if the builder's `build()` call fails.
pub fn standard_config(num_blocks: usize) -> LayoutConfig {
    let num_layers = 2;
    let outer_dim = 2;
    let page_size = 16;
    let inner_dim = 128;
    let dtype_width_bytes = 2;

    LayoutConfig::builder()
        .num_blocks(num_blocks)
        .num_layers(num_layers)
        .outer_dim(outer_dim)
        .page_size(page_size)
        .inner_dim(inner_dim)
        .dtype_width_bytes(dtype_width_bytes)
        .build()
        .unwrap()
}
/// Returns a `PhysicalLayoutBuilder` for tests with its config already set.
///
/// The builder is backed by a fresh agent named `"test_agent"` and configured
/// with [`standard_config`] for `num_blocks` blocks. The caller still has to
/// choose a layout and supply memory.
pub fn builder(num_blocks: usize) -> PhysicalLayoutBuilder<HasConfig, NoLayout, NoMemory> {
    // Tuple fields are evaluated left to right: the agent is created before the config.
    let (test_agent, layout_config) = (
        create_test_agent("test_agent"),
        standard_config(num_blocks),
    );
    PhysicalLayout::builder(test_agent).with_config(layout_config)
}
/// Creates a NIXL agent called `name` for tests, passing an empty backend list
/// to `NixlAgent::require_backends`.
///
/// # Panics
/// Panics with `"Failed to require backends: <error>"` if agent creation fails.
pub fn create_test_agent(name: &str) -> NixlAgent {
    match NixlAgent::require_backends(name, &[]) {
        Ok(agent) => agent,
        Err(err) => panic!("Failed to require backends: {err:?}"),
    }
}
/// CUDA-only test helpers.
#[cfg(feature = "testing-cuda")]
pub(crate) mod cuda {
    use anyhow::Result;
    use cudarc::driver::sys::CUdevice_attribute_enum;
    use cudarc::driver::{CudaContext, CudaStream, LaunchConfig, PushKernelArg};
    use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
    use std::collections::HashMap;
    use std::sync::{Arc, OnceLock};
    use std::time::{Duration, Instant};

    /// CUDA C source for `sleep_kernel`.
    ///
    /// The kernel spins until `clock64()` has advanced by at least `min_cycles`
    /// from its starting value.
    pub const SLEEP_KERNEL_SRC: &str = r#"
extern "C" __global__ void sleep_kernel(unsigned long long min_cycles) {
const unsigned long long start = clock64();
while ((clock64() - start) < min_cycles) {
asm volatile("");
}
}
"#;

    /// Wall-clock time the calibration loop tries to reach with a single launch.
    const CALIBRATION_TARGET: Duration = Duration::from_millis(600);
    /// Maximum number of calibration launches. The cycle count doubles after each one.
    const MAX_CALIBRATION_ROUNDS: usize = 8;
    /// Multiplier on the measured cycle rate.
    ///
    /// It makes `launch` request more cycles than measured, so it errs toward
    /// sleeping longer rather than shorter.
    const SAFETY_FACTOR: f64 = 1.2;

    /// A compiled and calibrated busy-wait kernel.
    ///
    /// It occupies a CUDA stream for roughly a requested wall-clock duration.
    pub struct CudaSleep {
        function: cudarc::driver::CudaFunction,
        /// Estimated `clock64()` cycles per millisecond, with [`SAFETY_FACTOR`] applied.
        cycles_per_ms: f64,
    }

    impl CudaSleep {
        /// Returns the shared `CudaSleep` for `cuda_ctx`'s device ordinal.
        ///
        /// The instance is created and calibrated on first use. Instances are
        /// cached process-wide and keyed by device ordinal.
        ///
        /// # Errors
        /// Returns an error if kernel compilation, module loading, or calibration fails.
        pub fn for_context(cuda_ctx: &Arc<CudaContext>) -> Result<Arc<Self>> {
            static INSTANCES: OnceLock<parking_lot::Mutex<HashMap<usize, Arc<CudaSleep>>>> =
                OnceLock::new();
            let instances = INSTANCES.get_or_init(|| parking_lot::Mutex::new(HashMap::new()));
            let device_ordinal = cuda_ctx.ordinal();

            if let Some(existing) = instances.lock().get(&device_ordinal).cloned() {
                return Ok(existing);
            }

            // Build outside the lock: calibration launches kernels and waits on them.
            let created = Arc::new(Self::new(cuda_ctx)?);
            let mut guard = instances.lock();
            // Another thread may have inserted first. Always hand out the cached
            // instance so every caller shares the same one.
            let shared = Arc::clone(guard.entry(device_ordinal).or_insert(created));
            Ok(shared)
        }

        /// Compiles the kernel for this device and calibrates its cycle rate.
        fn new(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
            let function = Self::compile_kernel(cuda_ctx)?;
            let clock_rate_khz =
                cuda_ctx.attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_CLOCK_RATE)? as u64;
            let cycles_per_ms = Self::calibrate(cuda_ctx, &function, clock_rate_khz)?;
            Ok(Self {
                function,
                cycles_per_ms,
            })
        }

        /// Compiles [`SLEEP_KERNEL_SRC`] with NVRTC and loads `sleep_kernel`.
        ///
        /// The target is the device's compute capability, `compute_<major><minor>`.
        fn compile_kernel(cuda_ctx: &Arc<CudaContext>) -> Result<cudarc::driver::CudaFunction> {
            let major = cuda_ctx
                .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
            let minor = cuda_ctx
                .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;
            let mut compile_opts = CompileOptions {
                name: Some("sleep_kernel.cu".into()),
                ..Default::default()
            };
            compile_opts
                .options
                .push(format!("--gpu-architecture=compute_{}{}", major, minor));
            let ptx = compile_ptx_with_opts(SLEEP_KERNEL_SRC, compile_opts)?;
            let module = cuda_ctx.load_module(ptx)?;
            Ok(module.load_function("sleep_kernel")?)
        }

        /// Estimates `clock64()` cycles per wall-clock millisecond, scaled by [`SAFETY_FACTOR`].
        ///
        /// It times launches of increasing length on a private stream until one
        /// reaches [`CALIBRATION_TARGET`] or [`MAX_CALIBRATION_ROUNDS`] is
        /// exhausted. If no time was measured, it falls back to the reported
        /// clock rate in kHz, which equals cycles per millisecond.
        fn calibrate(
            cuda_ctx: &Arc<CudaContext>,
            function: &cudarc::driver::CudaFunction,
            clock_rate_khz: u64,
        ) -> Result<f64> {
            let stream = cuda_ctx.new_stream()?;

            // Warm-up launch of about 10 ms at the reported clock rate. This
            // presumably absorbs first-launch overhead so it does not skew
            // the timed runs.
            let warm_cycles = clock_rate_khz.saturating_mul(10).max(1);
            Self::launch_kernel(function, &stream, warm_cycles)?;
            stream.synchronize()?;

            let mut cycles = clock_rate_khz.saturating_mul(50).max(1);
            // Track the cycle count that produced `measured` explicitly. Pairing
            // the duration with the already-doubled `cycles` after the final
            // round would overstate the rate by 2x.
            let mut measured_cycles = cycles;
            let mut measured = Duration::ZERO;
            for _ in 0..MAX_CALIBRATION_ROUNDS {
                let start = Instant::now();
                Self::launch_kernel(function, &stream, cycles)?;
                stream.synchronize()?;
                measured = start.elapsed();
                measured_cycles = cycles;
                if measured >= CALIBRATION_TARGET {
                    break;
                }
                cycles = cycles.saturating_mul(2);
            }

            // Fractional milliseconds avoid losing precision to integer truncation.
            let measured_ms = measured.as_secs_f64() * 1_000.0;
            let cycles_per_ms = if measured_ms > 0.0 {
                (measured_cycles as f64 / measured_ms) * SAFETY_FACTOR
            } else {
                clock_rate_khz as f64
            };
            Ok(cycles_per_ms)
        }

        /// Enqueues one single-thread launch of `function` on `stream` that spins for `cycles`.
        fn launch_kernel(
            function: &cudarc::driver::CudaFunction,
            stream: &Arc<CudaStream>,
            cycles: u64,
        ) -> Result<()> {
            let launch_cfg = LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            };
            let mut launch = stream.launch_builder(function);
            launch.arg(&cycles);
            // SAFETY: `sleep_kernel` declares exactly one `unsigned long long`
            // parameter. Exactly one argument, a `u64`, was pushed above, and
            // `cycles` outlives `launch`. The kernel reads no device memory.
            unsafe { launch.launch(launch_cfg) }?;
            Ok(())
        }

        /// Enqueues a kernel on `stream` that busy-waits for roughly `duration`.
        ///
        /// The wait uses the calibrated rate, so it tends to run slightly longer
        /// than requested. The call returns once the launch is enqueued; it does
        /// not synchronize.
        pub fn launch(&self, duration: Duration, stream: &Arc<CudaStream>) -> Result<()> {
            // Fractional milliseconds: sub-millisecond requests must not truncate to 0 cycles.
            let requested_ms = duration.as_secs_f64() * 1_000.0;
            // `f64 as u64` saturates (NaN becomes 0), so huge durations cannot wrap.
            let target_cycles = (requested_ms * self.cycles_per_ms) as u64;
            Self::launch_kernel(&self.function, stream, target_cycles)
        }
    }
}