use vyre_foundation::ir::Program;
use crate::backend::{BackendError, DispatchConfig};
use crate::binding::BindingPlan;
use crate::program_walks::dispatch_element_count;
/// Resolves the dispatch grid for `program` given concrete input buffers.
///
/// An explicit `config.grid_override` wins outright and is returned as-is,
/// without building a binding plan. Otherwise a [`BindingPlan`] is built from
/// the program and inputs, the element count is derived from the plan's
/// bindings, and the grid is sized by [`infer_dispatch_grid_for_count`] using
/// `config.workgroup_override` when present and the program's own workgroup
/// size when not.
///
/// # Errors
///
/// Propagates failures from `BindingPlan::from_program` and from
/// [`infer_dispatch_grid_for_count`].
pub fn infer_dispatch_grid(
    program: &Program,
    inputs: &[Vec<u8>],
    config: &DispatchConfig,
) -> Result<[u32; 3], BackendError> {
    match config.grid_override {
        // The caller pinned the grid explicitly; nothing to infer.
        Some(explicit_grid) => Ok(explicit_grid),
        None => {
            let binding_plan = BindingPlan::from_program(program, inputs)?;
            let elements = dispatch_element_count(&binding_plan.bindings);
            let workgroup = config
                .workgroup_override
                .unwrap_or(program.workgroup_size());
            infer_dispatch_grid_for_count(elements, workgroup)
        }
    }
}
/// Infers a dispatch grid for `program` without caller-supplied inputs.
///
/// The program is first validated against `backend` using a default
/// [`DispatchConfig`]. A [`BindingPlan`] is then built from the program alone,
/// its bindings yield the element count, and the grid is sized with the
/// program's own workgroup size via [`infer_dispatch_grid_for_count`].
///
/// # Errors
///
/// Propagates validation failures, `BindingPlan::build` failures, and errors
/// from [`infer_dispatch_grid_for_count`].
pub fn auto_grid(
    program: &Program,
    backend: &dyn crate::backend::VyreBackend,
) -> Result<[u32; 3], BackendError> {
    // Validate before doing any planning work.
    crate::validation::validate_program_for_backend(backend, program, &DispatchConfig::default())?;

    let elements = {
        let binding_plan = BindingPlan::build(program)?;
        dispatch_element_count(&binding_plan.bindings)
    };
    let workgroup = program.workgroup_size();
    infer_dispatch_grid_for_count(elements, workgroup)
}
/// Computes a workgroup grid covering `element_count` elements for the given
/// `workgroup` size.
///
/// The layout follows the shape of the workgroup:
///
/// * `[x, 1, 1]` — a one-dimensional grid of `ceil(count / x)` groups.
/// * `[x, y, 1]` — the count is spread over a near-square area whose side is
///   the ceiling square root of the count; each axis is then divided by the
///   matching workgroup dimension.
/// * anything else — the count is spread over a near-cubic volume whose side
///   is the ceiling cube root of the count.
///
/// A count of zero is treated as one, and every grid dimension is at least 1.
///
/// # Errors
///
/// Returns an error when any workgroup dimension is zero, or when a computed
/// grid dimension does not fit in `u32`.
pub fn infer_dispatch_grid_for_count(
    element_count: u32,
    workgroup: [u32; 3],
) -> Result<[u32; 3], BackendError> {
    let [wg_x, wg_y, wg_z] = workgroup.map(u64::from);
    if wg_x == 0 || wg_y == 0 || wg_z == 0 {
        return Err(BackendError::new(
            "workgroup dimensions must be non-zero. Fix: set Program::workgroup_size and DispatchConfig::workgroup_override to positive values.",
        ));
    }

    // An empty dispatch is still sized as a single element.
    let total = u64::from(element_count).max(1);

    match (wg_y, wg_z) {
        (1, 1) => Ok([ceil_div_u64(total, wg_x)?, 1, 1]),
        (_, 1) => {
            let edge = ceil_sqrt_u64(total);
            let grid_x = ceil_div_u64(edge, wg_x)?;
            // Rows needed when each row holds `edge` elements.
            let rows = ceil_div_u64(total, edge)?;
            let grid_y = ceil_div_u64(u64::from(rows), wg_y)?;
            Ok([grid_x, grid_y, 1])
        }
        _ => {
            let edge = ceil_cuberoot_u64(total);
            let layer_area = edge.saturating_mul(edge).max(1);
            let grid_x = ceil_div_u64(edge, wg_x)?;
            let grid_y = ceil_div_u64(edge, wg_y)?;
            // Layers needed when each layer holds `edge * edge` elements.
            let layers = ceil_div_u64(total, layer_area)?;
            let grid_z = ceil_div_u64(u64::from(layers), wg_z)?;
            Ok([grid_x, grid_y, grid_z])
        }
    }
}
/// Divides `value` by `divisor`, rounding up, and narrows the result to `u32`.
///
/// The quotient is clamped to a minimum of 1 so that a grid dimension is never
/// zero.
///
/// # Errors
///
/// Returns an error when the quotient does not fit in `u32`.
///
/// # Panics
///
/// Panics if `divisor` is zero (integer division by zero); callers in this
/// file only pass non-zero divisors.
fn ceil_div_u64(value: u64, divisor: u64) -> Result<u32, BackendError> {
    let quotient = value.div_ceil(divisor);
    let groups = if quotient == 0 { 1 } else { quotient };
    match u32::try_from(groups) {
        Ok(dimension) => Ok(dimension),
        Err(_) => Err(BackendError::new(
            "inferred dispatch grid dimension overflowed u32. Fix: split the Program into smaller dispatches.",
        )),
    }
}
/// Returns the smallest integer whose square is at least `value`, with a floor
/// of 1 (so `0` maps to `1`).
///
/// A floating-point square root seeds the search; the two correction loops
/// then walk the candidate up or down so the result is exact despite `f64`
/// rounding. Squares are computed with saturating multiplication.
fn ceil_sqrt_u64(value: u64) -> u64 {
    let square = |n: u64| n.saturating_mul(n);

    let mut candidate = (value as f64).sqrt() as u64;

    // Seed too small: climb until the square reaches `value`.
    loop {
        if square(candidate) >= value {
            break;
        }
        candidate += 1;
    }

    // Seed too large: descend while the next-lower integer still suffices.
    while let Some(below) = candidate.checked_sub(1) {
        if square(below) < value {
            break;
        }
        candidate = below;
    }

    if candidate == 0 {
        1
    } else {
        candidate
    }
}
/// Returns the smallest integer whose cube is at least `value`, with a floor
/// of 1 (so `0` maps to `1`).
///
/// A floating-point cube root seeds the search; the two correction loops then
/// walk the candidate up or down so the result is exact despite `f64`
/// rounding. Cubes are computed with saturating multiplication.
fn ceil_cuberoot_u64(value: u64) -> u64 {
    let cube = |n: u64| n.saturating_mul(n).saturating_mul(n);

    let mut candidate = (value as f64).cbrt() as u64;

    // Seed too small: climb until the cube reaches `value`.
    loop {
        if cube(candidate) >= value {
            break;
        }
        candidate += 1;
    }

    // Seed too large: descend while the next-lower integer still suffices.
    while let Some(below) = candidate.checked_sub(1) {
        if cube(below) < value {
            break;
        }
        candidate = below;
    }

    if candidate == 0 {
        1
    } else {
        candidate
    }
}
/// Result of coercing an element count to a power of two.
///
/// As populated by [`coerce_to_pow2_with_tail_mask`], the three fields record
/// the caller's count, the coerced count, and how many extra lanes the
/// coercion introduced (presumably the lanes a consumer must mask off).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TailMaskPolicy {
    /// Element count as supplied by the caller.
    pub original_count: u32,
    /// Element count after power-of-two coercion.
    pub rounded_count: u32,
    /// Lanes added by the coercion: `rounded_count - original_count`,
    /// saturating at zero.
    pub tail_lanes: u32,
}

impl TailMaskPolicy {
    /// Returns `true` when the coercion added no tail lanes.
    #[must_use]
    pub fn is_aligned(&self) -> bool {
        matches!(self.tail_lanes, 0)
    }
}
/// Coerces `element_count` to a power of two and reports the resulting tail.
///
/// * `0` passes through unchanged: all fields are zero.
/// * A power of two is returned as-is with `tail_lanes == 0`.
/// * Any other count up to `2^31` is rounded up to the next power of two and
///   `tail_lanes` is the difference.
/// * Counts above `2^31` cannot be rounded up within `u32`; the rounded count
///   is clamped to `2^31` (which is *below* the original) and the saturating
///   subtraction leaves `tail_lanes == 0`.
#[must_use]
pub fn coerce_to_pow2_with_tail_mask(element_count: u32) -> TailMaskPolicy {
    let rounded_count = match element_count {
        // Zero is not rounded up to one; it stays an empty dispatch.
        0 => 0,
        nonzero => next_pow2_u32_saturating(nonzero),
    };
    TailMaskPolicy {
        original_count: element_count,
        rounded_count,
        tail_lanes: rounded_count.saturating_sub(element_count),
    }
}
/// Returns the smallest power of two that is `>= value`, clamped to `2^31`.
///
/// Powers of two map to themselves and `0` maps to `1`. Values above `2^31`
/// have no representable next power of two in `u32`, so `2^31` is returned
/// instead of overflowing.
fn next_pow2_u32_saturating(value: u32) -> u32 {
    // Largest power of two representable in `u32`.
    const LARGEST_POW2: u32 = 1u32 << 31;
    value.checked_next_power_of_two().unwrap_or(LARGEST_POW2)
}
#[cfg(test)]
mod n6_tests {
    use super::*;

    /// A power-of-two count is kept unchanged and introduces no tail.
    #[test]
    fn already_pow2_is_identity_with_no_tail() {
        let policy = coerce_to_pow2_with_tail_mask(64);
        let TailMaskPolicy {
            original_count,
            rounded_count,
            tail_lanes,
        } = policy;
        assert_eq!((original_count, rounded_count, tail_lanes), (64, 64, 0));
        assert!(policy.is_aligned());
    }

    /// A non-power-of-two count rounds up and reports the added lanes.
    #[test]
    fn non_pow2_rounds_up_and_reports_tail() {
        let policy = coerce_to_pow2_with_tail_mask(100);
        let TailMaskPolicy {
            original_count,
            rounded_count,
            tail_lanes,
        } = policy;
        assert_eq!((original_count, rounded_count, tail_lanes), (100, 128, 28));
        assert!(!policy.is_aligned());
    }

    /// One is already a power of two.
    #[test]
    fn one_is_pow2_no_tail() {
        let TailMaskPolicy {
            rounded_count,
            tail_lanes,
            ..
        } = coerce_to_pow2_with_tail_mask(1);
        assert_eq!((rounded_count, tail_lanes), (1, 0));
    }

    /// Zero is passed through rather than rounded up to one.
    #[test]
    fn zero_passes_through_with_no_tail() {
        let policy = coerce_to_pow2_with_tail_mask(0);
        assert_eq!((policy.rounded_count, policy.tail_lanes), (0, 0));
        assert!(policy.is_aligned());
    }

    /// Large counts below 2^31 still round up to the next power of two.
    #[test]
    fn large_value_below_2_31_rounds_normally() {
        let expected_rounded = 1u32 << 30;
        let TailMaskPolicy {
            rounded_count,
            tail_lanes,
            ..
        } = coerce_to_pow2_with_tail_mask(1_000_000_000);
        assert_eq!(rounded_count, expected_rounded);
        assert_eq!(tail_lanes, expected_rounded - 1_000_000_000);
    }

    /// Counts above 2^31 clamp to 2^31 instead of overflowing.
    #[test]
    fn value_above_2_31_saturates_at_2_31() {
        let TailMaskPolicy { rounded_count, .. } = coerce_to_pow2_with_tail_mask(u32::MAX);
        assert_eq!(rounded_count, 1u32 << 31);
    }
}