use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
// ---- Node-kind tags --------------------------------------------------------
// Any kind value outside this set (including 0) evaluates to 0 in both the GPU
// kernel and the CPU reference.

/// Node kind: positive literal. Evaluates to `1` when its variable is assigned
/// `1` or holds the wildcard `u32::MAX`, otherwise `0`.
pub const LITERAL_TRUE: u32 = 1;

/// Node kind: negative literal. Evaluates to `1` when its variable is assigned
/// `0` or holds the wildcard `u32::MAX`, otherwise `0`.
pub const LITERAL_FALSE: u32 = 2;

/// Node kind: conjunction. Evaluates to the (wrapping) product of its
/// children's values; a childless AND evaluates to `1`.
pub const AND_NODE: u32 = 3;

/// Node kind: disjunction. Evaluates to the (wrapping) sum of its children's
/// values; a childless OR evaluates to `0`.
pub const OR_NODE: u32 = 4;

// ---- Operation identity ----------------------------------------------------

/// Identifier of this primitive; used as the region generator name of the GPU
/// program and when building the invalid-input trap program.
pub const OP_ID: &str = "vyre-primitives::graph::ddnnf_evaluate";
/// Builds a vyre [`Program`] that evaluates a d-DNNF circuit, one node per
/// invocation, under a (possibly partial) variable assignment.
///
/// Node `i` is described by:
/// * `node_kinds[i]`: one of [`LITERAL_TRUE`], [`LITERAL_FALSE`], [`AND_NODE`], [`OR_NODE`];
/// * `node_var[i]`: the variable index read by literal nodes;
/// * `child_offsets[i]` / `child_counts[i]`: the window
///   `children[child_offsets[i] .. child_offsets[i] + child_counts[i]]` used by AND/OR nodes.
///
/// The result for node `i` is stored in `out[i]`:
/// * `LITERAL_TRUE` stores `1` when `var_assignments[var]` is `1` or `u32::MAX`, else `0`;
/// * `LITERAL_FALSE` stores `1` when `var_assignments[var]` is `0` or `u32::MAX`, else `0`;
/// * `AND_NODE` stores the product of `out[child]` over its children (starting from `1`);
/// * `OR_NODE` stores the sum of `out[child]` over its children (starting from `0`);
/// * any other kind stores `0`.
///
/// These are the same per-node rules applied by [`ddnnf_evaluate_cpu`].
///
/// Buffers are bound in argument order at bindings 0..=6. `out` is the only
/// read-write buffer. `children` is sized `max(n_children, 1)`; every other
/// per-node buffer is sized `n_nodes`, and `var_assignments` is sized `n_vars`.
/// The workgroup size is `[256, 1, 1]`. Invocations with id `>= n_nodes` do
/// nothing.
///
/// If `n_nodes == 0` or `n_vars == 0`, this returns the program produced by
/// `crate::invalid_output_program`, which carries a "Fix: ..." message. The
/// tests below show it has a single buffer and contains a `Trap` node.
///
/// NOTE(review): AND/OR invocations read `out[child]` values that are written
/// by *other* invocations of the same dispatch. No barrier or ordering is
/// expressed in this body. Parent results are therefore only reliable if
/// children are already final when the parent runs. Presumably the caller
/// dispatches once per topological level, or repeats the dispatch until
/// stable. TODO confirm with callers.
///
/// NOTE(review): no bounds checks on `var_id` or child indices are emitted
/// here. Behaviour on out-of-range indices depends on the backend.
#[must_use]
#[allow(clippy::too_many_arguments)]
pub fn ddnnf_evaluate(
    node_kinds: &str,
    node_var: &str,
    child_offsets: &str,
    child_counts: &str,
    children: &str,
    var_assignments: &str,
    out: &str,
    n_nodes: u32,
    n_children: u32,
    n_vars: u32,
) -> Program {
    // Reject empty inputs up front with a trap program instead of emitting a
    // kernel whose buffers would have zero length.
    if n_nodes == 0 {
        return crate::invalid_output_program(
            OP_ID,
            out,
            DataType::U32,
            format!("Fix: ddnnf_evaluate requires n_nodes > 0, got {n_nodes}."),
        );
    }
    if n_vars == 0 {
        return crate::invalid_output_program(
            OP_ID,
            out,
            DataType::U32,
            format!("Fix: ddnnf_evaluate requires n_vars > 0, got {n_vars}."),
        );
    }
    // One invocation per circuit node.
    let lane = Expr::InvocationId { axis: 0 };
    // `child_base + k`: index into `children` for the AND loop (loop variable `k`).
    let child_index = Expr::add(Expr::var("child_base"), Expr::var("k"));
    let body = vec![Node::if_then(
        Expr::lt(lane.clone(), Expr::u32(n_nodes)),
        vec![
            // Per-node metadata, loaded unconditionally for in-range lanes.
            Node::let_bind("kind", Expr::load(node_kinds, lane.clone())),
            Node::let_bind("var_id", Expr::load(node_var, lane.clone())),
            Node::let_bind("child_base", Expr::load(child_offsets, lane.clone())),
            Node::let_bind("child_count", Expr::load(child_counts, lane.clone())),
            // Positive literal: satisfied by assignment 1 or wildcard u32::MAX.
            Node::if_then(
                Expr::eq(Expr::var("kind"), Expr::u32(LITERAL_TRUE)),
                vec![
                    Node::let_bind(
                        "assigned_true",
                        Expr::load(var_assignments, Expr::var("var_id")),
                    ),
                    Node::store(
                        out,
                        lane.clone(),
                        Expr::select(
                            Expr::or(
                                Expr::eq(Expr::var("assigned_true"), Expr::u32(1)),
                                Expr::eq(Expr::var("assigned_true"), Expr::u32(u32::MAX)),
                            ),
                            Expr::u32(1),
                            Expr::u32(0),
                        ),
                    ),
                ],
            ),
            // Negative literal: satisfied by assignment 0 or wildcard u32::MAX.
            Node::if_then(
                Expr::eq(Expr::var("kind"), Expr::u32(LITERAL_FALSE)),
                vec![
                    Node::let_bind(
                        "assigned_false",
                        Expr::load(var_assignments, Expr::var("var_id")),
                    ),
                    Node::store(
                        out,
                        lane.clone(),
                        Expr::select(
                            Expr::or(
                                Expr::eq(Expr::var("assigned_false"), Expr::u32(0)),
                                Expr::eq(Expr::var("assigned_false"), Expr::u32(u32::MAX)),
                            ),
                            Expr::u32(1),
                            Expr::u32(0),
                        ),
                    ),
                ],
            ),
            // AND: product of child results, starting from the identity 1.
            Node::if_then(
                Expr::eq(Expr::var("kind"), Expr::u32(AND_NODE)),
                vec![
                    Node::let_bind("acc_and", Expr::u32(1)),
                    Node::loop_for(
                        "k",
                        Expr::u32(0),
                        Expr::var("child_count"),
                        vec![
                            Node::let_bind("child_node", Expr::load(children, child_index.clone())),
                            Node::assign(
                                "acc_and",
                                Expr::mul(
                                    Expr::var("acc_and"),
                                    Expr::load(out, Expr::var("child_node")),
                                ),
                            ),
                        ],
                    ),
                    Node::store(out, lane.clone(), Expr::var("acc_and")),
                ],
            ),
            // OR: sum of child results, starting from the identity 0. The
            // binding names differ from the AND branch (`kk`, `or_child_node`,
            // `acc_or`); presumably the IR requires unique names across sibling
            // scopes. TODO confirm.
            Node::if_then(
                Expr::eq(Expr::var("kind"), Expr::u32(OR_NODE)),
                vec![
                    Node::let_bind("acc_or", Expr::u32(0)),
                    Node::loop_for(
                        "kk",
                        Expr::u32(0),
                        Expr::var("child_count"),
                        vec![
                            Node::let_bind(
                                "or_child_node",
                                Expr::load(
                                    children,
                                    Expr::add(Expr::var("child_base"), Expr::var("kk")),
                                ),
                            ),
                            Node::assign(
                                "acc_or",
                                Expr::add(
                                    Expr::var("acc_or"),
                                    Expr::load(out, Expr::var("or_child_node")),
                                ),
                            ),
                        ],
                    ),
                    Node::store(out, lane.clone(), Expr::var("acc_or")),
                ],
            ),
            // Fallback: unrecognised kinds evaluate to 0, matching the CPU reference.
            Node::if_then(
                Expr::and(
                    Expr::and(
                        Expr::ne(Expr::var("kind"), Expr::u32(LITERAL_TRUE)),
                        Expr::ne(Expr::var("kind"), Expr::u32(LITERAL_FALSE)),
                    ),
                    Expr::and(
                        Expr::ne(Expr::var("kind"), Expr::u32(AND_NODE)),
                        Expr::ne(Expr::var("kind"), Expr::u32(OR_NODE)),
                    ),
                ),
                vec![Node::store(out, lane.clone(), Expr::u32(0))],
            ),
        ],
    )];
    Program::wrapped(
        vec![
            BufferDecl::storage(node_kinds, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_nodes),
            BufferDecl::storage(node_var, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_nodes),
            BufferDecl::storage(child_offsets, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_nodes),
            BufferDecl::storage(child_counts, 3, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_nodes),
            // At least one element so an edge-free circuit still gets a valid buffer.
            BufferDecl::storage(children, 4, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_children.max(1)),
            BufferDecl::storage(var_assignments, 5, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_vars),
            BufferDecl::storage(out, 6, BufferAccess::ReadWrite, DataType::U32).with_count(n_nodes),
        ],
        [256, 1, 1],
        // Tag the body with OP_ID so the program can be identified by generator.
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// Reference CPU evaluator for a d-DNNF circuit, applying the same per-node
/// rules as the GPU program built by [`ddnnf_evaluate`].
///
/// `nodes[i]` is `(kind, child_offset, child_count)`:
/// * literal nodes read `var_assignments[node_var[i]]`;
/// * AND/OR nodes read the already-computed values of the nodes listed in
///   `children[child_offset .. child_offset + child_count]`.
///
/// Nodes are evaluated in the order given by `topo_order`. A node that never
/// appears there keeps the value `0`, and a child evaluated after its parent
/// contributes `0`.
///
/// Rules:
/// * literals yield `1` when the assignment matches their polarity (`1` for
///   [`LITERAL_TRUE`], `0` for [`LITERAL_FALSE`]) or is `u32::MAX`, else `0`;
/// * AND multiplies and OR sums child values, both with `u32` wrap-around;
/// * unknown kinds yield `0`.
///
/// # Panics
///
/// Panics if `topo_order`, `node_var`, `children`, or a child entry indexes
/// out of bounds of the slice it refers to.
#[must_use]
pub fn ddnnf_evaluate_cpu(
    nodes: &[(u32, u32, u32)],
    node_var: &[u32],
    children: &[u32],
    var_assignments: &[u32],
    topo_order: &[u32],
) -> Vec<u32> {
    let mut values = vec![0u32; nodes.len()];
    for idx in topo_order.iter().map(|&n| n as usize) {
        let (kind, offset, count) = nodes[idx];
        let offset = offset as usize;
        // Current value of this node's `j`-th child.
        let child_value = |j: usize| values[children[offset + j] as usize];
        // A literal is satisfied by a matching assignment or the wildcard `u32::MAX`.
        let literal = |wanted: u32| {
            let assigned = var_assignments[node_var[idx] as usize];
            u32::from(assigned == wanted || assigned == u32::MAX)
        };
        let value = match kind {
            LITERAL_TRUE => literal(1),
            LITERAL_FALSE => literal(0),
            AND_NODE => (0..count as usize)
                .map(child_value)
                .fold(1u32, u32::wrapping_mul),
            OR_NODE => (0..count as usize)
                .map(child_value)
                .fold(0u32, u32::wrapping_add),
            _ => 0,
        };
        values[idx] = value;
    }
    values
}
#[cfg(test)]
mod tests {
    //! Unit tests for the CPU reference evaluator and the GPU program builder.
    use super::*;

    #[test]
    fn cpu_single_true_literal_with_assigned_var() {
        let nodes = vec![(LITERAL_TRUE, 0, 0)];
        let node_var = vec![0];
        let children = vec![];
        let assigns = vec![1];
        let order = vec![0];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out[0], 1);
    }

    #[test]
    fn cpu_single_true_literal_with_unset_var() {
        let nodes = vec![(LITERAL_TRUE, 0, 0)];
        let node_var = vec![0];
        let children = vec![];
        let assigns = vec![u32::MAX];
        let order = vec![0];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out[0], 1);
    }

    #[test]
    fn cpu_false_literal_accepts_false_and_unset_only() {
        let nodes = vec![(LITERAL_FALSE, 0, 0)];
        let node_var = vec![0];
        let children: Vec<u32> = vec![];
        let order = vec![0];
        for (assigned, expected) in [(0, 1), (1, 0), (u32::MAX, 1)] {
            let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &[assigned], &order);
            assert_eq!(out[0], expected, "assignment {assigned}");
        }
    }

    #[test]
    fn cpu_and_of_two_literals() {
        let nodes = vec![(LITERAL_TRUE, 0, 0), (LITERAL_TRUE, 0, 0), (AND_NODE, 0, 2)];
        let node_var = vec![0, 1, 0];
        let children = vec![0, 1];
        let assigns = vec![u32::MAX; 2];
        let order = vec![0, 1, 2];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out[2], 1);
    }

    #[test]
    fn cpu_and_with_falsified_child_is_zero() {
        let nodes = vec![(LITERAL_TRUE, 0, 0), (LITERAL_FALSE, 0, 0), (AND_NODE, 0, 2)];
        let node_var = vec![0, 1, 0];
        let children = vec![0, 1];
        let order = vec![0, 1, 2];
        // x0 = 1 satisfies the positive literal; x1 = 1 falsifies the negative one.
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &[1, 1], &order);
        assert_eq!(out, vec![1, 0, 0]);
        // x1 = 0 satisfies both children.
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &[1, 0], &order);
        assert_eq!(out, vec![1, 1, 1]);
    }

    #[test]
    fn cpu_or_of_two_literals_counts_both() {
        let nodes = vec![(LITERAL_TRUE, 0, 0), (LITERAL_TRUE, 0, 0), (OR_NODE, 0, 2)];
        let node_var = vec![0, 1, 0];
        let children = vec![0, 1];
        let assigns = vec![u32::MAX; 2];
        let order = vec![0, 1, 2];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out[2], 2);
    }

    #[test]
    fn cpu_partial_assignment_constrains_count() {
        let nodes = vec![(LITERAL_TRUE, 0, 0), (LITERAL_TRUE, 0, 0), (OR_NODE, 0, 2)];
        let node_var = vec![0, 1, 0];
        let children = vec![0, 1];
        let assigns = vec![1, 0];
        let order = vec![0, 1, 2];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out[2], 1);
    }

    #[test]
    fn cpu_empty_gates_and_unknown_kinds() {
        // A childless AND is the multiplicative identity, a childless OR the
        // additive identity, and unrecognised kinds evaluate to zero.
        let nodes = vec![(AND_NODE, 0, 0), (OR_NODE, 0, 0), (99, 0, 0)];
        let node_var = vec![0, 0, 0];
        let children: Vec<u32> = vec![];
        let assigns = vec![1];
        let order = vec![0, 1, 2];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &assigns, &order);
        assert_eq!(out, vec![1, 0, 0]);
    }

    #[test]
    fn cpu_nodes_missing_from_topo_order_stay_zero() {
        let nodes = vec![(LITERAL_TRUE, 0, 0), (LITERAL_TRUE, 0, 0)];
        let node_var = vec![0, 0];
        let children: Vec<u32> = vec![];
        let out = ddnnf_evaluate_cpu(&nodes, &node_var, &children, &[1], &[1]);
        assert_eq!(out, vec![0, 1]);
    }

    #[test]
    fn gpu_program_builder_exposes_ddnnf_buffers() {
        let program = ddnnf_evaluate(
            "kinds",
            "node_var",
            "child_offsets",
            "child_counts",
            "children",
            "assignments",
            "out",
            3,
            2,
            2,
        );
        assert_eq!(program.buffers().len(), 7);
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
        assert!(program.entry().iter().any(|node| matches!(
            node,
            vyre_foundation::ir::Node::Region { generator, .. } if generator.as_str() == OP_ID
        )));
    }

    #[test]
    fn gpu_program_builder_rejects_empty_node_count_with_trap_program() {
        let program = ddnnf_evaluate(
            "kinds",
            "node_var",
            "child_offsets",
            "child_counts",
            "children",
            "assignments",
            "out",
            0,
            0,
            1,
        );
        assert_eq!(program.buffers().len(), 1);
        assert!(program.entry().iter().any(|node| matches!(
            node,
            vyre_foundation::ir::Node::Region { body, .. }
                if body.iter().any(|inner| matches!(inner, vyre_foundation::ir::Node::Trap { .. }))
        )));
    }

    #[test]
    fn gpu_program_builder_rejects_zero_vars_with_trap_program() {
        let program = ddnnf_evaluate(
            "kinds",
            "node_var",
            "child_offsets",
            "child_counts",
            "children",
            "assignments",
            "out",
            1,
            0,
            0,
        );
        assert_eq!(program.buffers().len(), 1);
        assert!(program.entry().iter().any(|node| matches!(
            node,
            vyre_foundation::ir::Node::Region { body, .. }
                if body.iter().any(|inner| matches!(inner, vyre_foundation::ir::Node::Trap { .. }))
        )));
    }
}