use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_U32_INPUTS, TWO_BYTES_OUTPUTS};
/// WGSL spelling of the partition byte predicate, as a raw source snippet.
///
/// Defines `data_movement_partition_nonzero(byte: u32) -> bool`, which returns
/// `byte != 0u` — the WGSL counterpart of `PREDICATE_NONZERO`.
///
/// NOTE(review): only the nonzero predicate is spelled here; the other ids that
/// `predicate_by_id` accepts (ASCII digit, ASCII alpha, even) have no WGSL
/// counterpart in this file — confirm whether they are provided elsewhere.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_partition_nonzero(byte: u32) -> bool {
return byte != 0u;
}
";
/// Op descriptor and IR builder for `data_movement.partition`.
impl Partition {
/// Registry spec for `data_movement.partition`: a composition op over
/// `BYTES_U32_INPUTS` producing `TWO_BYTES_OUTPUTS`, with the (empty) `LAWS`
/// list, whose IR is produced by [`Partition::program`].
pub const SPEC: OpSpec = OpSpec::composition(
"data_movement.partition",
BYTES_U32_INPUTS,
TWO_BYTES_OUTPUTS,
LAWS,
Self::program,
);
/// Builds the IR program for this op.
///
/// Buffers: `data` (read, bytes), `predicate_id` (uniform, u32), `pass`
/// (output, bytes), `fail` (read-write, bytes). Each invocation guarded by
/// `idx < buf_len("data")` loads `data[idx]` and stores it to `pass[idx]` when
/// it is nonzero, otherwise to `fail[idx]`.
///
/// NOTE(review): this does not match the CPU reference [`partition`]:
/// - only the nonzero test is emitted; the `predicate_id` buffer is declared
///   but never read, so other predicate ids are ignored;
/// - bytes are written at their input index `idx` in both `pass` and `fail`
///   rather than being compacted — confirm whether a later pass compacts them;
/// - `pass` is declared with `BufferDecl::output` while `fail` uses
///   `BufferDecl::read_write` — confirm the asymmetry is intended.
#[must_use]
pub fn program() -> Program {
// presumably the x component of the global invocation id — TODO confirm
let idx = Expr::gid_x();
// Expression for `data[idx]`; it is only placed inside the bounds guard below.
let byte = Expr::load("data", idx.clone());
Program::new(
vec![
// Second argument is presumably the binding slot — verify against BufferDecl.
BufferDecl::read("data", 0, DataType::Bytes),
BufferDecl::uniform("predicate_id", 1, DataType::U32),
BufferDecl::output("pass", 2, DataType::Bytes),
BufferDecl::read_write("fail", 3, DataType::Bytes),
],
// presumably the workgroup size — TODO confirm
[64, 1, 1],
vec![
// NOTE(review): the "idx" binding is never referenced; the nodes below
// reuse the `idx` expression directly.
Node::let_bind("idx", idx.clone()),
// Skip invocations past the end of `data`.
Node::if_then(
Expr::lt(idx.clone(), Expr::buf_len("data")),
vec![Node::if_then_else(
Expr::ne(byte.clone(), Expr::u32(0)),
vec![Node::store("pass", idx.clone(), byte.clone())],
vec![Node::store("fail", idx, byte)],
)],
),
],
)
}
}
/// Algebraic laws declared for `data_movement.partition`; none are declared.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// CPU reference for `data_movement.partition`.
///
/// Splits `data` into the bytes accepted by the predicate selected by
/// `predicate_id` (`pass`) and the bytes it rejects (`fail`). Both halves keep
/// the relative order the bytes had in `data`.
///
/// # Errors
///
/// - Whatever [`predicate_by_id`] returns for an unrecognized `predicate_id`.
///   This check runs first, so it wins even when `data` is also oversized.
/// - [`DataMovementError::OutputTooLarge`] when `data.len()` exceeds
///   `MAX_OUTPUT_BYTES`.
pub fn partition(data: &[u8], predicate_id: u32) -> Result<PartitionOutput, DataMovementError> {
    let accepts = predicate_by_id(predicate_id)?;

    let limit = crate::ops::data_movement::MAX_OUTPUT_BYTES;
    if data.len() > limit {
        return Err(DataMovementError::OutputTooLarge {
            requested: data.len(),
            max: limit,
        });
    }

    // `Iterator::partition` preserves input order on both sides.
    let (pass, fail): (Vec<u8>, Vec<u8>) =
        data.iter().copied().partition(|&byte| accepts(byte));

    Ok(PartitionOutput { pass, fail })
}
/// Zero-sized marker type for the `data_movement.partition` op.
///
/// It carries no data; it hosts the op's `SPEC` and `program` items.
#[derive(Debug, Clone, Copy, Default)]
pub struct Partition;
/// Result of the CPU `partition` reference: the input bytes split in two by a
/// predicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionOutput {
    /// Bytes for which the predicate returned `true`, in input order.
pub pass: Vec<u8>,
    /// Bytes for which the predicate returned `false`, in input order.
pub fail: Vec<u8>,
}
/// Predicate id 0: accepts every byte other than `0x00`.
pub const PREDICATE_NONZERO: u32 = 0;
/// Predicate id 1: accepts ASCII decimal digits (`b'0'..=b'9'`).
pub const PREDICATE_ASCII_DIGIT: u32 = 1;
/// Predicate id 2: accepts ASCII letters (`A-Z`, `a-z`).
pub const PREDICATE_ASCII_ALPHA: u32 = 2;
/// Predicate id 3: accepts bytes with an even numeric value.
pub const PREDICATE_EVEN: u32 = 3;

/// Resolves a numeric predicate id to the byte predicate it names.
///
/// # Errors
///
/// Returns [`DataMovementError::UnknownPredicate`] carrying `id` when it is not
/// one of the `PREDICATE_*` constants.
pub fn predicate_by_id(id: u32) -> Result<fn(u8) -> bool, DataMovementError> {
    fn nonzero(byte: u8) -> bool {
        byte != 0
    }
    fn ascii_digit(byte: u8) -> bool {
        byte.is_ascii_digit()
    }
    fn ascii_alpha(byte: u8) -> bool {
        byte.is_ascii_alphabetic()
    }
    fn even(byte: u8) -> bool {
        byte % 2 == 0
    }

    // The annotation coerces each distinct fn item to a common fn pointer.
    let chosen: fn(u8) -> bool = match id {
        PREDICATE_NONZERO => nonzero,
        PREDICATE_ASCII_DIGIT => ascii_digit,
        PREDICATE_ASCII_ALPHA => ascii_alpha,
        PREDICATE_EVEN => even,
        _ => return Err(DataMovementError::UnknownPredicate { id }),
    };
    Ok(chosen)
}
/// Replays the committed known-answer vectors whose `op` is `"partition"`
/// against the CPU reference `partition`.
///
/// Each such case supplies `input_hex` and `predicate_id`:
/// - `ok = true` cases must succeed and produce exactly `expected_pass_hex` and
///   `expected_fail_hex`. A byte mismatch panics through `assert_eq!`; an
///   unexpected error is returned as `Err`.
/// - `ok = false` cases must fail with `DataMovementError::UnknownPredicate`.
///
/// Cases for other ops in the shared fixture file are skipped.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
crate::ops::fixtures::run_committed_kats(
include_str!("../fixtures/reference-vectors.toml"),
|case| {
// The fixture file is shared across ops; ignore non-partition vectors.
if case.op.as_deref() != Some("partition") {
return Ok(());
}
let actual = partition(
&crate::ops::data_movement::hex_to_bytes(
case.input_hex.as_ref().ok_or("Fix: missing input_hex")?,
)?,
case.predicate_id.ok_or("Fix: missing predicate_id")?,
);
// `ok` records whether the vector is expected to succeed.
match (case.ok, actual) {
(Some(true), Ok(output)) => {
assert_eq!(
output.pass,
crate::ops::data_movement::hex_to_bytes(
case.expected_pass_hex.as_ref().ok_or("Fix: missing expected_pass_hex")?,
)?
);
assert_eq!(
output.fail,
crate::ops::data_movement::hex_to_bytes(
case.expected_fail_hex.as_ref().ok_or("Fix: missing expected_fail_hex")?,
)?
);
}
// The only accepted failure mode for invalid vectors.
(Some(false), Err(DataMovementError::UnknownPredicate { .. })) => {}
(Some(false), Err(error)) => {
return Err(format!("Fix: expected UnknownPredicate, got {error}"));
}
(Some(true), Err(error)) => {
return Err(format!("Fix: valid partition vector was rejected: {error}"));
}
(Some(false), Ok(output)) => {
return Err(format!(
"Fix: invalid partition vector should fail, produced {output:?}"
));
}
(None, _) => return Err("Fix: missing ok field".into()),
}
Ok(())
},
)
}