use std::sync::{Arc, Mutex};
use smallvec::SmallVec;
use vyre_driver::accounting::checked_add_usize_lazy;
use vyre_driver::binding::BindingRole;
use vyre_driver::input_identity::{domain_separated_exact_input_key, ExactInputKey};
use vyre_driver::{backend::private, BackendError, DispatchConfig, LaunchPlan};
use vyre_foundation::ir::Program;
use crate::backend::allocations::DeviceAllocation;
use crate::backend::module_cache::PtxSourceCacheKey;
use crate::backend::{CachedCudaGraph, CudaBackend, CudaDispatchPlan, ModuleCacheKey};
use crate::device::CudaDeviceCaps;
mod compiled_dispatch;
mod materialized_cache;
mod static_params;
#[cfg(test)]
pub(crate) use materialized_cache::{
materialized_input_key, MAX_MATERIALIZED_OUTPUT_CACHE_BYTES_PER_PIPELINE,
};
pub(crate) use materialized_cache::{
MaterializedPipelineOutputCache, MaterializedPipelineOutputCacheEntry,
};
use static_params::upload_static_launch_params;
/// A compiled CUDA pipeline together with the per-pipeline state that
/// `CudaCompiledPipeline::new` creates alongside it.
///
/// Dropping the pipeline hands `static_params` back to
/// `backend.transient_pool` (see the `Drop` impl in this file).
#[derive(Debug)]
pub(crate) struct CudaCompiledPipeline {
/// Backend handle; its `transient_pool` receives `static_params` on drop.
backend: CudaBackend,
/// Shared program; `new` reads `entry_op_id` from it for trace output.
program: Arc<Program>,
/// Source text handed to `new` (PTX, going by the field name).
ptx_src: Arc<str>,
/// Module cache key; its inner 32 bytes feed the pipeline identity digest.
module_key: ModuleCacheKey,
/// Prepared dispatch plan; its `launch` supplies the launch geometry and
/// the static parameter words uploaded by `new`.
prepared: CudaDispatchPlan,
/// Clone of the `DispatchConfig` that was passed to `new`.
compiled_config: DispatchConfig,
/// Cached CUDA graphs behind a mutex. Starts empty, with inline storage for
/// `MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE` entries.
graph_cache: Mutex<SmallVec<[CachedCudaGraph; MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE]>>,
/// Materialized output cache behind a mutex; starts as the type's `Default`.
materialized_output_cache: Mutex<MaterializedPipelineOutputCache>,
/// Allocation returned by `upload_static_launch_params` for
/// `prepared.launch.param_words`.
static_params: DeviceAllocation,
/// Identifier of the form `cuda:<hex digest>`, derived from the pipeline
/// identity key.
id: String,
}
/// Inline capacity of each pipeline's graph cache, and the upper bound applied
/// to graph-replay lane counts in this file.
pub(crate) const MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE: usize = 32;
/// Multiprocessors counted per replay lane: the device's multiprocessor count
/// is divided by this value, rounding up.
const CUDA_GRAPH_REPLAY_SMS_PER_LANE: usize = 8;
/// Lower bound on the hardware lane count when the device reports concurrent
/// kernel support.
const CUDA_GRAPH_REPLAY_MIN_CONCURRENT_LANES: usize = 2;
/// The graph-replay memory budget starts from `caps.total_memory` divided by
/// this value.
const CUDA_GRAPH_REPLAY_VRAM_FRACTION_DENOMINATOR: u64 = 64;
/// Domain tag passed to `domain_separated_exact_input_key` when deriving
/// compiled pipeline identities.
const CUDA_COMPILED_PIPELINE_ID_DOMAIN: &[u8] = b"vyre.cuda.pipeline.compiled.v1";
/// Derives the identity key of a compiled CUDA pipeline.
///
/// The key is produced by the shared `domain_separated_exact_input_key`
/// helper under `CUDA_COMPILED_PIPELINE_ID_DOMAIN`. The hashed fields are, in
/// order: the PTX source key, the module key, the launch element count, the
/// three workgroup dimensions and the three grid dimensions. Every numeric
/// field is encoded as little-endian bytes.
///
/// # Errors
///
/// Returns whatever error the shared key helper reports.
fn cuda_compiled_pipeline_identity_key(
    ptx_source_key: &[u8; 32],
    module_key: &[u8; 32],
    launch: &LaunchPlan,
) -> Result<ExactInputKey, BackendError> {
    let element_count = launch.element_count.to_le_bytes();
    let [workgroup_x, workgroup_y, workgroup_z] = launch.workgroup.map(|dim| dim.to_le_bytes());
    let [grid_x, grid_y, grid_z] = launch.grid.map(|dim| dim.to_le_bytes());

    // Field order is part of the identity contract; keep it stable.
    let fields: [&[u8]; 9] = [
        ptx_source_key,
        module_key,
        &element_count,
        &workgroup_x,
        &workgroup_y,
        &workgroup_z,
        &grid_x,
        &grid_y,
        &grid_z,
    ];

    // The two zero arguments are forwarded as-is; their meaning is defined by
    // the shared helper.
    domain_separated_exact_input_key(CUDA_COMPILED_PIPELINE_ID_DOMAIN, 0, 0, &fields)
}
impl CudaCompiledPipeline {
    /// Assembles a compiled pipeline from already-prepared pieces.
    ///
    /// The steps run in this order:
    /// 1. derive the identity digest from the PTX source key, the module key
    ///    and `prepared.launch`;
    /// 2. upload `prepared.launch.param_words` as the static launch parameters;
    /// 3. build the struct with an empty graph cache, a default materialized
    ///    output cache, a clone of `config`, and an id of the form
    ///    `cuda:<hex digest>`.
    ///
    /// When stage tracing is enabled, a debug line is emitted at the start and
    /// after steps 1 and 2.
    ///
    /// # Errors
    ///
    /// Propagates any error from deriving the identity key or from uploading
    /// the static launch parameters.
    pub(crate) fn new(
        backend: CudaBackend,
        program: Arc<Program>,
        ptx_src: Arc<str>,
        ptx_source_key: PtxSourceCacheKey,
        module_key: ModuleCacheKey,
        config: &DispatchConfig,
        prepared: CudaDispatchPlan,
    ) -> Result<Self, BackendError> {
        // Bound to a local so the value lives until this function returns.
        let _profiler_range =
            crate::profiler::cuda_profiler_range(crate::profiler::CUDA_PIPELINE_COMPILE_RANGE);
        let stage_trace = crate::instrumentation::cuda_stage_trace_enabled();
        let clock = std::time::Instant::now();

        if stage_trace {
            let entry = program.entry_op_id.as_deref().unwrap_or("<anonymous>");
            tracing::debug!("[cuda-pipeline] start entry={}", entry);
        }

        let identity = cuda_compiled_pipeline_identity_key(
            ptx_source_key.as_bytes(),
            &module_key.0,
            &prepared.launch,
        )?;
        if stage_trace {
            tracing::debug!(
                "[cuda-pipeline] +{}ms digest ready",
                clock.elapsed().as_millis()
            );
        }

        let static_params = upload_static_launch_params(&backend, &prepared.launch.param_words)?;
        if stage_trace {
            tracing::debug!(
                "[cuda-pipeline] +{}ms static params ready bytes={}",
                clock.elapsed().as_millis(),
                static_params.byte_len
            );
        }

        let id = format!("cuda:{}", blake3::Hash::from(identity).to_hex());
        let compiled_config = config.clone();

        Ok(Self {
            backend,
            program,
            ptx_src,
            module_key,
            prepared,
            compiled_config,
            graph_cache: Mutex::new(SmallVec::new()),
            materialized_output_cache: Mutex::new(MaterializedPipelineOutputCache::default()),
            static_params,
            id,
        })
    }
}
// Tests for the compiled-pipeline identity key: sensitivity to each input and
// a source-text guard that the shared key helper stays in use.
#[cfg(test)]
mod identity_tests {
use super::*;
/// Deterministically fills a 32-byte key from `seed` and `salt`.
///
/// Each byte comes from one wrapping multiply/add step of a `u32` state
/// followed by a left rotation by `index & 15`; the byte taken from the
/// state cycles through its four byte lanes.
fn generated_key(seed: u32, salt: u32) -> [u8; 32] {
let mut out = [0_u8; 32];
let mut state = seed ^ salt ^ 0xC0DA_CAFE;
for (index, byte) in out.iter_mut().enumerate() {
state = state
.wrapping_mul(1_664_525)
.wrapping_add(1_013_904_223)
.rotate_left((index as u32) & 15);
// Select byte lane 0..=3 of the state for this output position.
*byte = (state >> ((index & 3) * 8)) as u8;
}
out
}
/// Builds a `LaunchPlan` whose element count, workgroup and grid dimensions
/// are all non-zero and derived from `seed`; `param_words` is left empty.
fn generated_launch(seed: u32) -> LaunchPlan {
LaunchPlan {
element_count: 1 + (seed % 4096),
workgroup: [
32 + (seed % 8) * 32,
1 + (seed.rotate_left(3) % 4),
1 + (seed.rotate_left(5) % 2),
],
grid: [
1 + (seed % 1024),
1 + (seed.rotate_left(7) % 16),
1 + (seed.rotate_left(11) % 8),
],
param_words: Vec::new(),
max_binding_alignment: std::mem::size_of::<u64>(),
}
}
/// For 2048 seeds, changing only the PTX key, only the module key, or only
/// `grid[0]` must each yield a key different from the baseline.
#[test]
fn cuda_compiled_pipeline_identity_uses_shared_domain_separated_contract() {
for seed in 0_u32..2048 {
let ptx_key = generated_key(seed, 0x5054_5820);
let module_key = generated_key(seed, 0x4D4F_4420);
let launch = generated_launch(seed);
let key = cuda_compiled_pipeline_identity_key(&ptx_key, &module_key, &launch)
.expect("Fix: generated CUDA compiled pipeline key must fit");
// Variant 1: different PTX key, everything else unchanged.
let changed_ptx = cuda_compiled_pipeline_identity_key(
&generated_key(seed ^ 1, 0x5054_5820),
&module_key,
&launch,
)
.expect("Fix: generated CUDA compiled pipeline PTX variant must fit");
// Variant 2: different module key, everything else unchanged.
let changed_module = cuda_compiled_pipeline_identity_key(
&ptx_key,
&generated_key(seed ^ 1, 0x4D4F_4420),
&launch,
)
.expect("Fix: generated CUDA compiled pipeline module variant must fit");
// Variant 3: same keys, grid x bumped by one.
let mut changed_launch = launch.clone();
changed_launch.grid[0] = changed_launch.grid[0].wrapping_add(1);
let changed_launch_key =
cuda_compiled_pipeline_identity_key(&ptx_key, &module_key, &changed_launch)
.expect("Fix: generated CUDA compiled pipeline launch variant must fit");
assert_ne!(key, changed_ptx);
assert_ne!(key, changed_module);
assert_ne!(key, changed_launch_key);
}
}
/// Source-text guard: reads `pipeline.rs` at compile time and asserts that it
/// names the shared key helper and the local identity function, and that it
/// does not contain the text of a locally constructed BLAKE3 hasher.
///
/// The forbidden needle is assembled with `concat` so that this test's own
/// source does not contain it. Comments in this file must not spell it out
/// either.
#[test]
fn cuda_compiled_pipeline_source_does_not_fork_blake3_tuple_hashing() {
let source = include_str!("pipeline.rs");
assert!(
source.contains("domain_separated_exact_input_key")
&& source.contains("cuda_compiled_pipeline_identity_key")
&& !source.contains(&["blake", "3::Hasher::new()"].concat()),
"Fix: CUDA compiled pipeline identity must use the shared domain-separated exact-input key instead of local BLAKE3 tuple hashing."
);
}
}
impl Drop for CudaCompiledPipeline {
    /// Hands the static launch-parameter allocation back to the backend's
    /// transient pool when the pipeline is dropped.
    fn drop(&mut self) {
        // `take` moves the allocation out and leaves the type's `Default`
        // value in the field.
        let params = std::mem::take(&mut self.static_params);
        self.backend.transient_pool.release(params);
    }
}
// Marker impl of the driver's `private::Sealed` trait for the CUDA pipeline.
// NOTE(review): presumably a prerequisite for implementing sealed driver traits
// on this type — confirm against `vyre_driver::backend`.
impl private::Sealed for CudaCompiledPipeline {}
/// Reports whether CUDA graph replay is enabled, as decided by
/// `crate::instrumentation::cuda_graph_replay_enabled`.
fn cuda_graph_replay_enabled() -> bool {
crate::instrumentation::cuda_graph_replay_enabled()
}
/// Chooses how many graph-replay lanes to use for `batches`.
///
/// Returns `0` when `batches` is empty. Otherwise the result is the smallest
/// of the batch count, the hardware lane capacity and the memory lane
/// capacity, and never less than `1`.
///
/// The memory lane capacity is the memory budget divided by the cached byte
/// size of one shape, kept within `1..=MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE`.
/// The budget is `caps.total_memory /
/// CUDA_GRAPH_REPLAY_VRAM_FRACTION_DENOMINATOR`, raised to at least one
/// shape's bytes and capped at `usize::MAX`. Only the first batch's inputs are
/// used to size a shape.
///
/// # Errors
///
/// Propagates errors from the hardware lane and shape size helpers, and
/// returns `BackendError::InvalidProgram` if a width conversion fails.
pub(crate) fn cuda_graph_lane_count_for_batch(
    caps: &CudaDeviceCaps,
    prepared: &CudaDispatchPlan,
    batches: &[&[&[u8]]],
) -> Result<usize, BackendError> {
    let Some(&first_batch) = batches.first() else {
        return Ok(0);
    };

    let hardware_lanes = cuda_graph_hardware_lane_capacity(caps)?;
    let shape_bytes = cuda_graph_shape_cached_bytes(prepared, first_batch)?;

    let shape_bytes_wide = match u64::try_from(shape_bytes) {
        Ok(wide) => wide,
        Err(_) => {
            return Err(BackendError::InvalidProgram {
                fix: "Fix: CUDA graph replay shape byte count exceeds u64; split the replay batch before lane planning.".to_string(),
            });
        }
    };
    let host_cap = match u64::try_from(usize::MAX) {
        Ok(cap) => cap,
        Err(source) => {
            return Err(BackendError::InvalidProgram {
                fix: format!(
                    "Fix: host usize::MAX cannot fit u64 while planning CUDA graph lanes: {source}; use a supported host pointer width."
                ),
            });
        }
    };

    // Budget: a fixed fraction of device memory, but never less than a single
    // shape and never more than the host can address.
    let vram_share = caps.total_memory / CUDA_GRAPH_REPLAY_VRAM_FRACTION_DENOMINATOR;
    let memory_budget_u64 = std::cmp::min(std::cmp::max(vram_share, shape_bytes_wide), host_cap);
    let memory_budget = match usize::try_from(memory_budget_u64) {
        Ok(budget) => budget,
        Err(source) => {
            return Err(BackendError::InvalidProgram {
                fix: format!(
                    "Fix: CUDA graph replay memory budget {memory_budget_u64} cannot fit usize: {source}; split the replay batch before lane planning."
                ),
            });
        }
    };

    // `checked_div` yields `None` only for a zero-byte shape, which places no
    // memory limit on the lane count beyond the cache capacity.
    let memory_lanes = match memory_budget.checked_div(shape_bytes) {
        None => MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE,
        Some(fit) => std::cmp::min(std::cmp::max(fit, 1), MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE),
    };

    let lanes = batches.len().min(hardware_lanes).min(memory_lanes);
    Ok(lanes.max(1))
}
/// Computes how many graph-replay lanes the device hardware is allowed.
///
/// Returns `1` when `caps.concurrent_kernels` is false. Otherwise the
/// multiprocessor count is divided by `CUDA_GRAPH_REPLAY_SMS_PER_LANE`,
/// rounding up. The quotient is then raised to at least
/// `CUDA_GRAPH_REPLAY_MIN_CONCURRENT_LANES` and capped at
/// `MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE`.
///
/// # Errors
///
/// Returns `BackendError::InvalidProgram` if the multiprocessor count cannot
/// be represented as `usize`.
fn cuda_graph_hardware_lane_capacity(caps: &CudaDeviceCaps) -> Result<usize, BackendError> {
    if !caps.concurrent_kernels {
        return Ok(1);
    }

    let sm_count = match usize::try_from(caps.multi_processor_count_u32()) {
        Ok(count) => count,
        Err(source) => {
            return Err(BackendError::InvalidProgram {
                fix: format!(
                    "Fix: CUDA multiprocessor count cannot fit usize during graph lane planning: {source}; reject corrupt device capabilities before compiling graph replay."
                ),
            });
        }
    };

    let uncapped = sm_count.div_ceil(CUDA_GRAPH_REPLAY_SMS_PER_LANE);
    let floored = std::cmp::max(uncapped, CUDA_GRAPH_REPLAY_MIN_CONCURRENT_LANES);
    Ok(std::cmp::min(floored, MAX_GRAPH_CACHE_ENTRIES_PER_PIPELINE))
}
fn cuda_graph_shape_cached_bytes(
prepared: &CudaDispatchPlan,
inputs: &[&[u8]],
) -> Result<usize, BackendError> {
let mut bytes = bucketed_len(std::mem::size_of_val(
prepared.launch.param_words.as_slice(),
))?;
for binding in &prepared.bindings.bindings {
if binding.role == BindingRole::Shared {
continue;
}
let byte_len = binding
.input_index
.and_then(|input_index| inputs.get(input_index).map(|input| input.len()))
.or(binding.static_byte_len)
.ok_or_else(|| BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA graph replay shape cache found binding `{}` without a runtime input or static byte length. Preserve concrete binding byte lengths during dispatch planning instead of treating missing sizes as zero.",
binding.name
),
})?;
bytes = add_shape_bytes(bytes, bucketed_len(byte_len)?)?;
if binding.input_index.is_some() {
bytes = add_shape_bytes(bytes, bucketed_len(byte_len)?)?;
}
if binding.output_index.is_some() {
bytes = add_shape_bytes(bytes, bucketed_len(byte_len)?)?;
}
}
Ok(bytes)
}
/// Adds `component` to `total` through `checked_add_usize_lazy`, supplying the
/// graph-replay overflow error for the helper to use on failure.
///
/// # Errors
///
/// Returns whatever `checked_add_usize_lazy` reports; the error it is given to
/// build is a `BackendError::InvalidProgram` describing the overflow.
fn add_shape_bytes(total: usize, component: usize) -> Result<usize, BackendError> {
    let overflow_error = || BackendError::InvalidProgram {
        fix: "Fix: CUDA graph replay cached shape byte count overflowed; split the replay batch before graph-cache lane planning.".to_string(),
    };
    checked_add_usize_lazy(total, component, overflow_error)
}
/// Rounds `byte_len` up to the next power of two, treating zero as one byte.
///
/// # Errors
///
/// Returns `BackendError::InvalidProgram` when the next power of two does not
/// fit in `usize`.
fn bucketed_len(byte_len: usize) -> Result<usize, BackendError> {
    // A zero-length buffer still occupies the smallest bucket.
    let at_least_one = std::cmp::max(byte_len, 1);
    match at_least_one.checked_next_power_of_two() {
        Some(bucket) => Ok(bucket),
        None => Err(BackendError::InvalidProgram {
            fix: "Fix: CUDA graph replay bucketed shape byte count overflowed; split the oversized input or disable graph replay for this shape.".to_string(),
        }),
    }
}
#[cfg(test)]
mod tests;