// vyre 0.4.0: GPU compute intermediate representation with a standard operation library.
use crate::ir::DataType;
use crate::ops::Backend;
use crate::ops::{IntrinsicDescriptor, OpSpec};

// WGSL lowering source for `hash.sha256`.

/// WGSL source for the SHA-256 intrinsic.
///
/// This is the shared word helpers (`wgsl_shaders/words.wgsl`) followed by the
/// SHA-256 compression primitives (`wgsl_shaders/sha256.wgsl`), embedded at
/// compile time. The literal `"\n"` between them guarantees the last line of
/// the helpers never runs into the first line of the SHA-256 code.
pub const WGSL: &str = concat!(
    include_str!("wgsl_shaders/words.wgsl"),
    "\n",
    include_str!("wgsl_shaders/sha256.wgsl")
);

/// Operation inputs for `hash.sha256`: a single byte buffer (the message).
pub const INPUTS: &[DataType] = &[DataType::Bytes];

/// Operation outputs for `hash.sha256`: eight `U32` values, mirroring the
/// `[u32; 8]` digest words returned by `sha256`.
// `const { .. }` lets the element be repeated without requiring `DataType: Copy`.
pub const OUTPUTS: &[DataType] = &[const { DataType::U32 }; 8];

/// Algebraic laws declared for `hash.sha256` (none are declared).
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Hashes `input` with SHA-256.
///
/// The 32-byte digest is returned as eight `u32` words. Each word is four
/// consecutive digest bytes read big-endian, so word 0 holds digest bytes
/// `0..4`, word 1 holds bytes `4..8`, and so on.
///
/// The work is delegated to the crate's reference implementation,
/// `crate::ops::hash::reference::sha256::sha256_words`.
#[must_use]
pub fn sha256(input: &[u8]) -> [u32; 8] {
    use crate::ops::hash::reference::sha256::sha256_words;

    sha256_words(input)
}

/// Declarative operation specification for `hash.sha256`.
///
/// Registers the op id `"hash.sha256"` with the `INPUTS`, `OUTPUTS` and `LAWS`
/// tables from this module. It passes `wgsl_only` as the backend predicate,
/// which presumably gates which backends may run it.
///
/// The intrinsic descriptor carries two identifiers, `"hash.sha256.wgsl"` and
/// `"wgsl-sha256-compress"`, plus `crate::ops::hash::cpu_refs::sha256`. That
/// function is presumably the CPU reference used to check the GPU lowering.
///
/// NOTE(review): this descriptor points at `cpu_refs::sha256`, while the
/// public `sha256` function delegates to `reference::sha256::sha256_words`.
/// Confirm that both paths produce the same digests.
pub const SPEC: OpSpec = {
    // Building the descriptor first keeps the `OpSpec::intrinsic` call on one line.
    let descriptor = IntrinsicDescriptor::new(
        "hash.sha256.wgsl",
        "wgsl-sha256-compress",
        crate::ops::hash::cpu_refs::sha256,
    );
    OpSpec::intrinsic("hash.sha256", INPUTS, OUTPUTS, LAWS, wgsl_only, descriptor)
};

/// Backend predicate for `hash.sha256`.
///
/// Returns `true` only for `Backend::Wgsl` and `false` for every other backend.
pub fn wgsl_only(backend: &Backend) -> bool {
    // The pattern binds nothing, so matching on the dereferenced place moves nothing.
    matches!(*backend, Backend::Wgsl)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer tests for `sha256`, chosen around the 64-byte block
    /// boundary.
    ///
    /// A message fits in one block only if it leaves room for the `0x80`
    /// padding marker and the 8-byte length, which means at most 55 bytes.
    #[test]
    fn test_sha256_boundaries() {
        // Long input: 1,000,000 bytes of `a` spans many compression blocks.
        let million_a = vec![b'a'; 1_000_000];

        let cases: [(&str, &[u8], [u32; 8]); 5] = [
            (
                // Zero-length message: the single block is pure padding.
                "empty",
                b"",
                [
                    0xe3b0c442, 0x98fc1c14, 0x9afbf4c8, 0x996fb924, 0x27ae41e4, 0x649b934c,
                    0xa495991b, 0x7852b855,
                ],
            ),
            (
                // 3 bytes: fits comfortably in one block.
                "abc",
                b"abc",
                [
                    0xba7816bf, 0x8f01cfea, 0x414140de, 0x5dae2223, 0xb00361a3, 0x96177a9c,
                    0xb410ff61, 0xf20015ad,
                ],
            ),
            (
                // 56 bytes: one byte past the single-block limit, so the padding
                // spills into a second block.
                "56-byte padding spill",
                b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
                [
                    0x248d6a61, 0xd20638b8, 0xe5c02693, 0x0c3e6039, 0xa33ce459, 0x64ff2167,
                    0xf6ecedd4, 0x19db06c1,
                ],
            ),
            (
                // 112 bytes: a two-block message.
                "112-byte two-block",
                b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu",
                [
                    0xcf5b16a7, 0x78af8380, 0x036ce59e, 0x7b049237, 0x0b249b11, 0xe8f07a51,
                    0xafac4503, 0x7afee9d1,
                ],
            ),
            (
                "one million 'a'",
                million_a.as_slice(),
                [
                    0xcdc76e5c, 0x9914fb92, 0x81a1c7e2, 0x84d73e67, 0xf1809a48, 0xa497200e,
                    0x046d39cc, 0xc7112cd0,
                ],
            ),
        ];

        for (label, input, expected) in cases {
            assert_eq!(sha256(input), expected, "SHA-256 mismatch for {label} input");
        }
    }
}