use vyre_foundation::ir::Program;
use crate::backend::{BackendError, DispatchConfig};
use crate::binding::BindingPlan;
use crate::program_walks::dispatch_element_count_for_program;
/// Chooses the dispatch grid, in workgroups, for running `program` over `inputs`.
///
/// An explicit `config.grid_override` is returned as-is, without further checks.
/// Otherwise the element count is derived from the program's binding plan for
/// `inputs`. The grid is then shaped by [`infer_dispatch_grid_for_count`], using
/// `config.workgroup_override` when it is set and the program's own workgroup size
/// otherwise.
///
/// # Errors
///
/// Propagates failures from building the [`BindingPlan`] and from
/// [`infer_dispatch_grid_for_count`] (zero workgroup dimensions, `u32` overflow).
pub fn infer_dispatch_grid(
    program: &Program,
    inputs: &[Vec<u8>],
    config: &DispatchConfig,
) -> Result<[u32; 3], BackendError> {
    match config.grid_override {
        Some(grid) => Ok(grid),
        None => {
            let binding_plan = BindingPlan::from_program(program, inputs)?;
            let elements = dispatch_element_count_for_program(program, &binding_plan.bindings);
            let workgroup = config
                .workgroup_override
                .unwrap_or(program.workgroup_size());
            infer_dispatch_grid_for_count(elements, workgroup)
        }
    }
}
/// Infers a dispatch grid for `program` on `backend` with no caller overrides.
///
/// The program is first validated against `backend` under [`DispatchConfig::default`].
/// The binding plan is then built from the program alone, and the grid is shaped for
/// the program's declared workgroup size.
///
/// # Errors
///
/// Propagates validation failures, binding-plan failures, and errors from
/// [`infer_dispatch_grid_for_count`].
pub fn auto_grid(
    program: &Program,
    backend: &dyn crate::backend::VyreBackend,
) -> Result<[u32; 3], BackendError> {
    let default_config = DispatchConfig::default();
    crate::validation::validate_program_for_backend(backend, program, &default_config)?;
    let binding_plan = BindingPlan::build(program)?;
    infer_dispatch_grid_for_count(
        dispatch_element_count_for_program(program, &binding_plan.bindings),
        program.workgroup_size(),
    )
}
/// Shapes a dispatch grid, measured in workgroups, that launches at least
/// `element_count` invocations for the given `workgroup` size.
///
/// The grid's dimensionality follows the workgroup shape:
/// - `[x, 1, 1]` gets a linear grid `[ceil(count / x), 1, 1]`.
/// - `[x, y, 1]` (with `y != 1`) gets a near-square grid, built from the ceiling
///   integer square root of the count.
/// - Any other shape gets a near-cubic grid, built from the ceiling integer cube root.
///
/// Every axis is at least 1. An `element_count` of 0 is treated as 1, so the result
/// always launches one workgroup or more.
///
/// # Errors
///
/// Returns a [`BackendError`] when any workgroup dimension is zero, or when an
/// inferred grid dimension does not fit in `u32`.
pub fn infer_dispatch_grid_for_count(
    element_count: u32,
    workgroup: [u32; 3],
) -> Result<[u32; 3], BackendError> {
    let [wg_x, wg_y, wg_z] = workgroup;
    if wg_x == 0 || wg_y == 0 || wg_z == 0 {
        return Err(BackendError::new(
            "workgroup dimensions must be non-zero. Fix: set Program::workgroup_size and DispatchConfig::workgroup_override to positive values.",
        ));
    }
    // An empty workload still launches one workgroup.
    let total = u64::from(element_count.max(1));
    match (wg_y, wg_z) {
        (1, 1) => {
            let x = ceil_div_u64(total, u64::from(wg_x))?;
            Ok([x, 1, 1])
        }
        (_, 1) => {
            // Lay the elements out on an `edge x rows` rectangle, then tile it.
            let edge = ceil_sqrt_u64(total);
            let x = ceil_div_u64(edge, u64::from(wg_x))?;
            let rows = u64::from(ceil_div_u64(total, edge)?);
            let y = ceil_div_u64(rows, u64::from(wg_y))?;
            Ok([x, y, 1])
        }
        _ => {
            // Lay the elements out on a `side x side x layers` box, then tile it.
            let side = ceil_cuberoot_u64(total);
            // This multiply cannot overflow for u32 counts. It stays checked so that
            // any failure surfaces as a typed error, never an unchecked multiply.
            let xy = side.checked_mul(side).ok_or_else(|| {
                BackendError::new(format!(
                    "3D dispatch-grid side {side} overflows u64 square during shape planning. Fix: split the Program before GPU launch planning."
                ))
            })?;
            let x = ceil_div_u64(side, u64::from(wg_x))?;
            let y = ceil_div_u64(side, u64::from(wg_y))?;
            let layers = u64::from(ceil_div_u64(total, xy)?);
            let z = ceil_div_u64(layers, u64::from(wg_z))?;
            Ok([x, y, z])
        }
    }
}
/// Divides `value` by `divisor` rounding up, clamps the result to at least 1, and
/// narrows it to `u32`.
///
/// `divisor` must be non-zero; `u64::div_ceil` panics on a zero divisor. The callers
/// in this module pass validated workgroup dimensions or integer roots, which are
/// always at least 1.
///
/// # Errors
///
/// Returns a [`BackendError`] when the rounded quotient exceeds `u32::MAX`.
fn ceil_div_u64(value: u64, divisor: u64) -> Result<u32, BackendError> {
    // A grid axis is never allowed to be empty.
    let quotient = match value.div_ceil(divisor) {
        0 => 1,
        rounded => rounded,
    };
    u32::try_from(quotient).ok().ok_or_else(|| {
        BackendError::new(
            "inferred dispatch grid dimension overflowed u32. Fix: split the Program into smaller dispatches.",
        )
    })
}
/// Returns the smallest `root` with `root * root >= value`, or `1` when `value <= 1`.
///
/// Uses an exact integer binary search over `[1, 2^32]`. `2^32` is the ceiling square
/// root of `u64::MAX`, so the range always contains the answer. A square that
/// overflows `u64` is treated as large enough.
fn ceil_sqrt_u64(value: u64) -> u64 {
    if value < 2 {
        return 1;
    }
    let covers = |root: u64| match root.checked_mul(root) {
        Some(square) => square >= value,
        None => true,
    };
    let (mut low, mut high) = (1_u64, 1_u64 << 32);
    while low < high {
        let mid = low + (high - low) / 2;
        if covers(mid) {
            high = mid;
        } else {
            low = mid + 1;
        }
    }
    low
}
/// Returns the smallest `root` with `root^3 >= value`, or `1` when `value <= 1`.
///
/// Uses an exact integer binary search over `[1, 2^22]`. `2^22` cubed already
/// overflows `u64`, so the range always contains the answer. A cube that overflows
/// `u64` is treated as large enough.
fn ceil_cuberoot_u64(value: u64) -> u64 {
    if value < 2 {
        return 1;
    }
    let (mut low, mut high) = (1_u64, 1_u64 << 22);
    while low < high {
        let mid = low + (high - low) / 2;
        let reaches = match checked_cube_u64(mid) {
            Some(cube) => cube >= value,
            None => true,
        };
        if reaches {
            high = mid;
        } else {
            low = mid + 1;
        }
    }
    low
}
/// Returns `value^3`, or `None` if the cube does not fit in `u64`.
fn checked_cube_u64(value: u64) -> Option<u64> {
    value.checked_pow(3)
}
/// Result of rounding an element count up to a power-of-two launch size.
///
/// `tail_lanes` is the number of extra lanes in `rounded_count` beyond
/// `original_count`, which a kernel would need to mask off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TailMaskPolicy {
    /// Element count requested by the caller.
    pub original_count: u32,
    /// Element count actually planned for launch.
    pub rounded_count: u32,
    /// Padding lanes between `original_count` and `rounded_count`.
    pub tail_lanes: u32,
}
impl TailMaskPolicy {
    /// Returns `true` when no padding lanes exist, i.e. no tail mask is required.
    #[must_use]
    pub fn is_aligned(&self) -> bool {
        matches!(self.tail_lanes, 0)
    }
}
/// Infallible counterpart of [`try_coerce_to_pow2_with_tail_mask`].
///
/// On success this returns the same policy. If rounding up would overflow `u32`
/// (counts above `2^31`), it does not fail. Instead it returns an identity policy:
/// `rounded_count == element_count` and `tail_lanes == 0`. Note that such a
/// `rounded_count` is NOT a power of two. Callers that require one should use the
/// fallible variant.
#[must_use]
pub fn coerce_to_pow2_with_tail_mask(element_count: u32) -> TailMaskPolicy {
    try_coerce_to_pow2_with_tail_mask(element_count).unwrap_or_else(|_overflow| TailMaskPolicy {
        original_count: element_count,
        rounded_count: element_count,
        tail_lanes: 0,
    })
}
/// Rounds `element_count` up to the next power of two and reports the padding lanes.
///
/// A count of 0 maps to an all-zero policy. A count that is already a power of two
/// is returned unchanged, with no tail.
///
/// # Errors
///
/// Returns a [`BackendError`] when the next power of two does not fit in `u32`,
/// i.e. for counts above `2^31`.
pub fn try_coerce_to_pow2_with_tail_mask(
    element_count: u32,
) -> Result<TailMaskPolicy, BackendError> {
    let rounded_count = match element_count {
        0 => 0,
        nonzero => next_pow2_u32_checked(nonzero)?,
    };
    Ok(TailMaskPolicy {
        original_count: element_count,
        rounded_count,
        tail_lanes: rounded_count - element_count,
    })
}
/// Returns the smallest power of two `>= value`; `0` maps to `1`.
///
/// # Errors
///
/// Returns a [`BackendError`] when that power of two exceeds `u32::MAX`, i.e. when
/// `value > 2^31`.
fn next_pow2_u32_checked(value: u32) -> Result<u32, BackendError> {
    value.checked_next_power_of_two().ok_or_else(|| {
        BackendError::new(format!(
            "cannot round element_count={value} up to a power-of-two u32 grid without overflow. Fix: split the workload before grid-shape planning; do not silently saturate or fall back to an under-dispatching shape."
        ))
    })
}
#[cfg(test)]
mod n6_tests {
    use super::*;
    #[test]
    fn already_pow2_is_identity_with_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(64);
        assert_eq!(p.original_count, 64);
        assert_eq!(p.rounded_count, 64);
        assert_eq!(p.tail_lanes, 0);
        assert!(p.is_aligned());
    }
    #[test]
    fn non_pow2_rounds_up_and_reports_tail() {
        let p = coerce_to_pow2_with_tail_mask(100);
        assert_eq!(p.original_count, 100);
        assert_eq!(p.rounded_count, 128);
        assert_eq!(p.tail_lanes, 28);
        assert!(!p.is_aligned());
    }
    #[test]
    fn one_is_pow2_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(1);
        assert_eq!(p.rounded_count, 1);
        assert_eq!(p.tail_lanes, 0);
    }
    #[test]
    fn zero_passes_through_with_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(0);
        assert_eq!(p.rounded_count, 0);
        assert_eq!(p.tail_lanes, 0);
        assert!(p.is_aligned());
    }
    #[test]
    fn large_value_below_2_31_rounds_normally() {
        let p = coerce_to_pow2_with_tail_mask(1_000_000_000);
        assert_eq!(p.rounded_count, 1u32 << 30);
        assert_eq!(p.tail_lanes, (1u32 << 30) - 1_000_000_000);
    }
    #[test]
    fn value_above_2_31_errors_instead_of_saturating() {
        let error = try_coerce_to_pow2_with_tail_mask(u32::MAX)
            .expect_err("oversized power-of-two coercion must fail loudly");
        let message = error.to_string();
        assert!(
            message.contains("Fix:"),
            "oversized grid-shape error must be actionable"
        );
    }
    #[test]
    fn pow2_coercion_boundary_at_2_31() {
        // 2^31 is the largest power of two representable in u32.
        assert_eq!(
            try_coerce_to_pow2_with_tail_mask(1_u32 << 31).ok(),
            Some(TailMaskPolicy {
                original_count: 1 << 31,
                rounded_count: 1 << 31,
                tail_lanes: 0,
            })
        );
        assert!(try_coerce_to_pow2_with_tail_mask((1_u32 << 31) + 1).is_err());
    }
    #[test]
    fn infallible_coercion_falls_back_to_identity_on_overflow() {
        // Pins the documented best-effort fallback of the infallible variant.
        let p = coerce_to_pow2_with_tail_mask(u32::MAX);
        assert_eq!(p.original_count, u32::MAX);
        assert_eq!(p.rounded_count, u32::MAX);
        assert_eq!(p.tail_lanes, 0);
        assert!(p.is_aligned());
    }
    #[test]
    fn root_helpers_are_exact_at_large_boundaries() {
        assert_eq!(ceil_sqrt_u64((1_u64 << 32) - 1), 65_536);
        assert_eq!(ceil_sqrt_u64(1_u64 << 32), 65_536);
        assert_eq!(ceil_cuberoot_u64(2_642_245_u64.pow(3)), 2_642_245);
        assert_eq!(ceil_cuberoot_u64(2_642_245_u64.pow(3) - 1), 2_642_245);
    }
    #[test]
    fn root_helpers_handle_extreme_inputs() {
        assert_eq!(ceil_sqrt_u64(0), 1);
        assert_eq!(ceil_cuberoot_u64(0), 1);
        // (2^32 - 1)^2 < u64::MAX, so the ceiling root is exactly 2^32.
        assert_eq!(ceil_sqrt_u64(u64::MAX), 1_u64 << 32);
        // 2_642_245 is the largest integer whose cube fits in u64.
        assert_eq!(ceil_cuberoot_u64(u64::MAX), 2_642_246);
    }

    // Runs grid inference and fails the test with the error's message on failure.
    fn grid_for(count: u32, workgroup: [u32; 3]) -> [u32; 3] {
        infer_dispatch_grid_for_count(count, workgroup).unwrap_or_else(|error| {
            panic!("Fix: grid inference for count={count}, workgroup={workgroup:?} failed: {error}")
        })
    }

    #[test]
    fn zero_workgroup_dimension_is_rejected_with_actionable_error() {
        for workgroup in [[0, 1, 1], [8, 0, 1], [4, 4, 0]] {
            let error = infer_dispatch_grid_for_count(16, workgroup)
                .expect_err("zero workgroup dimensions must be rejected");
            assert!(
                error.to_string().contains("Fix:"),
                "zero-workgroup error must be actionable"
            );
        }
    }
    #[test]
    fn grid_shape_follows_workgroup_dimensionality() {
        assert_eq!(grid_for(1000, [64, 1, 1]), [16, 1, 1]);
        assert_eq!(grid_for(1024, [64, 1, 1]), [16, 1, 1]);
        assert_eq!(grid_for(1000, [8, 8, 1]), [4, 4, 1]);
        assert_eq!(grid_for(1000, [4, 4, 4]), [3, 3, 3]);
        assert_eq!(grid_for(u32::MAX, [1, 1, 1]), [u32::MAX, 1, 1]);
    }
    #[test]
    fn empty_workload_still_launches_one_workgroup() {
        assert_eq!(grid_for(0, [64, 1, 1]), [1, 1, 1]);
        assert_eq!(grid_for(0, [8, 8, 1]), [1, 1, 1]);
        assert_eq!(grid_for(0, [4, 4, 4]), [1, 1, 1]);
    }
    #[test]
    fn inferred_grids_cover_every_element() {
        let workgroups = [[64, 1, 1], [8, 8, 1], [16, 4, 1], [4, 4, 4], [1, 2, 3]];
        let counts = [1_u32, 2, 7, 63, 64, 65, 1000, 4097, 1 << 20, u32::MAX];
        for workgroup in workgroups {
            for count in counts {
                let grid = grid_for(count, workgroup);
                assert!(
                    grid.iter().all(|&axis| axis >= 1),
                    "every grid axis must be non-empty for count={count}, workgroup={workgroup:?}"
                );
                // u128 keeps the thread-count product free of overflow.
                let launched: u128 = grid
                    .iter()
                    .zip(workgroup)
                    .map(|(&groups, lanes)| u128::from(groups) * u128::from(lanes))
                    .product();
                assert!(
                    launched >= u128::from(count),
                    "grid {grid:?} under-dispatches count={count} for workgroup={workgroup:?}"
                );
            }
        }
    }
    #[test]
    fn dispatch_grid_planning_uses_integer_roots_and_typed_errors() {
        let source = include_str!("grid.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: dispatch-grid production source must precede tests");
        assert!(
            !production.contains(" as f64")
                && !production.contains(".sqrt()")
                && !production.contains(".cbrt()"),
            "Fix: dispatch-grid inference must use deterministic integer root arithmetic."
        );
        assert!(
            production.contains("try_coerce_to_pow2_with_tail_mask")
                && !production.contains("panic!("),
            "Fix: dispatch-grid planning should expose typed errors instead of production panics."
        );
    }
}