vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_BYTES_INPUTS, BYTES_OUTPUTS};

// WGSL spelling note for `data_movement.compact`.

/// WGSL helper body used by conformance tooling to identify compact lowering.
///
/// The helper treats any non-zero mask value as "keep". This mirrors the CPU
/// reference [`compact_mask_bytes`], which keeps a data byte whenever its mask
/// byte is non-zero.
///
/// NOTE(review): `pub` is not a WGSL keyword (WGSL functions are declared with a
/// plain `fn`), so this snippet is not valid WGSL as written. Confirm whether the
/// conformance tooling matches this text verbatim before changing it.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_compact_keep(mask: u32) -> bool {
    return mask != 0u;
}
";

/// Keep only the bytes of `data` whose matching `mask` entry is `true`.
///
/// Kept bytes appear in the output in their original relative order.
///
/// # Errors
///
/// Returns [`DataMovementError::MaskLengthMismatch`] when `mask` is non-empty and
/// its length differs from `data.len()`. An empty `mask` is accepted for any
/// `data` and yields an empty vector.
pub fn compact(data: &[u8], mask: &[bool]) -> Result<Vec<u8>, DataMovementError> {
    let (data_len, mask_len) = (data.len(), mask.len());

    // An empty mask is a valid "drop everything" request regardless of data length.
    if mask_len == 0 {
        return Ok(Vec::new());
    }
    if mask_len != data_len {
        return Err(DataMovementError::MaskLengthMismatch { data_len, mask_len });
    }

    let mut kept = Vec::new();
    for (&byte, &keep) in data.iter().zip(mask.iter()) {
        if keep {
            kept.push(byte);
        }
    }
    Ok(kept)
}

/// Marker type for the masked byte compaction operation.
///
/// Carries no state; its operation specification and IR builder are the
/// associated items [`Compact::SPEC`] and [`Compact::program`].
#[derive(Clone, Copy, Debug, Default)]
pub struct Compact;

/// Drop bytes using a byte-encoded mask.
///
/// A zero mask byte drops the matching data byte; any non-zero mask byte keeps it.
/// This is a thin adapter over [`compact`]: the byte mask is converted into a
/// boolean mask, and validation happens in that function.
///
/// # Errors
///
/// Returns [`DataMovementError`] when `mask` is non-empty and its length differs
/// from `data.len()`. An empty `mask` yields an empty vector.
pub fn compact_mask_bytes(data: &[u8], mask: &[u8]) -> Result<Vec<u8>, DataMovementError> {
    let mut flags = Vec::with_capacity(mask.len());
    for &mask_byte in mask {
        flags.push(mask_byte != 0);
    }
    compact(data, &flags)
}

impl Compact {
    /// Declarative specification registering `data_movement.compact` as a
    /// composition over byte inputs and outputs.
    ///
    /// It uses the laws in [`LAWS`] and lowers through [`Compact::program`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.compact",
        BYTES_BYTES_INPUTS,
        BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build a lowerable IR sketch that preserves kept bytes in place.
    ///
    /// Each invocation uses its global x id (`gid_x`) as the byte index. The data
    /// byte is copied to the same index of `out` when two conditions hold: the
    /// index is below the length of `data`, and the mask byte at that index is
    /// non-zero. Otherwise `out` is left untouched at that index.
    ///
    /// Unlike [`compact`], kept bytes are not packed to the front of the output.
    /// The `[64, 1, 1]` argument is presumably the workgroup size — TODO confirm.
    ///
    /// NOTE(review): the bounds check covers `data` only, but `mask` is loaded at
    /// the same index. A mask shorter than `data` (including the empty mask the
    /// CPU reference accepts) therefore relies on how the lowering handles
    /// out-of-bounds loads, and on whether `Expr::and` short-circuits. Confirm.
    ///
    /// NOTE(review): the `idx` binding is never referenced by the body; the raw
    /// `gid_x` expression is used directly.
    #[must_use]
    pub fn program() -> Program {
        let gid = Expr::gid_x();

        let buffers = vec![
            BufferDecl::read("data", 0, DataType::Bytes),
            BufferDecl::read("mask", 1, DataType::Bytes),
            BufferDecl::output("out", 2, DataType::Bytes),
        ];

        let bind_idx = Node::let_bind("idx", gid.clone());
        let in_bounds = Expr::lt(gid.clone(), Expr::buf_len("data"));
        let is_kept = Expr::ne(Expr::load("mask", gid.clone()), Expr::u32(0));
        let copy_byte = Node::store("out", gid.clone(), Expr::load("data", gid));

        let body = vec![
            bind_idx,
            Node::if_then(Expr::and(in_bounds, is_kept), vec![copy_byte]),
        ];

        Program::new(buffers, [64, 1, 1], body)
    }
}

/// Algebraic laws declared for `data_movement.compact`.
///
/// Empty: compaction declares no algebraic laws. This slice is passed to
/// [`Compact::SPEC`].
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests extracted from `ops/data_movement/compact/kernel.rs`.

/// Replay the committed known-answer vectors for `compact` against the CPU
/// reference [`compact_mask_bytes`].
///
/// Vectors marked `ok = true` must compact to `expected_output_hex`. Vectors
/// marked `ok = false` must be rejected with `MaskLengthMismatch`. Every failure
/// is reported through the returned `Err` with a `Fix:` hint.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::data_movement::hex_to_bytes;

    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            // The fixture file is shared across ops; only `compact` vectors are checked here.
            if case.op.as_deref() != Some("compact") {
                return Ok(());
            }
            let input = hex_to_bytes(case.input_hex.as_ref().ok_or("Fix: missing input_hex")?)?;
            let mask = hex_to_bytes(case.mask_hex.as_ref().ok_or("Fix: missing mask_hex")?)?;
            let actual = compact_mask_bytes(&input, &mask);
            match (case.ok, actual) {
                (Some(true), Ok(bytes)) => {
                    let expected = hex_to_bytes(
                        case.expected_output_hex
                            .as_ref()
                            .ok_or("Fix: missing expected_output_hex")?,
                    )?;
                    if bytes != expected {
                        return Err(format!(
                            "Fix: compact output mismatch, expected {expected:02x?}, got {bytes:02x?}"
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::MaskLengthMismatch { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected MaskLengthMismatch, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid compact vector was rejected: {error}"));
                }
                (Some(false), Ok(bytes)) => {
                    return Err(format!(
                        "Fix: invalid compact vector should fail, produced {bytes:02x?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}

/// An empty mask means "drop everything", even when `data` is non-empty.
#[test]
pub fn empty_mask_drops_everything_even_for_nonempty_data() -> Result<(), DataMovementError> {
    let kept = compact_mask_bytes(&[1, 2, 3], &[])?;
    assert!(kept.is_empty());
    Ok(())
}