use vyre_driver::BackendError;
use crate::backend::allocations::{DeviceAllocation, HostTransferAllocations};
use crate::backend::copy::aligned_async_copy_len;
use crate::backend::launch_params::launch_param_byte_len;
use crate::backend::CudaBackend;
use crate::numeric::CUDA_NUMERIC;
/// Uploads a compiled pipeline's static launch-parameter words into a
/// transient device allocation and returns that allocation.
///
/// An empty `param_words` slice short-circuits to `DeviceAllocation::default()`
/// without touching the pools or the device.
///
/// The words are staged through a host transfer allocation (padded to the
/// aligned async-copy length) and copied host-to-device on a stream borrowed
/// from `backend.launch_resources`. The stream is synchronized before this
/// function returns, so the host staging buffer stays alive until the copy has
/// completed. On success, host-to-device, upload-operation and param-upload
/// telemetry are recorded.
///
/// # Errors
///
/// Propagates any `BackendError` from the following steps:
/// - byte-length computation;
/// - the transient memory-budget check;
/// - alignment;
/// - numeric conversion;
/// - host/device pool acquisition;
/// - the async copy;
/// - stream synchronization.
///
/// Every error raised after the device allocation is acquired first returns
/// that allocation to `backend.transient_pool`.
pub(crate) fn upload_static_launch_params(
    backend: &CudaBackend,
    param_words: &[u32],
) -> Result<DeviceAllocation, BackendError> {
    if param_words.is_empty() {
        return Ok(DeviceAllocation::default());
    }
    let param_bytes = launch_param_byte_len(param_words, "compiled-pipeline static")?;
    // NOTE(review): the budget is checked against `param_bytes`, while the pool
    // allocation below is sized by the (possibly larger) aligned `transfer_bytes`.
    // Confirm this is intended.
    backend.validate_transient_allocation_memory_budget(
        param_bytes,
        "CUDA compiled-pipeline static parameter bytes",
        "CUDA compiled-pipeline static parameter upload",
    )?;
    let transfer_bytes = aligned_async_copy_len(param_bytes)?;

    // Do every fallible step that does not need the device allocation before
    // acquiring it: a bare `?` after `acquire` would drop the allocation
    // instead of returning it to the transient pool.
    let param_bytes_u64 =
        CUDA_NUMERIC.usize_to_u64(param_bytes, "static launch parameter upload byte count")?;
    let mut host_transfers =
        HostTransferAllocations::with_capacity(std::sync::Arc::clone(&backend.host_pool), 1, 0)?;

    let allocation = backend.transient_pool.acquire(transfer_bytes)?;
    // Every fallible step from here on runs inside this closure, so any failure
    // is funneled to the single release below.
    let upload_result = (|| {
        backend
            .telemetry
            .record_transient_allocation_bytes(CUDA_NUMERIC.usize_to_u64(
                allocation.byte_len,
                "static launch parameter allocation byte count",
            )?);
        let stream = backend.launch_resources.acquire_stream()?;
        // The copy runs in an inner closure so the stream is handed back to
        // `launch_resources` whether or not the copy and sync succeed.
        let copy_result = (|| {
            let param_host_ptr =
                host_transfers.push_u32_words_padded(param_words, transfer_bytes)?;
            // SAFETY:
            // - The device allocation was acquired for `transfer_bytes`, and
            //   the host pointer comes from
            //   `push_u32_words_padded(.., transfer_bytes)`. Both sides are
            //   presumed to cover `transfer_bytes` bytes (assumes those
            //   helpers guarantee it; confirm).
            // - `host_transfers` outlives the `stream.synchronize()` below, so
            //   the host buffer remains valid for the whole async copy.
            unsafe {
                crate::backend::copy::h2d_async_checked(
                    allocation.ptr,
                    param_host_ptr,
                    transfer_bytes,
                    stream.raw(),
                )?;
            }
            let sync_result = stream.synchronize();
            if sync_result.is_ok() {
                backend.telemetry.record_sync_point();
            }
            sync_result
        })();
        backend.launch_resources.release_stream(stream);
        copy_result
    })();
    if let Err(err) = upload_result {
        backend.transient_pool.release(allocation);
        return Err(err);
    }

    backend.telemetry.record_host_to_device_bytes(param_bytes_u64);
    backend.telemetry.record_host_upload_operations(1);
    backend.telemetry.record_param_upload_bytes(param_bytes_u64);
    Ok(allocation)
}