use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this primitive: used as the `generator` of the emitted IR region and passed to
/// `crate::invalid_output_program` when the builder rejects its arguments.
pub const OP_ID: &str = "vyre-primitives::parsing::planar_rewrite_schedule";
#[must_use]
pub fn planar_rewrite_schedule(candidates: &str, chosen: &str, h: u32, w: u32, k: u32) -> Program {
if h == 0 || w == 0 {
return crate::invalid_output_program(
OP_ID,
chosen,
DataType::U32,
format!("Fix: planar_rewrite_schedule requires h > 0 and w > 0, got h={h}, w={w}."),
);
}
if k == 0 {
return crate::invalid_output_program(
OP_ID,
chosen,
DataType::U32,
format!("Fix: planar_rewrite_schedule requires k > 0, got {k}."),
);
}
let cells = h * w;
let t = Expr::InvocationId { axis: 0 };
let body = vec![Node::if_then(
Expr::eq(t.clone(), Expr::u32(0)),
vec![
Node::loop_for(
"init",
Expr::u32(0),
Expr::u32(cells),
vec![Node::store(chosen, Expr::var("init"), Expr::u32(0))],
),
Node::loop_for(
"r",
Expr::u32(0),
Expr::u32(h),
vec![Node::loop_for(
"c",
Expr::u32(0),
Expr::u32(w),
vec![
Node::let_bind(
"addr",
Expr::add(Expr::mul(Expr::var("r"), Expr::u32(w)), Expr::var("c")),
),
Node::if_then(
Expr::ne(Expr::load(candidates, Expr::var("addr")), Expr::u32(0)),
vec![
Node::let_bind("conflict", Expr::u32(0)),
Node::loop_for(
"di",
Expr::u32(0),
Expr::u32(k),
vec![Node::loop_for(
"dj",
Expr::u32(0),
Expr::u32(k),
vec![Node::if_then(
Expr::and(
Expr::ge(Expr::var("r"), Expr::var("di")),
Expr::ge(Expr::var("c"), Expr::var("dj")),
),
vec![Node::if_then(
Expr::ne(
Expr::load(
chosen,
Expr::add(
Expr::mul(
Expr::sub(
Expr::var("r"),
Expr::var("di"),
),
Expr::u32(w),
),
Expr::sub(
Expr::var("c"),
Expr::var("dj"),
),
),
),
Expr::u32(0),
),
vec![Node::assign("conflict", Expr::u32(1))],
)],
)],
)],
),
Node::if_then(
Expr::eq(Expr::var("conflict"), Expr::u32(0)),
vec![Node::store(chosen, Expr::var("addr"), Expr::u32(1))],
),
],
),
],
)],
),
],
)];
Program::wrapped(
vec![
BufferDecl::storage(candidates, 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(cells),
BufferDecl::storage(chosen, 1, BufferAccess::ReadWrite, DataType::U32)
.with_count(cells),
],
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}],
)
}
/// Host-side reference implementation of the planar rewrite schedule.
///
/// Visits the `h x w` grid in row-major order. A cell is selected when two things hold:
/// * its entry in `candidates` is non-zero (entries past the end of the slice count as zero);
/// * no previously selected cell sits at `(r - di, c - dj)` for any `0 <= di, dj < k` that
///   stays inside the grid.
///
/// The returned vector has `h * w` entries: `1` for selected cells and `0` otherwise.
///
/// With `k == 0` the window is empty, so every candidate is selected. The IR builder
/// `planar_rewrite_schedule` rejects `k == 0` instead.
///
/// NOTE(review): the window only looks up and to the left. A selected up-right diagonal
/// neighbour does not block a candidate; confirm this is intended.
#[must_use]
pub fn planar_rewrite_schedule_cpu(candidates: &[u32], h: u32, w: u32, k: u32) -> Vec<u32> {
    let (rows, cols, span) = (h as usize, w as usize, k as usize);
    let cell_count = rows * cols;
    let mut selected = vec![0u32; cell_count];
    for idx in 0..cell_count {
        let is_candidate = matches!(candidates.get(idx), Some(&v) if v != 0);
        if !is_candidate {
            continue;
        }
        // `cols > 0` here: the loop only runs when `cell_count > 0`.
        let (row, col) = (idx / cols, idx % cols);
        // Clamp each window axis to the grid. Offsets larger than `row`/`col` would step off
        // the top/left edge, so they are excluded from the scan.
        let blocked = (0..span.min(row + 1)).any(|di| {
            (0..span.min(col + 1)).any(|dj| selected[(row - di) * cols + (col - dj)] != 0)
        });
        if !blocked {
            selected[idx] = 1;
        }
    }
    selected
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_no_candidates_no_chosen() {
        let cands = vec![0u32; 16];
        let chosen = planar_rewrite_schedule_cpu(&cands, 4, 4, 2);
        assert_eq!(chosen, vec![0u32; 16]);
    }

    #[test]
    fn cpu_isolated_candidate_is_chosen() {
        let mut cands = vec![0u32; 16];
        cands[5] = 1;
        let chosen = planar_rewrite_schedule_cpu(&cands, 4, 4, 2);
        let mut expected = vec![0u32; 16];
        expected[5] = 1;
        assert_eq!(chosen, expected);
    }

    #[test]
    fn cpu_overlapping_candidates_only_first_chosen() {
        let mut cands = vec![0u32; 9];
        cands[0] = 1;
        cands[1] = 1;
        let chosen = planar_rewrite_schedule_cpu(&cands, 3, 3, 2);
        assert_eq!(chosen[0], 1);
        assert_eq!(chosen[1], 0);
    }

    #[test]
    fn cpu_widely_spaced_candidates_all_chosen() {
        let mut cands = vec![0u32; 25];
        cands[0] = 1;
        cands[4] = 1;
        cands[20] = 1;
        cands[24] = 1;
        let chosen = planar_rewrite_schedule_cpu(&cands, 5, 5, 2);
        assert_eq!(chosen[0], 1);
        assert_eq!(chosen[4], 1);
        assert_eq!(chosen[20], 1);
        assert_eq!(chosen[24], 1);
    }

    #[test]
    fn cpu_short_candidate_buffer_treats_missing_cells_as_zero() {
        let cands = vec![1u32];
        let chosen = planar_rewrite_schedule_cpu(&cands, 2, 2, 1);
        assert_eq!(chosen, vec![1, 0, 0, 0]);
    }

    #[test]
    fn cpu_dense_candidates_alternate_chosen() {
        let cands = vec![1u32; 16];
        let chosen = planar_rewrite_schedule_cpu(&cands, 4, 4, 2);
        // The greedy row-major scan is deterministic, so pin the exact selection.
        assert_eq!(
            chosen,
            vec![
                1, 0, 1, 0, //
                0, 0, 0, 0, //
                1, 0, 1, 0, //
                0, 0, 0, 0,
            ]
        );
        // Independently check the invariant: no selected cell has another selected cell in
        // its up-left 2x2 window.
        for r in 0..4 {
            for c in 0..4 {
                if chosen[r * 4 + c] == 0 {
                    continue;
                }
                for di in 0..2 {
                    for dj in 0..2 {
                        if (di == 0 && dj == 0) || di > r || dj > c {
                            continue;
                        }
                        assert_eq!(chosen[(r - di) * 4 + (c - dj)], 0);
                    }
                }
            }
        }
    }

    #[test]
    fn cpu_unit_window_chooses_every_candidate() {
        let cands = vec![1u32, 0, 1, 1];
        let chosen = planar_rewrite_schedule_cpu(&cands, 2, 2, 1);
        assert_eq!(chosen, vec![1, 0, 1, 1]);
    }

    #[test]
    fn cpu_window_larger_than_grid_allows_single_choice() {
        // With k = 5, cell (0, 0) lies in the window of every other cell of a 2x3 grid.
        let cands = vec![1u32; 6];
        let chosen = planar_rewrite_schedule_cpu(&cands, 2, 3, 5);
        assert_eq!(chosen, vec![1, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn ir_program_buffer_layout() {
        let p = planar_rewrite_schedule("c", "ch", 4, 4, 2);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["c", "ch"]);
        assert_eq!(p.buffers[0].count(), 16);
        assert_eq!(p.buffers[1].count(), 16);
    }

    #[test]
    fn ir_program_buffer_count_is_h_times_w() {
        // A non-square grid catches h/w mix-ups that a square one would hide.
        let p = planar_rewrite_schedule("c", "ch", 3, 5, 2);
        assert_eq!(p.buffers[0].count(), 15);
        assert_eq!(p.buffers[1].count(), 15);
    }

    #[test]
    fn zero_h_traps() {
        let p = planar_rewrite_schedule("c", "ch", 0, 4, 2);
        assert!(p.stats().trap());
    }

    #[test]
    fn zero_w_traps() {
        let p = planar_rewrite_schedule("c", "ch", 4, 0, 2);
        assert!(p.stats().trap());
    }

    #[test]
    fn zero_k_traps() {
        let p = planar_rewrite_schedule("c", "ch", 4, 4, 0);
        assert!(p.stats().trap());
    }
}