use std::ffi::c_void;
use std::sync::Arc;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use vyre_driver::binding::BindingRole;
use vyre_driver::{BackendError, DispatchConfig};
use vyre_foundation::ir::Program;
use crate::backend::allocations::{DispatchAllocations, HostTransferAllocations};
use crate::backend::dispatch::CudaBackend;
use crate::backend::launch_params::launch_param_byte_len;
use crate::backend::module_cache::ModuleCacheKey;
use crate::backend::ordering::sort_unstable_by_key_if_needed;
use crate::backend::output_range::{cuda_output_readback, CudaOutputReadback};
use crate::backend::plan::CudaDispatchPlan;
use crate::backend::resident::{CudaResidentBuffer, ResidentViewCache};
use crate::backend::resident_dispatch::helpers::{
enqueue_optional_resident_h2d_copy, resident_required_handles,
};
use crate::backend::resident_dispatch_support::{
checked_resident_dispatch_capacity_mul, CudaResidentBatchDispatch,
};
use crate::backend::staging_reserve::{reserve_hash_set, reserve_smallvec};
impl CudaBackend {
#[allow(clippy::too_many_arguments)]
pub(crate) fn dispatch_resident_batch_async_concrete_with_ptx_key(
&self,
program: &Program,
batches: &[SmallVec<[CudaResidentBuffer; 8]>],
_config: &DispatchConfig,
ptx_src: &str,
module_key: ModuleCacheKey,
static_params_ptr: Option<u64>,
prepared: &CudaDispatchPlan,
) -> Result<CudaResidentBatchDispatch, BackendError> {
if batches.is_empty() {
return Err(BackendError::InvalidProgram {
fix:
"Fix: CUDA resident batch dispatch requires at least one resident handle tuple."
.into(),
});
}
self.warmup()?;
let required_handles = resident_required_handles(prepared)?;
let batch_handle_capacity = checked_resident_dispatch_capacity_mul(
batches.len(),
required_handles,
"batch handle",
)?;
let mut all_handles = SmallVec::<[CudaResidentBuffer; 32]>::new();
reserve_smallvec(
&mut all_handles,
batch_handle_capacity,
"resident batch all handles",
)?;
for (batch_index, handles) in batches.iter().enumerate() {
if handles.len() != required_handles {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch item {batch_index} expected {required_handles} resident buffer handle(s) but received {}.",
handles.len()
),
});
}
all_handles.extend(handles.iter().copied());
}
let param_bytes =
launch_param_byte_len(&prepared.launch.param_words, "resident batch dispatch")?;
let mut allocations =
DispatchAllocations::new(program.buffers().len(), Arc::clone(&self.transient_pool))?;
let mut host_transfers = HostTransferAllocations::with_capacity(
Arc::clone(&self.host_pool),
usize::from(static_params_ptr.is_none() && param_bytes != 0),
0,
)?;
let mut param_upload: Option<(u64, *const c_void, usize)> = None;
let params_ptr = match static_params_ptr {
Some(ptr) => ptr,
None if param_bytes == 0 => 0,
None => {
let (params_ptr, upload) = self.prepare_resident_param_upload(
&prepared.launch.param_words,
param_bytes,
"CUDA resident batch dispatch parameter bytes",
"CUDA resident batch dispatch parameter upload",
"resident batch dispatch parameter allocation byte count",
"resident batch dispatch parameter upload byte count",
&mut allocations,
&mut host_transfers,
)?;
param_upload = upload;
params_ptr
}
};
let func = self.resolve_launch_function(
ptx_src,
module_key,
&prepared.launch,
prepared.cooperative,
)?;
let mut output_handles_by_batch = SmallVec::<[SmallVec<[CudaResidentBuffer; 8]>; 8]>::new();
reserve_smallvec(
&mut output_handles_by_batch,
batches.len(),
"resident batch output handles",
)?;
let mut output_readbacks_by_batch =
SmallVec::<[SmallVec<[CudaOutputReadback; 8]>; 8]>::new();
reserve_smallvec(
&mut output_readbacks_by_batch,
batches.len(),
"resident batch output readbacks",
)?;
let mut launch_ptrs_by_batch = SmallVec::<[SmallVec<[u64; 8]>; 8]>::new();
reserve_smallvec(
&mut launch_ptrs_by_batch,
batches.len(),
"resident batch launch pointer groups",
)?;
let output_binding_count = prepared.output_binding_indices.len();
let total_output_entries = if output_binding_count == 0 {
0usize
} else {
checked_resident_dispatch_capacity_mul(
batches.len(),
output_binding_count,
"batch output-handle set",
)?
};
let seen_outputs_small = total_output_entries <= 8 && total_output_entries != 0;
let mut seen_output_handles_small = SmallVec::<[u64; 8]>::new();
reserve_smallvec(
&mut seen_output_handles_small,
total_output_entries.min(8),
"resident batch small output duplicate set",
)?;
let mut seen_output_handles = if !seen_outputs_small && total_output_entries != 0 {
let mut set = FxHashSet::default();
reserve_hash_set(
&mut set,
total_output_entries,
"resident batch output duplicate set",
)?;
Some(set)
} else {
None
};
for (batch_index, handles) in batches.iter().enumerate() {
let mut launch_ptrs = SmallVec::<[u64; 8]>::new();
reserve_smallvec(
&mut launch_ptrs,
prepared.bindings.bindings.len(),
"resident batch launch pointers",
)?;
let mut next_handle = 0usize;
let mut output_handles_by_index =
SmallVec::<[(usize, CudaResidentBuffer, CudaOutputReadback); 8]>::new();
reserve_smallvec(
&mut output_handles_by_index,
prepared.output_binding_indices.len(),
"resident batch output handles by index",
)?;
let mut resident_view_cache = ResidentViewCache::new();
reserve_smallvec(
&mut resident_view_cache,
handles.len(),
"resident batch dispatch view cache",
)?;
for binding in &prepared.bindings.bindings {
if binding.role == BindingRole::Shared {
continue;
}
let handle = handles[next_handle];
next_handle += 1;
let resident = self.resident_store.view_cached(
handle,
&mut resident_view_cache,
"resident batch dispatch view cache",
)?;
if let Some(expected) = binding.static_byte_len {
if resident.byte_len < expected {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch item {batch_index} binding `{}` expected at least {expected} bytes but handle {} has {} bytes.",
binding.name, handle.id, resident.byte_len
),
});
}
}
if resident.ptr == 0 {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch item {batch_index} binding `{}` resolved to a null device pointer; resident launch arguments must preserve descriptor order.",
binding.name
),
});
}
launch_ptrs.push(resident.ptr);
if let Some(output_index) = binding.output_index {
let full_byte_len = match binding.static_byte_len {
Some(len) => len,
None => resident.byte_len,
};
let readback = cuda_output_readback(
&program.buffers()[binding.buffer_index],
full_byte_len,
)?;
output_handles_by_index.push((output_index, handle, readback));
}
}
if output_handles_by_index.len() != prepared.output_binding_indices.len() {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch item {batch_index} expected {} output handle(s) but resolved {}.",
prepared.output_binding_indices.len(),
output_handles_by_index.len()
),
});
}
sort_unstable_by_key_if_needed(
output_handles_by_index.as_mut_slice(),
|(output_index, _, _)| *output_index,
);
let mut output_handles = SmallVec::<[CudaResidentBuffer; 8]>::new();
reserve_smallvec(
&mut output_handles,
output_handles_by_index.len(),
"resident batch output handles",
)?;
let mut output_readbacks = SmallVec::<[CudaOutputReadback; 8]>::new();
reserve_smallvec(
&mut output_readbacks,
output_handles_by_index.len(),
"resident batch output readbacks",
)?;
for (_, handle, readback) in output_handles_by_index {
if !seen_outputs_small {
if let Some(seen_output_handles) = seen_output_handles.as_mut() {
if !seen_output_handles.insert(handle.id) {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch cannot reuse output handle {} across submitted items; allocate one output resident buffer tuple per in-flight batch item so batched readback observes every result instead of the final overwrite.",
handle.id
),
});
}
}
} else {
if seen_output_handles_small.contains(&handle.id) {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch cannot reuse output handle {} across submitted items; allocate one output resident buffer tuple per in-flight batch item so batched readback observes every result instead of the final overwrite.",
handle.id
),
});
}
seen_output_handles_small.push(handle.id);
}
output_handles.push(handle);
output_readbacks.push(readback);
}
if output_handles.len() != prepared.output_binding_indices.len() {
return Err(BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA resident batch dispatch item {batch_index} expected {} output handle(s) but resolved {}.",
prepared.output_binding_indices.len(),
output_handles.len()
),
});
}
if output_handles.len() != output_readbacks.len() {
return Err(BackendError::InvalidProgram {
fix: "Fix: CUDA resident batch dispatch output handle/readback stream mismatch after reordering outputs."
.into(),
});
}
launch_ptrs_by_batch.push(launch_ptrs);
output_handles_by_batch.push(output_handles);
output_readbacks_by_batch.push(output_readbacks);
}
let resident_use = self.resident_store.mark_inflight(&all_handles)?;
let launch_resources = crate::stream::CudaLaunchResourceLease::acquire(
Arc::clone(&self.launch_resources),
false,
)?;
let stream_raw = launch_resources.stream_raw()?;
enqueue_optional_resident_h2d_copy(param_upload, stream_raw)?;
let mut kernel_args = SmallVec::<[*mut c_void; 8]>::new();
for mut launch_ptrs in launch_ptrs_by_batch {
let mut params_ref = params_ptr;
Self::kernel_args_into(&mut launch_ptrs, &mut params_ref, &mut kernel_args)?;
for _ in 0..prepared.fixpoint_iterations {
self.launch_prevalidated_function(
func,
&mut kernel_args,
&prepared.launch,
stream_raw,
false,
prepared.cooperative,
)?;
}
}
let event = self.launch_resources.acquire_event()?;
if let Err(error) = event.record(stream_raw) {
self.launch_resources.release_event(event);
return Err(error);
}
let (stream, _) = launch_resources.into_parts()?;
let pending = crate::stream::CudaPendingDispatch::new_resident_batch_pending(
Arc::clone(&self.ctx),
Arc::clone(&self.launch_resources),
event,
stream,
allocations,
resident_use,
host_transfers,
Arc::clone(&self.telemetry),
);
Ok(CudaResidentBatchDispatch {
pending,
output_handles: output_handles_by_batch,
output_readbacks: output_readbacks_by_batch,
})
}
}