use vyre_foundation::ir::Program;
use crate::binding::Binding;
use crate::program_walks::{
dispatch_element_count, dispatch_param_words_into, infer_dispatch_grid_for_count,
};
use crate::validation::{validate_launch_geometry, LaunchGeometryLimits};
use crate::{BackendError, DispatchConfig};
/// Resolved launch parameters for one dispatch.
///
/// A plan is filled in by [`LaunchPlan::prepare_into`]; the same value can be
/// prepared repeatedly so that `param_words` keeps using its existing buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LaunchPlan {
    /// Element count resolved for the launch; also handed to the
    /// parameter-word builder in `prepare_into`.
    pub element_count: u32,
    /// Workgroup dimensions as `[x, y, z]`: the dispatch config's override
    /// when present, otherwise the program's own workgroup size.
    pub workgroup: [u32; 3],
    /// Grid dimensions as `[x, y, z]`: the dispatch config's override when
    /// present, otherwise inferred from `element_count` and `workgroup`.
    pub grid: [u32; 3],
    /// Parameter words written by `dispatch_param_words_into` during
    /// `prepare_into`.
    pub param_words: Vec<u32>,
    /// Largest `preferred_alignment` among the bindings used to prepare the
    /// plan, or `1` when there were no bindings.
    pub max_binding_alignment: usize,
}
impl LaunchPlan {
    /// Creates a plan for a single element: every workgroup and grid
    /// dimension is `1`, there are no parameter words, and the binding
    /// alignment is `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            element_count: 1,
            workgroup: [1; 3],
            grid: [1; 3],
            param_words: Vec::new(),
            max_binding_alignment: 1,
        }
    }

    /// Builds a fresh plan and prepares it for `program` with `bindings`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LaunchPlan::prepare_into`] reports.
    pub fn from_bindings(
        program: &Program,
        bindings: &[Binding],
        config: &DispatchConfig,
        limits: LaunchGeometryLimits,
    ) -> Result<Self, BackendError> {
        let mut plan = Self::new();
        plan.prepare_into(program, bindings, config, limits)
            .map(|()| plan)
    }

    /// Recomputes this plan in place for `program`, `bindings`, and `config`.
    ///
    /// The workgroup is `config.workgroup_override` when set, otherwise the
    /// program's workgroup size. The grid is `config.grid_override` when set;
    /// otherwise it is inferred from the element count, which is only done
    /// when the workgroup's y and z dimensions are both `1`.
    ///
    /// `param_words` is refilled through `dispatch_param_words_into`, which
    /// receives the existing vector rather than a new one.
    ///
    /// # Errors
    ///
    /// * `BackendError::InvalidProgram` when no grid override is given and
    ///   the workgroup is not 1D.
    /// * Any error from `launch_element_count`,
    ///   `infer_dispatch_grid_for_count`, or `validate_launch_geometry`.
    ///
    /// Every fallible step runs before `self` is written, so an `Err` return
    /// leaves the plan as it was.
    pub fn prepare_into(
        &mut self,
        program: &Program,
        bindings: &[Binding],
        config: &DispatchConfig,
        limits: LaunchGeometryLimits,
    ) -> Result<(), BackendError> {
        let program_workgroup = program.workgroup_size();
        let workgroup = config.workgroup_override.unwrap_or(program_workgroup);
        let element_count = launch_element_count(bindings, workgroup, config, limits)?;

        let grid = if let Some(explicit) = config.grid_override {
            explicit
        } else if matches!(workgroup, [_, 1, 1]) {
            // Only a 1D workgroup has an obvious default grid.
            infer_dispatch_grid_for_count(element_count, workgroup)?
        } else {
            return Err(BackendError::InvalidProgram {
                fix: format!(
                    "Fix: backend `{}` requires DispatchConfig::grid_override for non-1D workgroups. \
                     workgroup={:?} has no unambiguous default grid; set grid_override to the logical [x, y, z] you want.",
                    limits.backend, workgroup,
                ),
            });
        };
        validate_launch_geometry(workgroup, grid, limits)?;

        // Nothing below can fail, so the plan is only touched on success.
        let widest_alignment = bindings
            .iter()
            .map(|entry| entry.preferred_alignment)
            .max()
            .unwrap_or(1);
        self.element_count = element_count;
        self.workgroup = workgroup;
        self.grid = grid;
        self.max_binding_alignment = widest_alignment;
        dispatch_param_words_into(bindings, element_count, &mut self.param_words);
        Ok(())
    }
}
impl Default for LaunchPlan {
fn default() -> Self {
Self::new()
}
}
/// Resolves the element count for a launch.
///
/// Without a `grid_override` the count is whatever `dispatch_element_count`
/// reports for `bindings`. With an override the count is
/// `grid_override.x * workgroup.x`.
///
/// # Errors
///
/// With a `grid_override` present, returns `BackendError::InvalidProgram` when
/// any workgroup or grid dimension is zero, or when the x-dimension product
/// does not fit in `u32`.
fn launch_element_count(
    bindings: &[Binding],
    workgroup: [u32; 3],
    config: &DispatchConfig,
    limits: LaunchGeometryLimits,
) -> Result<u32, BackendError> {
    let inferred = dispatch_element_count(bindings);
    let grid = match config.grid_override {
        Some(explicit) => explicit,
        None => return Ok(inferred),
    };

    let has_zero_dim = workgroup.iter().chain(grid.iter()).any(|&dim| dim == 0);
    if has_zero_dim {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: {} grid_override and workgroup dimensions must all be non-zero.",
                limits.backend
            ),
        });
    }

    match grid[0].checked_mul(workgroup[0]) {
        Some(count) if count != 0 => Ok(count),
        _ => Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: {} grid_override.x * workgroup_size.x must fit in u32.",
                limits.backend
            ),
        }),
    }
}
/// Returns the words from [`program_vsa_fingerprint_words`] as an owned
/// `Vec<u32>`.
#[must_use]
pub fn program_vsa_fingerprint(program: &Program) -> Vec<u32> {
    let words = program_vsa_fingerprint_words(program);
    Vec::from(words)
}
/// Packs the bytes of `program.fingerprint()` into eight `u32` words.
///
/// Each complete 4-byte chunk is decoded as little-endian, in order. Any word
/// without a corresponding complete chunk stays `0`, and chunks beyond the
/// eighth are ignored.
#[must_use]
pub fn program_vsa_fingerprint_words(program: &Program) -> [u32; 8] {
    let digest = program.fingerprint();
    let mut packed = [0_u32; 8];
    digest
        .chunks_exact(4)
        .zip(packed.iter_mut())
        .for_each(|(bytes, slot)| {
            *slot = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        });
    packed
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binding::BindingRole;
    use vyre_foundation::ir::Program;

    /// Preparing a plan whose `param_words` already has capacity must fill
    /// that same buffer instead of replacing it.
    #[test]
    fn launch_plan_prepare_into_reuses_param_words() {
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        let input = Binding {
            name: std::sync::Arc::from("input"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Input,
            element_size: 4,
            preferred_alignment: 64,
            element_count: 7,
            static_byte_len: Some(28),
            input_index: Some(0),
            output_index: None,
        };
        let program = Program::wrapped(vec![], [64, 1, 1], vec![]);

        // Pre-size the buffer so its address can be compared afterwards.
        let mut plan = LaunchPlan::new();
        plan.param_words = Vec::with_capacity(8);
        let buffer_before = plan.param_words.as_ptr();

        plan.prepare_into(&program, &[input], &DispatchConfig::default(), limits)
            .unwrap();

        assert_eq!(plan.element_count, 7);
        assert_eq!(plan.grid, [1, 1, 1]);
        assert_eq!(plan.param_words, [7, 7]);
        assert_eq!(plan.max_binding_alignment, 64);
        assert_eq!(plan.param_words.as_ptr(), buffer_before);
    }
}