use smallvec::SmallVec;
use crate::backend::{CudaBackend, CudaResidentBuffer};
use crate::egraph_device_image::{
plan_cuda_egraph_device_upload_from_image_ref, CudaEGraphDeviceKernelView,
};
use crate::egraph_readback::{
cleanup_egraph_kernel_handles, decode_unique_equivalence_pairs, device_ptr_at,
download_structural_equivalence_output_ranges, read_u64_le,
upload_structural_equivalence_scratch,
};
use crate::CudaResidentEGraphDeviceImage;
use vyre_driver::BackendError;
use vyre_driver::LaunchPlan;
use vyre_foundation::optimizer::eqsat_gpu::GpuEGraphDeviceImage;
use super::{
args::EGraphStructuralKernelArgs,
plan_cuda_egraph_structural_equivalence_launch_artifact_from_plan, plan_cuda_egraph_signature_buckets,
CudaEGraphKernelLaunchConfig, CudaEGraphStructuralEquivalenceKernelPtx,
CudaEGraphStructuralEquivalenceKernelResult, CudaEGraphStructuralEquivalenceLaunchArtifact,
};
use super::ptx::cuda_egraph_structural_equivalence_kernel_ptx;
impl CudaBackend {
pub fn warm_egraph_structural_equivalence_kernel(
&self,
) -> Result<CudaEGraphStructuralEquivalenceKernelPtx, BackendError> {
self.warm_egraph_structural_equivalence_kernel_with_key()
.map(|(kernel, _)| kernel)
}
fn warm_egraph_structural_equivalence_kernel_with_key(
&self,
) -> Result<
(
CudaEGraphStructuralEquivalenceKernelPtx,
cudarc::driver::sys::CUfunction,
),
BackendError,
> {
let kernel = cuda_egraph_structural_equivalence_kernel_ptx(self.ptx_target_sm()).map_err(
|error| BackendError::InvalidProgram {
fix: error.to_string(),
},
)?;
let module_key = self.module_cache_key_for_raw_ptx_artifact(&kernel.source)?;
let function = self.module_for_ptx_with_key(&kernel.source, module_key)?;
if function.is_null() {
return Err(BackendError::InvalidProgram {
fix: "Fix: CUDA e-graph structural-equivalence kernel loaded but resolved a null `main` function. Inspect generated PTX entry metadata before launch.".to_string(),
});
}
Ok((kernel, function))
}
pub fn run_egraph_structural_equivalence_kernel(
&self,
image: CudaResidentEGraphDeviceImage,
artifact: &CudaEGraphStructuralEquivalenceLaunchArtifact,
) -> Result<CudaEGraphStructuralEquivalenceKernelResult, BackendError> {
let view = self.egraph_device_kernel_view(image)?;
let (_kernel, func) = self.warm_egraph_structural_equivalence_kernel_with_key()?;
let mut handles = SmallVec::<[CudaResidentBuffer; 4]>::new();
let result =
self.run_egraph_structural_equivalence_kernel_inner(view, artifact, func, &mut handles);
let cleanup = cleanup_egraph_kernel_handles(self, &handles);
match (result, cleanup) {
(Ok(result), Ok(())) => Ok(result),
(Err(error), Ok(())) => Err(error),
(Ok(_), Err(error)) | (Err(_), Err(error)) => Err(error),
}
}
pub fn discover_egraph_structural_equivalences(
&self,
image: GpuEGraphDeviceImage,
config: CudaEGraphKernelLaunchConfig,
) -> Result<CudaEGraphStructuralEquivalenceKernelResult, BackendError> {
let upload_plan =
plan_cuda_egraph_device_upload_from_image_ref(&image).map_err(|error| {
BackendError::InvalidProgram {
fix: error.to_string(),
}
})?;
let resident = self.upload_egraph_device_image_borrowed_plan(upload_plan)?;
let result = (|| {
let view = self.egraph_device_kernel_view(resident)?;
let signature_plan =
plan_cuda_egraph_signature_buckets(&image, view, config).map_err(|error| {
BackendError::InvalidProgram {
fix: error.to_string(),
}
})?;
let artifact =
plan_cuda_egraph_structural_equivalence_launch_artifact_from_plan(signature_plan)
.map_err(|error| BackendError::InvalidProgram {
fix: error.to_string(),
})?;
self.run_egraph_structural_equivalence_kernel(resident, &artifact)
})();
let cleanup = self.free_resident(resident.handle());
match (result, cleanup) {
(Ok(result), Ok(())) => Ok(result),
(Err(error), Ok(())) => Err(error),
(Ok(_), Err(error)) | (Err(_), Err(error)) => Err(error),
}
}
fn run_egraph_structural_equivalence_kernel_inner(
&self,
view: CudaEGraphDeviceKernelView,
artifact: &CudaEGraphStructuralEquivalenceLaunchArtifact,
func: cudarc::driver::sys::CUfunction,
handles: &mut SmallVec<[CudaResidentBuffer; 4]>,
) -> Result<CudaEGraphStructuralEquivalenceKernelResult, BackendError> {
let scratch = upload_structural_equivalence_scratch(self, artifact)?;
handles.push(scratch.handle);
let scratch_base_ptr = self.resident_device_ptr(scratch.handle)?;
let bucket_words_ptr = device_ptr_at(
scratch_base_ptr,
scratch.bucket_words_offset,
"bucket words",
)?;
let bucket_rows_ptr =
device_ptr_at(scratch_base_ptr, scratch.bucket_rows_offset, "bucket rows")?;
let output_pairs_ptr = device_ptr_at(
scratch_base_ptr,
scratch.output_pairs_offset,
"output pairs",
)?;
let output_count_ptr = device_ptr_at(
scratch_base_ptr,
scratch.output_count_offset,
"output count",
)?;
let stream = crate::stream::CudaStream::non_blocking()?;
let mut kernel_args = SmallVec::<[*mut std::ffi::c_void; 8]>::new();
for wave in &artifact.pair_waves {
let launch = LaunchPlan {
element_count: u32::try_from(wave.pair_count).map_err(|error| {
BackendError::InvalidProgram {
fix: format!(
"Fix: CUDA e-graph structural-equivalence wave pair count {} does not fit u32 launch accounting: {error}. Split the wave before launch.",
wave.pair_count
),
}
})?,
workgroup: [wave.threads_per_block, 1, 1],
grid: [wave.blocks, 1, 1],
param_words: Vec::new(),
max_binding_alignment: std::mem::size_of::<u64>(),
};
let mut args = EGraphStructuralKernelArgs {
row_eclass_ids_ptr: view.row_eclass_ids_ptr(),
row_language_op_ids_ptr: view.row_language_op_ids_ptr(),
row_children_offsets_ptr: view.row_children_offsets_ptr(),
row_children_lens_ptr: view.row_children_lens_ptr(),
row_signatures_ptr: view.row_signatures_ptr(),
children_ptr: view.children_ptr(),
bucket_words_ptr,
bucket_rows_ptr,
output_pairs_ptr,
output_count_ptr,
bucket_index: wave.bucket_index,
first_pair: wave.first_pair,
pair_count: wave.pair_count,
};
args.write_kernel_args_into(&mut kernel_args)?;
self.launch_resolved_function(
func,
&mut kernel_args,
&launch,
stream.raw(),
false,
false,
)?;
}
stream.synchronize()?;
let (count_bytes, pair_bytes) =
download_structural_equivalence_output_ranges(self, &scratch)?;
let count_bytes = count_bytes
.get(..8)
.ok_or_else(|| BackendError::InvalidProgram {
fix: "Fix: CUDA e-graph fused scratch readback did not contain the 8-byte structural equivalence counter.".to_string(),
})?;
let device_reported_count = read_u64_le(count_bytes, "structural equivalence count")?;
let planned_capacity = artifact.output.max_equivalences;
let capped_count = device_reported_count.min(planned_capacity);
let pair_bytes = pair_bytes
.get(..scratch.output_pairs_bytes)
.ok_or_else(|| BackendError::InvalidProgram {
fix: "Fix: CUDA e-graph fused scratch readback did not contain the planned structural equivalence output-pair range.".to_string(),
})?;
let (emitted_pair_count, unique) =
decode_unique_equivalence_pairs(&pair_bytes, capped_count)?;
Ok(CudaEGraphStructuralEquivalenceKernelResult {
emitted_pair_count,
unique,
device_reported_count,
overflowed_output_capacity: device_reported_count > planned_capacity,
})
}
}