use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// String identifier for the backdoor-descendants check.
///
/// Within this file it is passed to `crate::invalid_output_program` when the
/// builder rejects its input, and it is converted into the `generator` ident of
/// the `Node::Region` that wraps the generated body.
pub const OP_ID: &str = "vyre-primitives::graph::backdoor_descendants_check";
/// Builds the IR [`Program`] that flags any index at which both input buffers
/// hold a nonzero value.
///
/// NOTE(review): judging by the names, this looks like the "no descendant of X
/// in the adjustment set Z" condition of the back-door criterion — confirm
/// against the caller before relying on that reading.
///
/// # Parameters
/// - `candidate_z`: name of the read-only `u32` storage buffer at binding 0,
///   declared with `n` elements.
/// - `descendants_of_x`: name of the read-only `u32` storage buffer at
///   binding 1, declared with `n` elements.
/// - `out_violation`: name of the read-write `u32` storage buffer at binding 2,
///   declared with a single element.
/// - `n`: element count of both input buffers and the loop bound of the scan.
///
/// # Returns
/// A program with workgroup size `[256, 1, 1]` whose body is one
/// `Node::Region` tagged with [`OP_ID`]. When `n == 0` the result of
/// `crate::invalid_output_program` is returned instead.
///
/// # Generated body
/// Only the invocation whose axis-0 id equals `0` does any work. It binds
/// `violated` to `0`, loops `k` between `0` and `n` (presumably the half-open
/// range `0..n`, matching the declared buffer counts — TODO confirm), assigns
/// `1` to `violated` whenever both loads at `k` are nonzero, and finally stores
/// `violated` into element `0` of `out_violation`.
#[must_use]
pub fn backdoor_descendants_check(
    candidate_z: &str,
    descendants_of_x: &str,
    out_violation: &str,
    n: u32,
) -> Program {
    // Reject the empty problem up front; the helper builds the replacement program.
    if n == 0 {
        let message = format!("Fix: backdoor_descendants_check requires n > 0, got {n}.");
        return crate::invalid_output_program(OP_ID, out_violation, DataType::U32, message);
    }

    // Per-index predicates: each buffer holds a nonzero value at `k`.
    let z_is_set = Expr::ne(Expr::load(candidate_z, Expr::var("k")), Expr::u32(0));
    let d_is_set = Expr::ne(Expr::load(descendants_of_x, Expr::var("k")), Expr::u32(0));

    // Raise the flag when both predicates hold at the same index.
    let flag_on_overlap = Node::if_then(
        Expr::and(z_is_set, d_is_set),
        vec![Node::assign("violated", Expr::u32(1))],
    );
    let scan = Node::loop_for("k", Expr::u32(0), Expr::u32(n), vec![flag_on_overlap]);

    // Everything is guarded so that only the invocation with axis-0 id 0 runs it.
    let guarded_work = Node::if_then(
        Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0)),
        vec![
            Node::let_bind("violated", Expr::u32(0)),
            scan,
            Node::store(out_violation, Expr::u32(0), Expr::var("violated")),
        ],
    );

    let buffers = vec![
        BufferDecl::storage(candidate_z, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n),
        BufferDecl::storage(descendants_of_x, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(n),
        BufferDecl::storage(out_violation, 2, BufferAccess::ReadWrite, DataType::U32)
            .with_count(1),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![guarded_work]),
    };

    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference: reports whether any position holds a nonzero value in both
/// slices.
///
/// The slices are walked in lockstep, so only positions present in *both* are
/// examined; surplus trailing elements of the longer slice are ignored.
///
/// # Returns
/// `true` if some shared index `i` has `candidate_z[i] != 0` and
/// `descendants_of_x[i] != 0`, otherwise `false` (including when either slice
/// is empty).
#[must_use]
pub fn backdoor_descendants_check_cpu(candidate_z: &[u32], descendants_of_x: &[u32]) -> bool {
    for (&z_flag, &descendant_flag) in candidate_z.iter().zip(descendants_of_x) {
        if z_flag != 0 && descendant_flag != 0 {
            // One overlapping position is enough to decide the answer.
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Masks that never share a nonzero index produce no violation.
    #[test]
    fn cpu_disjoint_z_passes() {
        let violated = backdoor_descendants_check_cpu(&[1, 1, 0, 0], &[0, 0, 1, 1]);
        assert!(!violated);
    }

    /// A single shared nonzero index (here index 2) is a violation.
    #[test]
    fn cpu_overlap_violates() {
        let violated = backdoor_descendants_check_cpu(&[1, 0, 1, 0], &[0, 0, 1, 1]);
        assert!(violated);
    }

    /// An all-zero candidate mask cannot overlap anything.
    #[test]
    fn cpu_empty_z_never_violates() {
        let violated = backdoor_descendants_check_cpu(&[0, 0, 0, 0], &[1, 1, 1, 1]);
        assert!(!violated);
    }

    /// An all-zero descendants mask cannot overlap anything.
    #[test]
    fn cpu_empty_descendants_never_violates() {
        let violated = backdoor_descendants_check_cpu(&[1, 1, 1, 1], &[0, 0, 0, 0]);
        assert!(!violated);
    }

    /// With unequal lengths only indices present in both slices are compared.
    #[test]
    fn cpu_mismatched_inputs_only_check_complete_pairs() {
        let nothing_to_pair = backdoor_descendants_check_cpu(&[1], &[]);
        assert!(!nothing_to_pair);

        let overlap_in_shared_prefix = backdoor_descendants_check_cpu(&[0, 1], &[0, 1, 1]);
        assert!(overlap_in_shared_prefix);
    }

    /// The built program declares the three buffers in order with the expected
    /// element counts and workgroup size.
    #[test]
    fn ir_program_buffer_layout() {
        let program = backdoor_descendants_check("z", "d", "v", 8);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|buffer| buffer.name()).collect();
        assert_eq!(names, ["z", "d", "v"]);

        assert_eq!(program.buffers[0].count(), 8);
        assert_eq!(program.buffers[1].count(), 8);
        assert_eq!(program.buffers[2].count(), 1);
    }

    /// `n == 0` yields a program whose stats report a trap.
    #[test]
    fn zero_n_traps() {
        let program = backdoor_descendants_check("z", "d", "v", 0);
        assert!(program.stats().trap());
    }
}