use vyre_driver::launch_fusion::{
plan_launch_fusion, plan_launch_fusion_with_scratch, LaunchFusionError, LaunchFusionGroup,
LaunchFusionPlan, LaunchFusionScratch, LaunchFusionStage,
};
/// CUDA-facing name for the driver's [`LaunchFusionStage`]; a plain alias, not a new type.
pub type CudaFusionStage = LaunchFusionStage;
/// CUDA-facing name for the driver's [`LaunchFusionGroup`]; a plain alias, not a new type.
pub type CudaLaunchFusionGroup = LaunchFusionGroup;
/// CUDA-facing name for the driver's [`LaunchFusionPlan`]; a plain alias, not a new type.
pub type CudaLaunchFusionPlan = LaunchFusionPlan;
/// CUDA-facing name for the driver's [`LaunchFusionScratch`]; a plain alias, not a new type.
pub type CudaLaunchFusionScratch = LaunchFusionScratch;
/// CUDA-facing name for the driver's [`LaunchFusionError`]; a plain alias, not a new type.
pub type CudaLaunchFusionError = LaunchFusionError;
/// Plans launch fusion for `stages` by forwarding to the driver-level
/// [`plan_launch_fusion`]; both arguments and the result pass through unchanged.
///
/// This function adds no CUDA-specific logic of its own.
///
/// # Errors
///
/// Returns whatever error the driver planner reports. The tests in this file
/// exercise `ZeroBudget` (a `max_group_bytes` of 0), `DuplicateStage` (two
/// stages sharing an id) and `StageOverBudget` (a single stage whose required
/// bytes exceed `max_group_bytes`).
pub fn plan_cuda_launch_fusion(
stages: &[CudaFusionStage],
max_group_bytes: u64,
) -> Result<CudaLaunchFusionPlan, CudaLaunchFusionError> {
plan_launch_fusion(stages, max_group_bytes)
}
/// Plans launch fusion for `stages` by forwarding to the driver-level
/// [`plan_launch_fusion_with_scratch`], handing it the caller-owned `scratch`.
///
/// All three arguments and the result pass through unchanged; this function
/// adds no CUDA-specific logic of its own.
///
/// NOTE(review): the tests in this file reuse one `scratch` across two calls
/// and describe it as duplicate-detection scratch; what the driver stores in
/// it is not visible here, so confirm against the driver documentation.
///
/// # Errors
///
/// Returns whatever error the driver planner reports.
pub fn plan_cuda_launch_fusion_with_scratch(
stages: &[CudaFusionStage],
max_group_bytes: u64,
scratch: &mut CudaLaunchFusionScratch,
) -> Result<CudaLaunchFusionPlan, CudaLaunchFusionError> {
plan_launch_fusion_with_scratch(stages, max_group_bytes, scratch)
}
#[cfg(test)]
mod tests {
    //! Tests for the CUDA launch-fusion adapter: one source-level check that the
    //! adapter only delegates, plus behavioral checks run through the adapter.

    use super::*;

    /// Builds a [`CudaFusionStage`] from positional values so the stage tables
    /// in the tests below stay compact.
    fn stage(
        stage_id: u32,
        layout: u64,
        input: u64,
        output: u64,
        scratch: u64,
        host_materialization: bool,
    ) -> CudaFusionStage {
        CudaFusionStage {
            id: stage_id,
            layout_hash: layout,
            input_bytes: input,
            output_bytes: output,
            scratch_bytes: scratch,
            requires_host_materialization: host_materialization,
        }
    }

    /// Scans this file's own pre-test source text for signs of a local copy of
    /// the fusion planner.
    #[test]
    fn cuda_launch_fusion_is_adapter_not_algorithm_fork() {
        // Identifiers and signatures the adapter must not define for itself.
        let banned_fragments = [
            "FxHashSet",
            "CudaStorageReserveFailure",
            "CudaArithmeticOverflow",
            "fn singleton_group_with_capacity",
            "fn can_append_to_group",
            "fn fused_required_bytes",
            "fn stage_required_bytes",
            "reserved_typed_vec",
            "reserve_typed_hash_set",
        ];

        let source = include_str!("launch_fusion.rs");
        // Everything ahead of the first test-cfg attribute counts as production code.
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: CUDA launch fusion production source must precede tests.");

        let delegates_to_driver = production.contains("vyre_driver::launch_fusion");
        assert!(
            delegates_to_driver,
            "Fix: CUDA launch fusion must delegate to the backend-neutral driver owner."
        );

        for forbidden in banned_fragments {
            let is_absent = !production.contains(forbidden);
            assert!(
                is_absent,
                "Fix: CUDA launch fusion must not carry local adjacent-stage fusion logic: {forbidden}."
            );
        }
    }

    /// Three stages with the same layout hash and no host materialization are
    /// expected to end up in a single launch.
    #[test]
    fn launch_fusion_groups_adjacent_compatible_stages() {
        let chain = [
            stage(1, 7, 64, 32, 8, false),
            stage(2, 7, 32, 48, 8, false),
            stage(3, 7, 48, 16, 8, false),
        ];

        let fused = plan_cuda_launch_fusion(&chain, 256).expect("Fix: compatible stages should fuse");

        assert_eq!(fused.launch_count, 1);
        assert_eq!(fused.avoided_launches, 2);
        assert_eq!(fused.groups[0].stage_ids, vec![1, 2, 3]);
        assert_eq!(fused.avoided_intermediate_bytes, 80);
    }

    /// Stages that differ in layout hash or require host materialization are
    /// expected to stay in separate single-stage groups, in input order.
    #[test]
    fn launch_fusion_splits_on_layout_host_boundary_and_budget() {
        let mixed = [
            stage(1, 7, 64, 32, 8, false),
            stage(2, 8, 32, 48, 8, false),
            stage(3, 8, 48, 16, 8, true),
            stage(4, 9, 16, 16, 8, false),
        ];

        let split = plan_cuda_launch_fusion(&mixed, 128)
            .expect("Fix: incompatible stages should split deterministically");

        assert_eq!(split.launch_count, 4);
        assert_eq!(split.avoided_launches, 0);
        // Group `index` must hold exactly the stage with id `index + 1`.
        for (index, expected_id) in (1..=4).enumerate() {
            assert_eq!(split.groups[index].stage_ids, vec![expected_id]);
        }
    }

    /// Zero budget, duplicate stage ids, and a lone over-budget stage must each
    /// surface the matching error variant.
    #[test]
    fn launch_fusion_rejects_invalid_inputs() {
        let zero_budget = plan_cuda_launch_fusion(&[stage(1, 7, 1, 1, 1, false)], 0)
            .expect_err("zero budget should fail");
        assert_eq!(zero_budget, CudaLaunchFusionError::ZeroBudget);

        let repeated_id = [stage(1, 7, 1, 1, 1, false), stage(1, 7, 1, 1, 1, false)];
        let duplicate = plan_cuda_launch_fusion(&repeated_id, 128)
            .expect_err("duplicate stages should fail");
        assert_eq!(duplicate, CudaLaunchFusionError::DuplicateStage { id: 1 });

        let oversized = plan_cuda_launch_fusion(&[stage(9, 7, 64, 32, 64, false)], 128)
            .expect_err("single over-budget stage should fail");
        let expected_over_budget = CudaLaunchFusionError::StageOverBudget {
            id: 9,
            required_bytes: 160,
            budget_bytes: 128,
        };
        assert_eq!(oversized, expected_over_budget);
    }

    /// One scratch value is passed to two consecutive planning calls; its
    /// reported id capacity must not drop after the second, smaller call.
    #[test]
    fn launch_fusion_reuses_caller_owned_duplicate_detection_scratch() {
        let wide: Vec<CudaFusionStage> = (0..64)
            .map(|id| stage(id, 7, 16, 16, 4, false))
            .collect();
        let mut scratch = CudaLaunchFusionScratch::try_with_capacity(64)
            .expect("Fix: fusion scratch should reserve");

        let wide_plan = plan_cuda_launch_fusion_with_scratch(&wide, 8_192, &mut scratch)
            .expect("Fix: wide compatible CUDA stages should fuse");
        let capacity_after_wide = scratch.id_capacity();
        assert_eq!(wide_plan.launch_count, 1);
        assert_eq!(wide_plan.avoided_launches, 63);

        let narrow = [
            stage(10, 7, 64, 32, 8, false),
            stage(11, 8, 32, 48, 8, false),
        ];
        let narrow_plan = plan_cuda_launch_fusion_with_scratch(&narrow, 512, &mut scratch)
            .expect("Fix: smaller incompatible CUDA stages should reuse duplicate-detection scratch");
        assert_eq!(narrow_plan.launch_count, 2);
        assert!(scratch.id_capacity() >= capacity_after_wide);
    }
}