// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::data_movement::{
    read_u32_words_le, DataMovementError, BYTES_BYTES_INPUTS, BYTES_OUTPUTS,
};
use crate::ops::OpSpec;

/// WGSL helper body used by conformance tooling to identify gather lowering.
///
/// The snippet declares `data_movement_gather_byte`, which masks its `u32`
/// argument with `0xffu`, keeping only the low eight bits.
// NOTE(review): the snippet is spelled `pub fn`, which is Rust syntax; WGSL
// declares functions with plain `fn`. Presumably this text is only matched as
// a marker and never compiled as WGSL — confirm with the conformance tooling.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_gather_byte(value: u32) -> u32 {
    return value & 0xffu;
}
";

/// Gather individual bytes of `data` selected by `indices_le`.
///
/// `indices_le` is a buffer of little-endian `u32` indices, and every index
/// picks exactly one byte of `data`. For logical elements wider than one byte
/// call [`gather_elements`] instead.
///
/// # Errors
///
/// Returns [`DataMovementError`] when the index buffer is malformed or any
/// index is outside the input.
pub fn gather(data: &[u8], indices_le: &[u8]) -> Result<Vec<u8>, DataMovementError> {
    // A byte gather is the element gather specialised to one-byte elements.
    const BYTE_WIDTH: usize = 1;
    gather_elements(data, indices_le, BYTE_WIDTH)
}


/// Indexed byte gather operation.
///
/// This is a field-less unit type; it only serves as the namespace for the
/// operation's associated items (see its `impl` block in this file).
#[derive(Debug, Clone, Copy, Default)]
pub struct Gather;


/// Gather fixed-width elements of `data` selected by a little-endian u32
/// index buffer.
///
/// `data` is treated as a sequence of `element_size`-byte elements. Each index
/// decoded from `indices_le` copies one whole element into the output, in
/// index order.
///
/// # Errors
///
/// Returns [`DataMovementError`] in the following cases, checked in this order:
///
/// * `ZeroElementSize` when `element_size` is zero.
/// * `MisalignedElements` when `data.len()` is not a multiple of
///   `element_size`.
/// * Whatever `read_u32_words_le` reports for a malformed index buffer.
/// * `OutputTooLarge` when the output would exceed `MAX_OUTPUT_BYTES`, or when
///   its size does not fit in `usize`.
/// * `IndexOutOfBounds` for the first index that is not below the element
///   count.
pub fn gather_elements(
    data: &[u8],
    indices_le: &[u8],
    element_size: usize,
) -> Result<Vec<u8>, DataMovementError> {
    if element_size == 0 {
        return Err(DataMovementError::ZeroElementSize);
    }
    let data_len = data.len();
    if data_len % element_size != 0 {
        return Err(DataMovementError::MisalignedElements {
            data_len,
            element_size,
        });
    }

    let indices = read_u32_words_le(indices_le)?;
    let element_count = data_len / element_size;

    // Validate the output size before allocating anything for it.
    let max = crate::ops::data_movement::MAX_OUTPUT_BYTES;
    let requested = match indices.len().checked_mul(element_size) {
        Some(bytes) if bytes <= max => bytes,
        Some(bytes) => {
            return Err(DataMovementError::OutputTooLarge {
                requested: bytes,
                max,
            });
        }
        // The byte count overflowed `usize`; report the largest value.
        None => {
            return Err(DataMovementError::OutputTooLarge {
                requested: usize::MAX,
                max,
            });
        }
    };

    let mut gathered = Vec::with_capacity(requested);
    for (position, &index) in indices.iter().enumerate() {
        // An index that does not fit in `usize` is out of bounds by definition,
        // so both failure modes produce the same error.
        let slot = match usize::try_from(index) {
            Ok(slot) if slot < element_count => slot,
            _ => {
                return Err(DataMovementError::IndexOutOfBounds {
                    position,
                    index,
                    element_count,
                });
            }
        };
        // `slot < element_count` keeps the whole range inside `data`.
        let begin = slot * element_size;
        let end = begin + element_size;
        gathered.extend_from_slice(&data[begin..end]);
    }
    Ok(gathered)
}


impl Gather {
    /// Operation specification for `"data_movement.gather"`, composed from the
    /// shared bytes/bytes input and bytes output signatures, [`LAWS`], and
    /// [`Gather::program`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.gather",
        BYTES_BYTES_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build a single-byte lowerable IR sketch for gather.
    ///
    /// Using `gid_x` as the output position, the program loads the source
    /// index from `indices` at that position and, when the position is below
    /// the length of `out` and the source index is below the length of `data`,
    /// stores `data[source index]` into `out` at the position.
    #[must_use]
    pub fn program() -> Program {
        let position = Expr::gid_x();
        let source_index = Expr::load("indices", position.clone());

        let buffers = vec![
            BufferDecl::read("data", 0, DataType::Bytes),
            BufferDecl::read("indices", 1, DataType::U32),
            BufferDecl::output("out", 2, DataType::Bytes),
        ];

        let bind_position = Node::let_bind("idx", position.clone());
        // Both the output slot and the source slot must be inside their buffers.
        let in_bounds = Expr::and(
            Expr::lt(position.clone(), Expr::buf_len("out")),
            Expr::lt(source_index.clone(), Expr::buf_len("data")),
        );
        let copy_byte = Node::store("out", position, Expr::load("data", source_index));
        let guarded_copy = Node::if_then(in_bounds, vec![copy_byte]);

        Program::new(buffers, [64, 1, 1], vec![bind_position, guarded_copy])
    }
}


/// Algebraic laws attached to the gather operation's spec.
///
/// The list is currently empty: no laws are declared for gather.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests.
// Unit tests extracted from `ops/data_movement/gather/kernel.rs`.

/// Replays the committed reference vectors for `gather` against the CPU
/// reference implementation.
///
/// Vectors marked `ok = true` must produce exactly `expected_output_hex`;
/// vectors marked `ok = false` must fail with `IndexOutOfBounds`. Any other
/// combination, or a vector with a missing field, fails the test with a
/// descriptive message.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            // Only cases tagged as `gather` are checked here; all others are skipped.
            if case.op.as_deref() != Some("gather") {
                return Ok(());
            }
            let actual = gather(
                &crate::ops::data_movement::hex_to_bytes(
                    case.input_hex.as_ref().ok_or("Fix: missing input_hex")?,
                )?,
                &crate::ops::data_movement::hex_to_bytes(
                    case.indices_hex.as_ref().ok_or("Fix: missing indices_hex")?,
                )?,
            );
            match (case.ok, actual) {
                // The expected output is decoded only when it is actually needed.
                (Some(true), Ok(bytes)) => assert_eq!(
                    bytes,
                    crate::ops::data_movement::hex_to_bytes(
                        case.expected_output_hex
                            .as_ref()
                            // Name the field that is actually read, matching the
                            // `input_hex` / `indices_hex` messages above.
                            .ok_or("Fix: missing expected_output_hex")?,
                    )?
                ),
                (Some(false), Err(DataMovementError::IndexOutOfBounds { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected IndexOutOfBounds, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid gather vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid gather vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}

/// Four bytes split into two-byte elements give two elements, so element
/// index 2 must be rejected with `IndexOutOfBounds` instead of being sliced.
#[test]
pub fn element_bound_is_checked_before_slicing() -> Result<(), String> {
    let data = [1_u8, 2, 3, 4];
    let indices = 2_u32.to_le_bytes();

    // A successful gather is the failure case for this test.
    let error = match gather_elements(&data, &indices, 2) {
        Ok(bytes) => {
            return Err(format!(
                "Fix: out-of-bounds gather element should fail, produced {bytes:02x?}"
            ));
        }
        Err(error) => error,
    };

    assert_eq!(
        error,
        DataMovementError::IndexOutOfBounds {
            position: 0,
            index: 2,
            element_count: 2
        }
    );
    Ok(())
}