vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use vyre::ir::{self, DataType, Program};
use vyre::lower::wgsl;
use vyre::ops::compression::{lz4::Lz4Decompress, zstd::ZstdDecompress};
use vyre::ops::decode::{
    base64::Base64Decode, hex::HexDecode, unicode::UnicodeDecode, url::UrlDecode,
};
use vyre::ops::graph::{bfs::Bfs, dfs::Dfs, reachability::ReachabilityOp};
use vyre::ops::hash::{crc32, entropy, fnv1a32, rolling};
use vyre::ops::scan::PrefixSumInclusiveU32;
use vyre::ops::{AlgebraicLaw, Category, Compose, OpSpec};

/// Asserts that `spec` is a Category A composition and returns the program it builds.
///
/// The checks, in order:
/// - `spec.category()` is `Category::A`;
/// - `spec.compose()` is `Compose::Composition(_)`;
/// - `spec.program()` succeeds;
/// - the built program's entry is non-empty.
///
/// # Panics
///
/// Panics (failing the calling test) if any check fails. The category and
/// composition failures name the offending spec, so a loop over several specs
/// reports which one broke instead of a bare `assertion failed: matches!(..)`.
fn assert_category_a(spec: &OpSpec) -> Program {
    assert!(
        matches!(spec.category(), Category::A),
        "{} must be a Category A op",
        spec.id()
    );
    assert!(
        matches!(spec.compose(), Compose::Composition(_)),
        "{} must be composed via Compose::Composition",
        spec.id()
    );
    let program = spec
        .program()
        .expect("Category A specs must build programs");
    assert!(
        !program.entry().is_empty(),
        "{} must not return an empty program",
        spec.id()
    );
    program
}

/// Asserts that `spec` passes the Category A checks, that its program passes
/// `ir::validate` with no errors, and that it lowers with `wgsl::lower`.
///
/// # Panics
///
/// Panics if the Category A checks fail, if `ir::validate` reports any error
/// (the errors are included in the message), or if `wgsl::lower` fails.
fn assert_validates_and_lowers(spec: &OpSpec) {
    let program = assert_category_a(spec);

    let errors = ir::validate(&program);
    if !errors.is_empty() {
        panic!("{} must validate cleanly: {errors:?}", spec.id());
    }

    if let Err(err) = wgsl::lower(&program) {
        panic!("{} must lower: {err}", spec.id());
    }
}

/// Asserts that `spec` passes the Category A checks and that its program
/// lowers with `wgsl::lower`. No IR validation is run here.
///
/// # Panics
///
/// Panics if the Category A checks fail or if `wgsl::lower` returns an error.
fn assert_lowers(spec: &OpSpec) {
    let program = assert_category_a(spec);
    if let Err(err) = wgsl::lower(&program) {
        panic!("{} must lower: {err}", spec.id());
    }
}

/// Asserts that every law in `laws` appears in `spec.laws()`.
///
/// # Panics
///
/// Panics on the first entry of `laws` (in slice order) that `spec` does not
/// declare. The message names both the spec and the missing law.
fn assert_laws(spec: &OpSpec, laws: &[AlgebraicLaw]) {
    let declared = spec.laws();
    if let Some(law) = laws.iter().find(|&wanted| !declared.contains(wanted)) {
        panic!("{} must declare {law:?}", spec.id());
    }
}

/// Asserts that `spec` declares at least one `AlgebraicLaw::Bounded` law,
/// with any bounds.
///
/// # Panics
///
/// Panics, naming the spec, if none of its declared laws is `Bounded`.
fn assert_has_bounded_law(spec: &OpSpec) {
    let has_bounded = spec
        .laws()
        .iter()
        .any(|candidate| matches!(candidate, AlgebraicLaw::Bounded { .. }));
    if !has_bounded {
        panic!("{} must declare a bounded law", spec.id());
    }
}

// BFS and reachability must each build, validate, and lower as Category A programs.
#[test]
fn graph_bfs_and_reachability_are_complete_category_a_programs() {
    for spec in [&Bfs::SPEC, &ReachabilityOp::SPEC] {
        assert_validates_and_lowers(spec);
    }
}

// DFS must build, validate, and lower as a Category A program.
#[test]
fn graph_dfs_is_complete_category_a_program() {
    let spec = &Dfs::SPEC;
    assert_validates_and_lowers(spec);
}

// Each decoder must:
// - take exactly one `DataType::Bytes` input;
// - pass the Category A checks and lower (no IR validation);
// - declare `Bounded { lo: 0, hi: 255 }`;
// - not be inlinable.
#[test]
fn decode_ops_are_non_empty_category_a_programs_that_lower() {
    let byte_bound = [AlgebraicLaw::Bounded { lo: 0, hi: 255 }];
    let decoders = [
        &Base64Decode::SPEC,
        &HexDecode::SPEC,
        &UrlDecode::SPEC,
        &UnicodeDecode::SPEC,
    ];
    for spec in decoders {
        assert_eq!(spec.inputs(), &[DataType::Bytes]);
        assert_lowers(spec);
        assert_laws(spec, &byte_bound);
        if spec.inlinable() {
            panic!("{} must not be inlinable", spec.id());
        }
    }
}

// Each hash op must:
// - take exactly one `DataType::Bytes` input;
// - pass the Category A checks and lower (no IR validation);
// - declare some `Bounded` law;
// - not be inlinable.
#[test]
fn hash_ops_are_non_empty_category_a_programs_that_lower() {
    let hashes = [&fnv1a32::SPEC, &crc32::SPEC, &rolling::SPEC, &entropy::SPEC];
    for spec in hashes {
        assert_eq!(spec.inputs(), &[DataType::Bytes]);
        assert_lowers(spec);
        assert_has_bounded_law(spec);
        if spec.inlinable() {
            panic!("{} must not be inlinable", spec.id());
        }
    }
}

// The inclusive u32 prefix sum must:
// - take exactly one `DataType::U32` input;
// - validate and lower;
// - declare `Identity { element: 0 }`;
// - not be inlinable.
#[test]
fn scan_prefix_sum_inclusive_validates_lowers_and_declares_identity() {
    let spec = &PrefixSumInclusiveU32::SPEC;
    assert_eq!(spec.inputs(), &[DataType::U32]);
    assert_validates_and_lowers(spec);
    assert_laws(spec, &[AlgebraicLaw::Identity { element: 0 }]);
    if spec.inlinable() {
        panic!("{} must not be inlinable", spec.id());
    }
}

// The LZ4 and Zstd decompressors must each:
// - take exactly two `DataType::U32` inputs;
// - validate and lower;
// - declare `Bounded { lo: 0, hi: u32::MAX }`;
// - not be inlinable.
#[test]
fn compression_ops_validate_lower_and_declare_bounded_output() {
    let full_u32_bound = [AlgebraicLaw::Bounded {
        lo: 0,
        hi: u32::MAX,
    }];
    let decompressors = [&Lz4Decompress::SPEC, &ZstdDecompress::SPEC];
    for spec in decompressors {
        assert_eq!(spec.inputs(), &[DataType::U32, DataType::U32]);
        assert_validates_and_lowers(spec);
        assert_laws(spec, &full_u32_bound);
        if spec.inlinable() {
            panic!("{} must not be inlinable", spec.id());
        }
    }
}