vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_OUTPUTS, U32_U32_INPUTS};

// WGSL spelling note for `data_movement.broadcast`.

/// WGSL helper body used by conformance tooling to identify broadcast lowering.
///
/// The helper returns its `u32` argument unchanged; writing that value into
/// every output element is done by the IR program built in
/// `Broadcast::program`.
///
/// NOTE(review): the text begins with `pub fn`, which is Rust syntax — WGSL
/// declares functions with plain `fn`. If this string is ever compiled as
/// WGSL (rather than only matched textually), it will not parse. Confirm how
/// the conformance tooling consumes it before changing the literal.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_broadcast_value(value: u32) -> u32 {
    return value;
}
";

/// Fill a byte buffer with `count` little-endian copies of `value`.
///
/// On success the result is `value.to_le_bytes()` repeated `count` times,
/// i.e. exactly `count * ELEMENT_SIZE` bytes.
///
/// # Errors
///
/// Returns [`DataMovementError::OutputTooLarge`] when `count` exceeds
/// [`MAX_BROADCAST_COUNT`], i.e. when the output would be larger than the
/// 64 MiB data movement output cap (`MAX_OUTPUT_BYTES`). The same error,
/// with `requested: usize::MAX`, is returned on targets where `count` does
/// not fit in `usize`.
pub fn broadcast(value: u32, count: u32) -> Result<Vec<u8>, DataMovementError> {
    // Single construction point for the size error so both rejection paths agree.
    let too_large = |requested: usize| DataMovementError::OutputTooLarge {
        requested,
        max: crate::ops::data_movement::MAX_OUTPUT_BYTES,
    };

    if count > MAX_BROADCAST_COUNT {
        // Report the byte size the caller asked for, saturating instead of overflowing.
        let requested = usize::try_from(count)
            .unwrap_or(usize::MAX)
            .saturating_mul(ELEMENT_SIZE);
        return Err(too_large(requested));
    }

    // Can only fail on targets where `usize` is narrower than `u32`.
    let count = usize::try_from(count).map_err(|_| too_large(usize::MAX))?;

    // `[u8]::repeat` allocates the exact output length once and fills it with
    // the 4-byte pattern; the cap check above keeps the length bounded.
    Ok(value.to_le_bytes().repeat(count))
}

/// U32 broadcast operation.
///
/// Stateless marker type. Its associated items (`Broadcast::SPEC` and
/// `Broadcast::program`) describe the op's IR lowering. The free function
/// `broadcast` is the CPU reference that the committed known-answer vectors
/// are checked against.
#[derive(Debug, Clone, Copy, Default)]
pub struct Broadcast;

/// Width in bytes of one broadcast element (a `u32`), used to convert element counts to byte sizes.
pub const ELEMENT_SIZE: usize = (u32::BITS / 8) as usize;

impl Broadcast {
    /// Declarative operation specification.
    ///
    /// Registers `data_movement.broadcast` as a composition over the shared
    /// `U32_U32_INPUTS` / `BYTES_OUTPUTS` signatures, with no algebraic laws,
    /// whose IR is produced by [`Broadcast::program`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.broadcast",
        U32_U32_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build a lowerable IR program for repeated u32 writes.
    ///
    /// Each invocation takes its global x id as an index into `out`. If the
    /// index is below `out`'s length, it stores element 0 of the `value`
    /// uniform there. The dispatch shape is `[64, 1, 1]` (presumably the
    /// workgroup size — confirm against `Program::new`).
    ///
    /// NOTE(review): the `count` uniform (binding 1) is declared but never
    /// read. The number of writes is bounded only by the length of `out`.
    #[must_use]
    pub fn program() -> Program {
        let lane = Expr::gid_x();
        let in_bounds = Expr::lt(lane.clone(), Expr::buf_len("out"));
        let splat = Node::store("out", lane.clone(), Expr::load("value", Expr::u32(0)));

        let buffers = vec![
            BufferDecl::uniform("value", 0, DataType::U32),
            BufferDecl::uniform("count", 1, DataType::U32),
            BufferDecl::output("out", 2, DataType::U32),
        ];
        let body = vec![
            Node::let_bind("idx", lane),
            Node::if_then(in_bounds, vec![splat]),
        ];

        Program::new(buffers, [64, 1, 1], body)
    }
}

/// Algebraic laws declared for `data_movement.broadcast` (passed to `Broadcast::SPEC`); currently none.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Maximum number of u32 values accepted by one broadcast call.
///
/// Derived from the data movement output byte cap (`MAX_OUTPUT_BYTES`), so
/// `broadcast` never produces more than that many bytes. The element count is
/// checked during constant evaluation to fit in `u32`. If the byte cap is ever
/// raised past that range, the build fails instead of the narrowing cast
/// silently truncating the limit.
pub const MAX_BROADCAST_COUNT: u32 = {
    let max_count = crate::ops::data_movement::MAX_OUTPUT_BYTES / ELEMENT_SIZE;
    assert!(
        max_count as u64 <= u32::MAX as u64,
        "MAX_OUTPUT_BYTES / ELEMENT_SIZE must fit in u32"
    );
    max_count as u32
};

// Unit tests extracted from `ops/data_movement/broadcast/kernel.rs`.

/// Check `broadcast` against the committed known-answer vectors.
///
/// Rows whose `op` is not `"broadcast"` are skipped. Each remaining row must
/// either produce exactly `expected_output_hex` (`ok = true`) or be rejected
/// with `OutputTooLarge` (`ok = false`). Every failure is reported as an
/// `Err` with a `Fix:` message rather than a panic.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            // The fixture file may hold vectors for other ops; only check broadcast rows.
            if case.op.as_deref() != Some("broadcast") {
                return Ok(());
            }
            let actual = broadcast(
                case.value.ok_or("Fix: missing value")?,
                case.count.ok_or("Fix: missing count")?,
            );
            match (case.ok, actual) {
                (Some(true), Ok(bytes)) => {
                    let expected = crate::ops::data_movement::hex_to_bytes(
                        case.expected_output_hex.as_ref().ok_or("Fix: missing expected_hex")?,
                    )?;
                    // Report mismatches through the Result like every other arm,
                    // instead of panicking mid-closure.
                    if bytes != expected {
                        return Err(format!(
                            "Fix: broadcast output mismatch, expected {expected:02x?}, got {bytes:02x?}"
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::OutputTooLarge { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected OutputTooLarge, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid broadcast vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid broadcast vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}

/// A count one past `MAX_BROADCAST_COUNT` must be rejected with `OutputTooLarge`.
#[test]
pub fn rejects_count_above_oom_cap() -> Result<(), String> {
    let over_cap = MAX_BROADCAST_COUNT.saturating_add(1);
    let outcome = broadcast(7, over_cap);

    match outcome {
        Ok(bytes) => Err(format!(
            "Fix: count above cap should fail, produced {bytes:02x?}"
        )),
        Err(DataMovementError::OutputTooLarge { .. }) => Ok(()),
        Err(error) => Err(format!("Fix: expected OutputTooLarge, got {error}")),
    }
}
}