use std::ffi::c_void;
use std::sync::Arc;
use cudarc::driver::sys::CUstream;
use smallvec::SmallVec;
use vyre_driver::accounting::checked_add_usize_lazy;
use vyre_driver::binding::BindingRole;
use vyre_driver::transfer_accounting::TransferAccountingPolicy;
use vyre_driver::{BackendError, DispatchConfig, OutputBuffers, PendingDispatch, VyreBackend};
use vyre_foundation::ir::Program;
use crate::numeric::CUDA_NUMERIC;
use crate::CUDA_BACKEND_ID;
use super::allocations::{DispatchAllocations, HostTransferAllocations};
use super::copy::aligned_async_copy_len;
use super::dispatch::CudaBackend;
use super::launch_params::launch_param_byte_len;
use super::module_cache::ModuleCacheKey;
use super::output_range::cuda_output_readback;
use super::plan::CudaDispatchPlan;
use super::staging_reserve::{reserve_smallvec, reserved_vec};
#[derive(Clone, Copy)]
struct HostUpload {
dst: u64,
src: *const c_void,
byte_len: usize,
}
#[derive(Clone, Copy)]
struct DeviceClear {
dst: u64,
byte_len: usize,
}
struct CudaReadyPending {
outputs: Vec<Vec<u8>>,
}
const CUDA_HOST_TRANSFER_ACCOUNTING: TransferAccountingPolicy =
TransferAccountingPolicy::new("CUDA", "split the dispatch into bounded chunks");
impl vyre_driver::backend::private::Sealed for CudaReadyPending {}
impl PendingDispatch for CudaReadyPending {
fn is_ready(&self) -> bool {
true
}
fn await_result(self: Box<Self>) -> Result<Vec<Vec<u8>>, BackendError> {
Ok(self.outputs)
}
}
struct GridSyncSplitCudaBackend<'a>(&'a CudaBackend);
impl vyre_driver::backend::private::Sealed for GridSyncSplitCudaBackend<'_> {}
impl VyreBackend for GridSyncSplitCudaBackend<'_> {
fn id(&self) -> &'static str {
CUDA_BACKEND_ID
}
fn dispatch(
&self,
program: &Program,
inputs: &[Vec<u8>],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
let mut borrowed_inputs = SmallVec::<[&[u8]; 8]>::new();
reserve_smallvec(&mut borrowed_inputs, inputs.len(), "grid-sync CUDA input")?;
borrowed_inputs.extend(inputs.iter().map(Vec::as_slice));
self.0
.dispatch_borrowed_async(program, &borrowed_inputs, config)?
.await_result()
}
fn dispatch_borrowed(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
self.0
.dispatch_borrowed_async(program, inputs, config)?
.await_result()
}
fn dispatch_borrowed_into(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
outputs: &mut OutputBuffers,
) -> Result<(), BackendError> {
self.0
.dispatch_borrowed_async(program, inputs, config)?
.await_result_into(outputs)
}
fn supports_grid_sync(&self) -> bool {
self.0.supports_grid_sync()
}
fn max_workgroup_size(&self) -> [u32; 3] {
self.0.max_block_dim()
}
fn max_compute_invocations_per_workgroup(&self) -> u32 {
self.0.max_threads_per_block()
}
fn max_compute_workgroups_per_dimension(&self) -> u32 {
self.0.max_grid_dim()[0]
}
}
fn add_transfer_bytes(total: &mut u64, bytes: usize, label: &str) -> Result<(), BackendError> {
CUDA_HOST_TRANSFER_ACCOUNTING.add_bytes(total, bytes, label)
}
fn add_transfer_operation(total: &mut u64, label: &str) -> Result<(), BackendError> {
CUDA_HOST_TRANSFER_ACCOUNTING.add_operation(total, label)
}
impl CudaBackend {
fn dispatch_borrowed_with_grid_sync_split(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
let adapter = GridSyncSplitCudaBackend(self);
vyre_driver::grid_sync::dispatch_with_grid_sync_split(&adapter, program, inputs, config)
}
pub fn dispatch_borrowed(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
if vyre_driver::grid_sync::contains_grid_sync(program) && !self.supports_grid_sync() {
return self.dispatch_borrowed_with_grid_sync_split(program, inputs, config);
}
self.dispatch_borrowed_async(program, inputs, config)?
.await_result()
}
pub fn dispatch_async(
&self,
program: &Program,
inputs: &[Vec<u8>],
config: &DispatchConfig,
) -> Result<Box<dyn PendingDispatch>, BackendError> {
let mut borrowed_inputs = SmallVec::<[&[u8]; 8]>::new();
reserve_smallvec(&mut borrowed_inputs, inputs.len(), "borrowed input")?;
for input in inputs {
borrowed_inputs.push(input.as_slice());
}
self.dispatch_borrowed_async(program, &borrowed_inputs, config)
}
pub fn dispatch_borrowed_async(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<Box<dyn PendingDispatch>, BackendError> {
let lowered_program =
vyre_foundation::transform::collectives::lower_single_rank_collectives(program)
.map_err(|error| BackendError::InvalidProgram {
fix: error.to_string(),
})?;
let program = lowered_program.as_ref().unwrap_or(program);
if vyre_driver::grid_sync::contains_grid_sync(program) && !self.supports_grid_sync() {
let outputs = self.dispatch_borrowed_with_grid_sync_split(program, inputs, config)?;
return Ok(Box::new(CudaReadyPending { outputs }));
}
let trace = crate::instrumentation::cuda_stage_trace_enabled();
let start = std::time::Instant::now();
if trace {
tracing::debug!(
"[cuda-trace] dispatch_borrowed_async start buffers={} inputs={}",
program.buffers().len(),
inputs.len()
);
}
let prepared = self.prepare_host_dispatch(program, inputs, config)?;
if trace {
tracing::debug!(
"[cuda-trace] +{}ms prepare_host_dispatch",
start.elapsed().as_millis()
);
}
let (ptx_src, ptx_source_key) = self.ptx_for_program_cached_with_key(program, config)?;
if trace {
tracing::debug!(
"[cuda-trace] +{}ms ptx_for_program_cached bytes={}",
start.elapsed().as_millis(),
ptx_src.len()
);
}
let module_key = self.module_cache_key_for_ptx_source_key(ptx_source_key)?;
self.dispatch_prepared_borrowed_async_with_ptx_key(
program, inputs, &ptx_src, module_key, &prepared,
)
}
pub fn dispatch_borrowed_timed(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<vyre_driver::TimedDispatchResult, BackendError> {
let lowered_program =
vyre_foundation::transform::collectives::lower_single_rank_collectives(program)
.map_err(|error| BackendError::InvalidProgram {
fix: error.to_string(),
})?;
let program = lowered_program.as_ref().unwrap_or(program);
let prepared = self.prepare_host_dispatch(program, inputs, config)?;
let (ptx_src, ptx_source_key) = self.ptx_for_program_cached_with_key(program, config)?;
let module_key = self.module_cache_key_for_ptx_source_key(ptx_source_key)?;
self.dispatch_prepared_borrowed_timed_with_ptx_key(
program, inputs, config, &ptx_src, module_key, &prepared,
)
}
pub(crate) fn dispatch_borrowed_async_with_ptx_key(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
ptx_src: &str,
module_key: ModuleCacheKey,
) -> Result<Box<dyn PendingDispatch>, BackendError> {
let prepared = self.prepare_host_dispatch(program, inputs, config)?;
self.dispatch_prepared_borrowed_async_with_ptx_key(
program, inputs, ptx_src, module_key, &prepared,
)
}
pub(crate) fn dispatch_borrowed_timed_with_ptx_key(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
ptx_src: &str,
module_key: ModuleCacheKey,
) -> Result<vyre_driver::TimedDispatchResult, BackendError> {
let prepared = self.prepare_host_dispatch(program, inputs, config)?;
self.dispatch_prepared_borrowed_timed_with_ptx_key(
program, inputs, config, ptx_src, module_key, &prepared,
)
}
pub(crate) fn dispatch_prepared_borrowed_async_with_ptx_key(
&self,
program: &Program,
inputs: &[&[u8]],
ptx_src: &str,
module_key: ModuleCacheKey,
prepared: &CudaDispatchPlan,
) -> Result<Box<dyn PendingDispatch>, BackendError> {
Ok(Box::new(self.dispatch_borrowed_async_with_ptx_concrete(
program, inputs, ptx_src, module_key, false, prepared,
)?))
}
pub(crate) fn dispatch_prepared_borrowed_timed_with_ptx_key(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
ptx_src: &str,
module_key: ModuleCacheKey,
prepared: &CudaDispatchPlan,
) -> Result<vyre_driver::TimedDispatchResult, BackendError> {
let started = std::time::Instant::now();
let enqueue_started = std::time::Instant::now();
let pending = self.dispatch_borrowed_async_with_ptx_concrete(
program, inputs, ptx_src, module_key, true, prepared,
)?;
let enqueue_ns =
CUDA_NUMERIC.elapsed_nanos_u64(enqueue_started, "host-dispatch enqueue latency")?;
let wait_started = std::time::Instant::now();
let (outputs, device_ns) = pending.await_timed_result()?;
let wait_ns = CUDA_NUMERIC.elapsed_nanos_u64(wait_started, "host-dispatch wait latency")?;
if let Some(measured_device_ns) = device_ns {
let _accepted = vyre_driver::launch::record_launch_measurement(
program,
config,
self.launch_limits(),
prepared.launch.element_count,
prepared.launch.workgroup,
measured_device_ns,
);
}
let wall_ns = CUDA_NUMERIC.elapsed_nanos_u64(started, "host-dispatch wall latency")?;
self.telemetry
.record_timed_dispatch(wall_ns, device_ns, Some(enqueue_ns), Some(wait_ns));
Ok(vyre_driver::TimedDispatchResult {
outputs,
wall_ns,
device_ns,
enqueue_ns: Some(enqueue_ns),
wait_ns: Some(wait_ns),
})
}
fn dispatch_borrowed_async_with_ptx_concrete(
&self,
program: &Program,
inputs: &[&[u8]],
ptx_src: &str,
module_key: ModuleCacheKey,
capture_timing: bool,
prepared: &CudaDispatchPlan,
) -> Result<crate::stream::CudaPendingDispatch, BackendError> {
let _profiler_range =
crate::profiler::cuda_profiler_range(crate::profiler::CUDA_HOST_DISPATCH_RANGE);
if prepared
.bindings
.bindings
.iter()
.any(|binding| binding.role == BindingRole::Persistent)
{
return Err(BackendError::UnsupportedFeature {
name: "cuda_persistent_memory_binding".to_string(),
backend: crate::CUDA_BACKEND_ID.to_string(),
});
}
let trace = crate::instrumentation::cuda_stage_trace_enabled();
let start = std::time::Instant::now();
self.warmup()?;
if trace {
tracing::debug!("[cuda-trace] +{}ms warmup", start.elapsed().as_millis());
}
self.validate_transient_dispatch_memory_budget(prepared, inputs, "CUDA host dispatch")?;
let buffers = program.buffers();
let mut allocations =
DispatchAllocations::new(buffers.len(), Arc::clone(&self.transient_pool))?;
let (transfer_capacity, output_capacity) = host_transfer_capacities(prepared)?;
let mut host_transfers = HostTransferAllocations::with_capacity(
Arc::clone(&self.host_pool),
transfer_capacity,
output_capacity,
)?;
let mut host_uploads = SmallVec::<[HostUpload; 8]>::new();
reserve_smallvec(
&mut host_uploads,
host_upload_batch_capacity(prepared)?,
"host upload",
)?;
let mut device_clears = SmallVec::<[DeviceClear; 8]>::new();
reserve_smallvec(
&mut device_clears,
prepared.bindings.bindings.len(),
"device clear",
)?;
let mut upload_bytes = 0_u64;
let mut upload_operations = 0_u64;
for binding in &prepared.bindings.bindings {
if binding.role == BindingRole::Shared {
continue;
}
let byte_len = match binding.input_index {
Some(input_index) => inputs[input_index].len(),
None => binding.static_byte_len.ok_or_else(|| BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA output `{}` needs a static byte length before launch; set BufferDecl::with_count or output_byte_range.",
binding.name
),
})?,
};
let allocation_byte_len = aligned_async_copy_len(byte_len)?;
let allocation = self.transient_pool.acquire(allocation_byte_len)?;
self.telemetry.record_transient_allocation_bytes(
CUDA_NUMERIC
.usize_to_u64(allocation.byte_len, "transient allocation byte count")?,
);
let dev_ptr = allocation.ptr;
allocations.set_ptr(binding.buffer_index, allocation);
if let Some(input_index) = binding.input_index {
let input = inputs[input_index];
let copy_byte_len = aligned_async_copy_len(input.len())?;
let host_ptr = host_transfers.push_upload_padded(input, copy_byte_len)?;
add_transfer_bytes(&mut upload_bytes, input.len(), "host upload")?;
if !input.is_empty() {
add_transfer_operation(&mut upload_operations, "host upload")?;
}
host_uploads.push(HostUpload {
dst: dev_ptr,
src: host_ptr,
byte_len: copy_byte_len,
});
} else if byte_len != 0 {
device_clears.push(DeviceClear {
dst: dev_ptr,
byte_len: allocation.byte_len,
});
}
}
let param_bytes = launch_param_byte_len(&prepared.launch.param_words, "host dispatch")?;
let params_buf_ptr = if param_bytes == 0 {
0
} else {
let param_copy_bytes = aligned_async_copy_len(param_bytes)?;
let params_allocation = self.transient_pool.acquire(param_copy_bytes)?;
self.telemetry
.record_transient_allocation_bytes(CUDA_NUMERIC.usize_to_u64(
params_allocation.byte_len,
"parameter allocation byte count",
)?);
let params_buf_ptr = params_allocation.ptr;
let param_host_ptr = host_transfers
.push_u32_words_padded(&prepared.launch.param_words, param_copy_bytes)?;
host_uploads.push(HostUpload {
dst: params_buf_ptr,
src: param_host_ptr,
byte_len: param_copy_bytes,
});
add_transfer_bytes(&mut upload_bytes, param_bytes, "parameter upload")?;
add_transfer_operation(&mut upload_operations, "parameter upload")?;
self.telemetry.record_param_upload_bytes(
CUDA_NUMERIC.usize_to_u64(param_bytes, "parameter upload byte count")?,
);
allocations.set_params(params_allocation);
params_buf_ptr
};
let launch_resources = crate::stream::CudaLaunchResourceLease::acquire(
Arc::clone(&self.launch_resources),
capture_timing,
)?;
let stream_raw = launch_resources.stream_raw()?;
if trace {
tracing::debug!(
"[cuda-trace] +{}ms stream/events",
start.elapsed().as_millis()
);
}
enqueue_host_uploads_async(&host_uploads, stream_raw)?;
self.telemetry.record_host_to_device_bytes(upload_bytes);
self.telemetry
.record_host_upload_operations(upload_operations);
enqueue_device_clears_async(&device_clears, stream_raw)?;
if trace {
tracing::debug!(
"[cuda-trace] +{}ms alloc/upload/clear",
start.elapsed().as_millis()
);
}
if let Some((start_event, _)) = launch_resources.timing_events()? {
start_event.record(stream_raw)?;
}
let func = self.resolve_launch_function(
ptx_src,
module_key,
&prepared.launch,
prepared.cooperative,
)?;
if trace {
tracing::debug!(
"[cuda-trace] +{}ms resolve_launch_function",
start.elapsed().as_millis()
);
}
let mut ptr_values = SmallVec::<[u64; 8]>::new();
reserve_smallvec(
&mut ptr_values,
prepared.bindings.bindings.len(),
"kernel pointer argument",
)?;
for binding in &prepared.bindings.bindings {
if binding.role == BindingRole::Shared {
continue;
}
let ptr = allocations.ptr(binding.buffer_index);
if ptr == 0 {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA launch binding `{}` has no device allocation; argument order must match the lowered kernel descriptor.",
binding.name
),
});
}
ptr_values.push(ptr);
}
if trace {
tracing::debug!(
"[cuda-trace] +{}ms host args ptr_values={:x?} params=0x{params_buf_ptr:x} words={:?} grid={:?} workgroup={:?} element_count={}",
start.elapsed().as_millis(),
ptr_values,
prepared.launch.param_words,
prepared.launch.grid,
prepared.launch.workgroup,
prepared.launch.element_count
);
}
let mut params_ref = params_buf_ptr;
let mut kernel_args = Self::kernel_args(&mut ptr_values, &mut params_ref)?;
for _ in 0..prepared.fixpoint_iterations {
self.launch_prevalidated_function(
func,
&mut kernel_args,
&prepared.launch,
stream_raw,
false,
prepared.cooperative,
)?;
}
if trace {
tracing::debug!("[cuda-trace] +{}ms launch", start.elapsed().as_millis());
}
let mut readback_bytes = 0_u64;
let mut readback_operations = 0_u64;
for &binding_index in &prepared.output_binding_indices {
let binding = &prepared.bindings.bindings[binding_index];
let full_byte_len = match binding.static_byte_len {
Some(len) => len,
None => match binding.input_index {
Some(input_index) => inputs[input_index].len(),
None => {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA output `{}` needs a static byte length before readback.",
binding.name
),
});
}
},
};
let readback = cuda_output_readback(&buffers[binding.buffer_index], full_byte_len)?;
let allocation_byte_len = allocations.byte_len(binding.buffer_index);
let padded_readback_len = aligned_async_copy_len(readback.byte_len)?;
let readback_end = readback
.device_offset
.checked_add(padded_readback_len)
.unwrap_or(usize::MAX);
let copy_byte_len = if readback_end <= allocation_byte_len {
padded_readback_len
} else {
readback.byte_len
};
let out_ptr = host_transfers.push_output_padded(readback.byte_len, copy_byte_len)?;
if readback.byte_len != 0 {
add_transfer_bytes(&mut readback_bytes, readback.byte_len, "output readback")?;
add_transfer_operation(&mut readback_operations, "output readback")?;
let base_ptr = allocations.ptr(binding.buffer_index);
let device_ptr = vyre_driver::accounting::checked_add_u64_usize_offset_lazy(
base_ptr,
readback.device_offset,
|| {
BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA host dispatch readback device offset {} for output `{}` does not fit CUdeviceptr arithmetic.",
readback.device_offset, binding.name
),
}
},
|| {
BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA host dispatch readback pointer overflowed for output `{}` at device_ptr={base_ptr} offset={}. Rebuild the program with a valid output byte range or split the output buffer.",
binding.name, readback.device_offset
),
}
},
)?;
unsafe {
super::copy::d2h_async_checked(
out_ptr,
device_ptr,
copy_byte_len,
stream_raw,
)?;
}
}
}
self.telemetry
.record_device_to_host_readback(readback_bytes);
self.telemetry
.record_device_readback_operations(readback_operations);
if let Some((_, end_event)) = launch_resources.timing_events()? {
end_event.record(stream_raw)?;
}
let event = self.launch_resources.acquire_event()?;
if let Err(error) = event.record(stream_raw) {
self.launch_resources.release_event(event);
return Err(error);
}
if trace {
tracing::debug!(
"[cuda-trace] +{}ms readback/event",
start.elapsed().as_millis()
);
}
let (stream, timing_events) = launch_resources.into_parts()?;
let output_storage = reserved_vec(prepared.output_binding_indices.len(), "pending output")?;
if let Some((start_event, end_event)) = timing_events {
Ok(crate::stream::CudaPendingDispatch::new_with_timing(
Arc::clone(&self.ctx),
Arc::clone(&self.launch_resources),
event,
stream,
allocations,
None,
Some(host_transfers),
output_storage,
start_event,
end_event,
Arc::clone(&self.telemetry),
))
} else {
Ok(crate::stream::CudaPendingDispatch::new(
Arc::clone(&self.ctx),
Arc::clone(&self.launch_resources),
event,
stream,
allocations,
None,
Some(host_transfers),
output_storage,
Arc::clone(&self.telemetry),
))
}
}
pub fn dispatch(
&self,
program: &Program,
inputs: &[Vec<u8>],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
let lowered_program =
vyre_foundation::transform::collectives::lower_single_rank_collectives(program)
.map_err(|error| BackendError::InvalidProgram {
fix: error.to_string(),
})?;
let program = lowered_program.as_ref().unwrap_or(program);
let required = vyre_foundation::program_caps::scan(program);
let validation_caps = self.program_validation_caps();
vyre_foundation::program_caps::check_backend_capabilities(
validation_caps.backend_id,
validation_caps.supports_subgroup_ops,
validation_caps.supports_f16,
validation_caps.supports_bf16,
validation_caps.supports_indirect_dispatch,
validation_caps.supports_trap_propagation,
validation_caps.supports_distributed_collectives,
validation_caps.max_workgroup_size,
&required,
)
.map_err(|error| BackendError::InvalidProgram {
fix: error.to_string(),
})?;
if vyre_driver::grid_sync::contains_grid_sync(program) && !self.supports_grid_sync() {
let mut borrowed_inputs = SmallVec::<[&[u8]; 8]>::new();
reserve_smallvec(
&mut borrowed_inputs,
inputs.len(),
"grid-sync CUDA dispatch input",
)?;
borrowed_inputs.extend(inputs.iter().map(Vec::as_slice));
return self.dispatch_borrowed_with_grid_sync_split(program, &borrowed_inputs, config);
}
self.dispatch_async(program, inputs, config)?.await_result()
}
}
#[inline]
fn host_transfer_capacities(prepared: &CudaDispatchPlan) -> Result<(usize, usize), BackendError> {
let output_capacity = prepared.output_binding_indices.len();
let upload_capacity = host_upload_batch_capacity(prepared)?;
let transfer_capacity = checked_add_usize_lazy(upload_capacity, output_capacity, || {
BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA host transfer capacity overflowed usize for {upload_capacity} upload slot(s) plus {output_capacity} output slot(s); split the dispatch."
),
}
})?;
Ok((transfer_capacity, output_capacity))
}
#[inline]
fn host_upload_batch_capacity(prepared: &CudaDispatchPlan) -> Result<usize, BackendError> {
let input_slots = prepared.bindings.input_indices.len();
checked_add_usize_lazy(
input_slots,
usize::from(!prepared.launch.param_words.is_empty()),
|| {
BackendError::InvalidProgram {
fix: "Fix: CUDA host upload batch capacity overflowed usize while adding the params upload slot; split the dispatch."
.to_string(),
}
},
)
}
#[inline]
fn enqueue_host_uploads_async(
uploads: &[HostUpload],
stream: CUstream,
) -> Result<(), BackendError> {
for upload in uploads {
if upload.byte_len == 0 {
continue;
}
unsafe {
super::copy::h2d_async_checked(upload.dst, upload.src, upload.byte_len, stream)?;
}
}
Ok(())
}
#[inline]
fn enqueue_device_clears_async(
clears: &[DeviceClear],
stream: CUstream,
) -> Result<(), BackendError> {
for clear in clears {
unsafe {
super::copy::memset_d8_async_checked(clear.dst, 0, clear.byte_len, stream)?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::{host_transfer_capacities, host_upload_batch_capacity};
use crate::backend::CudaDispatchPlan;
use smallvec::smallvec;
use std::sync::Arc;
use vyre_driver::binding::{Binding, BindingPlan, BindingRole};
use vyre_driver::LaunchPlan;
#[test]
fn host_upload_batch_capacity_counts_inputs_once_plus_params() {
let plan = CudaDispatchPlan {
bindings: BindingPlan {
bindings: vec![
Binding {
name: Arc::from("a"),
binding: 0,
buffer_index: 0,
role: BindingRole::Input,
element_size: 4,
preferred_alignment: 4,
element_count: 16,
static_byte_len: Some(64),
input_index: Some(0),
output_index: None,
},
Binding {
name: Arc::from("b"),
binding: 1,
buffer_index: 1,
role: BindingRole::InputOutput,
element_size: 4,
preferred_alignment: 4,
element_count: 16,
static_byte_len: Some(64),
input_index: Some(1),
output_index: Some(0),
},
Binding {
name: Arc::from("out"),
binding: 2,
buffer_index: 2,
role: BindingRole::Output,
element_size: 4,
preferred_alignment: 4,
element_count: 16,
static_byte_len: Some(64),
input_index: None,
output_index: Some(1),
},
],
input_indices: vec![0, 1],
output_indices: vec![1, 2],
shared_indices: Vec::new(),
},
output_binding_indices: smallvec![1, 2],
launch: LaunchPlan::new(),
cooperative: false,
fixpoint_iterations: 1,
};
assert_eq!(
host_upload_batch_capacity(&plan).expect("Fix: capacity must fit"),
2,
"zero-byte launch params must not reserve a fake H2D upload slot"
);
assert_eq!(
host_transfer_capacities(&plan).expect("Fix: capacity must fit"),
(4, 2),
"pinned-host transfer storage must reserve inputs + outputs only when params are empty"
);
let mut plan_with_params = plan;
plan_with_params.launch.param_words.push(7);
assert_eq!(
host_upload_batch_capacity(&plan_with_params).expect("Fix: capacity must fit"),
3,
"non-empty launch params must reserve one H2D upload slot"
);
assert_eq!(
host_transfer_capacities(&plan_with_params).expect("Fix: capacity must fit"),
(5, 2),
"pinned-host transfer storage must reserve inputs + params + outputs when params exist"
);
}
}