use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::data_movement::{
read_u32_words_le, DataMovementError, BYTES_BYTES_INPUTS, BYTES_OUTPUTS,
};
use crate::ops::OpSpec;
/// WGSL spelling of a byte helper, `data_movement_gather_byte`, which masks a `u32` down to its
/// low 8 bits (`value & 0xffu`).
///
/// NOTE(review): WGSL declares functions with plain `fn`; `pub` is not a WGSL visibility keyword
/// (the WGSL spec lists it as reserved). This snippet would likely be rejected if handed to a
/// WGSL compiler verbatim — confirm how it is consumed. Also note that the IR built by
/// `Gather::program` in this file does not call `data_movement_gather_byte`.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_gather_byte(value: u32) -> u32 {
return value & 0xffu;
}
";
/// Gathers one byte from `data` for each little-endian `u32` index packed in `indices_le`.
///
/// This is [`gather_elements`] specialised to an element width of a single byte, so the output
/// has exactly one byte per decoded index, in index order.
///
/// # Errors
///
/// Forwards every error produced by [`gather_elements`]; in particular an index that is not
/// smaller than `data.len()` yields [`DataMovementError::IndexOutOfBounds`].
pub fn gather(data: &[u8], indices_le: &[u8]) -> Result<Vec<u8>, DataMovementError> {
    const BYTE_WIDTH: usize = 1;
    gather_elements(data, indices_le, BYTE_WIDTH)
}
/// Stateless marker type for the `data_movement.gather` operation.
///
/// The op's registry entry and IR builder are exposed as [`Gather::SPEC`] and
/// [`Gather::program`].
#[derive(Clone, Copy, Debug, Default)]
pub struct Gather;
/// Gathers `element_size`-byte elements from `data`, one per little-endian `u32` index decoded
/// from `indices_le`, and concatenates them in index order.
///
/// `data` is viewed as `data.len() / element_size` consecutive elements; index `i` selects bytes
/// `i * element_size .. (i + 1) * element_size`.
///
/// # Errors
///
/// Checks run in this order:
/// - [`DataMovementError::ZeroElementSize`] when `element_size` is zero.
/// - [`DataMovementError::MisalignedElements`] when `data.len()` is not a multiple of
///   `element_size`.
/// - Any error from `read_u32_words_le` while decoding `indices_le`.
/// - [`DataMovementError::OutputTooLarge`] when the output would exceed `MAX_OUTPUT_BYTES`;
///   `requested` is `usize::MAX` if computing the size itself overflows.
/// - [`DataMovementError::IndexOutOfBounds`] for the first index that does not name an element.
pub fn gather_elements(
    data: &[u8],
    indices_le: &[u8],
    element_size: usize,
) -> Result<Vec<u8>, DataMovementError> {
    if element_size == 0 {
        return Err(DataMovementError::ZeroElementSize);
    }
    let data_len = data.len();
    if data_len % element_size != 0 {
        return Err(DataMovementError::MisalignedElements {
            data_len,
            element_size,
        });
    }

    let indices = read_u32_words_le(indices_le)?;
    let element_count = data_len / element_size;

    let max = crate::ops::data_movement::MAX_OUTPUT_BYTES;
    let total_bytes = match indices.len().checked_mul(element_size) {
        Some(bytes) if bytes <= max => bytes,
        Some(bytes) => {
            return Err(DataMovementError::OutputTooLarge {
                requested: bytes,
                max,
            })
        }
        // The byte count does not even fit in `usize`; report the saturated value.
        None => {
            return Err(DataMovementError::OutputTooLarge {
                requested: usize::MAX,
                max,
            })
        }
    };

    let mut gathered = Vec::with_capacity(total_bytes);
    for (position, &index) in indices.iter().enumerate() {
        // Validate the index before slicing so a bad index becomes an error, never a panic.
        let element = usize::try_from(index)
            .ok()
            .filter(|&slot| slot < element_count)
            .map(|slot| {
                let start = slot * element_size;
                &data[start..start + element_size]
            })
            .ok_or_else(|| DataMovementError::IndexOutOfBounds {
                position,
                index,
                element_count,
            })?;
        gathered.extend_from_slice(element);
    }
    Ok(gathered)
}
impl Gather {
    /// Registry entry for `data_movement.gather`: a composition op with `BYTES_BYTES_INPUTS`
    /// inputs, `BYTES_OUTPUTS` outputs, the laws listed in `LAWS`, and [`Gather::program`] as
    /// its IR builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.gather",
        BYTES_BYTES_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the IR program for a byte gather.
    ///
    /// Buffers: `data` (binding 0, read, `Bytes`), `indices` (binding 1, read, `U32`) and `out`
    /// (binding 2, output, `Bytes`), with dispatch shape `[64, 1, 1]` (presumably the workgroup
    /// size). Each invocation takes its `gid_x` index `id`, loads `indices[id]`, and stores
    /// `data[indices[id]]` into `out[id]`, but only when `id < len(out)` and
    /// `indices[id] < len(data)`. Otherwise the store is skipped.
    ///
    /// NOTE(review):
    /// - Unlike the CPU reference `gather_elements`, an out-of-range index is silently skipped
    ///   here instead of being reported.
    /// - There is no explicit `id < len(indices)` guard; presumably `out` is sized to the index
    ///   count — confirm.
    /// - The `idx` let-binding is not referenced by any expression built in this function.
    #[must_use]
    pub fn program() -> Program {
        let lane = Expr::gid_x();
        let source_index = Expr::load("indices", lane.clone());

        let buffers = vec![
            BufferDecl::read("data", 0, DataType::Bytes),
            BufferDecl::read("indices", 1, DataType::U32),
            BufferDecl::output("out", 2, DataType::Bytes),
        ];

        let bind_lane = Node::let_bind("idx", lane.clone());
        let in_bounds = Expr::and(
            Expr::lt(lane.clone(), Expr::buf_len("out")),
            Expr::lt(source_index.clone(), Expr::buf_len("data")),
        );
        let guarded_store = Node::if_then(
            in_bounds,
            vec![Node::store("out", lane, Expr::load("data", source_index))],
        );

        Program::new(buffers, [64, 1, 1], vec![bind_lane, guarded_store])
    }
}
/// Algebraic laws declared for `data_movement.gather` (passed into `Gather::SPEC`); currently
/// none.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// Replays every `gather` case from the committed reference-vector fixture against the CPU
/// reference [`gather`].
///
/// Cases for other ops are skipped. Each case must have:
/// - `ok = true`: `gather` succeeds and produces exactly `expected_output_hex`;
/// - `ok = false`: `gather` fails with `DataMovementError::IndexOutOfBounds`.
///
/// Every mismatch is returned as an `Err` carrying a `Fix:` message, rather than panicking, so
/// all failure paths report in the same shape.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::data_movement::hex_to_bytes;

    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("gather") {
                return Ok(());
            }
            let data = hex_to_bytes(case.input_hex.as_ref().ok_or("Fix: missing input_hex")?)?;
            let indices =
                hex_to_bytes(case.indices_hex.as_ref().ok_or("Fix: missing indices_hex")?)?;
            match (case.ok, gather(&data, &indices)) {
                (Some(true), Ok(bytes)) => {
                    // Parsed only when needed: invalid cases need not carry an expected output.
                    let expected = hex_to_bytes(
                        case.expected_output_hex
                            .as_ref()
                            .ok_or("Fix: missing expected_output_hex")?,
                    )?;
                    if bytes != expected {
                        return Err(format!(
                            "Fix: gather produced {bytes:02x?}, expected {expected:02x?}"
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::IndexOutOfBounds { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected IndexOutOfBounds, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid gather vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid gather vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}
/// Regression check for the element bound in `gather_elements`.
///
/// `data` holds two 2-byte elements (indices 0 and 1), so index `2` is one past the end. It
/// must be rejected with `IndexOutOfBounds` rather than reaching the slice.
#[test]
pub fn element_bound_is_checked_before_slicing() -> Result<(), String> {
    let one_past_end = 2_u32.to_le_bytes();
    let expected = DataMovementError::IndexOutOfBounds {
        position: 0,
        index: 2,
        element_count: 2,
    };
    match gather_elements(&[1, 2, 3, 4], &one_past_end, 2) {
        Ok(bytes) => Err(format!(
            "Fix: out-of-bounds gather element should fail, produced {bytes:02x?}"
        )),
        Err(error) => {
            assert_eq!(error, expected);
            Ok(())
        }
    }
}