use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_OUTPUTS, U32_U32_INPUTS};
/// WGSL spelling for this module's broadcast helper: a function that returns
/// its `value` argument unchanged.
///
/// NOTE(review): WGSL function declarations take no `pub` qualifier, yet this
/// text begins with `pub fn`. Confirm whether the consumer strips it before
/// compiling, or whether the string is only used as a textual reference.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_broadcast_value(value: u32) -> u32 {
return value;
}
";
/// CPU reference for broadcast: replicates `value` `count` times as
/// little-endian `u32` bytes.
///
/// The result holds `count * ELEMENT_SIZE` bytes. Every `ELEMENT_SIZE`-byte
/// chunk equals `value.to_le_bytes()`.
///
/// # Errors
///
/// Returns [`DataMovementError::OutputTooLarge`] when `count` exceeds
/// [`MAX_BROADCAST_COUNT`], or when `count` does not fit in `usize` on the
/// current target.
pub fn broadcast(value: u32, count: u32) -> Result<Vec<u8>, DataMovementError> {
    let too_large = |requested: usize| DataMovementError::OutputTooLarge {
        requested,
        max: crate::ops::data_movement::MAX_OUTPUT_BYTES,
    };
    // Reject before allocating so an oversized request cannot trigger an OOM.
    if count > MAX_BROADCAST_COUNT {
        let requested = usize::try_from(count)
            .unwrap_or(usize::MAX)
            .saturating_mul(ELEMENT_SIZE);
        return Err(too_large(requested));
    }
    let count = usize::try_from(count).map_err(|_| too_large(usize::MAX))?;
    // `count <= MAX_BROADCAST_COUNT` here, so the `repeat` allocation is bounded
    // by the output byte cap and cannot overflow.
    Ok(value.to_le_bytes().repeat(count))
}
/// Width in bytes of one broadcast element (a single `u32`).
pub const ELEMENT_SIZE: usize = core::mem::size_of::<u32>();

/// Stateless marker type for the `data_movement.broadcast` operation.
///
/// Its op spec and IR kernel builder are provided by an inherent impl.
#[derive(Clone, Copy, Debug, Default)]
pub struct Broadcast;
impl Broadcast {
    /// Operation spec for `data_movement.broadcast`, built with
    /// `OpSpec::composition` from the shared input/output descriptors, this
    /// module's `LAWS`, and [`Broadcast::program`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.broadcast",
        U32_U32_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the IR kernel for broadcast.
    ///
    /// Declares `value` (binding 0) and `count` (binding 1) as `u32` uniforms
    /// and `out` (binding 2) as the `u32` output, with a `[64, 1, 1]` workgroup.
    /// Each invocation whose `gid_x` index is below `buf_len("out")` stores
    /// `value[0]` into `out[idx]`.
    ///
    /// NOTE(review): the `count` uniform is declared but never read. The bound
    /// comes from the output buffer length, which assumes `out` is sized to
    /// exactly `count` elements. Confirm this against the dispatcher.
    #[must_use]
    pub fn program() -> Program {
        // Per-invocation index along x (presumably the global invocation id).
        let lane = Expr::gid_x();
        let buffers = vec![
            BufferDecl::uniform("value", 0, DataType::U32),
            BufferDecl::uniform("count", 1, DataType::U32),
            BufferDecl::output("out", 2, DataType::U32),
        ];
        let workgroup = [64, 1, 1];

        let bind_lane = Node::let_bind("idx", lane.clone());
        // Guard against invocations past the end of the output buffer.
        let in_bounds = Expr::lt(lane.clone(), Expr::buf_len("out"));
        let write_value = Node::store("out", lane, Expr::load("value", Expr::u32(0)));
        let body = vec![bind_lane, Node::if_then(in_bounds, vec![write_value])];

        Program::new(buffers, workgroup, body)
    }
}
/// Algebraic laws declared for the broadcast op (none).
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Largest `count` accepted by [`broadcast`]: the number of `u32` elements
/// that fit within `MAX_OUTPUT_BYTES`.
///
/// The quotient saturates at `u32::MAX` instead of going through a bare
/// `as u32` cast. That cast would silently wrap to a much smaller (wrong) cap
/// if the byte limit divided by `ELEMENT_SIZE` ever exceeded `u32::MAX`.
pub const MAX_BROADCAST_COUNT: u32 = {
    let elements = crate::ops::data_movement::MAX_OUTPUT_BYTES / ELEMENT_SIZE;
    if elements > u32::MAX as usize {
        u32::MAX
    } else {
        // Lossless: the branch above guarantees `elements <= u32::MAX`.
        elements as u32
    }
};
/// Replays the committed known-answer vectors whose `op` is `"broadcast"`
/// against [`broadcast`].
///
/// Cases with `ok = true` must produce exactly `expected_output_hex`. Cases
/// with `ok = false` must be rejected with `OutputTooLarge`. Every mismatch,
/// including a wrong output, is reported as an `Err` carrying a `Fix:` hint
/// rather than a panic, consistent with the other failure arms.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("broadcast") {
                return Ok(());
            }
            let actual = broadcast(
                case.value.ok_or("Fix: missing value")?,
                case.count.ok_or("Fix: missing count")?,
            );
            match (case.ok, actual) {
                (Some(true), Ok(bytes)) => {
                    let expected = crate::ops::data_movement::hex_to_bytes(
                        case.expected_output_hex.as_ref().ok_or("Fix: missing expected_hex")?,
                    )?;
                    if bytes != expected {
                        return Err(format!(
                            "Fix: broadcast output mismatch, expected {expected:02x?}, got {bytes:02x?}"
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::OutputTooLarge { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected OutputTooLarge, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid broadcast vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid broadcast vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}
/// A `count` one past [`MAX_BROADCAST_COUNT`] must be refused with
/// `OutputTooLarge` rather than producing bytes.
#[test]
pub fn rejects_count_above_oom_cap() -> Result<(), String> {
    let over_cap = MAX_BROADCAST_COUNT.saturating_add(1);
    let outcome = broadcast(7, over_cap);
    match outcome {
        Ok(bytes) => Err(format!(
            "Fix: count above cap should fail, produced {bytes:02x?}"
        )),
        Err(DataMovementError::OutputTooLarge { .. }) => Ok(()),
        Err(error) => Err(format!("Fix: expected OutputTooLarge, got {error}")),
    }
}