use crate::egraph_device_image::CudaEGraphDeviceKernelView;
use crate::numeric::CUDA_NUMERIC;
use vyre_foundation::optimizer::eqsat_gpu::GpuEGraphDeviceImage;
use super::{
CudaEGraphKernelLaunchConfig, CudaEGraphKernelPass, CudaEGraphKernelPlanError,
CudaEGraphKernelWave, CudaEGraphSignatureBucketPlan, CudaEGraphSignaturePairWave,
};
/// Resolves a pair ordinal inside one signature bucket to the two rows it names.
///
/// Pairs are numbered row-major over the strict upper triangle of the bucket's local rows:
/// local row `i` is paired with every later local row `j > i`, so for a bucket of `n` rows the
/// ordinals run `(0, 1), (0, 2), …, (0, n - 1), (1, 2), …`. The local indices are translated
/// through `plan.bucket_rows`, starting at the bucket's `first_bucket_row`.
///
/// # Errors
///
/// * `SignaturePairOrdinalOutOfBounds` when `bucket_index` names no bucket (reported with a
///   `candidate_pair_count` of 0), when `pair_ordinal` is not below the bucket's
///   `candidate_pair_count`, or when it is not below the number of pairs the bucket's
///   `row_count` can actually form (reported with that formable count).
/// * `CountOverflow` when an intermediate computation does not fit its integer type.
/// * `SignatureBucketRowsOutOfBounds` when the bucket's row range extends past
///   `plan.bucket_rows`.
pub fn cuda_egraph_signature_pair_rows(
    plan: &CudaEGraphSignatureBucketPlan,
    bucket_index: u32,
    pair_ordinal: u64,
) -> Result<(u32, u32), CudaEGraphKernelPlanError> {
    let Some(bucket) = plan.buckets.get(bucket_index as usize) else {
        return Err(CudaEGraphKernelPlanError::SignaturePairOrdinalOutOfBounds {
            bucket_index,
            pair_ordinal,
            candidate_pair_count: 0,
        });
    };
    if pair_ordinal >= bucket.candidate_pair_count {
        return Err(CudaEGraphKernelPlanError::SignaturePairOrdinalOutOfBounds {
            bucket_index,
            pair_ordinal,
            candidate_pair_count: bucket.candidate_pair_count,
        });
    }

    let row_count = u64::from(bucket.row_count);
    // A bucket without rows forms no pairs; bail out instead of underflowing `row_count - 1`.
    let Some(last_local_row) = row_count.checked_sub(1) else {
        return Err(CudaEGraphKernelPlanError::SignaturePairOrdinalOutOfBounds {
            bucket_index,
            pair_ordinal,
            candidate_pair_count: 0,
        });
    };
    // Every pair has its left row strictly before the last row, so the pairs preceding the
    // last row are all the pairs this bucket can form. An ordinal at or beyond that count
    // would decode to a right row outside the bucket even if `candidate_pair_count` allows it.
    let formable_pairs = signature_pairs_before_row(last_local_row, row_count)?;
    if pair_ordinal >= formable_pairs {
        return Err(CudaEGraphKernelPlanError::SignaturePairOrdinalOutOfBounds {
            bucket_index,
            pair_ordinal,
            candidate_pair_count: formable_pairs,
        });
    }

    // Binary search for the left row: the smallest local row whose successor's first pair
    // ordinal lies beyond `pair_ordinal`.
    let mut lo = 0_u64;
    let mut hi = last_local_row;
    while lo < hi {
        let mid = lo + ((hi - lo) / 2);
        let next_start = signature_pairs_before_row(mid + 1, row_count)?;
        if next_start <= pair_ordinal {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    let local_left = lo;
    let row_pair_base = signature_pairs_before_row(local_left, row_count)?;
    // The search guarantees `row_pair_base <= pair_ordinal`; the difference is the offset of
    // the right row among the rows following `local_left`.
    let local_right = local_left
        .checked_add(1)
        .and_then(|value| value.checked_add(pair_ordinal - row_pair_base))
        .ok_or(CudaEGraphKernelPlanError::CountOverflow {
            field: "signature pair local right row",
        })?;

    let base = bucket.first_bucket_row as usize;
    let bucket_end = base.checked_add(bucket.row_count as usize).ok_or(
        CudaEGraphKernelPlanError::CountOverflow {
            field: "signature bucket row range end",
        },
    )?;
    if bucket_end > plan.bucket_rows.len() {
        return Err(CudaEGraphKernelPlanError::SignatureBucketRowsOutOfBounds {
            bucket_index,
            first_bucket_row: base,
            row_count: bucket.row_count as usize,
            bucket_rows_len: plan.bucket_rows.len(),
        });
    }
    // Index through the bucket's own window so a lookup can never reach a neighbouring bucket.
    let rows = &plan.bucket_rows[base..bucket_end];
    Ok((rows[local_left as usize], rows[local_right as usize]))
}
/// Checks that a device kernel view reports the same dimensions as the image layout.
///
/// The row count, child count and e-class group count are compared in that order; the first
/// one that differs is returned as `ImageViewMismatch`, carrying the field label together with
/// the image-side and view-side values.
pub(super) fn validate_image_view_matches(
    image: &GpuEGraphDeviceImage,
    view: CudaEGraphDeviceKernelView,
) -> Result<(), CudaEGraphKernelPlanError> {
    let mismatch = |field, image_value, view_value| {
        CudaEGraphKernelPlanError::ImageViewMismatch {
            field,
            image: image_value,
            view: view_value,
        }
    };

    let (image_rows, view_rows) = (image.layout().row_count(), view.row_count());
    if image_rows != view_rows {
        return Err(mismatch("row count", image_rows, view_rows));
    }

    let (image_children, view_children) = (image.layout().child_count(), view.child_count());
    if image_children != view_children {
        return Err(mismatch("child count", image_children, view_children));
    }

    let (image_groups, view_groups) = (
        image.layout().eclass_group_count(),
        view.eclass_group_count(),
    );
    if image_groups != view_groups {
        return Err(mismatch("eclass group count", image_groups, view_groups));
    }

    Ok(())
}
/// Reports whether two packed rows of `image` agree on every compared column.
///
/// Two rows are equal when their signature, language op id, child count and child list all
/// match; the comparison stops at the first column that differs.
///
/// # Errors
///
/// Returns `ImageColumnOutOfBounds` (column `"rows"`) when either row index is not below the
/// layout's row count — the left row is checked first — and forwards any error produced while
/// reading a row's children.
pub(super) fn packed_rows_structurally_equal(
    image: &GpuEGraphDeviceImage,
    left_row: u32,
    right_row: u32,
) -> Result<bool, CudaEGraphKernelPlanError> {
    let row_count = image.layout().row_count();
    for row in [left_row, right_row] {
        let index = row as usize;
        if index >= row_count {
            return Err(CudaEGraphKernelPlanError::ImageColumnOutOfBounds {
                column: "rows",
                row,
                start: index,
                end: index.saturating_add(1),
                len: row_count,
            });
        }
    }

    let (left, right) = (left_row as usize, right_row as usize);
    // Cheap per-row columns first; `&&` keeps the original short-circuit order.
    let headers_match = image.row_signatures()[left] == image.row_signatures()[right]
        && image.row_language_op_ids()[left] == image.row_language_op_ids()[right]
        && image.row_children_lens()[left] == image.row_children_lens()[right];
    if !headers_match {
        return Ok(false);
    }

    let left_children = packed_row_children(image, left_row)?;
    let right_children = packed_row_children(image, right_row)?;
    Ok(left_children == right_children)
}
/// Returns the slice of `image.children()` that belongs to `row`.
///
/// The span starts at the row's entry in `row_children_offsets()` and covers as many elements
/// as its entry in `row_children_lens()`.
///
/// # Errors
///
/// * `ImageColumnOutOfBounds` when `row` has no entry in the offsets or lengths column, or
///   when the resulting span reaches past the end of the children column.
/// * `CountOverflow` when `offset + length` does not fit in `usize`.
pub(super) fn packed_row_children(
    image: &GpuEGraphDeviceImage,
    row: u32,
) -> Result<&[u32], CudaEGraphKernelPlanError> {
    let row_index = row as usize;

    // Checked lookups: an out-of-range row is reported as an error rather than a panic,
    // matching how the children span itself is validated below.
    let offsets = image.row_children_offsets();
    let Some(&offset) = offsets.get(row_index) else {
        return Err(CudaEGraphKernelPlanError::ImageColumnOutOfBounds {
            column: "row children offsets",
            row,
            start: row_index,
            end: row_index.saturating_add(1),
            len: offsets.len(),
        });
    };
    let lens = image.row_children_lens();
    let Some(&child_len) = lens.get(row_index) else {
        return Err(CudaEGraphKernelPlanError::ImageColumnOutOfBounds {
            column: "row children lens",
            row,
            start: row_index,
            end: row_index.saturating_add(1),
            len: lens.len(),
        });
    };

    let start = offset as usize;
    let len = child_len as usize;
    let end = start
        .checked_add(len)
        .ok_or(CudaEGraphKernelPlanError::CountOverflow {
            field: "packed row child span end",
        })?;
    let children = image.children();
    if end > children.len() {
        return Err(CudaEGraphKernelPlanError::ImageColumnOutOfBounds {
            column: "children",
            row,
            start,
            end,
            len: children.len(),
        });
    }
    Ok(&children[start..end])
}
/// Splits `item_count` logical items of `pass` into launch waves and appends them to `waves`.
///
/// Each wave covers at most `threads_per_block * max_blocks_per_launch` items and records the
/// number of blocks needed for its items. `total_items` and `total_blocks` are advanced by the
/// items and blocks of every appended wave. An `item_count` of zero appends nothing.
///
/// # Errors
///
/// Returns `CountOverflow` when a count does not fit its integer type, or when the launch
/// configuration yields a wave capacity of zero (zero threads per block or zero blocks per
/// launch). Waves pushed before an error is detected remain in `waves`.
pub(super) fn append_pass_waves(
    waves: &mut Vec<CudaEGraphKernelWave>,
    total_items: &mut u64,
    total_blocks: &mut u64,
    pass: CudaEGraphKernelPass,
    item_count: u64,
    config: CudaEGraphKernelLaunchConfig,
) -> Result<(), CudaEGraphKernelPlanError> {
    if item_count == 0 {
        return Ok(());
    }
    let threads_per_block = u64::from(config.threads_per_block);
    let items_per_wave = threads_per_block
        .checked_mul(u64::from(config.max_blocks_per_launch))
        .ok_or(CudaEGraphKernelPlanError::CountOverflow {
            field: "items per launch wave",
        })?;
    // A zero-capacity wave never advances `first_item`, so the loop below would push empty
    // waves forever. Reject it with the error `wave_count_for` reports for the same
    // configuration (a zero ceil-division denominator).
    if items_per_wave == 0 {
        return Err(CudaEGraphKernelPlanError::CountOverflow {
            field: "ceil division denominator",
        });
    }

    let mut first_item = 0_u64;
    while first_item < item_count {
        let remaining = item_count - first_item;
        let wave_items = remaining.min(items_per_wave);
        let blocks = ceil_div_u64(wave_items, threads_per_block)?;
        let blocks =
            u32::try_from(blocks).map_err(|_| CudaEGraphKernelPlanError::CountOverflow {
                field: "blocks per launch wave",
            })?;
        waves.push(CudaEGraphKernelWave {
            pass,
            first_item,
            item_count: wave_items,
            blocks,
            threads_per_block: config.threads_per_block,
        });
        *total_items = total_items.checked_add(wave_items).ok_or(
            CudaEGraphKernelPlanError::CountOverflow {
                field: "total logical items",
            },
        )?;
        *total_blocks = total_blocks.checked_add(u64::from(blocks)).ok_or(
            CudaEGraphKernelPlanError::CountOverflow {
                field: "total blocks",
            },
        )?;
        first_item =
            first_item
                .checked_add(wave_items)
                .ok_or(CudaEGraphKernelPlanError::CountOverflow {
                    field: "next wave first item",
                })?;
    }
    Ok(())
}
/// Splits the `pair_count` candidate pairs of one signature bucket into launch waves.
///
/// Each appended `CudaEGraphSignaturePairWave` covers at most
/// `threads_per_block * max_blocks_per_launch` pairs of bucket `bucket_index` and records the
/// number of blocks needed for its pairs. `total_blocks` is advanced by the blocks of every
/// appended wave. A `pair_count` of zero appends nothing.
///
/// # Errors
///
/// Returns `CountOverflow` when a count does not fit its integer type, or when pairs exist but
/// the launch configuration yields a wave capacity of zero (zero threads per block or zero
/// blocks per launch). Waves pushed before an error is detected remain in `pair_waves`.
pub(super) fn append_signature_pair_waves(
    pair_waves: &mut Vec<CudaEGraphSignaturePairWave>,
    total_blocks: &mut u64,
    bucket_index: u32,
    pair_count: u64,
    config: CudaEGraphKernelLaunchConfig,
) -> Result<(), CudaEGraphKernelPlanError> {
    let threads_per_block = u64::from(config.threads_per_block);
    let items_per_wave = threads_per_block
        .checked_mul(u64::from(config.max_blocks_per_launch))
        .ok_or(CudaEGraphKernelPlanError::CountOverflow {
            field: "items per signature pair launch wave",
        })?;
    if pair_count == 0 {
        return Ok(());
    }
    // A zero-capacity wave never advances `first_pair`, so the loop below would push empty
    // waves forever. Reject it with the error `wave_count_for` reports for the same
    // configuration (a zero ceil-division denominator).
    if items_per_wave == 0 {
        return Err(CudaEGraphKernelPlanError::CountOverflow {
            field: "ceil division denominator",
        });
    }

    let mut first_pair = 0_u64;
    while first_pair < pair_count {
        let remaining = pair_count - first_pair;
        let wave_pairs = remaining.min(items_per_wave);
        let blocks = ceil_div_u64(wave_pairs, threads_per_block)?;
        let blocks =
            u32::try_from(blocks).map_err(|_| CudaEGraphKernelPlanError::CountOverflow {
                field: "blocks per signature pair launch wave",
            })?;
        pair_waves.push(CudaEGraphSignaturePairWave {
            bucket_index,
            first_pair,
            pair_count: wave_pairs,
            blocks,
            threads_per_block: config.threads_per_block,
        });
        *total_blocks = total_blocks.checked_add(u64::from(blocks)).ok_or(
            CudaEGraphKernelPlanError::CountOverflow {
                field: "signature pair total blocks",
            },
        )?;
        first_pair =
            first_pair
                .checked_add(wave_pairs)
                .ok_or(CudaEGraphKernelPlanError::CountOverflow {
                    field: "next signature pair first item",
                })?;
    }
    Ok(())
}
/// Computes how many launch waves are needed to cover `item_count` items.
///
/// A wave holds up to `threads_per_block * max_blocks_per_launch` items; the result is the
/// ceiling of `item_count` divided by that capacity, and zero when there are no items.
///
/// # Errors
///
/// Returns `CountOverflow` when the wave capacity overflows `u64`, or whatever `ceil_div_u64`
/// reports for the division (including a capacity of zero).
pub(super) fn wave_count_for(
    item_count: u64,
    config: CudaEGraphKernelLaunchConfig,
) -> Result<u64, CudaEGraphKernelPlanError> {
    match item_count {
        0 => Ok(0),
        items => {
            let threads = u64::from(config.threads_per_block);
            let blocks = u64::from(config.max_blocks_per_launch);
            let Some(capacity) = threads.checked_mul(blocks) else {
                return Err(CudaEGraphKernelPlanError::CountOverflow {
                    field: "items per launch wave",
                });
            };
            ceil_div_u64(items, capacity)
        }
    }
}
/// Divides `numerator` by `denominator`, rounding the quotient up.
///
/// A zero numerator yields zero. Every non-zero denominator succeeds: the quotient is formed
/// without adding to the numerator, so large numerators cannot overflow.
///
/// # Errors
///
/// Returns `CountOverflow` with field `"ceil division denominator"` when `denominator` is zero.
pub(super) fn ceil_div_u64(
    numerator: u64,
    denominator: u64,
) -> Result<u64, CudaEGraphKernelPlanError> {
    if denominator == 0 {
        return Err(CudaEGraphKernelPlanError::CountOverflow {
            field: "ceil division denominator",
        });
    }
    // Round up via the remainder instead of `numerator + denominator - 1`: the sum can exceed
    // `u64::MAX` even though the rounded-up quotient always fits. The `+ 1` below is safe
    // because a non-zero remainder implies `denominator >= 2`, so the quotient is at most
    // `u64::MAX / 2`.
    let quotient = numerator / denominator;
    let round_up = u64::from(numerator % denominator != 0);
    Ok(quotient + round_up)
}
/// Counts the unordered pairs of distinct items among `item_count` items: `n * (n - 1) / 2`.
///
/// Zero or one item yields zero pairs.
///
/// # Errors
///
/// Returns `CountOverflow` with field `"unordered pair count"` when the pair count itself does
/// not fit in `u64`.
pub(super) fn unordered_pair_count(item_count: u64) -> Result<u64, CudaEGraphKernelPlanError> {
    let predecessor = item_count.saturating_sub(1);
    // Exactly one of `n` and `n - 1` is even. Halving that factor before multiplying keeps
    // the product equal to the final result, so overflow is reported only when the pair
    // count itself is too large, not when the intermediate `n * (n - 1)` is.
    let (halved, other) = if item_count % 2 == 0 {
        (item_count / 2, predecessor)
    } else {
        (predecessor / 2, item_count)
    };
    halved
        .checked_mul(other)
        .ok_or(CudaEGraphKernelPlanError::CountOverflow {
            field: "unordered pair count",
        })
}
/// Computes `local_row * (2 * row_count - local_row - 1) / 2` with checked arithmetic.
///
/// For a bucket of `row_count` rows whose pairs are numbered row-major over the strict upper
/// triangle, this is the number of pairs whose left row precedes `local_row`, i.e. the ordinal
/// of the first pair that has `local_row` as its left row.
///
/// # Errors
///
/// Returns `CountOverflow` with field `"signature pair row width"` when
/// `2 * row_count - local_row - 1` overflows or underflows, and with field
/// `"signature pairs before row"` when the multiplication by `local_row` overflows.
pub(super) fn signature_pairs_before_row(
    local_row: u64,
    row_count: u64,
) -> Result<u64, CudaEGraphKernelPlanError> {
    let width = row_count
        .checked_mul(2)
        .and_then(|doubled| doubled.checked_sub(local_row))
        .and_then(|span| span.checked_sub(1));
    let Some(width) = width else {
        return Err(CudaEGraphKernelPlanError::CountOverflow {
            field: "signature pair row width",
        });
    };

    match local_row.checked_mul(width) {
        // Halving cannot fail, so plain division matches the original `checked_div(2)`.
        Some(doubled_pairs) => Ok(doubled_pairs / 2),
        None => Err(CudaEGraphKernelPlanError::CountOverflow {
            field: "signature pairs before row",
        }),
    }
}
/// Converts `value` to `u64` through `CUDA_NUMERIC`.
///
/// Any conversion failure is replaced by `CountOverflow` carrying `field`; the underlying
/// error value is discarded.
pub(crate) fn usize_to_u64(
    value: usize,
    field: &'static str,
) -> Result<u64, CudaEGraphKernelPlanError> {
    match CUDA_NUMERIC.usize_to_u64(value, field) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(CudaEGraphKernelPlanError::CountOverflow { field }),
    }
}