use crate::proof::algebra::checker::support::{
call_binary, call_unary, engine_failure_violation, missing_companion_violation,
};
use crate::proof::algebra::{LawResult, VerificationLevel};
use crate::spec::law::{AlgebraicLaw, LawViolation};
use crate::spec::program::program_for_spec_input;
use crate::spec::types::{DataType, OpSpec};
/// Checks every law declared on `op` against the GPU `backend`.
///
/// The operation's arity is classified once from its signature and each entry
/// of `op.laws` is then verified with `witness_count` requested witnesses.
///
/// Returns one [`LawResult`] per declared law, in declaration order.
#[inline]
pub fn verify_gpu_laws_witnessed(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    witness_count: u64,
) -> Vec<LawResult> {
    let shape = arity(op);
    let mut results = Vec::new();
    for law in op.laws.iter() {
        results.push(verify_one_gpu_law(backend, op, law, shape, witness_count));
    }
    results
}
/// Signature shape of an operation as seen by the GPU law checker.
#[derive(Copy, Clone)]
enum Arity {
    /// One `U32` input producing a `U32` output.
    Unary,
    /// Two `U32` inputs producing a `U32` output.
    Binary,
    /// Any other signature; no GPU witnesses are run for it.
    Unsupported,
}
/// Classifies `op` by its signature.
///
/// Only signatures with a `U32` output and one or two `U32` inputs are
/// supported; everything else maps to [`Arity::Unsupported`].
fn arity(op: &OpSpec) -> Arity {
    if op.signature.output != DataType::U32 {
        return Arity::Unsupported;
    }
    match op.signature.inputs.as_slice() {
        [DataType::U32] => Arity::Unary,
        [DataType::U32, DataType::U32] => Arity::Binary,
        _ => Arity::Unsupported,
    }
}
/// Verifies a single `law` of `op` on the GPU and packages the outcome.
///
/// The check is routed by `arity` to the unary or binary checker. The returned
/// [`LawResult`] always carries the `GpuWitnessedU32` level with the requested
/// `witness_count`, alongside the number of cases actually tested.
///
/// NOTE(review): an `Arity::Unsupported` op yields `cases_tested == 0` and no
/// violation, so consumers must inspect `cases_tested` to tell "not checked"
/// apart from "checked and passed".
fn verify_one_gpu_law(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    law: &AlgebraicLaw,
    arity: Arity,
    witness_count: u64,
) -> LawResult {
    let outcome = match arity {
        Arity::Unary => check_unary_gpu_law(backend, op, law, witness_count),
        Arity::Binary => check_binary_gpu_law(backend, op, law, witness_count),
        Arity::Unsupported => (0, None),
    };
    let level = VerificationLevel::GpuWitnessedU32 {
        count: witness_count,
    };
    LawResult {
        op_id: op.id.to_string(),
        law_name: law.name().to_string(),
        level,
        cases_tested: outcome.0,
        violation: outcome.1,
    }
}
fn check_binary_gpu_law(
backend: &dyn vyre::VyreBackend,
op: &OpSpec,
law: &AlgebraicLaw,
count: u64,
) -> (u64, Option<LawViolation>) {
let mut rng = simple_rng(op.id, law.name());
match law {
AlgebraicLaw::Commutative => {
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let ab = match gpu_binary(backend, op, a, b, "commutative f(a,b)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
let ba = match gpu_binary(backend, op, b, a, "commutative f(b,a)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
if ab != ba {
return (
i + 1,
Some(violation(op.id, "gpu-commutative", a, b, 0, ab, ba)),
);
}
}
(count, None)
}
AlgebraicLaw::Associative => {
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let c = rng.next_u32();
let ab = match gpu_binary(backend, op, a, b, "associative f(a,b)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
let bc = match gpu_binary(backend, op, b, c, "associative f(b,c)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
let ab_c = match gpu_binary(backend, op, ab, c, "associative f(f(a,b),c)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
let a_bc = match gpu_binary(backend, op, a, bc, "associative f(a,f(b,c))") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
if ab_c != a_bc {
return (
i + 1,
Some(violation(op.id, "gpu-associative", a, b, c, ab_c, a_bc)),
);
}
}
(count, None)
}
AlgebraicLaw::Identity { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) =
check_binary_equals(backend, op, a, *element, a, "identity(right)")
{
return (i + 1, Some(err));
}
if let Some(err) =
check_binary_equals(backend, op, *element, a, a, "identity(left)")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::SelfInverse { result } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) = check_binary_equals(backend, op, a, a, *result, "self-inverse") {
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::Idempotent => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) = check_binary_equals(backend, op, a, a, a, "idempotent") {
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::Absorbing { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) =
check_binary_equals(backend, op, a, *element, *element, "absorbing(right)")
{
return (i + 1, Some(err));
}
if let Some(err) =
check_binary_equals(backend, op, *element, a, *element, "absorbing(left)")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::Bounded { lo, hi } => {
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let ab = match gpu_binary(backend, op, a, b, "bounded") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
if ab < *lo || ab > *hi {
return (
i + 1,
Some(violation(op.id, "gpu-bounded", a, b, 0, ab, *lo)),
);
}
}
(count, None)
}
AlgebraicLaw::ZeroProduct { holds: true } => {
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let ab = match gpu_binary(backend, op, a, b, "zero-product(true)") {
Ok(value) => value,
Err(err) => return (i + 1, Some(err)),
};
if ab == 0 && a != 0 && b != 0 {
return (
i + 1,
Some(violation(op.id, "gpu-zero-product(true)", a, b, 0, ab, 1)),
);
}
}
(count, None)
}
AlgebraicLaw::ZeroProduct { holds: false } => {
let pairs = [
(2, 0x8000_0000),
(0x8000_0000, 2),
(0x0001_0000, 0x0001_0000),
(0xFFFF_0000, 0x0001_0000),
];
for (i, (a, b)) in pairs.into_iter().enumerate() {
let ab = match gpu_binary(backend, op, a, b, "zero-product(false)") {
Ok(value) => value,
Err(err) => return (i as u64 + 1, Some(err)),
};
if ab == 0 {
return (i as u64 + 1, None);
}
}
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let ab = match gpu_binary(backend, op, a, b, "zero-product(false)") {
Ok(value) => value,
Err(err) => return (pairs.len() as u64 + i + 1, Some(err)),
};
if a != 0 && b != 0 && ab == 0 {
return (pairs.len() as u64 + i + 1, None);
}
}
let cpu_zp = match call_binary(op.cpu_fn, 2, 0x8000_0000) {
Ok(v) => v,
Err(e) => {
return (
pairs.len() as u64 + count,
Some(engine_failure_violation(op.id, e)),
)
}
};
(
pairs.len() as u64 + count,
Some(violation(
op.id,
"gpu-zero-product(false)",
2,
0x8000_0000,
0,
cpu_zp,
0,
)),
)
}
AlgebraicLaw::LeftIdentity { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) = check_binary_equals(backend, op, *element, a, a, "left-identity")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::RightIdentity { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) =
check_binary_equals(backend, op, a, *element, a, "right-identity")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::LeftAbsorbing { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) =
check_binary_equals(backend, op, *element, a, *element, "left-absorbing")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::RightAbsorbing { element } => {
for i in 0..count {
let a = rng.next_u32();
if let Some(err) =
check_binary_equals(backend, op, a, *element, *element, "right-absorbing")
{
return (i + 1, Some(err));
}
}
(count, None)
}
AlgebraicLaw::DeMorgan { inner_op, dual_op } => {
let specs = crate::spec::op_registry::compiled_specs();
let inner_fn = specs.iter().find(|s| s.id == *inner_op).map(|s| s.cpu_fn);
let dual_fn = specs.iter().find(|s| s.id == *dual_op).map(|s| s.cpu_fn);
let (Some(_inner), Some(dual)) = (inner_fn, dual_fn) else {
return (0, Some(unimplemented_gpu_law(op.id, law))); };
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let a_inner_b = match gpu_binary(backend, op, a, b, "demorgan inner") {
Ok(v) => v,
Err(e) => return (i + 1, Some(e)),
};
let not_ab = match call_unary(op.cpu_fn, a_inner_b) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
}; let not_a = match call_unary(op.cpu_fn, a) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
};
let not_b = match call_unary(op.cpu_fn, b) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
};
let na_dual_nb = match call_binary(dual, not_a, not_b) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
};
if not_ab != na_dual_nb {
return (
i + 1,
Some(violation(
op.id,
"gpu-demorgan",
a,
b,
0,
not_ab,
na_dual_nb,
)),
);
}
}
(count, None)
}
AlgebraicLaw::DistributiveOver { over_op } => {
let specs = crate::spec::op_registry::compiled_specs();
let over_fn = specs.iter().find(|s| s.id == *over_op).map(|s| s.cpu_fn);
let Some(over) = over_fn else {
return (0, Some(unimplemented_gpu_law(op.id, law)));
};
for i in 0..count {
let a = rng.next_u32();
let b = rng.next_u32();
let c = rng.next_u32();
let b_over_c = match call_binary(over, b, c) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
};
let lhs = match gpu_binary(backend, op, a, b_over_c, "distributive lhs") {
Ok(v) => v,
Err(e) => return (i + 1, Some(e)),
};
let a_op_b = match gpu_binary(backend, op, a, b, "distributive a*b") {
Ok(v) => v,
Err(e) => return (i + 1, Some(e)),
};
let a_op_c = match gpu_binary(backend, op, a, c, "distributive a*c") {
Ok(v) => v,
Err(e) => return (i + 1, Some(e)),
};
let rhs = match call_binary(over, a_op_b, a_op_c) {
Ok(v) => v,
Err(e) => return (i + 1, Some(engine_failure_violation(op.id, e))),
};
if lhs != rhs {
return (
i + 1,
Some(violation(op.id, "gpu-distributive", a, b, c, lhs, rhs)),
);
}
}
(count, None)
}
AlgebraicLaw::Complement { .. }
| AlgebraicLaw::LatticeAbsorption { .. }
| AlgebraicLaw::InverseOf { .. }
| AlgebraicLaw::Trichotomy { .. } => (0, Some(unimplemented_gpu_law(op.id, law))),
_ => (0, Some(unimplemented_gpu_law(op.id, law))),
}
}
/// Checks one algebraic `law` of a unary `u32` operation on the GPU backend.
///
/// Witness operands come from a deterministic generator seeded from the op id
/// and law name. Every GPU evaluation goes through `gpu_unary`, so a dispatch
/// failure or a GPU/CPU mismatch is reported as a violation too.
///
/// Returns `(cases_tested, violation)`: on success the number of witnesses run
/// and `None`; on failure the 1-based index of the failing witness; for laws
/// without a unary GPU checker, `(0, Some(..))`.
fn check_unary_gpu_law(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    law: &AlgebraicLaw,
    count: u64,
) -> (u64, Option<LawViolation>) {
    /// Runs `case` up to `count` times and stops at the first failing witness.
    fn witness_loop(
        count: u64,
        mut case: impl FnMut() -> Result<(), LawViolation>,
    ) -> (u64, Option<LawViolation>) {
        for i in 0..count {
            if let Err(found) = case() {
                return (i + 1, Some(found));
            }
        }
        (count, None)
    }

    let mut rng = simple_rng(op.id, law.name());
    match law {
        AlgebraicLaw::Involution => witness_loop(count, || {
            let a = rng.next_u32();
            let fa = gpu_unary(backend, op, a, "involution f(a)")?;
            let ffa = gpu_unary(backend, op, fa, "involution f(f(a))")?;
            if ffa == a {
                Ok(())
            } else {
                Err(violation(op.id, "gpu-involution", a, 0, 0, ffa, a))
            }
        }),
        AlgebraicLaw::Bounded { lo, hi } => witness_loop(count, || {
            let a = rng.next_u32();
            let fa = gpu_unary(backend, op, a, "bounded")?;
            // Report the bound that was actually crossed as the expected value.
            if fa < *lo {
                Err(violation(op.id, "gpu-bounded", a, 0, 0, fa, *lo))
            } else if fa > *hi {
                Err(violation(op.id, "gpu-bounded", a, 0, 0, fa, *hi))
            } else {
                Ok(())
            }
        }),
        AlgebraicLaw::Monotonic { direction } => {
            use crate::spec::law::MonotonicDirection;
            witness_loop(count, || {
                let first = rng.next_u32();
                let second = rng.next_u32();
                // Order the pair so that `a <= b` before comparing images.
                let (a, b) = if first <= second {
                    (first, second)
                } else {
                    (second, first)
                };
                let fa = gpu_unary(backend, op, a, "monotonic f(a)")?;
                let fb = gpu_unary(backend, op, b, "monotonic f(b)")?;
                let ok = match direction {
                    MonotonicDirection::NonDecreasing => fa <= fb,
                    MonotonicDirection::NonIncreasing => fa >= fb,
                    // Any other direction is treated as a failed witness.
                    _ => false,
                };
                if ok {
                    Ok(())
                } else {
                    Err(violation(op.id, "gpu-monotonic", a, b, 0, fa, fb))
                }
            })
        }
        // DeMorgan, Complement, Monotone, Custom and every other law have no
        // unary GPU checker; they fail loudly with zero cases.
        _ => (0, Some(unimplemented_gpu_law(op.id, law))),
    }
}
/// Builds the violation reported when `law` has no GPU checker for `op_id`.
///
/// Delegates to `missing_companion_violation`, passing the law's name for both
/// name arguments and "GPU checker" as the fourth argument.
fn unimplemented_gpu_law(op_id: &str, law: &AlgebraicLaw) -> LawViolation {
    let name = law.name();
    let detail = format!("unimplemented GPU checker for {}", name);
    missing_companion_violation(op_id, name, name, "GPU checker", &detail)
}
/// Evaluates `f(a, b)` on the GPU and compares it with `expected`.
///
/// Returns `None` when the result matches. Otherwise returns either the
/// failure produced by `gpu_binary` or a `gpu-{law_name}` violation carrying
/// the actual and expected values.
fn check_binary_equals(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    a: u32,
    b: u32,
    expected: u32,
    law_name: &str,
) -> Option<LawViolation> {
    let actual = match gpu_binary(backend, op, a, b, law_name) {
        Ok(value) => value,
        Err(failure) => return Some(failure),
    };
    if actual == expected {
        return None;
    }
    let label = format!("gpu-{law_name}");
    Some(violation(op.id, &label, a, b, 0, actual, expected))
}
/// Runs the binary operation on `(a, b)` via the GPU backend.
///
/// The operands are packed as two little-endian `u32` words (8 bytes). The CPU
/// reference is computed first; if it fails, an engine-failure violation is
/// returned without dispatching. `context` labels any violation produced.
fn gpu_binary(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    a: u32,
    b: u32,
    context: &str,
) -> Result<u32, LawViolation> {
    let mut input = [0_u8; 8];
    input[..4].copy_from_slice(&a.to_le_bytes());
    input[4..].copy_from_slice(&b.to_le_bytes());
    let cpu = match call_binary(op.cpu_fn, a, b) {
        Ok(reference) => reference,
        Err(e) => return Err(engine_failure_violation(op.id, e)),
    };
    dispatch_u32(backend, op, &input, cpu, a, b, 0, context)
}
/// Runs the unary operation on `a` via the GPU backend.
///
/// The operand is passed as one little-endian `u32` word. The CPU reference is
/// computed first; if it fails, an engine-failure violation is returned
/// without dispatching. `context` labels any violation produced.
fn gpu_unary(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    a: u32,
    context: &str,
) -> Result<u32, LawViolation> {
    let cpu = match call_unary(op.cpu_fn, a) {
        Ok(reference) => reference,
        Err(e) => return Err(engine_failure_violation(op.id, e)),
    };
    let input = a.to_le_bytes();
    dispatch_u32(backend, op, &input, cpu, a, 0, 0, context)
}
/// Dispatches `op` on the backend with `input` and checks the result against `cpu`.
///
/// The first four bytes of the first output buffer are decoded as a
/// little-endian `u32`.
///
/// # Errors
///
/// Returns a `gpu-dispatch({context})` violation when the program cannot be
/// built, the dispatch fails, no output buffer is returned, or the first
/// buffer holds fewer than four bytes. Returns a `gpu-cpu-parity({context})`
/// violation when the decoded value differs from `cpu`. The operands `a`, `b`
/// and `c` are only recorded in the violation.
fn dispatch_u32(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    cpu: u32,
    a: u32,
    b: u32,
    c: u32,
    context: &str,
) -> Result<u32, LawViolation> {
    let program = match program_for_spec_input(op, input) {
        Ok(program) => program,
        Err(err) => return Err(backend_violation(op.id, a, b, c, cpu, context, &err)),
    };
    let outputs = match backend.dispatch(
        &program,
        &[input.to_vec()],
        &vyre::DispatchConfig::default(),
    ) {
        Ok(buffers) => buffers,
        Err(err) => {
            return Err(backend_violation(
                op.id,
                a,
                b,
                c,
                cpu,
                context,
                &err.message,
            ))
        }
    };

    // Only the first output buffer is inspected.
    let Some(output) = outputs.into_iter().next() else {
        return Err(backend_violation(
            op.id,
            a,
            b,
            c,
            cpu,
            context,
            "backend returned zero output buffers, expected one. Fix: return the operation result as outputs[0].",
        ));
    };
    if output.len() < 4 {
        let detail = format!(
            "backend returned {} bytes, expected 4. Fix: return exactly the requested output size.",
            output.len()
        );
        return Err(backend_violation(op.id, a, b, c, cpu, context, &detail));
    }

    let gpu = u32::from_le_bytes([output[0], output[1], output[2], output[3]]);
    if gpu == cpu {
        return Ok(gpu);
    }
    Err(LawViolation {
        law: format!("gpu-cpu-parity({context})"),
        op_id: op.id.to_string(),
        a,
        b,
        c,
        lhs: gpu,
        rhs: cpu,
        message: format!(
            "GPU result diverged from CPU reference during {context}: gpu={gpu}, cpu={cpu}. Fix: make the backend WGSL implementation byte-for-byte equivalent to the CPU reference."
        ),
    })
}
/// Builds the violation reported when a GPU dispatch fails outright.
///
/// No GPU value exists in this case, so `lhs` is recorded as `0` and `rhs`
/// carries the CPU reference `cpu`. `context` names the evaluation that was in
/// progress and `err` is the failure detail embedded in the message.
fn backend_violation(
    op_id: &str,
    a: u32,
    b: u32,
    c: u32,
    cpu: u32,
    context: &str,
    err: &str,
) -> LawViolation {
    let law = format!("gpu-dispatch({context})");
    let message = format!("GPU dispatch failed during {context}: {err}. Fix: execute the conformance vyre IR for this operation.");
    LawViolation {
        law,
        op_id: op_id.to_string(),
        a,
        b,
        c,
        lhs: 0,
        rhs: cpu,
        message,
    }
}
/// Builds a law violation with a human-readable message.
///
/// `lhs` is reported as the observed value and `rhs` as the expected one. The
/// third operand `c` is shown in the message only when it is non-zero.
fn violation(op_id: &str, law: &str, a: u32, b: u32, c: u32, lhs: u32, rhs: u32) -> LawViolation {
    let third_operand = match c {
        0 => String::new(),
        shown => format!(", {shown}"),
    };
    let message = format!(
        "{law} violated on GPU: f({a}, {b}{}) = {lhs}, expected {rhs}. Fix: make the GPU backend satisfy the declared algebraic law and match the CPU reference.",
        third_operand
    );
    LawViolation {
        law: law.to_string(),
        op_id: op_id.to_string(),
        a,
        b,
        c,
        lhs,
        rhs,
        message,
    }
}
/// Small deterministic generator used to draw witness operands.
struct SimpleRng {
    /// Counter advanced by a fixed increment on every draw.
    state: u64,
}

impl SimpleRng {
    /// Advances the state and returns the low 32 bits of the mixed value.
    ///
    /// The state grows by a fixed odd constant (wrapping) and is then
    /// scrambled by two xor-shift/multiply rounds and a final xor-shift.
    fn next_u32(&mut self) -> u32 {
        const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
        const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
        const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

        self.state = self.state.wrapping_add(INCREMENT);
        let mixed = self.state;
        let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(MIX_A);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MIX_B);
        // Truncation to the low word is intentional.
        (mixed ^ (mixed >> 31)) as u32
    }
}
/// Creates the witness generator for one op/law pair.
///
/// The initial state is a 64-bit multiply-xor hash over, in order, the bytes
/// of the `VYRE_CONFORM_SEED` environment variable (when it is set and
/// readable), the bytes of `op_id`, and the bytes of `law_name`.
fn simple_rng(op_id: &str, law_name: &str) -> SimpleRng {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325_u64;
    const PRIME: u64 = 0x0000_0100_0000_01B3;

    // An unset or unreadable seed contributes no bytes to the hash.
    let env_seed = std::env::var("VYRE_CONFORM_SEED").unwrap_or_default();
    let state = env_seed
        .as_bytes()
        .iter()
        .chain(op_id.as_bytes())
        .chain(law_name.as_bytes())
        .fold(OFFSET_BASIS, |digest: u64, byte: &u8| {
            (digest ^ u64::from(*byte)).wrapping_mul(PRIME)
        });
    SimpleRng { state }
}