use vyre::ir::{self, DataType, Program};
use vyre::lower::wgsl;
use vyre::ops::compression::{lz4::Lz4Decompress, zstd::ZstdDecompress};
use vyre::ops::decode::{
base64::Base64Decode, hex::HexDecode, unicode::UnicodeDecode, url::UrlDecode,
};
use vyre::ops::graph::{bfs::Bfs, dfs::Dfs, reachability::ReachabilityOp};
use vyre::ops::hash::{crc32, entropy, fnv1a32, rolling};
use vyre::ops::scan::PrefixSumInclusiveU32;
use vyre::ops::{AlgebraicLaw, Category, Compose, OpSpec};
/// Asserts that `spec` is a Category A composition and returns its built, non-empty program.
///
/// The returned [`Program`] can be validated or lowered further by the caller.
///
/// # Panics
///
/// Panics, naming the op via `spec.id()`, if:
/// - the spec's category is not `Category::A`,
/// - its compose mode is not `Compose::Composition`, or
/// - the built program's entry is empty.
///
/// Also panics with "Category A specs must build programs" if `spec.program()` fails.
fn assert_category_a(spec: &OpSpec) -> Program {
    // Name the op in every failure so a table-driven caller reports which spec broke.
    assert!(
        matches!(spec.category(), Category::A),
        "{} must be Category A",
        spec.id()
    );
    assert!(
        matches!(spec.compose(), Compose::Composition(_)),
        "{} must be composed as Compose::Composition",
        spec.id()
    );
    let program = spec
        .program()
        .expect("Category A specs must build programs");
    assert!(
        !program.entry().is_empty(),
        "{} must not return an empty program",
        spec.id()
    );
    program
}
/// Checks that `spec` builds a Category A program that passes IR validation
/// with zero errors and then lowers to WGSL without error.
///
/// # Panics
///
/// Panics if any Category A check fails, if `ir::validate` reports errors,
/// or if `wgsl::lower` returns an error.
fn assert_validates_and_lowers(spec: &OpSpec) {
    let built = assert_category_a(spec);

    let errors = ir::validate(&built);
    assert!(
        errors.is_empty(),
        "{} must validate cleanly: {errors:?}",
        spec.id()
    );

    if let Err(err) = wgsl::lower(&built) {
        panic!("{} must lower: {err}", spec.id());
    }
}
/// Checks that `spec` builds a Category A program that lowers to WGSL.
///
/// Unlike `assert_validates_and_lowers`, this does not run `ir::validate`.
///
/// # Panics
///
/// Panics if any Category A check fails or if `wgsl::lower` returns an error.
fn assert_lowers(spec: &OpSpec) {
    let built = assert_category_a(spec);
    if let Err(err) = wgsl::lower(&built) {
        panic!("{} must lower: {err}", spec.id());
    }
}
/// Checks that every law in `laws` appears in `spec.laws()`.
///
/// # Panics
///
/// Panics on the first law in `laws` (in order) that `spec` does not declare.
fn assert_laws(spec: &OpSpec, laws: &[AlgebraicLaw]) {
    let first_missing = laws.iter().find(|&required| !spec.laws().contains(required));
    if let Some(law) = first_missing {
        panic!("{} must declare {law:?}", spec.id());
    }
}
/// Checks that `spec` declares at least one `AlgebraicLaw::Bounded` law,
/// whatever its bounds.
///
/// # Panics
///
/// Panics if no declared law is `Bounded`.
fn assert_has_bounded_law(spec: &OpSpec) {
    let declares_bounded = spec
        .laws()
        .iter()
        .any(|candidate| matches!(candidate, AlgebraicLaw::Bounded { .. }));
    assert!(declares_bounded, "{} must declare a bounded law", spec.id());
}
/// BFS and reachability specs must build, validate, and lower, in that order.
#[test]
fn graph_bfs_and_reachability_are_complete_category_a_programs() {
    for spec in [&Bfs::SPEC, &ReachabilityOp::SPEC] {
        assert_validates_and_lowers(spec);
    }
}
/// The DFS spec must build, validate, and lower.
#[test]
fn graph_dfs_is_complete_category_a_program() {
    let spec = &Dfs::SPEC;
    assert_validates_and_lowers(spec);
}
/// Each decode op must meet four conditions:
/// - it takes a single `Bytes` input;
/// - it builds a Category A program that lowers;
/// - it declares `Bounded { lo: 0, hi: 255 }`;
/// - it is not inlinable.
///
/// Validation is not checked here.
#[test]
fn decode_ops_are_non_empty_category_a_programs_that_lower() {
    let decoders = [
        &Base64Decode::SPEC,
        &HexDecode::SPEC,
        &UrlDecode::SPEC,
        &UnicodeDecode::SPEC,
    ];
    let byte_range = [AlgebraicLaw::Bounded { lo: 0, hi: 255 }];

    for spec in decoders {
        assert_eq!(spec.inputs(), &[DataType::Bytes]);
        assert_lowers(spec);
        assert_laws(spec, &byte_range);
        assert!(!spec.inlinable(), "{} must not be inlinable", spec.id());
    }
}
/// Each hash op must meet four conditions:
/// - it takes a single `Bytes` input;
/// - it builds a Category A program that lowers;
/// - it declares some `Bounded` law;
/// - it is not inlinable.
#[test]
fn hash_ops_are_non_empty_category_a_programs_that_lower() {
    let hashers = [&fnv1a32::SPEC, &crc32::SPEC, &rolling::SPEC, &entropy::SPEC];
    for spec in hashers {
        assert_eq!(spec.inputs(), &[DataType::Bytes]);
        assert_lowers(spec);
        assert_has_bounded_law(spec);
        assert!(!spec.inlinable(), "{} must not be inlinable", spec.id());
    }
}
/// The inclusive u32 prefix sum must meet four conditions:
/// - it takes a single `U32` input;
/// - it validates and lowers;
/// - it declares `Identity { element: 0 }`;
/// - it is not inlinable.
#[test]
fn scan_prefix_sum_inclusive_validates_lowers_and_declares_identity() {
    let spec = &PrefixSumInclusiveU32::SPEC;
    assert_eq!(spec.inputs(), &[DataType::U32]);
    assert_validates_and_lowers(spec);
    assert_laws(spec, &[AlgebraicLaw::Identity { element: 0 }]);
    assert!(!spec.inlinable(), "{} must not be inlinable", spec.id());
}
/// Each decompression op must meet four conditions:
/// - it takes two `U32` inputs;
/// - it validates and lowers;
/// - it declares `Bounded { lo: 0, hi: u32::MAX }`;
/// - it is not inlinable.
#[test]
fn compression_ops_validate_lower_and_declare_bounded_output() {
    let expected_inputs = [DataType::U32, DataType::U32];
    let full_u32_range = [AlgebraicLaw::Bounded {
        lo: 0,
        hi: u32::MAX,
    }];

    for spec in [&Lz4Decompress::SPEC, &ZstdDecompress::SPEC] {
        assert_eq!(spec.inputs(), &expected_inputs);
        assert_validates_and_lowers(spec);
        assert_laws(spec, &full_u32_range);
        assert!(!spec.inlinable(), "{} must not be inlinable", spec.id());
    }
}