use crate::spec::types::{
AltWgslSource, ArchetypeRef, ComparatorKind, DataType, DeclaredLaw, EquivalenceClass,
MutationClass, OpSignature, OpSpec, OracleKind, Strictness, Version,
};
use crate::spec::{AlgebraicLaw, Verification};
use vyre_spec::Category;
/// Stable registry identifier for the primitive `add` operation.
pub const ID: &str = "primitive.math.add";
/// Repository-relative path of the Markdown page documenting this op.
pub const DOCS_PATH: &str = "docs/ops/primitive/add.md";
/// Archetype tags this op opts into (passed to the spec builder's `archetypes`).
pub const DECLARED_ARCHETYPES: &[ArchetypeRef] = &["A1", "A2", "A3", "A5"];
/// Algebraic laws claimed for `add`, each paired with the verification
/// strategy that backs the claim.
///
/// Every entry here uses `Verification::ExhaustiveU8`.
pub const DECLARED_LAWS: &[DeclaredLaw] = &[
    // a + b == b + a
    DeclaredLaw {
        law: AlgebraicLaw::Commutative,
        verified_by: Verification::ExhaustiveU8,
    },
    // (a + b) + c == a + (b + c)
    DeclaredLaw {
        law: AlgebraicLaw::Associative,
        verified_by: Verification::ExhaustiveU8,
    },
    // a + 0 == a
    DeclaredLaw {
        law: AlgebraicLaw::Identity { element: 0 },
        verified_by: Verification::ExhaustiveU8,
    },
];
/// Mutation classes that this op's test suite is expected to detect.
pub const MUTATION_SENSITIVITY: &[MutationClass] = &[
    MutationClass::ArithmeticMutations,
    MutationClass::ConstantMutations,
];
/// Caller-supplied pieces needed to assemble the `add` [`OpSpec`].
///
/// Both members are plain function pointers, so the struct is trivially
/// `Copy`.
#[derive(Copy, Clone)]
pub struct AddSpecSource {
    /// CPU reference implementation operating on raw byte buffers.
    cpu_fn: fn(&[u8]) -> Vec<u8>,
    /// Factory that receives a WGSL generator and yields alternate WGSL sources.
    category_a_sources: fn(fn() -> String) -> Vec<AltWgslSource>,
}
impl AddSpecSource {
    /// Bundles a CPU reference function and an alternate-WGSL factory.
    #[must_use]
    pub const fn new(
        cpu_fn: fn(&[u8]) -> Vec<u8>,
        category_a_sources: fn(fn() -> String) -> Vec<AltWgslSource>,
    ) -> Self {
        Self {
            category_a_sources,
            cpu_fn,
        }
    }

    /// Returns the stored CPU reference function pointer.
    #[must_use]
    pub const fn cpu_fn(self) -> fn(&[u8]) -> Vec<u8> {
        let Self { cpu_fn, .. } = self;
        cpu_fn
    }

    /// Calls the stored factory with `wgsl_fn` and returns whatever it produces.
    ///
    /// `wgsl_fn` is presumably the op's primary WGSL generator. Confirm this
    /// against the callers.
    #[inline]
    #[must_use]
    pub fn category_a_sources(self, wgsl_fn: fn() -> String) -> Vec<AltWgslSource> {
        let factory = self.category_a_sources;
        factory(wgsl_fn)
    }
}
/// Assembles the complete [`OpSpec`] for `primitive.math.add`.
///
/// The CPU reference and the alternate WGSL sources come from `source`.
/// Everything else is fixed metadata declared in this module: the signature
/// `(u32, u32) -> u32`, strict exact-match comparison, and the declared laws,
/// archetypes and mutation classes.
///
/// # Panics
///
/// Panics if the spec builder rejects the assembled metadata.
#[inline]
pub fn build(source: AddSpecSource) -> OpSpec {
    let signature = OpSignature {
        inputs: vec![DataType::U32, DataType::U32],
        output: DataType::U32,
    };
    // A primitive is its own (single-element) composition.
    let category = Category::A {
        composition_of: vec![ID],
    };
    let laws = vec![
        AlgebraicLaw::Commutative,
        AlgebraicLaw::Associative,
        AlgebraicLaw::Identity { element: 0 },
    ];

    let builder = OpSpec::builder(ID)
        .signature(signature)
        .cpu_fn(source.cpu_fn())
        .wgsl_fn(wgsl)
        .category(category)
        .laws(laws)
        .strictness(Strictness::Strict)
        .version(1)
        .alt_wgsl_fns(source.category_a_sources(wgsl))
        .declared_laws(DECLARED_LAWS)
        .spec_table(crate::spec::tables::add::ROWS)
        .archetypes(DECLARED_ARCHETYPES)
        .mutation_sensitivity(MUTATION_SENSITIVITY)
        .oracle_override(None::<OracleKind>)
        .since_version(Version::V1_0)
        .docs_path(DOCS_PATH)
        .equivalence_classes(vec![EquivalenceClass::universal("all u32 pairs")])
        .comparator(ComparatorKind::ExactMatch);

    builder
        .build()
        .expect("registry invariant violated. Fix: repair primitive.math.add spec metadata.")
}
#[inline]
pub fn spec(source: AddSpecSource) -> OpSpec {
build(source)
}
/// Returns the primary WGSL source for this op.
///
/// The emitted `vyre_op` returns `input.data[0u] + input.data[1u]`. Its
/// `index` and `input_len` parameters are accepted but not used. `input` is
/// not declared here; presumably the surrounding shader declares it (verify).
fn wgsl() -> String {
    String::from(
        "fn vyre_op(index: u32, input_len: u32) -> u32 { return input.data[0u] + input.data[1u]; }",
    )
}