use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_BYTES_INPUTS, BYTES_OUTPUTS};
/// WGSL-flavoured source text for this op: a `data_movement_compact_keep` helper that
/// reports "keep" for any non-zero `u32` mask word.
///
/// NOTE(review): the snippet declares the helper with `pub fn`, but WGSL function
/// declarations use plain `fn` (there is no `pub` keyword). Presumably the consumer of this
/// string rewrites or only inspects it — confirm before handing it to a WGSL compiler.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_compact_keep(mask: u32) -> bool {
return mask != 0u;
}
";
/// CPU reference for `data_movement.compact`: keeps each byte of `data` whose matching
/// `mask` entry is `true`, preserving the original order.
///
/// An empty `mask` is treated as "keep nothing" and yields an empty vector regardless of
/// `data`. This check happens before the length comparison, so it never errors.
///
/// # Errors
///
/// Returns [`DataMovementError::MaskLengthMismatch`] when `mask` is non-empty and its
/// length differs from `data.len()`.
pub fn compact(data: &[u8], mask: &[bool]) -> Result<Vec<u8>, DataMovementError> {
    if mask.is_empty() {
        return Ok(Vec::new());
    }
    if mask.len() != data.len() {
        return Err(DataMovementError::MaskLengthMismatch {
            data_len: data.len(),
            mask_len: mask.len(),
        });
    }

    let mut kept = Vec::new();
    for (position, &selected) in mask.iter().enumerate() {
        if selected {
            kept.push(data[position]);
        }
    }
    Ok(kept)
}
/// Marker type for the `data_movement.compact` operation; it carries no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct Compact;
/// Byte-mask front end for [`compact`]: any non-zero mask byte means "keep".
///
/// # Errors
///
/// Propagates [`compact`]'s error when a non-empty `mask` differs in length from `data`.
pub fn compact_mask_bytes(data: &[u8], mask: &[u8]) -> Result<Vec<u8>, DataMovementError> {
    let mut keep_flags: Vec<bool> = Vec::with_capacity(mask.len());
    keep_flags.extend(mask.iter().map(|flag_byte| *flag_byte != 0));
    compact(data, &keep_flags)
}
impl Compact {
    /// Operation spec registering `data_movement.compact` as a composition with byte
    /// inputs ([`BYTES_BYTES_INPUTS`]), byte outputs ([`BYTES_OUTPUTS`]), the laws in
    /// [`LAWS`], and [`Compact::program`] as its program builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.compact",
        BYTES_BYTES_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );
    /// Builds the IR program for this op.
    ///
    /// The buffers are all [`DataType::Bytes`]:
    /// - `data` at binding 0, declared read;
    /// - `mask` at binding 1, declared read;
    /// - `out` at binding 2, declared as the output.
    ///
    /// The program uses a `[64, 1, 1]` shape, presumably the workgroup size. For each
    /// invocation, `idx = gid_x`. When `idx < buf_len("data")` and `mask[idx] != 0`, it
    /// stores `data[idx]` into `out[idx]`.
    ///
    /// NOTE(review): kept bytes are written at their *original* index, so `out` is not
    /// densely packed the way [`compact`] packs its result. Unless a later stage compacts
    /// the gaps, this program's output will not match the CPU reference — confirm intent.
    #[must_use]
    pub fn program() -> Program {
        let idx = Expr::gid_x();
        // NOTE(review): only `data`'s length is bounds-checked below. If `Expr::and` does
        // not short-circuit, this load can read past a shorter `mask` buffer — confirm.
        let keep = Expr::load("mask", idx.clone());
        Program::new(
            vec![
                BufferDecl::read("data", 0, DataType::Bytes),
                BufferDecl::read("mask", 1, DataType::Bytes),
                BufferDecl::output("out", 2, DataType::Bytes),
            ],
            [64, 1, 1],
            vec![
                // NOTE(review): later nodes embed the `gid_x` expression directly rather
                // than referring to this `idx` binding by name; the binding may be unused.
                Node::let_bind("idx", idx.clone()),
                Node::if_then(
                    Expr::and(
                        Expr::lt(idx.clone(), Expr::buf_len("data")),
                        Expr::ne(keep, Expr::u32(0)),
                    ),
                    vec![Node::store("out", idx.clone(), Expr::load("data", idx))],
                ),
            ],
        )
    }
}
/// Algebraic laws advertised for `data_movement.compact` through [`Compact::SPEC`];
/// currently none are declared.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// Replays the committed known-answer vectors for `compact` against the CPU reference.
///
/// Vectors whose `op` is not `"compact"` are skipped. For each remaining vector,
/// `input_hex` and `mask_hex` are decoded and passed to [`compact_mask_bytes`]:
/// - `ok = true` vectors must produce exactly `expected_output_hex`;
/// - `ok = false` vectors must fail with [`DataMovementError::MaskLengthMismatch`].
///
/// Every failure is reported as an `Err(String)` prefixed with `Fix:` rather than a panic,
/// so all mismatch kinds surface uniformly.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::data_movement::hex_to_bytes;

    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("compact") {
                return Ok(());
            }
            let data = hex_to_bytes(case.input_hex.as_ref().ok_or("Fix: missing input_hex")?)?;
            let mask = hex_to_bytes(case.mask_hex.as_ref().ok_or("Fix: missing mask_hex")?)?;
            let actual = compact_mask_bytes(&data, &mask);
            match (case.ok, actual) {
                (Some(true), Ok(bytes)) => {
                    // The message names the real fixture field (`expected_output_hex`).
                    let expected = hex_to_bytes(
                        case.expected_output_hex
                            .as_ref()
                            .ok_or("Fix: missing expected_output_hex")?,
                    )?;
                    if bytes != expected {
                        return Err(format!(
                            "Fix: compact produced {bytes:02x?}, expected {expected:02x?}"
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::MaskLengthMismatch { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected MaskLengthMismatch, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid compact vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid compact vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}
/// An empty mask short-circuits to an empty result even when `data` is non-empty,
/// instead of being reported as a length mismatch.
#[test]
pub fn empty_mask_drops_everything_even_for_nonempty_data() -> Result<(), DataMovementError> {
    let data = [1_u8, 2, 3];
    let kept = compact_mask_bytes(&data, &[])?;
    assert!(kept.is_empty());
    Ok(())
}