use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of the family-resolution op; used as the region generator of the
/// program built by `resolve_family` and as the key of its registry entry.
pub const OP_ID: &str = "vyre-primitives::label::resolve_family";
/// Builds the [`Program`] that marks every node whose tag intersects `family_mask`.
///
/// Each invocation uses its axis-0 invocation id as a node index. When that index is
/// below `node_count`, the invocation loads its word from `node_tags`; if
/// `tag & family_mask` is non-zero it atomically ORs bit `index & 31` into word
/// `index >> 5` of `nodeset_out`. The program only ever ORs bits in — it never
/// clears `nodeset_out`.
///
/// # Parameters
/// * `node_tags` – name of the read-only `u32` buffer, declared with `node_count` elements.
/// * `nodeset_out` – name of the read-write `u32` buffer, declared with
///   `node_count.div_ceil(32)` elements (one bit per node).
/// * `node_count` – number of nodes; also the bound of the in-range guard.
/// * `family_mask` – bits a tag must share with the mask for its node to be marked.
#[must_use]
pub fn resolve_family(
    node_tags: &str,
    nodeset_out: &str,
    node_count: u32,
    family_mask: u32,
) -> Program {
    let node_idx = Expr::InvocationId { axis: 0 };
    // One output bit per node, packed 32 to a word.
    let nodeset_words = node_count.div_ceil(32);

    // `tag & family_mask != 0`
    let tag_in_family = Expr::ne(
        Expr::bitand(Expr::var("tag"), Expr::u32(family_mask)),
        Expr::u32(0),
    );

    // Set bit `idx & 31` of word `idx >> 5`; the value produced by the atomic is
    // bound to "_" and not used afterwards.
    let mark_node = vec![
        Node::let_bind("word_idx", Expr::shr(node_idx.clone(), Expr::u32(5))),
        Node::let_bind(
            "bit",
            Expr::shl(
                Expr::u32(1),
                Expr::bitand(node_idx.clone(), Expr::u32(31)),
            ),
        ),
        Node::let_bind(
            "_",
            Expr::atomic_or(nodeset_out, Expr::var("word_idx"), Expr::var("bit")),
        ),
    ];

    let per_node = vec![
        Node::let_bind("tag", Expr::load(node_tags, node_idx.clone())),
        Node::if_then(tag_in_family, mark_node),
    ];

    // Invocations at or beyond `node_count` do nothing.
    let in_range = Expr::lt(node_idx.clone(), Expr::u32(node_count));
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![Node::if_then(in_range, per_node)]),
    };

    let buffers = vec![
        BufferDecl::storage(node_tags, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(node_count),
        BufferDecl::storage(nodeset_out, 1, BufferAccess::ReadWrite, DataType::U32)
            .with_count(nodeset_words),
    ];

    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference: returns a freshly allocated node-set bitmap for `node_tags`.
///
/// Convenience wrapper around [`cpu_ref_into`]: bit `v % 32` of word `v / 32` is set
/// exactly when `node_tags[v] & family_mask` is non-zero.
#[must_use]
pub fn cpu_ref(node_tags: &[u32], family_mask: u32) -> Vec<u32> {
    let mut nodeset = Vec::new();
    cpu_ref_into(node_tags, family_mask, &mut nodeset);
    nodeset
}
/// CPU reference that writes the node-set bitmap into a caller-supplied buffer.
///
/// `out` is cleared and refilled with `ceil(node_tags.len() / 32)` words, so any
/// previous contents are discarded while the existing allocation can be reused.
/// Bit `v % 32` of word `v / 32` is set exactly when `node_tags[v] & family_mask`
/// is non-zero.
pub fn cpu_ref_into(node_tags: &[u32], family_mask: u32, out: &mut Vec<u32>) {
    out.clear();
    // Each chunk of up to 32 tags collapses into one output word; the position of a
    // tag inside its chunk is its bit index.
    out.extend(node_tags.chunks(32).map(|group| {
        group
            .iter()
            .enumerate()
            .filter(|&(_, &tag)| tag & family_mask != 0)
            .fold(0u32, |word, (lane, _)| word | (1u32 << lane))
    }));
}
// Registers this op with the harness registry; compiled only when the
// `inventory-registry` feature is enabled.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
// Program under test: 4 nodes, family mask 0b0010.
|| resolve_family("tags", "nodeset", 4, 0b0010),
// NOTE(review): presumably the input buffers for one case, in declaration order
// (the four tags, then a single zeroed nodeset word) — confirm against `OpEntry::new`.
Some(|| {
// Serializes u32 words as little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![to_bytes(&[0x01, 0x02, 0x06, 0x04]), to_bytes(&[0])]]
}),
// NOTE(review): presumably the expected output for that case — confirm against
// `OpEntry::new`. Tags at indices 1 and 2 contain bit 0b0010, giving 0b0110.
Some(|| {
// Serializes u32 words as little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![to_bytes(&[0b0110])]]
}),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: only the tags at indices 1 and 2 contain bit `0x02`.
    const TAGS: [u32; 4] = [0x01, 0x02, 0x06, 0x04];

    /// Nodes 1 and 2 match the mask, so bits 1 and 2 of the single word are set.
    #[test]
    fn single_family_bit() {
        let nodeset = cpu_ref(&TAGS, 0x02);
        assert_eq!(nodeset, vec![0b0110_u32]);
    }

    /// A zero mask matches nothing; the bitmap still has one (zeroed) word.
    #[test]
    fn empty_family_yields_empty_nodeset() {
        let nodeset = cpu_ref(&[0x01, 0x02], 0x00);
        assert_eq!(nodeset, vec![0_u32]);
    }

    /// With enough capacity up front, `cpu_ref_into` must not reallocate.
    #[test]
    fn cpu_ref_into_reuses_nodeset_buffer() {
        let mut scratch: Vec<u32> = Vec::with_capacity(4);
        let original_ptr = scratch.as_ptr();
        cpu_ref_into(&TAGS, 0x02, &mut scratch);
        assert_eq!(scratch, vec![0b0110_u32]);
        assert_eq!(scratch.as_ptr(), original_ptr);
    }
}