use crate::dispatch_buffers::{
checked_product_count, decode_u32_output_exact, ensure_input_slots, write_u32_slice_le_bytes,
write_zero_bytes,
};
use crate::optimizer::dispatcher::{DispatchError, OptimizerDispatcher};
use vyre_primitives::parsing::planar_rewrite::planar_rewrite_schedule;
#[cfg(test)]
use vyre_primitives::parsing::planar_rewrite::reference_planar_rewrite_schedule;
/// Caller-owned scratch storage reused across scheduler dispatches.
///
/// Passing the same value to `schedule_disjoint_rewrites_via_with_scratch_into`
/// on every call lets the dispatch input buffers be reused rather than
/// rebuilt from an empty scratch each time.
#[derive(Debug, Default)]
pub struct PlanarRewriteScheduleGpuScratch {
// Byte buffers handed to the dispatcher as its input slots. The dispatch path
// in this file sizes this to two slots: slot 0 is filled from the candidate
// words and slot 1 is a zero-filled buffer sized for the output.
inputs: Vec<Vec<u8>>,
}
/// Test-only CPU path: forwards the candidate grid to
/// `reference_planar_rewrite_schedule` and returns its result unchanged.
///
/// The scheduler call counter is bumped before the footprint is validated.
///
/// # Panics
///
/// Panics when `k` is zero.
#[cfg(test)]
#[must_use]
pub fn schedule_disjoint_rewrites(candidates: &[u32], h: u32, w: u32, k: u32) -> Vec<u32> {
    use crate::observability::{bump, planar_rewrite_pass_scheduler_calls};

    // Count the call first so rejected invocations are still observable.
    bump(&planar_rewrite_pass_scheduler_calls);

    assert!(k != 0, "Fix: rewrite footprint k must be > 0.");

    reference_planar_rewrite_schedule(candidates, h, w, k)
}
/// Runs the scheduler through `dispatcher` and returns the result in a freshly
/// allocated vector.
///
/// This is the allocating convenience form of
/// `schedule_disjoint_rewrites_via_into`; the arguments are forwarded as-is.
///
/// # Errors
///
/// Returns whatever error the `_into` variant reports.
pub fn schedule_disjoint_rewrites_via(
    dispatcher: &impl OptimizerDispatcher,
    candidates: &[u32],
    h: u32,
    w: u32,
    k: u32,
) -> Result<Vec<u32>, DispatchError> {
    let mut chosen = Vec::new();
    // On success hand the filled vector to the caller; on failure it is dropped.
    schedule_disjoint_rewrites_via_into(dispatcher, candidates, h, w, k, &mut chosen)
        .map(|()| chosen)
}
/// Runs the scheduler through `dispatcher`, writing the result into `out`.
///
/// A default-constructed [`PlanarRewriteScheduleGpuScratch`] is created for
/// this single call and dropped afterwards. Callers that dispatch repeatedly
/// can keep their own scratch and call
/// `schedule_disjoint_rewrites_via_with_scratch_into` directly.
///
/// # Errors
///
/// Returns whatever error the scratch-taking variant reports.
pub fn schedule_disjoint_rewrites_via_into(
    dispatcher: &impl OptimizerDispatcher,
    candidates: &[u32],
    h: u32,
    w: u32,
    k: u32,
    out: &mut Vec<u32>,
) -> Result<(), DispatchError> {
    // The temporary scratch lives until the end of this expression.
    schedule_disjoint_rewrites_via_with_scratch_into(
        dispatcher,
        candidates,
        h,
        w,
        k,
        &mut PlanarRewriteScheduleGpuScratch::default(),
        out,
    )
}
/// Dispatches the planar rewrite scheduling program using caller-owned buffers.
///
/// Steps, in order:
/// 1. bump the scheduler call counter;
/// 2. validate `k` and the shape of `candidates` against `h` and `w`;
/// 3. build the program with buffer names `"candidates"` and `"chosen"`;
/// 4. prepare two input slots in `scratch`: slot 0 from `candidates`, slot 1
///    zero-filled with four bytes per cell;
/// 5. dispatch with a `[1, 1, 1]` grid override and pass the first returned
///    buffer to `decode_u32_output_exact` together with `out`.
///
/// # Errors
///
/// * [`DispatchError::BadInputs`] when `k == 0`, when `candidates.len()` differs
///   from the `h`-by-`w` cell count, or when the output byte count overflows
///   `usize`.
/// * [`DispatchError::BackendError`] when the dispatcher returns no buffers.
/// * Any error propagated from `checked_product_count`, from the dispatcher, or
///   from `decode_u32_output_exact`.
pub fn schedule_disjoint_rewrites_via_with_scratch_into(
    dispatcher: &impl OptimizerDispatcher,
    candidates: &[u32],
    h: u32,
    w: u32,
    k: u32,
    scratch: &mut PlanarRewriteScheduleGpuScratch,
    out: &mut Vec<u32>,
) -> Result<(), DispatchError> {
    use crate::observability::{bump, planar_rewrite_pass_scheduler_calls};

    // Counted before validation, so rejected calls are included as well.
    bump(&planar_rewrite_pass_scheduler_calls);

    if k == 0 {
        let message = "Fix: schedule_disjoint_rewrites_via requires k > 0.".to_string();
        return Err(DispatchError::BadInputs(message));
    }

    let cells = checked_product_count(h, w, "h", "w", "schedule_disjoint_rewrites_via")?;
    if candidates.len() != cells {
        let message = format!(
            "Fix: schedule_disjoint_rewrites_via requires candidates.len() == h*w, got len={}, h={h}, w={w}, h*w={cells}.",
            candidates.len()
        );
        return Err(DispatchError::BadInputs(message));
    }

    let program = planar_rewrite_schedule("candidates", "chosen", h, w, k);

    // One u32 per cell in the output slot.
    let output_bytes = match cells.checked_mul(std::mem::size_of::<u32>()) {
        Some(bytes) => bytes,
        None => {
            return Err(DispatchError::BadInputs(format!(
                "Fix: schedule_disjoint_rewrites_via output byte count overflows usize for {cells} cells."
            )));
        }
    };

    // Slot 0: candidate words. Slot 1: zeroed storage sized for the result.
    ensure_input_slots(&mut scratch.inputs, 2);
    write_u32_slice_le_bytes(&mut scratch.inputs[0], candidates);
    write_zero_bytes(&mut scratch.inputs[1], output_bytes);

    let outputs = dispatcher.dispatch(&program, &scratch.inputs, Some([1, 1, 1]))?;

    // Only the first returned buffer is consumed; extra buffers are ignored.
    match outputs.first() {
        Some(chosen_bytes) => {
            decode_u32_output_exact(chosen_bytes, cells, "schedule_disjoint_rewrites_via", out)
        }
        None => Err(DispatchError::BackendError(format!(
            "Fix: schedule_disjoint_rewrites_via expected at least one output buffer, got {}.",
            outputs.len()
        ))),
    }
}
/// Convenience: count the non-zero entries of `schedule`.
///
/// Every entry different from `0` is counted once. The total is narrowed to
/// `u32` with an `as` cast, so a slice holding more than `u32::MAX` non-zero
/// entries would wrap.
#[must_use]
pub fn count_scheduled(schedule: &[u32]) -> u32 {
    let mut selected = 0usize;
    for &cell in schedule {
        if cell != 0 {
            selected += 1;
        }
    }
    selected as u32
}
/// Ratio of candidate rewrites to scheduled rewrites.
///
/// Returns `candidate_count / scheduled_count` as an `f64`. When
/// `scheduled_count` is zero the result is defined as `0.0`, so the function
/// never yields infinity or NaN.
#[must_use]
pub fn batch_reduction_ratio(candidate_count: u32, scheduled_count: u32) -> f64 {
    match scheduled_count {
        0 => 0.0,
        scheduled => f64::from(candidate_count) / f64::from(scheduled),
    }
}
// Unit tests for the scheduler wrappers and the small counting/ratio helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::dispatch_buffers::u32_slice_to_le_bytes;
use vyre_foundation::ir::Program;
// An all-zero 4x4 candidate grid must produce a schedule with no selected cells.
#[test]
fn empty_grid_yields_no_schedule() {
let candidates = vec![0u32; 16];
let schedule = schedule_disjoint_rewrites(&candidates, 4, 4, 2);
assert_eq!(count_scheduled(&schedule), 0);
}
// An all-candidate 4x4 grid with k=2 must select between one and four cells.
#[test]
fn full_grid_yields_disjoint_subset() {
let candidates = vec![1u32; 16];
let schedule = schedule_disjoint_rewrites(&candidates, 4, 4, 2);
let count = count_scheduled(&schedule);
assert!(count >= 1, "at least one rewrite must be schedulable");
assert!(count <= 4, "at most 4 disjoint k=2 rewrites in a 4x4 grid");
}
// The ratio is 0.0 for a zero denominator and a plain quotient otherwise.
#[test]
fn batch_reduction_well_defined() {
assert_eq!(batch_reduction_ratio(0, 0), 0.0);
let r = batch_reduction_ratio(100, 4);
assert!((r - 25.0).abs() < 1e-9);
}
// With k=1 every candidate in a 2x2 grid is expected to be scheduled.
#[test]
fn k_one_allows_every_candidate() {
let candidates = vec![1u32, 1, 1, 1];
let schedule = schedule_disjoint_rewrites(&candidates, 2, 2, 1);
assert_eq!(count_scheduled(&schedule), 4);
}
// Test double for the dispatcher. It ignores the program, checks the grid
// override and the number of input slots, and answers with the CPU reference
// schedule computed from slot 0. The grid is assumed square (side derived from
// the candidate count) and the footprint is fixed at 2, so it only matches
// calls made with a square grid and k=2.
struct PlanarDispatcher;
impl OptimizerDispatcher for PlanarDispatcher {
fn dispatch(
&self,
_program: &Program,
inputs: &[Vec<u8>],
grid_override: Option<[u32; 3]>,
) -> Result<Vec<Vec<u8>>, DispatchError> {
assert_eq!(grid_override, Some([1, 1, 1]));
assert_eq!(inputs.len(), 2);
let candidates = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
let n = integer_sqrt(candidates.len());
let chosen = reference_planar_rewrite_schedule(&candidates, n as u32, n as u32, 2);
Ok(vec![u32_slice_to_le_bytes(&chosen)])
}
}
// The dispatched path must return the same schedule as the CPU wrapper.
#[test]
fn schedule_disjoint_rewrites_via_dispatches_primitive() {
let candidates = vec![1u32; 16];
let via = schedule_disjoint_rewrites_via(&PlanarDispatcher, &candidates, 4, 4, 2).unwrap();
let reference = schedule_disjoint_rewrites(&candidates, 4, 4, 2);
assert_eq!(via, reference);
}
// Two identical calls sharing one scratch and one output vector must leave the
// capacities of the input slots and of `out` unchanged after the second call,
// and the final output must still equal the CPU wrapper's result.
#[test]
fn schedule_disjoint_rewrites_via_with_scratch_reuses_dispatch_and_output_storage() {
let candidates = vec![1u32; 16];
let mut scratch = PlanarRewriteScheduleGpuScratch::default();
let mut out = Vec::with_capacity(16);
schedule_disjoint_rewrites_via_with_scratch_into(
&PlanarDispatcher,
&candidates,
4,
4,
2,
&mut scratch,
&mut out,
)
.unwrap();
// Capacities recorded after the first call are the baseline for the second.
let input_capacities = scratch.inputs.iter().map(Vec::capacity).collect::<Vec<_>>();
let out_capacity = out.capacity();
schedule_disjoint_rewrites_via_with_scratch_into(
&PlanarDispatcher,
&candidates,
4,
4,
2,
&mut scratch,
&mut out,
)
.unwrap();
assert_eq!(
scratch.inputs.iter().map(Vec::capacity).collect::<Vec<_>>(),
input_capacities
);
assert_eq!(out.capacity(), out_capacity);
assert_eq!(out, schedule_disjoint_rewrites(&candidates, 4, 4, 2));
}
// Three candidates for a 2x2 grid must be rejected as bad inputs.
#[test]
fn schedule_disjoint_rewrites_via_rejects_bad_shape() {
let err =
schedule_disjoint_rewrites_via(&PlanarDispatcher, &[1, 0, 1], 2, 2, 2).unwrap_err();
assert!(matches!(err, DispatchError::BadInputs(_)));
}
// Source-text guard: reads this file (it must be named
// planar_rewrite_pass_scheduler.rs), takes the text from the first dispatched
// entry point up to the doc-comment marker searched for below, and asserts
// that the slice never names a CPU reference helper. The test panics if either
// marker is missing, so the doc comment it looks for must be kept in the file.
#[test]
fn release_via_path_does_not_call_cpu_or_reference_helpers() {
let source = include_str!("planar_rewrite_pass_scheduler.rs");
let start = source
.find("pub fn schedule_disjoint_rewrites_via")
.expect("Fix: via path marker must exist");
let end = source
.find("\n/// Convenience: count")
.expect("Fix: convenience marker must exist");
let release_path = &source[start..end];
assert!(!release_path.contains("reference_planar_rewrite_schedule"));
assert!(!release_path.contains("reference_"));
}
// Returns the smallest `root` with `root * root >= n`, i.e. the exact square
// root for perfect squares and the next integer up otherwise.
fn integer_sqrt(n: usize) -> usize {
let mut root = 0usize;
while root * root < n {
root += 1;
}
root
}
}