vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::OpSpec;
use crate::ops::data_movement::{DataMovementError, BYTES_U32_INPUTS, TWO_BYTES_OUTPUTS};

// WGSL spelling note for `data_movement.partition`.

/// WGSL helper body used by conformance tooling to identify partition lowering.
///
/// The helper tests one byte (carried as a `u32`) against zero. This mirrors the
/// `Expr::ne(byte, Expr::u32(0))` routing in `Partition::program` and the
/// `PREDICATE_NONZERO` predicate of the CPU reference.
///
/// NOTE(review): `pub` is not a WGSL keyword (WGSL functions are declared with a
/// bare `fn`), so this snippet is not valid standalone WGSL as written. It is left
/// unchanged because tooling may compare this text verbatim. Confirm how it is
/// consumed before dropping `pub`.
pub const WGSL_SPELLING: &str = r"
pub fn data_movement_partition_nonzero(byte: u32) -> bool {
    return byte != 0u;
}
";

impl Partition {
    /// Operation spec that registers `data_movement.partition` as a composition.
    ///
    /// It takes the shared bytes + `u32` inputs, produces two byte outputs, and
    /// builds its IR via [`Self::program`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "data_movement.partition",
        BYTES_U32_INPUTS,
        TWO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the lowerable IR sketch for partition.
    ///
    /// Each invocation loads `data[gid.x]`. When that index is inside `data`, the
    /// byte is stored at the same index into `pass` if it is nonzero, and into
    /// `fail` otherwise. The sketch is dispatched with `[64, 1, 1]` (presumably
    /// the workgroup size).
    ///
    /// Differences from the CPU reference `partition`:
    /// - The `predicate_id` uniform is declared but never read, so only the
    ///   nonzero predicate is lowered.
    /// - The outputs are not compacted, because each byte keeps its input index.
    #[must_use]
    pub fn program() -> Program {
        let gid = Expr::gid_x();
        let value = Expr::load("data", gid.clone());

        let buffers = vec![
            BufferDecl::read("data", 0, DataType::Bytes),
            BufferDecl::uniform("predicate_id", 1, DataType::U32),
            BufferDecl::output("pass", 2, DataType::Bytes),
            BufferDecl::read_write("fail", 3, DataType::Bytes),
        ];

        // `idx` is bound for the generated code. The nodes below reuse the
        // `gid_x` expression directly.
        let bind_idx = Node::let_bind("idx", gid.clone());
        let in_bounds = Expr::lt(gid.clone(), Expr::buf_len("data"));
        let is_nonzero = Expr::ne(value.clone(), Expr::u32(0));
        let to_pass = vec![Node::store("pass", gid.clone(), value.clone())];
        let to_fail = vec![Node::store("fail", gid, value)];
        let route = Node::if_then_else(is_nonzero, to_pass, to_fail);

        Program::new(
            buffers,
            [64, 1, 1],
            vec![bind_idx, Node::if_then(in_bounds, vec![route])],
        )
    }
}

/// Algebraic laws declared for `data_movement.partition`.
///
/// The list is empty: partition declares no laws for conformance checking.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Split `data` into the bytes that satisfy the predicate selected by
/// `predicate_id` and the bytes that do not.
///
/// Each half keeps the bytes in their original input order.
///
/// # Errors
///
/// Returns [`DataMovementError::UnknownPredicate`] when `predicate_id` is not
/// registered. This is checked first.
///
/// Returns [`DataMovementError::OutputTooLarge`] when `data` is longer than the
/// data movement output cap. Either output could be as long as the whole input,
/// so the input length is what gets checked.
pub fn partition(data: &[u8], predicate_id: u32) -> Result<PartitionOutput, DataMovementError> {
    let predicate = predicate_by_id(predicate_id)?;

    let cap = crate::ops::data_movement::MAX_OUTPUT_BYTES;
    let requested = data.len();
    if requested > cap {
        return Err(DataMovementError::OutputTooLarge { requested, max: cap });
    }

    // Reserve the worst case for both halves, so pushes never reallocate.
    let mut output = PartitionOutput {
        pass: Vec::with_capacity(requested),
        fail: Vec::with_capacity(requested),
    };
    for byte in data.iter().copied() {
        let bucket = if predicate(byte) {
            &mut output.pass
        } else {
            &mut output.fail
        };
        bucket.push(byte);
    }
    Ok(output)
}

/// Zero-sized marker for the predicate-driven byte partition operation.
///
/// It carries no state. The operation is described by its associated `SPEC`
/// and `program`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Partition;

/// The two byte buffers returned by `partition`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PartitionOutput {
    /// Input bytes for which the predicate returned `true`.
    pub pass: Vec<u8>,
    /// Input bytes for which the predicate returned `false`.
    pub fail: Vec<u8>,
}

/// Predicate id that passes non-zero bytes.
pub const PREDICATE_NONZERO: u32 = 0;

/// Predicate id that passes ASCII decimal digit bytes.
pub const PREDICATE_ASCII_DIGIT: u32 = 1;

/// Predicate id that passes ASCII alphabetic bytes.
pub const PREDICATE_ASCII_ALPHA: u32 = 2;

/// Predicate id that passes even byte values.
pub const PREDICATE_EVEN: u32 = 3;

/// Resolve a registered predicate id to its byte predicate.
///
/// The registered ids are:
/// - `PREDICATE_NONZERO`: passes `byte != 0`.
/// - `PREDICATE_ASCII_DIGIT`: passes `b'0'..=b'9'`.
/// - `PREDICATE_ASCII_ALPHA`: passes ASCII letters.
/// - `PREDICATE_EVEN`: passes `byte % 2 == 0`.
///
/// # Errors
///
/// Returns [`DataMovementError::UnknownPredicate`], carrying `id`, when `id` is
/// not one of the ids listed above.
pub fn predicate_by_id(id: u32) -> Result<fn(u8) -> bool, DataMovementError> {
    // Non-capturing closures coerce to plain `fn(u8) -> bool` pointers.
    match id {
        PREDICATE_NONZERO => Ok(|byte| byte != 0),
        PREDICATE_ASCII_DIGIT => Ok(|byte| byte.is_ascii_digit()),
        PREDICATE_ASCII_ALPHA => Ok(|byte| byte.is_ascii_alphabetic()),
        PREDICATE_EVEN => Ok(|byte| byte % 2 == 0),
        _ => Err(DataMovementError::UnknownPredicate { id }),
    }
}

// Unit tests extracted from `ops/data_movement/partition/kernel.rs`.

/// Check every committed `partition` known-answer vector against the CPU reference.
///
/// Failures are reported through the `Err(String)` channel with a `Fix:` prefix,
/// consistently for every mismatch kind, rather than panicking inside the harness.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::data_movement::hex_to_bytes;

    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            // The fixture file may hold vectors for other ops; only partition is checked here.
            if case.op.as_deref() != Some("partition") {
                return Ok(());
            }
            let input = hex_to_bytes(case.input_hex.as_ref().ok_or("Fix: missing input_hex")?)?;
            let predicate_id = case.predicate_id.ok_or("Fix: missing predicate_id")?;
            match (case.ok, partition(&input, predicate_id)) {
                (Some(true), Ok(output)) => {
                    let expected_pass = hex_to_bytes(
                        case.expected_pass_hex.as_ref().ok_or("Fix: missing expected_pass_hex")?,
                    )?;
                    let expected_fail = hex_to_bytes(
                        case.expected_fail_hex.as_ref().ok_or("Fix: missing expected_fail_hex")?,
                    )?;
                    if output.pass != expected_pass {
                        return Err(format!(
                            "Fix: partition pass mismatch for predicate {predicate_id}: expected {expected_pass:?}, got {:?}",
                            output.pass
                        ));
                    }
                    if output.fail != expected_fail {
                        return Err(format!(
                            "Fix: partition fail mismatch for predicate {predicate_id}: expected {expected_fail:?}, got {:?}",
                            output.fail
                        ));
                    }
                }
                (Some(false), Err(DataMovementError::UnknownPredicate { .. })) => {}
                (Some(false), Err(error)) => {
                    return Err(format!("Fix: expected UnknownPredicate, got {error}"));
                }
                (Some(true), Err(error)) => {
                    return Err(format!("Fix: valid partition vector was rejected: {error}"));
                }
                (Some(false), Ok(output)) => {
                    return Err(format!(
                        "Fix: invalid partition vector should fail, produced {output:?}"
                    ));
                }
                (None, _) => return Err("Fix: missing ok field".into()),
            }
            Ok(())
        },
    )
}