use std::path::{Path, PathBuf};
use crate::proof::comparator::ComparatorKind;
use crate::spec::types::{AltWgslSource, DataType, OpSignature, OpSpec, Strictness};
use crate::spec::{OverflowContract, SpecRow, SpecSource, Version};
use serde::Deserialize;
use vyre::ops::{Category as CoreCategory, Compose};
pub use crate::spec::law::AlgebraicLaw;
pub use crate::spec::types::{BoundaryValue, EquivalenceClass};
/// Function pointers that connect one core-registry primitive to the conform harness.
///
/// Instances are produced by `adapter_for`, which wires in the functions of the
/// module generated by `primitive_adapter!` for that op id.
#[derive(Clone, Copy)]
struct Adapter {
/// CPU reference: maps one flat input buffer to the op's output bytes.
cpu: fn(&[u8]) -> Vec<u8>,
/// Produces the op's WGSL source.
wgsl: fn() -> String,
/// Produces the op's core IR program.
program: fn() -> vyre::ir::Program,
}
/// Deserialized contents of one primitive KAT TOML file.
#[derive(Debug, Deserialize)]
struct KatDocument {
/// Op id the file claims to describe; `load_kat_document` panics if it differs
/// from the id that was requested.
op_id: String,
/// Golden known-answer vectors, one per `[[golden]]` table.
golden: Vec<KatRow>,
}
/// One golden known-answer vector as written in a KAT TOML file.
#[derive(Debug, Deserialize)]
struct KatRow {
/// Hex-encoded flat input: every operand's bytes concatenated in signature order.
input: String,
/// Hex-encoded expected output bytes.
expected: String,
/// Human-readable justification; used as the row's reason / vector source text.
reason: String,
}
/// Archetype references attached to every registry-backed primitive spec built
/// by `spec_by_id`.
pub const PRIMITIVE_ARCHETYPES: &[crate::spec::types::ArchetypeRef] =
&["A1", "A2", "A3", "A4", "A5", "A6", "A7"];
/// Signature of a `(u32, u32) -> u32` operation.
#[inline]
pub fn binary_u32_sig() -> OpSignature {
    OpSignature {
        output: DataType::U32,
        inputs: vec![DataType::U32; 2],
    }
}
/// Signature of a `u32 -> u32` operation.
#[inline]
pub fn unary_u32_sig() -> OpSignature {
    let inputs = vec![DataType::U32];
    OpSignature {
        inputs,
        output: DataType::U32,
    }
}
/// Builds a conform spec for every `primitive.*` op in the core registry.
///
/// # Panics
///
/// Panics on the first registry primitive for which `spec_by_id` returns
/// `None` (i.e. no adapter entry), and on any panic raised while building a spec.
#[must_use]
#[inline]
pub fn specs() -> Vec<OpSpec> {
    let mut collected = Vec::new();
    for id in vyre::ops::registry::known_op_ids() {
        if !id.starts_with("primitive.") {
            continue;
        }
        let Some(spec) = spec_by_id(id) else {
            panic!("missing primitive spec for {id}. Fix: add an adapter entry for this core registry op.");
        };
        collected.push(spec);
    }
    collected
}
/// Builds the conform [`OpSpec`] for one core-registry primitive.
///
/// Returns `None` when `id` is not present in `vyre::ops::registry` or has no
/// entry in `adapter_for`. Otherwise the signature, laws, and category come
/// from the core op, and the spec table and boundary values come from the op's
/// KAT TOML file.
///
/// Every call re-reads the KAT file and leaks freshly allocated row data
/// (via `Box::leak`), so callers that need the spec repeatedly should keep it.
///
/// # Panics
///
/// Panics when the KAT file is missing, malformed, empty, or inconsistent with
/// the signature; when the core op has an unsupported data type, category, or
/// output count; and (presumably) when the builder's final `expect` rejects the
/// configuration — NOTE(review): confirm against the builder API.
#[must_use]
#[inline]
pub fn spec_by_id(id: &str) -> Option<OpSpec> {
let core = vyre::ops::registry::lookup(id)?;
let adapter = adapter_for(id)?;
// `'static` id owned by the core registry; the spec stores this one.
let id_static = core.id();
let signature = signature_from_core(core);
let laws = core.laws().to_vec();
// Leaked, KAT-backed rows; also used below to derive boundary values.
let spec_table = load_spec_table(id_static, &signature);
let mut builder = OpSpec::builder(id_static)
.signature(signature.clone())
.cpu_fn(adapter.cpu)
.wgsl_fn(adapter.wgsl)
.category(category_from_core(core))
.laws(laws)
.strictness(Strictness::Strict)
.version(1)
.alt_wgsl_fns(Vec::new())
.declared_laws(Vec::<crate::spec::types::DeclaredLaw>::new())
.spec_table(spec_table)
.archetypes(PRIMITIVE_ARCHETYPES)
.mutation_sensitivity(&[])
.oracle_override(None)
.since_version(Version::V1_0)
.docs_path("")
.equivalence_classes(vec![EquivalenceClass::universal(
"core registry primitive domain",
)])
.boundary_values(boundaries_from_rows(spec_table))
.comparator(ComparatorKind::ExactMatch)
.ir_program(Some(adapter.program))
// Output width is floored at 4 bytes. NOTE(review): presumably because
// narrower outputs (e.g. Bool) are carried in a 32-bit word — confirm.
.expected_output_bytes(Some(signature.output.min_bytes().max(4)));
if core.laws().is_empty() {
builder = builder.no_algebraic_laws_reason(Some(
"core registry declares no algebraic laws for this primitive",
));
}
// Only signatures touching U32/I32/U64 get an overflow contract:
// ids ending in `_sat` saturate, everything else wraps.
if signature_uses_integer(&signature) {
builder = builder.overflow_contract(overflow_contract(id_static));
}
Some(builder.expect("Fix: registry-backed primitive spec must satisfy the typestate builder"))
}
/// Flattens every primitive KAT row into a [`vyre_spec::KatVector`].
///
/// Vectors are emitted in `specs()` order, and within each op in KAT file
/// order. Input/expected bytes and the reason text are leaked to obtain
/// `'static` lifetimes, so every call allocates fresh, never-freed copies.
///
/// # Panics
///
/// Panics under the same conditions as [`specs`], or when a KAT file cannot be
/// read, parsed, or hex-decoded.
#[must_use]
#[inline]
pub fn kat_vectors() -> Vec<vyre_spec::KatVector> {
    let mut vectors = Vec::new();
    for spec in specs() {
        // Copy the id out of the per-iteration `spec` so nothing that outlives
        // this iteration borrows from it (a nested closure inside `flat_map`
        // would otherwise capture `spec.id` by reference).
        let id = spec.id;
        for row in load_kat_document(id).golden {
            vectors.push(vyre_spec::KatVector {
                input: leak_bytes(decode_hex(id, "input", &row.input)),
                expected: leak_bytes(decode_hex(id, "expected", &row.expected)),
                source: leak_str(row.reason),
            });
        }
    }
    vectors
}
fn signature_from_core(core: &vyre::ops::OpSpec) -> OpSignature {
let [output] = core.outputs() else {
panic!(
"{} declares {} outputs. Fix: primitive conform specs require exactly one output.",
core.id(),
core.outputs().len()
);
};
OpSignature {
inputs: core
.inputs()
.iter()
.cloned()
.map(convert_data_type)
.collect(),
output: convert_data_type(output.clone()),
}
}
/// Maps a core IR data type onto the conform [`DataType`] with the same name.
///
/// # Panics
///
/// Panics for IR types that have no conform counterpart.
fn convert_data_type(data_type: vyre::ir::DataType) -> DataType {
    use vyre::ir::DataType as Ir;
    match data_type {
        Ir::U32 => DataType::U32,
        Ir::I32 => DataType::I32,
        Ir::U64 => DataType::U64,
        Ir::F32 => DataType::F32,
        Ir::Bool => DataType::Bool,
        Ir::Vec2U32 => DataType::Vec2U32,
        Ir::Vec4U32 => DataType::Vec4U32,
        Ir::Bytes => DataType::Bytes,
        _ => panic!("unsupported primitive data type {data_type:?}. Fix: add conform type mapping."),
    }
}
/// Translates a core op's category/compose mode into a conform category.
///
/// Category-A ops, and category-C ops built by composition, are reported as a
/// composition of the op itself. Category-C intrinsics keep their hardware tag
/// and receive a backend-availability predicate that always returns `false`.
///
/// # Panics
///
/// Panics for categories or compose modes without a mapping.
fn category_from_core(core: &vyre::ops::OpSpec) -> crate::enforce::category::Category {
    use crate::enforce::category::Category;

    let self_composition = || Category::A {
        composition_of: vec![core.id()],
    };

    match core.category() {
        CoreCategory::A => self_composition(),
        CoreCategory::C { .. } => match core.compose() {
            Compose::Composition(_) => self_composition(),
            Compose::Intrinsic(intrinsic) => Category::C {
                hardware: intrinsic.hardware(),
                backend_availability: vyre_spec::BackendAvailabilityPredicate::new(|_| false),
            },
            _ => panic!(
                "unsupported primitive compose mode for {}. Fix: add conform category mapping.",
                core.id()
            ),
        },
        _ => panic!(
            "unsupported primitive category for {}. Fix: add conform category mapping.",
            core.id()
        ),
    }
}
fn load_spec_table(id: &'static str, signature: &OpSignature) -> &'static [SpecRow] {
let rows = load_kat_document(id)
.golden
.into_iter()
.map(|row| {
let input = decode_hex(id, "input", &row.input);
let expected = decode_hex(id, "expected", &row.expected);
let inputs = split_inputs(id, signature, &input);
SpecRow::new(
leak_input_slices(inputs),
leak_bytes(expected),
leak_str(row.reason),
SpecSource::FromCorpus("rules/kat/primitive"),
)
})
.collect::<Vec<_>>();
if rows.is_empty() {
panic!("{id} has no KAT rows. Fix: add at least one [[golden]] vector.");
}
Box::leak(rows.into_boxed_slice())
}
/// Reads and parses the KAT TOML file for `id` (path from `kat_path`).
///
/// # Panics
///
/// Panics when the file cannot be read, is not valid KAT TOML, or declares an
/// `op_id` different from `id`.
fn load_kat_document(id: &str) -> KatDocument {
    let path = kat_path(id);
    let text = match std::fs::read_to_string(&path) {
        Ok(text) => text,
        Err(err) => panic!(
            "missing KAT TOML for {id} at {}: {err}. Fix: add rules/kat/primitive/<family>/<op>.toml.",
            path.display()
        ),
    };
    let document = match toml::from_str::<KatDocument>(&text) {
        Ok(document) => document,
        Err(err) => panic!(
            "invalid KAT TOML for {id} at {}: {err}. Fix: keep the primitive KAT schema valid.",
            path.display()
        ),
    };
    assert!(
        document.op_id == id,
        "KAT TOML path {} declares op_id {} but walker expected {id}. Fix: correct op_id or move the file.",
        path.display(),
        document.op_id
    );
    document
}
/// Resolves the KAT TOML path for a primitive op id.
///
/// `primitive.<family>.<op...>` maps to
/// `<CARGO_MANIFEST_DIR>/../rules/kat/primitive/<family>/<op>.toml`. Any further
/// `.`-separated segments of the op name are joined with `_`
/// (e.g. `primitive.a.b.c` -> `a/b_c.toml`).
///
/// # Panics
///
/// Panics when `id` does not start with `primitive.`, or when its family or op
/// segment is missing or empty (e.g. `primitive.math`, `primitive..add`). Such
/// ids cannot name a KAT file.
fn kat_path(id: &str) -> PathBuf {
    let rest = id.strip_prefix("primitive.").unwrap_or_else(|| {
        panic!("{id} is not a primitive op id. Fix: call kat_path only for primitive registry ids.")
    });
    // `split_once` yields `None` when there is no op segment at all; empty
    // halves are rejected explicitly instead of producing paths like `math/.toml`.
    let (family, op) = match rest.split_once('.') {
        Some((family, op)) if !family.is_empty() && !op.is_empty() => {
            (family, op.replace('.', "_"))
        }
        _ => panic!(
            "{id} has no primitive family or op name. Fix: use primitive.<family>.<op> ids."
        ),
    };
    Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("../rules/kat/primitive")
        .join(family)
        .join(format!("{op}.toml"))
}
/// Decodes a KAT hex field; `id` and `field` only feed the panic message.
///
/// # Panics
///
/// Panics when `hex_text` is not valid hex.
fn decode_hex(id: &str, field: &str, hex_text: &str) -> Vec<u8> {
    match hex::decode(hex_text) {
        Ok(bytes) => bytes,
        Err(err) => {
            panic!("{id} KAT {field} is not valid hex: {err}. Fix: use even-length lowercase hex.")
        }
    }
}
/// Splits one flat KAT input into per-operand leaked slices, using each input
/// type's `min_bytes()` as its width, in signature order.
///
/// # Panics
///
/// Panics when `input.len()` differs from `signature.min_input_bytes()`.
fn split_inputs(id: &str, signature: &OpSignature, input: &[u8]) -> Vec<&'static [u8]> {
    let expected = signature.min_input_bytes();
    assert!(
        input.len() == expected,
        "{id} KAT input has {} bytes, expected {expected}. Fix: encode exactly one flat primitive invocation.",
        input.len()
    );
    let mut operands = Vec::with_capacity(signature.inputs.len());
    let mut start = 0usize;
    for ty in &signature.inputs {
        let stop = start + ty.min_bytes();
        operands.push(leak_bytes(input[start..stop].to_vec()));
        start = stop;
    }
    operands
}
/// Derives one boundary value per spec row, labelled `kat_<index>`. Each input
/// operand is represented by the `u32` read from its first (up to) 4 bytes.
fn boundaries_from_rows(rows: &'static [SpecRow]) -> Vec<BoundaryValue> {
    let mut boundaries = Vec::with_capacity(rows.len());
    for (index, row) in rows.iter().enumerate() {
        let label = leak_str(format!("kat_{index}"));
        let inputs = row.inputs.iter().map(|bytes| read_u32_prefix(bytes)).collect();
        boundaries.push(BoundaryValue { label, inputs });
    }
    boundaries
}
/// CPU reference for a core-registry op.
///
/// Only the first `sum(input min_bytes)` bytes of `input` are used. A shorter
/// input yields a zero-filled buffer sized to the op's outputs. Composition ops
/// run through `vyre_reference::flat_cpu::run_flat`; intrinsics call their own
/// CPU function.
///
/// # Panics
///
/// Panics when the op is not registered, when the reference interpreter fails,
/// or for unsupported compose modes.
fn cpu_by_id(id: &'static str, input: &[u8]) -> Vec<u8> {
    let Some(core) = vyre::ops::registry::lookup(id) else {
        panic!("missing core op {id}. Fix: register the op in vyre::ops::registry.");
    };
    let input_width: usize = core.inputs().iter().map(|ty| ty.min_bytes()).sum();
    let Some(input) = input.get(..input_width) else {
        let output_width: usize = core.outputs().iter().map(|ty| ty.min_bytes()).sum();
        return vec![0u8; output_width];
    };
    let mut output = Vec::new();
    match core.compose() {
        Compose::Intrinsic(intrinsic) => intrinsic.cpu_fn()(input, &mut output),
        Compose::Composition(build) => {
            let program = build().with_entry_op_id(id);
            if let Err(err) = vyre_reference::flat_cpu::run_flat(&program, input, &mut output) {
                panic!("{id} CPU reference failed: {err}. Fix: repair the core IR program.");
            }
        }
        _ => panic!("unsupported primitive compose mode for {id}. Fix: add CPU adapter mapping."),
    }
    output
}
fn lower_core_op(id: &'static str) -> String {
let program = vyre::ops::registry::lookup_program(id).unwrap_or_else(|| {
panic!("missing core IR program for {id}. Fix: register the op in vyre::ops::registry.")
});
vyre::lower::wgsl::lower(&program).unwrap_or_else(|err| {
panic!("failed to lower core IR program for {id}: {err}. Fix: repair core WGSL lowering.")
})
}
/// Fetches the registered core IR program for `id`.
///
/// # Panics
///
/// Panics when the core registry has no program for `id`.
fn program_by_id(id: &'static str) -> vyre::ir::Program {
    let Some(program) = vyre::ops::registry::lookup_program(id) else {
        panic!("missing core IR program for {id}. Fix: register the op in vyre::ops::registry.");
    };
    program
}
fn signature_uses_integer(signature: &OpSignature) -> bool {
signature
.inputs
.iter()
.chain(std::iter::once(&signature.output))
.any(|ty| matches!(ty, DataType::U32 | DataType::I32 | DataType::U64))
}
/// Picks the overflow contract from the op id: a `_sat` suffix means
/// saturating, anything else wrapping.
fn overflow_contract(id: &str) -> OverflowContract {
    match id.strip_suffix("_sat") {
        Some(_) => OverflowContract::Saturating,
        None => OverflowContract::Wrapping,
    }
}
/// Reads up to the first 4 bytes as a little-endian `u32`, treating missing
/// high bytes as zero.
fn read_u32_prefix(bytes: &[u8]) -> u32 {
    bytes
        .iter()
        .take(4)
        .enumerate()
        .fold(0u32, |word, (position, &byte)| {
            word | (u32::from(byte) << (8 * position))
        })
}
/// Leaks `bytes` (shrunk to fit) to obtain a `'static` slice.
fn leak_bytes(bytes: Vec<u8>) -> &'static [u8] {
    let boxed: Box<[u8]> = Box::from(bytes);
    Box::leak(boxed)
}
/// Leaks the list of operand slices to obtain a `'static` slice of slices.
fn leak_input_slices(inputs: Vec<&'static [u8]>) -> &'static [&'static [u8]] {
    let boxed: Box<[&'static [u8]]> = Box::from(inputs);
    Box::leak(boxed)
}
/// Leaks `text` to obtain a `'static` string slice.
fn leak_str(text: impl Into<String>) -> &'static str {
    let owned: String = text.into();
    let boxed: Box<str> = Box::from(owned);
    Box::leak(boxed)
}
fn category_a_sources(wgsl: fn() -> String) -> Vec<AltWgslSource> {
vec![("core-registry-wgsl", wgsl)]
}
macro_rules! primitive_adapter {
($module:ident, $id:literal) => {
pub mod $module {
use crate::OpSpec;
#[inline]
pub fn vyre_op() -> OpSpec {
super::spec_by_id($id)
.expect("Fix: primitive adapter id must exist in core registry.")
}
#[inline]
pub fn spec() -> OpSpec {
vyre_op()
}
#[inline]
pub fn spec_layer_source() -> crate::spec::ops::add::AddSpecSource {
crate::spec::ops::add::AddSpecSource::new(cpu, super::category_a_sources)
}
#[inline]
pub(crate) fn cpu_fn() -> fn(&[u8]) -> Vec<u8> {
cpu
}
fn cpu(input: &[u8]) -> Vec<u8> {
super::cpu_by_id($id, input)
}
pub(super) fn wgsl() -> String {
super::lower_core_op($id)
}
pub(super) fn program() -> vyre::ir::Program {
super::program_by_id($id)
}
}
};
}
// One adapter module per core-registry primitive. Keep this list in sync with
// the id -> module table in `adapter_for`: `specs()` panics for any registry
// primitive whose id has no adapter entry there.
primitive_adapter!(abs, "primitive.math.abs");
primitive_adapter!(abs_diff, "primitive.math.abs_diff");
primitive_adapter!(add, "primitive.math.add");
primitive_adapter!(add_sat, "primitive.math.add_sat");
primitive_adapter!(and, "primitive.bitwise.and");
primitive_adapter!(clamp, "primitive.math.clamp");
primitive_adapter!(clz, "primitive.bitwise.clz");
primitive_adapter!(ctz, "primitive.bitwise.ctz");
primitive_adapter!(div, "primitive.math.div");
primitive_adapter!(eq, "primitive.compare.eq");
primitive_adapter!(extract_bits, "primitive.bitwise.extract_bits");
primitive_adapter!(f32_abs, "primitive.float.f32_abs");
primitive_adapter!(f32_add, "primitive.float.f32_add");
primitive_adapter!(f32_cos, "primitive.float.f32_cos");
primitive_adapter!(f32_div, "primitive.float.f32_div");
primitive_adapter!(f32_mul, "primitive.float.f32_mul");
primitive_adapter!(f32_neg, "primitive.float.f32_neg");
primitive_adapter!(f32_sin, "primitive.float.f32_sin");
primitive_adapter!(f32_sqrt, "primitive.float.f32_sqrt");
primitive_adapter!(f32_sub, "primitive.float.f32_sub");
primitive_adapter!(gcd, "primitive.math.gcd");
primitive_adapter!(ge, "primitive.compare.ge");
primitive_adapter!(gt, "primitive.compare.gt");
primitive_adapter!(insert_bits, "primitive.bitwise.insert_bits");
primitive_adapter!(lcm, "primitive.math.lcm");
primitive_adapter!(le, "primitive.compare.le");
primitive_adapter!(logical_not, "primitive.compare.logical_not");
primitive_adapter!(lt, "primitive.compare.lt");
primitive_adapter!(max, "primitive.math.max");
primitive_adapter!(min, "primitive.math.min");
primitive_adapter!(mod_op, "primitive.math.mod");
primitive_adapter!(mul, "primitive.math.mul");
primitive_adapter!(ne, "primitive.compare.ne");
primitive_adapter!(neg, "primitive.math.neg");
primitive_adapter!(negate, "primitive.math.negate");
primitive_adapter!(not, "primitive.bitwise.not");
primitive_adapter!(or, "primitive.bitwise.or");
primitive_adapter!(popcount, "primitive.bitwise.popcount");
primitive_adapter!(popcount_sw, "primitive.bitwise.popcount_sw");
primitive_adapter!(reverse_bits, "primitive.bitwise.reverse_bits");
primitive_adapter!(rotl, "primitive.bitwise.rotl");
primitive_adapter!(rotr, "primitive.bitwise.rotr");
primitive_adapter!(select, "primitive.compare.select");
primitive_adapter!(shl, "primitive.bitwise.shl");
primitive_adapter!(shr, "primitive.bitwise.shr");
primitive_adapter!(sign, "primitive.math.sign");
primitive_adapter!(sub, "primitive.math.sub");
primitive_adapter!(sub_sat, "primitive.math.sub_sat");
primitive_adapter!(xor, "primitive.bitwise.xor");
fn adapter_for(id: &str) -> Option<Adapter> {
let adapter = match id {
"primitive.math.abs" => Adapter {
cpu: abs::cpu_fn(),
wgsl: abs::wgsl,
program: abs::program,
},
"primitive.math.abs_diff" => Adapter {
cpu: abs_diff::cpu_fn(),
wgsl: abs_diff::wgsl,
program: abs_diff::program,
},
"primitive.math.add" => Adapter {
cpu: add::cpu_fn(),
wgsl: add::wgsl,
program: add::program,
},
"primitive.math.add_sat" => Adapter {
cpu: add_sat::cpu_fn(),
wgsl: add_sat::wgsl,
program: add_sat::program,
},
"primitive.bitwise.and" => Adapter {
cpu: and::cpu_fn(),
wgsl: and::wgsl,
program: and::program,
},
"primitive.math.clamp" => Adapter {
cpu: clamp::cpu_fn(),
wgsl: clamp::wgsl,
program: clamp::program,
},
"primitive.bitwise.clz" => Adapter {
cpu: clz::cpu_fn(),
wgsl: clz::wgsl,
program: clz::program,
},
"primitive.bitwise.ctz" => Adapter {
cpu: ctz::cpu_fn(),
wgsl: ctz::wgsl,
program: ctz::program,
},
"primitive.math.div" => Adapter {
cpu: div::cpu_fn(),
wgsl: div::wgsl,
program: div::program,
},
"primitive.compare.eq" => Adapter {
cpu: eq::cpu_fn(),
wgsl: eq::wgsl,
program: eq::program,
},
"primitive.bitwise.extract_bits" => Adapter {
cpu: extract_bits::cpu_fn(),
wgsl: extract_bits::wgsl,
program: extract_bits::program,
},
"primitive.float.f32_abs" => Adapter {
cpu: f32_abs::cpu_fn(),
wgsl: f32_abs::wgsl,
program: f32_abs::program,
},
"primitive.float.f32_add" => Adapter {
cpu: f32_add::cpu_fn(),
wgsl: f32_add::wgsl,
program: f32_add::program,
},
"primitive.float.f32_cos" => Adapter {
cpu: f32_cos::cpu_fn(),
wgsl: f32_cos::wgsl,
program: f32_cos::program,
},
"primitive.float.f32_div" => Adapter {
cpu: f32_div::cpu_fn(),
wgsl: f32_div::wgsl,
program: f32_div::program,
},
"primitive.float.f32_mul" => Adapter {
cpu: f32_mul::cpu_fn(),
wgsl: f32_mul::wgsl,
program: f32_mul::program,
},
"primitive.float.f32_neg" => Adapter {
cpu: f32_neg::cpu_fn(),
wgsl: f32_neg::wgsl,
program: f32_neg::program,
},
"primitive.float.f32_sin" => Adapter {
cpu: f32_sin::cpu_fn(),
wgsl: f32_sin::wgsl,
program: f32_sin::program,
},
"primitive.float.f32_sqrt" => Adapter {
cpu: f32_sqrt::cpu_fn(),
wgsl: f32_sqrt::wgsl,
program: f32_sqrt::program,
},
"primitive.float.f32_sub" => Adapter {
cpu: f32_sub::cpu_fn(),
wgsl: f32_sub::wgsl,
program: f32_sub::program,
},
"primitive.math.gcd" => Adapter {
cpu: gcd::cpu_fn(),
wgsl: gcd::wgsl,
program: gcd::program,
},
"primitive.compare.ge" => Adapter {
cpu: ge::cpu_fn(),
wgsl: ge::wgsl,
program: ge::program,
},
"primitive.compare.gt" => Adapter {
cpu: gt::cpu_fn(),
wgsl: gt::wgsl,
program: gt::program,
},
"primitive.bitwise.insert_bits" => Adapter {
cpu: insert_bits::cpu_fn(),
wgsl: insert_bits::wgsl,
program: insert_bits::program,
},
"primitive.math.lcm" => Adapter {
cpu: lcm::cpu_fn(),
wgsl: lcm::wgsl,
program: lcm::program,
},
"primitive.compare.le" => Adapter {
cpu: le::cpu_fn(),
wgsl: le::wgsl,
program: le::program,
},
"primitive.compare.logical_not" => Adapter {
cpu: logical_not::cpu_fn(),
wgsl: logical_not::wgsl,
program: logical_not::program,
},
"primitive.compare.lt" => Adapter {
cpu: lt::cpu_fn(),
wgsl: lt::wgsl,
program: lt::program,
},
"primitive.math.max" => Adapter {
cpu: max::cpu_fn(),
wgsl: max::wgsl,
program: max::program,
},
"primitive.math.min" => Adapter {
cpu: min::cpu_fn(),
wgsl: min::wgsl,
program: min::program,
},
"primitive.math.mod" => Adapter {
cpu: mod_op::cpu_fn(),
wgsl: mod_op::wgsl,
program: mod_op::program,
},
"primitive.math.mul" => Adapter {
cpu: mul::cpu_fn(),
wgsl: mul::wgsl,
program: mul::program,
},
"primitive.compare.ne" => Adapter {
cpu: ne::cpu_fn(),
wgsl: ne::wgsl,
program: ne::program,
},
"primitive.math.neg" => Adapter {
cpu: neg::cpu_fn(),
wgsl: neg::wgsl,
program: neg::program,
},
"primitive.math.negate" => Adapter {
cpu: negate::cpu_fn(),
wgsl: negate::wgsl,
program: negate::program,
},
"primitive.bitwise.not" => Adapter {
cpu: not::cpu_fn(),
wgsl: not::wgsl,
program: not::program,
},
"primitive.bitwise.or" => Adapter {
cpu: or::cpu_fn(),
wgsl: or::wgsl,
program: or::program,
},
"primitive.bitwise.popcount" => Adapter {
cpu: popcount::cpu_fn(),
wgsl: popcount::wgsl,
program: popcount::program,
},
"primitive.bitwise.popcount_sw" => Adapter {
cpu: popcount_sw::cpu_fn(),
wgsl: popcount_sw::wgsl,
program: popcount_sw::program,
},
"primitive.bitwise.reverse_bits" => Adapter {
cpu: reverse_bits::cpu_fn(),
wgsl: reverse_bits::wgsl,
program: reverse_bits::program,
},
"primitive.bitwise.rotl" => Adapter {
cpu: rotl::cpu_fn(),
wgsl: rotl::wgsl,
program: rotl::program,
},
"primitive.bitwise.rotr" => Adapter {
cpu: rotr::cpu_fn(),
wgsl: rotr::wgsl,
program: rotr::program,
},
"primitive.compare.select" => Adapter {
cpu: select::cpu_fn(),
wgsl: select::wgsl,
program: select::program,
},
"primitive.bitwise.shl" => Adapter {
cpu: shl::cpu_fn(),
wgsl: shl::wgsl,
program: shl::program,
},
"primitive.bitwise.shr" => Adapter {
cpu: shr::cpu_fn(),
wgsl: shr::wgsl,
program: shr::program,
},
"primitive.math.sign" => Adapter {
cpu: sign::cpu_fn(),
wgsl: sign::wgsl,
program: sign::program,
},
"primitive.math.sub" => Adapter {
cpu: sub::cpu_fn(),
wgsl: sub::wgsl,
program: sub::program,
},
"primitive.math.sub_sat" => Adapter {
cpu: sub_sat::cpu_fn(),
wgsl: sub_sat::wgsl,
program: sub_sat::program,
},
"primitive.bitwise.xor" => Adapter {
cpu: xor::cpu_fn(),
wgsl: xor::wgsl,
program: xor::program,
},
_ => return None,
};
Some(adapter)
}
#[cfg(test)]
mod tests {
    /// Every `primitive.*` op in the core registry must yield a spec with at
    /// least one KAT row, and the adapter's CPU function must reproduce every
    /// row's expected bytes exactly.
    #[test]
    fn every_core_registry_primitive_has_a_kat_backed_spec() {
        let specs = super::specs();
        let expected = vyre::ops::registry::known_op_ids()
            .filter(|id| id.starts_with("primitive."))
            .count();
        assert_eq!(
            specs.len(),
            expected,
            "primitive spec count does not match the core registry"
        );
        for spec in specs {
            assert!(!spec.spec_table.is_empty(), "{} has no KAT rows", spec.id);
            for (index, row) in spec.spec_table.iter().enumerate() {
                let actual = (spec.cpu_fn)(row.inputs.concat().as_slice());
                // Name the op and row so a failing KAT vector can be located.
                assert_eq!(
                    actual, row.expected,
                    "{} KAT row {index}: CPU output does not match expected bytes",
                    spec.id
                );
            }
        }
    }
}