use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::OpSpec;
/// Stand-in for the retired "dispatch a WGSL string" probe path.
///
/// Only `backend` is consulted (for its id); the shader source, input bytes,
/// output size and dispatch configuration are all ignored. The call never
/// succeeds: it always returns an `Err` explaining that the probe has to be
/// re-expressed as vyre IR. Every caller in this file therefore takes its
/// error branch.
#[inline]
fn dispatch_legacy_wgsl_probe(
    backend: &dyn vyre::VyreBackend,
    _wgsl: &str,
    _input: &[u8],
    _output_size: usize,
    _config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
    let backend_id = backend.id();
    let reason = format!(
        "Backend `{}` cannot run legacy WGSL-string float probes through the public VyreBackend API. Fix: express this float probe as vyre IR and dispatch it with VyreBackend::dispatch.",
        backend_id
    );
    Err(reason)
}
/// Status attached to a [`FloatFinding`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FloatGateStatus {
/// Not produced by any constructor visible in this part of the file.
Passed,
/// Status set by [`FloatFinding::pending`].
Pending,
/// Status set by [`FloatFinding::failed`].
Failed,
}
/// One result reported by a float-semantics enforcer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FloatFinding {
/// Rule identifier; the enforcers in this file use "B1", "B2", "B3" and "B4".
pub rule: &'static str,
/// Line number the enforcers cite as `SPEC.md:<line>` in their messages
/// (for example 310 for B1, 315 for B2).
pub spec_line: u32,
/// What the finding is about; callers pass an op id, a backend id or a
/// source location.
pub subject: String,
/// Outcome recorded for this finding.
pub status: FloatGateStatus,
/// Human-readable explanation; the messages built in this file include a
/// "Fix:" hint.
pub message: String,
}
impl FloatFinding {
#[inline]
pub fn failed(
rule: &'static str,
spec_line: u32,
subject: impl Into<String>,
message: impl Into<String>,
) -> Self {
Self {
rule,
spec_line,
subject: subject.into(),
status: FloatGateStatus::Failed,
message: message.into(),
}
}
#[inline]
pub fn pending(
rule: &'static str,
spec_line: u32,
subject: impl Into<String>,
message: impl Into<String>,
) -> Self {
Self {
rule,
spec_line,
subject: subject.into(),
status: FloatGateStatus::Pending,
message: message.into(),
}
}
}
#[inline]
pub(crate) fn enforce_float_semantics(
backend: &dyn vyre::VyreBackend,
spec: &OpSpec,
) -> Vec<FloatFinding> {
let mut findings = Vec::new();
if let Some(program) = spec.program() {
if let Err(finding) = b1_fma_fusion::check_fma_fusion(&program) {
findings.push(finding.into());
}
}
findings.extend(precision::enforce_declared_float_width(backend, spec));
findings.extend(rounding::enforce_round_to_nearest_even(backend, spec));
findings.extend(b2_reduction_ordering::enforce_reduction_ordering(
backend, spec,
));
findings.extend(b3_subnormal::enforce_subnormal_preservation(backend));
findings.extend(b4_transcendentals::enforce_transcendentals(backend, spec));
findings.extend(b5_div_sqrt::enforce_div_sqrt(backend, spec));
findings.extend(b6_tensor::enforce_tensor_precision(backend, spec));
let wgsl = (spec.wgsl_fn)();
findings.extend(
b1_fma_fusion::inspect_backend_shader_for_fma_fusion(&wgsl)
.into_iter()
.map(Into::into),
);
findings.extend(b1_fma_fusion::probe_backend_for_fma_fusion(backend));
findings
}
/// Normalizes an op id for substring matching: ASCII letters are lowercased
/// and every `-` or `_` becomes `.`. Non-ASCII characters pass through
/// unchanged.
#[inline]
pub(crate) fn lower_id(id: &str) -> String {
    id.chars()
        .map(|ch| match ch {
            '-' | '_' => '.',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
/// Reports whether `spec` should be treated as a float op.
///
/// True when the output type is in the float family, when the normalized id
/// (see [`lower_id`]) contains `float` or `matmul`, or when the output type
/// is `DataType::Tensor`.
#[inline]
pub(crate) fn is_float_spec(spec: &OpSpec) -> bool {
    let normalized = lower_id(spec.id);
    if spec.signature.output.is_float_family() {
        return true;
    }
    if normalized.contains("float") || normalized.contains("matmul") {
        return true;
    }
    matches!(spec.signature.output, crate::spec::types::DataType::Tensor)
}
pub mod b1_fma_fusion {
use std::collections::HashSet;
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::Convention;
use crate::enforce::enforcers::float_semantics::FloatFinding;
use vyre::ir::{BinOp, Expr, Node, Program};
/// A B1 violation found by the IR walk or by the WGSL text scan.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FmaFusionFinding {
/// Where the problem was seen: an IR path such as `IR entry node 1.then.2`
/// or a shader position such as `WGSL:12`.
pub location: String,
/// Explanation of the violation, including a "Fix:" hint.
pub message: String,
}
impl From<FmaFusionFinding> for FloatFinding {
fn from(value: FmaFusionFinding) -> Self {
Self::failed("B1", 310, value.location, value.message)
}
}
/// Walks the entry nodes of `program` in order and rejects the first
/// add/sub whose operand is a multiplication (directly, or through a
/// variable that currently holds one).
///
/// # Errors
/// Returns the first [`FmaFusionFinding`]; its location is
/// `IR entry node N` (1-based) plus any nested path suffix.
#[inline]
pub fn check_fma_fusion(program: &Program) -> Result<(), FmaFusionFinding> {
    let mut state = FmaState::default();
    program
        .entry
        .iter()
        .enumerate()
        .try_for_each(|(index, node)| {
            check_node(node, &mut state, &format!("IR entry node {}", index + 1))
        })
}
/// Scans WGSL source line by line for calls to `fma`.
///
/// A line is reported when it contains `fma` as a call token and does not
/// also contain the text `primitive.float.fma`. Line numbers in the findings
/// are 1-based.
#[inline]
pub fn inspect_backend_shader_for_fma_fusion(wgsl: &str) -> Vec<FmaFusionFinding> {
    wgsl.lines()
        .enumerate()
        .filter(|(_, text)| {
            contains_call_token(text, "fma") && !text.contains("primitive.float.fma")
        })
        .map(|(zero_based, _)| {
            let line_no = zero_based + 1;
            FmaFusionFinding {
                location: format!("WGSL:{line_no}"),
                message: format!(
                    "FMA fusion detected at WGSL:{line_no}. Use primitive.float.fma explicitly or reshape the expression. Fix: emit separate multiply and add for SPEC.md:310."
                ),
            }
        })
        .collect()
}
/// Runs the B1 behavioral probe and turns the outcome into findings.
///
/// The probe shader writes two little-endian `u32` values: the bits of
/// `(a * b) + c` and the bits of `fma(a, b, c)`. The backend passes only
/// when they are exactly `0x3400_0000` and `0x3400_0001`. Equal values are
/// reported as implicit fusion; anything else, a short output, or a dispatch
/// error is reported as a generic B1 failure. At most one finding is
/// returned.
#[inline]
pub(crate) fn probe_backend_for_fma_fusion(
    backend: &dyn vyre::VyreBackend,
) -> Vec<FloatFinding> {
    // Every failure path produces exactly one B1 finding about this backend.
    let report = |message: String| vec![FloatFinding::failed("B1", 310, backend.id(), message)];

    let config = ConformDispatchConfig {
        workgroup_size: 1,
        workgroup_count: 1,
        convention: Convention::V1,
        lookup_data: None,
        buffer_init: BufferInitPolicy::Zero,
    };
    let output = match super::dispatch_legacy_wgsl_probe(
        backend,
        fma_fusion_probe_wgsl(),
        &[],
        8,
        config,
    ) {
        Ok(output) => output,
        Err(error) => {
            return report(format!(
                "Backend `{}` failed FMA fusion probe: {error}. Fix: strict float backends must execute the B1 behavioral probe.",
                backend.id()
            ));
        }
    };

    // The first eight bytes hold the two results; extra bytes are ignored.
    let &[s0, s1, s2, s3, f0, f1, f2, f3, ..] = output.as_slice() else {
        return report(format!(
            "Backend `{}` returned {} bytes for FMA fusion probe. Fix: return separate and fused f32 bit patterns.",
            backend.id(),
            output.len()
        ));
    };
    let separate = u32::from_le_bytes([s0, s1, s2, s3]);
    let fused = u32::from_le_bytes([f0, f1, f2, f3]);

    if separate == 0x3400_0000 && fused == 0x3400_0001 {
        return Vec::new();
    }
    let message = if separate == fused {
        format!(
            "Backend `{}` fused `(a * b) + c` into single-rounding FMA: both outputs were 0x{separate:08X}. Fix: disable implicit FMA contraction unless the op is primitive.float.fma.",
            backend.id()
        )
    } else {
        format!(
            "Backend `{}` returned unexpected FMA probe bits separate=0x{separate:08X}, fused=0x{fused:08X}. Fix: execute the B1 probe with strict f32 semantics before certification.",
            backend.id()
        )
    };
    report(message)
}
/// State carried through the IR walk of the B1 check.
#[derive(Clone, Default)]
struct FmaState {
/// Names of variables whose most recent `Let`/`Assign` value was a
/// multiplication (possibly wrapped in unary ops or casts).
multiply_vars: HashSet<String>,
}
fn check_node(
node: &Node,
state: &mut FmaState,
location: &str,
) -> Result<(), FmaFusionFinding> {
match node {
Node::Let { name, value } => {
check_expr(value, state, location)?;
if is_mul(value) {
state.multiply_vars.insert(name.clone());
} else {
state.multiply_vars.remove(name);
}
}
Node::Assign { name, value } => {
check_expr(value, state, location)?;
if is_mul(value) {
state.multiply_vars.insert(name.clone());
} else {
state.multiply_vars.remove(name);
}
}
Node::Store { index, value, .. } => {
check_expr(index, state, location)?;
check_expr(value, state, location)?;
}
Node::If {
cond,
then,
otherwise,
} => {
check_expr(cond, state, location)?;
let before_branch = state.clone();
let mut then_state = before_branch.clone();
for (index, child) in then.iter().enumerate() {
check_node(
child,
&mut then_state,
&format!("{location}.then.{}", index + 1),
)?;
}
let mut otherwise_state = before_branch;
for (index, child) in otherwise.iter().enumerate() {
check_node(
child,
&mut otherwise_state,
&format!("{location}.else.{}", index + 1),
)?;
}
then_state
.multiply_vars
.retain(|name| otherwise_state.multiply_vars.contains(name));
*state = then_state;
}
Node::Loop { from, to, body, .. } => {
check_expr(from, state, location)?;
check_expr(to, state, location)?;
for (index, child) in body.iter().enumerate() {
check_node(child, state, &format!("{location}.loop.{}", index + 1))?;
}
}
Node::Block(nodes) => {
for (index, child) in nodes.iter().enumerate() {
check_node(child, state, &format!("{location}.block.{}", index + 1))?;
}
}
Node::Return | Node::Barrier => {}
}
Ok(())
}
/// Recursively checks `expr` for an add or subtract that has a
/// multiplication as an operand.
///
/// The operand test runs before the operands themselves are visited, so the
/// outermost offending add/sub is the one reported.
///
/// # Errors
/// - "FMA fusion detected" when an `Add`/`Sub` operand is a multiplication
///   (see [`add_operand_is_mul`]).
/// - "Unhandled IR expression" for any expression variant this function does
///   not know how to traverse.
fn check_expr(expr: &Expr, state: &FmaState, location: &str) -> Result<(), FmaFusionFinding> {
    match expr {
        Expr::BinOp { op, left, right } => {
            let additive = matches!(op, BinOp::Add | BinOp::Sub);
            if additive && (add_operand_is_mul(left, state) || add_operand_is_mul(right, state)) {
                return Err(FmaFusionFinding {
                    location: location.to_string(),
                    message: format!(
                        "FMA fusion detected at {location}. Use primitive.float.fma explicitly or reshape the expression. Fix: keep SPEC.md:310 two-rounding FMul+FAdd explicit."
                    ),
                });
            }
            check_expr(left, state, location)?;
            check_expr(right, state, location)
        }
        Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
            check_expr(operand, state, location)
        }
        Expr::Load { index, .. } => check_expr(index, state, location),
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            check_expr(cond, state, location)?;
            check_expr(true_val, state, location)?;
            check_expr(false_val, state, location)
        }
        Expr::Call { args, .. } => {
            for arg in args {
                check_expr(arg, state, location)?;
            }
            Ok(())
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            check_expr(index, state, location)?;
            if let Some(expected) = expected {
                check_expr(expected, state, location)?;
            }
            check_expr(value, state, location)
        }
        // Leaves: nothing to traverse.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => Ok(()),
        // Unknown variants fail closed rather than being silently accepted.
        _ => Err(FmaFusionFinding {
            location: location.to_string(),
            message: format!(
                "Unhandled IR expression in FMA fusion check at {location}. Fix: teach B1 to inspect this expression variant before certifying SPEC.md:310."
            ),
        }),
    }
}
/// Decides whether an add/sub operand carries a multiplication result.
///
/// True when the operand is a multiplication per [`is_mul`], or is a
/// variable currently recorded in `state.multiply_vars`; unary ops and casts
/// around either form are looked through.
fn add_operand_is_mul(expr: &Expr, state: &FmaState) -> bool {
    if is_mul(expr) {
        return true;
    }
    match expr {
        Expr::Var(name) => state.multiply_vars.contains(name.as_str()),
        Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
            add_operand_is_mul(operand, state)
        }
        _ => false,
    }
}
/// Returns true when `expr` is a `BinOp::Mul`, looking through any number of
/// enclosing unary ops and casts.
fn is_mul(expr: &Expr) -> bool {
    if let Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } = expr {
        return is_mul(operand);
    }
    matches!(expr, Expr::BinOp { op: BinOp::Mul, .. })
}
/// Returns true when `line` contains `token` used as a call-like identifier.
///
/// An occurrence counts when:
/// - the character before it (if any) is not `_` or an ASCII alphanumeric,
///   so `my_fma(` and `fmafma(` are not calls to `fma`; and
/// - the character after it is `(` or ASCII whitespace.
///
/// Occurrences are examined left to right without overlapping. The
/// "character before" is always taken from the full `line`, so an occurrence
/// that directly follows an earlier rejected occurrence is still seen as
/// being in the middle of a longer identifier.
///
/// An empty `token` never matches.
fn contains_call_token(line: &str, token: &str) -> bool {
    if token.is_empty() {
        return false;
    }
    let is_ident_char = |ch: char| ch == '_' || ch.is_ascii_alphanumeric();
    line.match_indices(token).any(|(start, matched)| {
        let before = line[..start].chars().next_back();
        let after = line[start + matched.len()..].chars().next();
        let starts_identifier = !before.is_some_and(is_ident_char);
        let looks_like_call = after.is_some_and(|ch| ch == '(' || ch.is_ascii_whitespace());
        starts_identifier && looks_like_call
    })
}
/// WGSL source of the B1 behavioral probe.
///
/// Invocation 0 loads `a = b = bitcast<f32>(0x3F800001)` and
/// `c = bitcast<f32>(0xBF800001)`, then stores the bits of `(a * b) + c` in
/// `output.data[0]` and the bits of `fma(a, b, c)` in `output.data[1]`.
/// `probe_backend_for_fma_fusion` accepts the backend only when those two
/// words are `0x3400_0000` and `0x3400_0001`, i.e. when the separate
/// multiply-then-add differs from the fused result.
fn fma_fusion_probe_wgsl() -> &'static str {
r"
struct Bytes { data: array<u32>, };
@group(0) @binding(1) var<storage, read_write> output: Bytes;
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x == 0u) {
let a = bitcast<f32>(0x3F800001u);
let b = bitcast<f32>(0x3F800001u);
let c = bitcast<f32>(0xBF800001u);
output.data[0u] = bitcast<u32>((a * b) + c);
output.data[1u] = bitcast<u32>(fma(a, b, c));
}
}
"
}
// Unit tests for the B1 IR walk and the WGSL text scan. The IR fixtures use
// u32 literals; the checker only looks at expression shape, not at types.
#[cfg(test)]
mod tests {
use super::*;
use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
// Wraps `entry` in a program with a single read-write `out` buffer.
fn program(entry: Vec<Node>) -> Program {
Program::new(
vec![BufferDecl::read_write("out", 0, DataType::U32)],
[1, 1, 1],
entry,
)
}
// A multiply bound to a variable that is never added to anything is fine.
#[test]
fn positive_separate_mul_and_add_passes() {
let p = program(vec![
Node::let_bind("m", Expr::mul(Expr::u32(2), Expr::u32(3))),
Node::store("out", Expr::u32(0), Expr::add(Expr::u32(4), Expr::u32(5))),
]);
assert!(check_fma_fusion(&p).is_ok());
}
// `(2 * 3) + 4` written inline is the canonical violation.
#[test]
fn negative_direct_mul_add_is_rejected() {
let p = program(vec![Node::store(
"out",
Expr::u32(0),
Expr::add(Expr::mul(Expr::u32(2), Expr::u32(3)), Expr::u32(4)),
)]);
let err = check_fma_fusion(&p).unwrap_err();
assert!(err.message.contains("FMA fusion detected"));
assert!(err.message.contains("SPEC.md:310"));
}
// Subtraction is treated the same way as addition.
#[test]
fn negative_mul_sub_is_rejected() {
let p = program(vec![Node::store(
"out",
Expr::u32(0),
Expr::sub(Expr::mul(Expr::u32(2), Expr::u32(3)), Expr::u32(4)),
)]);
let err = check_fma_fusion(&p).unwrap_err();
assert!(err.message.contains("FMA fusion detected"));
}
// An explicit `primitive.float.fma` call is an ordinary `Call` expression.
#[test]
fn boundary_explicit_fma_call_passes() {
let p = program(vec![Node::store(
"out",
Expr::u32(0),
Expr::call(
"primitive.float.fma",
vec![Expr::u32(1), Expr::u32(2), Expr::u32(3)],
),
)]);
assert!(check_fma_fusion(&p).is_ok());
}
// The WGSL builtin `fma(` in shader text yields exactly one finding.
#[test]
fn wgsl_fma_builtin_is_rejected() {
let findings = inspect_backend_shader_for_fma_fusion("let x = fma(a, b, c);");
assert_eq!(findings.len(), 1);
}
}
}
pub mod b2_reduction_ordering {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::OpSpec;
use std::hint::black_box;
use crate::enforce::enforcers::float_semantics::{lower_id, FloatFinding};
/// Canonical float reduction orders an op id may declare.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionKind {
/// Reference result is a left-to-right fold (see `canonical_reduce`).
Strict,
/// Reference result is a balanced binary tree (see `canonical_reduce`).
TreeBinary,
}
/// Binary operator combined by a reduction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionOp {
/// Sum; identity 0.0.
Add,
/// Product; identity 1.0.
Mul,
/// Minimum; identity +infinity.
Min,
/// Maximum; identity -infinity.
Max,
}
/// Association order used when computing a reference reduction on the CPU.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ReductionShape {
/// `((identity op v0) op v1) op ...`, first element first.
LeftFold,
/// `v0 op (v1 op (... op identity))`, last element combined first.
RightFold,
/// Recursive split at the midpoint, combining the two halves.
Balanced,
}
/// Validates the reduction kind implied by an op id.
///
/// The id is normalized with `lower_id` first. It is rejected when it names
/// an unordered or approximate reduction, or when it mentions `freduce`
/// without a kind that [`reduction_kind`] recognizes.
///
/// # Errors
/// Returns a failed B2 finding (SPEC.md:315) whose subject is `op_id`.
#[inline]
pub fn check_reduction_registry_id(op_id: &str) -> Result<(), FloatFinding> {
    const FORBIDDEN_MARKERS: [&str; 6] = [
        "freduce.unordered",
        "freduceunordered",
        "f.reduce.unordered",
        "freduce.approx",
        "freduceapprox",
        "f.reduce.approx",
    ];

    let normalized = lower_id(op_id);
    let message = if FORBIDDEN_MARKERS
        .iter()
        .any(|marker| normalized.contains(marker))
    {
        format!(
            "FReduceUnordered or approximate float reduction registered as `{op_id}`. Fix: use FReduceStrict or FReduceTreeBinary per SPEC.md:315."
        )
    } else if normalized.contains("freduce") && reduction_kind(op_id).is_none() {
        format!(
            "Unknown float reduction kind registered as `{op_id}`. Fix: declare only FReduceStrict or FReduceTreeBinary per SPEC.md:315."
        )
    } else {
        return Ok(());
    };
    Err(FloatFinding::failed("B2", 315, op_id, message))
}
/// Extracts the declared reduction kind from an op id, if any.
///
/// Matching is done on the `lower_id`-normalized id: `freduce.strict` or
/// `freducestrict` selects [`ReductionKind::Strict`]; `freduce.tree.binary`
/// or `freducetreebinary` selects [`ReductionKind::TreeBinary`]. Strict wins
/// when both appear.
#[inline]
pub fn reduction_kind(op_id: &str) -> Option<ReductionKind> {
    let normalized = lower_id(op_id);
    let mentions_any =
        |markers: [&str; 2]| markers.iter().any(|marker| normalized.contains(marker));

    if mentions_any(["freduce.strict", "freducestrict"]) {
        return Some(ReductionKind::Strict);
    }
    mentions_any(["freduce.tree.binary", "freducetreebinary"]).then_some(ReductionKind::TreeBinary)
}
#[inline]
pub fn enforce_reduction_ordering(
backend: &dyn vyre::VyreBackend,
spec: &OpSpec,
) -> Vec<FloatFinding> {
let mut findings = Vec::new();
if let Err(finding) = check_reduction_registry_id(spec.id) {
findings.push(finding);
return findings;
}
findings.extend(inspect_shader_for_unordered_reductions(
&(spec.wgsl_fn)(),
spec.id,
));
if !findings.is_empty() {
return findings;
}
let Some(kind) = reduction_kind(spec.id) else {
return findings;
};
let op = reduction_op(spec.id);
for witness in reduction_witnesses() {
let expected = canonical_reduce(kind, op, &witness)
.to_bits()
.to_le_bytes()
.to_vec();
let input = witness
.iter()
.flat_map(|value| value.to_bits().to_le_bytes())
.collect::<Vec<_>>();
for workgroup_size in [1, 64, 256] {
match super::dispatch_legacy_wgsl_probe(backend,
&(spec.wgsl_fn)(),
&input,
4,
ConformDispatchConfig {
workgroup_size,
workgroup_count: 1,
convention: spec.convention,
lookup_data: None,
buffer_init: BufferInitPolicy::Zero,
},
) {
Ok(output) if reduction_bytes_match(&output, &expected) => {}
Ok(output) => findings.push(FloatFinding::failed(
"B2",
315,
spec.id,
format!(
"Reduction ordering mismatch for `{}` at workgroup_size={workgroup_size}. Fix: match canonical {:?} {:?} byte-exactly per SPEC.md:315; left=0x{:08X}, right=0x{:08X}, balanced=0x{:08X}.",
spec.id,
kind,
op,
reduce_shape(ReductionShape::LeftFold, op, &witness).to_bits(),
reduce_shape(ReductionShape::RightFold, op, &witness).to_bits(),
reduce_shape(ReductionShape::Balanced, op, &witness).to_bits(),
),
)
.with_bytes(output, expected.clone())),
Err(error) => findings.push(FloatFinding::failed(
"B2",
315,
spec.id,
format!(
"Reduction backend dispatch failed for `{}` at workgroup_size={workgroup_size}: {error}. Fix: compile and run the canonical reduction conformance shader.",
spec.id
),
)),
}
if !findings.is_empty() {
return findings;
}
}
}
findings
}
/// Looks for unordered-reduction primitives in shader text.
///
/// The shader is ASCII-lowercased and searched for a fixed list of
/// substrings. One failed B2 finding is returned per token that occurs, in
/// list order; `subject` is used as the finding subject and in the message.
#[inline]
pub fn inspect_shader_for_unordered_reductions(wgsl: &str, subject: &str) -> Vec<FloatFinding> {
    const UNORDERED_TOKENS: [&str; 10] = [
        "workgroupreduceadd",
        "workgroupreducemin",
        "workgroupreducemax",
        "subgroupadd",
        "subgroupinclusiveadd",
        "subgroupexclusiveadd",
        "warp_shuffle",
        "shuffle_xor",
        "unordered_reduction",
        "freduce_unordered",
    ];

    let haystack = wgsl.to_ascii_lowercase();
    let mut findings = Vec::new();
    for token in UNORDERED_TOKENS {
        if !haystack.contains(token) {
            continue;
        }
        findings.push(FloatFinding::failed(
            "B2",
            315,
            subject,
            format!(
                "Unordered reduction token `{token}` found in `{subject}` shader. Fix: use the declared FReduceStrict or FReduceTreeBinary order per SPEC.md:315."
            ),
        ));
    }
    findings
}
/// Module-private extension for attaching observed and expected bytes to a
/// finding.
trait WithBytes {
/// Consumes `self` and returns it annotated with the backend output
/// (`gpu`) and the reference bytes (`expected`).
fn with_bytes(self, gpu: Vec<u8>, expected: Vec<u8>) -> Self;
}
impl WithBytes for FloatFinding {
    /// Appends ` Observed gpu=[..], expected=[..].` (two-digit upper-case hex
    /// per byte) to the finding's message.
    fn with_bytes(mut self, gpu: Vec<u8>, expected: Vec<u8>) -> Self {
        let suffix = format!(" Observed gpu={gpu:02X?}, expected={expected:02X?}.");
        self.message += &suffix;
        self
    }
}
/// Compares a backend result with the reference result.
///
/// Byte-identical slices always match. Otherwise both must be exactly four
/// bytes, read as little-endian `f32`, and they match only when both are NaN
/// (any payload) or both are zero (either sign).
fn reduction_bytes_match(output: &[u8], expected: &[u8]) -> bool {
    if output == expected {
        return true;
    }
    let (Ok(observed_bytes), Ok(reference_bytes)) =
        (<[u8; 4]>::try_from(output), <[u8; 4]>::try_from(expected))
    else {
        return false;
    };
    let observed = f32::from_le_bytes(observed_bytes);
    let reference = f32::from_le_bytes(reference_bytes);

    let both_nan = observed.is_nan() && reference.is_nan();
    let both_zero = observed == 0.0 && reference == 0.0;
    both_nan || both_zero
}
/// Computes the CPU reference result for a declared reduction kind:
/// `Strict` is a left fold, `TreeBinary` is a balanced tree.
#[inline]
pub fn canonical_reduce(kind: ReductionKind, op: ReductionOp, values: &[f32]) -> f32 {
    let shape = match kind {
        ReductionKind::Strict => ReductionShape::LeftFold,
        ReductionKind::TreeBinary => ReductionShape::Balanced,
    };
    reduce_shape(shape, op, values)
}
/// Derives the reduction operator from the last `.`-separated segment of the
/// normalized op id.
///
/// A segment ending in `mul`, `min` or `max` selects that operator (checked
/// in that order); anything else is treated as `Add`.
fn reduction_op(op_id: &str) -> ReductionOp {
    let normalized = lower_id(op_id);
    let tail = normalized
        .rsplit_once('.')
        .map_or(normalized.as_str(), |(_, last)| last);

    [
        ("mul", ReductionOp::Mul),
        ("min", ReductionOp::Min),
        ("max", ReductionOp::Max),
    ]
    .into_iter()
    .find(|(suffix, _)| tail.ends_with(suffix))
    .map_or(ReductionOp::Add, |(_, op)| op)
}
/// Reduces `values` with `op` using the requested association order.
///
/// Both folds start from `identity(op)`; the left fold keeps the accumulator
/// on the left of each operation, the right fold keeps it on the right while
/// walking the slice backwards.
fn reduce_shape(shape: ReductionShape, op: ReductionOp, values: &[f32]) -> f32 {
    match shape {
        ReductionShape::LeftFold => {
            let mut accumulator = identity(op);
            for &value in values {
                accumulator = apply_op(op, accumulator, value);
            }
            accumulator
        }
        ReductionShape::RightFold => values
            .iter()
            .rfold(identity(op), |accumulator, &value| {
                apply_op(op, value, accumulator)
            }),
        ReductionShape::Balanced => tree_reduce(op, values),
    }
}
/// Balanced-tree reduction: split at `len / 2`, reduce each half, combine.
///
/// An empty slice yields `identity(op)`; a single element is returned as is
/// (the identity is not folded in).
fn tree_reduce(op: ReductionOp, values: &[f32]) -> f32 {
    if values.is_empty() {
        return identity(op);
    }
    if let [only] = values {
        return *only;
    }
    let (lower, upper) = values.split_at(values.len() / 2);
    let left = tree_reduce(op, lower);
    let right = tree_reduce(op, upper);
    apply_op(op, left, right)
}
/// Input vectors used to exercise reduction ordering.
///
/// Nine hand-written cases (empty, single value, signed zeros, infinities,
/// NaN bit patterns, small integers, values near 2^24 and denormals) are
/// followed by three generated vectors of length 64, 256 and 1024 built from
/// `reduction_stress_value`.
fn reduction_witnesses() -> Vec<Vec<f32>> {
    let nan_7fc00000 = f32::from_bits(0x7FC0_0000);
    let nan_7f800001 = f32::from_bits(0x7F80_0001);
    let nan_7fc00001 = f32::from_bits(0x7FC0_0001);
    let denormal_one = f32::from_bits(1);
    let denormal_two = f32::from_bits(2);

    let mut witnesses: Vec<Vec<f32>> = vec![
        Vec::new(),
        vec![1.0],
        vec![0.0, -0.0],
        vec![f32::INFINITY, f32::NEG_INFINITY, 1.0],
        vec![nan_7fc00000, 1.0, 2.0],
        vec![nan_7f800001, nan_7fc00001, 1.0],
        vec![1.0, 2.0, 3.0, 4.0],
        vec![16_777_216.0, 1.0, -16_777_216.0, 1.0],
        vec![denormal_one, 1.0, -1.0, denormal_two],
    ];
    witnesses.extend(
        [64, 256, 1024]
            .into_iter()
            .map(|len| (0..len).map(reduction_stress_value).collect::<Vec<f32>>()),
    );
    witnesses
}
/// Element `index` of the generated stress vectors: an 8-value pattern that
/// repeats, mixing ±2^24, ±1, 2, negative zero and the two smallest positive
/// denormals.
fn reduction_stress_value(index: usize) -> f32 {
    // Bit patterns, in order: 16777216.0, 1.0, -16777216.0, -0.0,
    // denormal 0x1, -1.0, denormal 0x2, 2.0.
    const PATTERN_BITS: [u32; 8] = [
        0x4B80_0000,
        0x3F80_0000,
        0xCB80_0000,
        0x8000_0000,
        0x0000_0001,
        0xBF80_0000,
        0x0000_0002,
        0x4000_0000,
    ];
    f32::from_bits(PATTERN_BITS[index % PATTERN_BITS.len()])
}
/// Identity element of `op`: 0 for add, 1 for mul, +infinity for min and
/// -infinity for max.
fn identity(op: ReductionOp) -> f32 {
    match op {
        ReductionOp::Min => f32::INFINITY,
        ReductionOp::Max => f32::NEG_INFINITY,
        ReductionOp::Mul => 1.0,
        ReductionOp::Add => 0.0,
    }
}
/// Applies `op` to one pair of operands.
///
/// Operands and result are passed through `black_box`, which returns its
/// argument unchanged and acts as an optimization barrier around each
/// individual operation.
fn apply_op(op: ReductionOp, left: f32, right: f32) -> f32 {
    let (lhs, rhs) = (black_box(left), black_box(right));
    let combined = match op {
        ReductionOp::Add => lhs + rhs,
        ReductionOp::Mul => lhs * rhs,
        ReductionOp::Min => lhs.min(rhs),
        ReductionOp::Max => lhs.max(rhs),
    };
    black_box(combined)
}
// Unit tests for the B2 reference reductions, id checks and shader scan.
#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::Strictness;
use crate::spec::types::{DataType, OpSignature};
use vyre_spec::Category;
// Mock backend that answers every dispatch with the canonical `Add`
// reduction of its input for the configured kind.
// NOTE(review): `enforce_reduction_ordering` takes `&dyn vyre::VyreBackend`
// and dispatches through `dispatch_legacy_wgsl_probe`, which always returns
// `Err` in this file — confirm this mock is still reachable.
struct ReductionBackend {
kind: ReductionKind,
}
impl WgslBackend for ReductionBackend {
fn name(&self) -> &str {
"reduction-mock"
}
fn dispatch(
&self,
_wgsl: &str,
input: &[u8],
_output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
// Decode the input as consecutive little-endian f32 values.
let values = input
.chunks_exact(4)
.map(|chunk| {
f32::from_bits(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
})
.collect::<Vec<_>>();
Ok(canonical_reduce(self.kind, ReductionOp::Add, &values)
.to_bits()
.to_le_bytes()
.to_vec())
}
}
// CPU reference stub: always returns four zero bytes.
fn cpu(_: &[u8]) -> Vec<u8> {
0_u32.to_le_bytes().to_vec()
}
// Builds a minimal f32 -> f32 spec with the given id and a trivial shader.
fn test_spec(id: &'static str) -> OpSpec {
OpSpec::builder(id)
.signature(OpSignature {
inputs: vec![DataType::F32],
output: DataType::F32,
})
.cpu_fn(cpu)
.wgsl_fn(|| {
"fn vyre_op(index: u32, input_len: u32) -> u32 { return 0u; }".to_string()
})
.category(Category::A {
composition_of: vec![],
})
.laws(vec![])
.strictness(Strictness::Strict)
.version(1)
.build()
.expect("registry invariant violated")
}
// Left fold of [2^24, 1, -2^24, 1]: the first +1 is absorbed, leaving 1.0.
#[test]
fn positive_strict_reference_matches_left_to_right() {
let values = [16_777_216.0, 1.0, -16_777_216.0, 1.0];
assert_eq!(
canonical_reduce(ReductionKind::Strict, ReductionOp::Add, &values).to_bits(),
1.0_f32.to_bits()
);
}
// `_` is normalized to `.`, so `freduce_unordered` hits the forbidden list.
#[test]
fn negative_unordered_id_rejected() {
let err =
check_reduction_registry_id("primitive.float.freduce_unordered_add").unwrap_err();
assert!(err.message.contains("SPEC.md:315"));
}
// Strict and tree orders must be distinguishable on a cancellation-heavy input.
#[test]
fn boundary_tree_differs_from_strict_but_is_canonical() {
let values = [1.0_f32, 1.0e20_f32, -1.0e20_f32, 1.0_f32];
let strict = canonical_reduce(ReductionKind::Strict, ReductionOp::Add, &values);
let tree = canonical_reduce(ReductionKind::TreeBinary, ReductionOp::Add, &values);
assert_ne!(
strict.to_bits(),
tree.to_bits(),
"strict={strict} tree={tree} — inputs must expose reduction shape divergence"
);
}
// A backend that implements the tree order should yield no findings.
#[test]
fn backend_tree_kind_passes_tree_witnesses() {
let spec = test_spec("primitive.float.freduce_tree_binary_add");
let findings = enforce_reduction_ordering(
&ReductionBackend {
kind: ReductionKind::TreeBinary,
},
&spec,
);
assert!(findings.is_empty(), "{findings:?}");
}
// The shader scan is case-insensitive: `workgroupReduceAdd` is caught.
#[test]
fn negative_unordered_shader_token_fails() {
let findings = inspect_shader_for_unordered_reductions(
"fn x(v: f32) -> f32 { return workgroupReduceAdd(v); }",
"primitive.float.freduce_strict_add",
);
assert_eq!(findings[0].rule, "B2");
}
// NaN payloads and the sign of zero are not significant when comparing results.
#[test]
fn boundary_nan_and_signed_zero_match_ieee_reduction_semantics() {
assert!(reduction_bytes_match(
&0x7FC0_0001_u32.to_le_bytes(),
&0x7FC0_0000_u32.to_le_bytes()
));
assert!(reduction_bytes_match(
&0x8000_0000_u32.to_le_bytes(),
&0x0000_0000_u32.to_le_bytes()
));
}
}
}
pub mod b3_subnormal {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::Convention;
use crate::enforce::enforcers::float_semantics::FloatFinding;
/// Outcome of the B3 subnormal probe.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SubnormalProbe {
/// Every probe case returned its expected bit pattern.
Preserved,
/// Some case returned its flush-to-zero bit pattern.
FlushesToZero,
/// The dispatch failed or returned bytes matching neither pattern; the
/// string carries the dispatch error or a description of the output.
UnsupportedByBackend(String),
}
/// One input/expectation pair for the subnormal probe shader.
#[derive(Clone, Copy, Debug)]
struct SubnormalCase {
/// Name used in diagnostics.
label: &'static str,
/// Operation selector read by the probe shader: 0 and 3 multiply,
/// 1 adds, 2 subtracts.
op: u32,
/// Bit pattern of the first f32 operand.
x_bits: u32,
/// Bit pattern of the second f32 operand.
y_bits: u32,
/// Result bits when subnormals are preserved.
expected_bits: u32,
/// Result bits that are classified as flush-to-zero.
ftz_bits: u32,
}
/// Runs every subnormal case on `backend` and classifies the result.
///
/// Each case sends twelve input bytes (`x`, `y` and the op selector as
/// little-endian `u32`) and requests four output bytes. Cases run in order
/// and the first one that does not return its expected bits decides the
/// outcome: the flush-to-zero pattern gives `FlushesToZero`; any other
/// output or a dispatch error gives `UnsupportedByBackend`.
#[inline]
pub(crate) fn probe_subnormal_preservation(backend: &dyn vyre::VyreBackend) -> SubnormalProbe {
    for case in subnormal_cases() {
        let input: Vec<u8> = [case.x_bits, case.y_bits, case.op]
            .into_iter()
            .flat_map(u32::to_le_bytes)
            .collect();
        let config = ConformDispatchConfig {
            workgroup_size: 1,
            workgroup_count: 1,
            convention: Convention::V1,
            lookup_data: None,
            buffer_init: BufferInitPolicy::Zero,
        };
        let output = match super::dispatch_legacy_wgsl_probe(
            backend,
            subnormal_probe_wgsl(),
            &input,
            4,
            config,
        ) {
            Ok(output) => output,
            Err(error) => return SubnormalProbe::UnsupportedByBackend(error),
        };

        if output == case.expected_bits.to_le_bytes() {
            continue;
        }
        if output == case.ftz_bits.to_le_bytes() {
            return SubnormalProbe::FlushesToZero;
        }
        return SubnormalProbe::UnsupportedByBackend(format!(
            "subnormal probe `{}` returned {output:02X?}; Fix: preserve f32 denormal bits, including signed zero, for SPEC.md:321.",
            case.label
        ));
    }
    SubnormalProbe::Preserved
}
#[inline]
pub(crate) fn enforce_subnormal_preservation(
backend: &dyn vyre::VyreBackend,
) -> Vec<FloatFinding> {
match probe_subnormal_preservation(backend) {
SubnormalProbe::Preserved => Vec::new(),
SubnormalProbe::FlushesToZero => vec![FloatFinding::failed(
"B3",
321,
backend.id(),
"Subnormal flushing detected by the f32 denormal corpus. Fix: disable flush-to-zero mode, preserve signed denormal results, or report UnsupportedByBackend per SPEC.md:321.",
)],
SubnormalProbe::UnsupportedByBackend(error) => vec![FloatFinding::failed(
"B3",
321,
backend.id(),
format!(
"Subnormal preservation probe unsupported on backend `{}`: {error}. Fix: expose a backend with FTZ disabled or return UnsupportedByBackend for strict float ops per SPEC.md:321.",
backend.id()
),
)],
}
}
/// The fixed corpus of subnormal probe cases, in execution order.
fn subnormal_cases() -> Vec<SubnormalCase> {
    // Columns: label, op, x_bits, y_bits, expected_bits, ftz_bits.
    const CASES: [(&str, u32, u32, u32, u32, u32); 5] = [
        (
            "smallest-positive-mul",
            0,
            0x0000_0001,
            0x3F80_0000,
            0x0000_0001,
            0x0000_0000,
        ),
        (
            "smallest-negative-mul",
            0,
            0x8000_0001,
            0x3F80_0000,
            0x8000_0001,
            0x8000_0000,
        ),
        (
            "largest-positive-add-zero",
            1,
            0x007F_FFFF,
            0x0000_0000,
            0x007F_FFFF,
            0x0000_0000,
        ),
        (
            "smallest-negative-sub-zero",
            2,
            0x8000_0001,
            0x0000_0000,
            0x8000_0001,
            0x8000_0000,
        ),
        (
            "normal-generates-subnormal",
            3,
            0x0080_0000,
            0x3F00_0000,
            0x0040_0000,
            0x0000_0000,
        ),
    ];

    CASES
        .into_iter()
        .map(
            |(label, op, x_bits, y_bits, expected_bits, ftz_bits)| SubnormalCase {
                label,
                op,
                x_bits,
                y_bits,
                expected_bits,
                ftz_bits,
            },
        )
        .collect()
}
/// WGSL source of the B3 subnormal probe.
///
/// Invocation 0 reads `x`, `y` (as f32 bit patterns) and an op selector from
/// the input buffer and stores the bits of the result in `output.data[0]`:
/// selectors 0 and 3 multiply, 1 adds, 2 subtracts. Any other selector
/// stores the sentinel `0xFFFFFFFF`.
///
/// The sentinel is written without a digit separator: WGSL numeric literals
/// do not accept `_`, so `0xFFFF_FFFFu` would make the shader fail to
/// compile.
fn subnormal_probe_wgsl() -> &'static str {
r"
struct Bytes { data: array<u32>, };
@group(0) @binding(0) var<storage, read> input: Bytes;
@group(0) @binding(1) var<storage, read_write> output: Bytes;
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x == 0u) {
let x = bitcast<f32>(input.data[0u]);
let y = bitcast<f32>(input.data[1u]);
let op = input.data[2u];
if (op == 0u) {
output.data[0u] = bitcast<u32>(x * y);
} else if (op == 1u) {
output.data[0u] = bitcast<u32>(x + y);
} else if (op == 2u) {
output.data[0u] = bitcast<u32>(x - y);
} else if (op == 3u) {
output.data[0u] = bitcast<u32>(x * y);
} else {
output.data[0u] = 0xFFFFFFFFu;
}
}
}
"
}
// Unit tests for the B3 subnormal probe using a CPU mock backend.
#[cfg(test)]
mod tests {
use super::*;
// Behaviour the mock backend simulates.
enum Mode {
// Compute the operation with ordinary f32 arithmetic.
Preserve,
// Return only the sign bit of `x`, i.e. a signed zero.
Flush,
// Fail every dispatch.
Error,
}
// NOTE(review): the probe functions take `&dyn vyre::VyreBackend` and go
// through `dispatch_legacy_wgsl_probe`, which always returns `Err` in this
// file — confirm this `WgslBackend` mock is still reachable.
struct Backend {
mode: Mode,
}
impl WgslBackend for Backend {
fn name(&self) -> &str {
"subnormal-mock"
}
fn dispatch(
&self,
_wgsl: &str,
input: &[u8],
_output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
// Input layout: x bits, y bits, op selector, each a little-endian u32.
let x_bits = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
let y_bits = u32::from_le_bytes([input[4], input[5], input[6], input[7]]);
let op = u32::from_le_bytes([input[8], input[9], input[10], input[11]]);
let x = f32::from_bits(x_bits);
let y = f32::from_bits(y_bits);
match self.mode {
Mode::Preserve => {
let value = match op {
0 | 3 => x * y,
1 => x + y,
2 => x - y,
_ => f32::from_bits(0xFFFF_FFFF),
};
Ok(value.to_bits().to_le_bytes().to_vec())
}
Mode::Flush => {
let sign = x_bits & 0x8000_0000;
Ok(sign.to_le_bytes().to_vec())
}
Mode::Error => Err("UnsupportedByBackend: no float controls".to_string()),
}
}
}
#[test]
fn positive_subnormal_preserved() {
assert_eq!(
probe_subnormal_preservation(&Backend {
mode: Mode::Preserve
}),
SubnormalProbe::Preserved
);
}
#[test]
fn negative_ftz_is_rejected() {
let findings = enforce_subnormal_preservation(&Backend { mode: Mode::Flush });
assert_eq!(findings.len(), 1);
assert!(findings[0].message.contains("SPEC.md:321"));
}
#[test]
fn boundary_unsupported_backend_is_not_silent() {
let findings = enforce_subnormal_preservation(&Backend { mode: Mode::Error });
assert_eq!(findings[0].rule, "B3");
assert!(findings[0].message.contains("UnsupportedByBackend"));
}
}
}
pub mod b4_transcendentals {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::OpSpec;
use crate::enforce::enforcers::float_semantics::{lower_id, FloatFinding};
/// Number of witness inputs `enforce_transcendentals` takes from the witness
/// stream. The parity run dispatches each witness at two workgroup sizes.
const DEFAULT_WITNESSES: u32 = 10_000_000;
/// Transcendental functions covered by rule B4.
///
/// `Pow` is the only two-argument operation; the others ignore the second
/// operand passed to the oracle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TranscendentalOp {
Sin,
Cos,
Tan,
Exp,
Log,
Pow,
Exp2,
Log2,
}
/// Maps an op id to the transcendental it names, if any.
///
/// Only the last `.`-separated segment of the normalized id is considered,
/// and it must be exactly one of `sin`, `cos`, `tan`, `exp`, `log`, `pow`,
/// `exp2`, `log2`, optionally prefixed with a single `f`.
#[inline]
pub fn transcendental_op(op_id: &str) -> Option<TranscendentalOp> {
    let normalized = lower_id(op_id);
    let tail = normalized
        .rsplit_once('.')
        .map_or(normalized.as_str(), |(_, last)| last);
    // None of the bare names starts with `f`, so dropping one leading `f`
    // folds the `fsin`-style spellings onto the bare ones.
    let bare = tail.strip_prefix('f').unwrap_or(tail);

    let op = match bare {
        "sin" => TranscendentalOp::Sin,
        "cos" => TranscendentalOp::Cos,
        "tan" => TranscendentalOp::Tan,
        "exp" => TranscendentalOp::Exp,
        "log" => TranscendentalOp::Log,
        "pow" => TranscendentalOp::Pow,
        "exp2" => TranscendentalOp::Exp2,
        "log2" => TranscendentalOp::Log2,
        _ => return None,
    };
    Some(op)
}
/// Enforces rule B4 for `spec` on `backend` using the default witness count
/// (`DEFAULT_WITNESSES`). See `enforce_transcendentals_with_witnesses` for
/// the individual checks.
#[inline]
pub(crate) fn enforce_transcendentals(
backend: &dyn vyre::VyreBackend,
spec: &OpSpec,
) -> Vec<FloatFinding> {
enforce_transcendentals_with_witnesses(backend, spec, DEFAULT_WITNESSES)
}
/// Enforces rule B4 (SPEC.md:325) with an explicit witness budget.
///
/// Ops whose id does not name a transcendental produce no findings.
/// Otherwise the checks run in order and the first one that fails decides
/// the result:
/// 1. the op must not declare `pending_wgsl_correct_rounding`;
/// 2. the CPU oracle must pass its self-check;
/// 3. the shader must not contain approximation tokens;
/// 4. the backend must match the oracle on up to `witnesses` inputs.
#[inline]
pub(crate) fn enforce_transcendentals_with_witnesses(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    witnesses: u32,
) -> Vec<FloatFinding> {
    let Some(op) = transcendental_op(spec.id) else {
        return Vec::new();
    };
    let single_failure = |message: String| vec![FloatFinding::failed("B4", 325, spec.id, message)];

    if op_declares_pending_correct_rounding(spec) {
        return single_failure(format!(
            "`{}` declares pending_wgsl_correct_rounding=true. Fix: replace vendor transcendental lowering with CORE-MATH parity before claiming L2f per SPEC.md:325.",
            spec.id
        ));
    }
    if let Err(message) = validate_core_math_oracle(op, core_math_reference) {
        return single_failure(message);
    }

    let shader = (spec.wgsl_fn)();
    let shader_findings = inspect_shader_for_transcendental_approximations(&shader, spec.id);
    if shader_findings.is_empty() {
        run_transcendental_parity(backend, spec, op, witnesses, core_math_reference)
    } else {
        shader_findings
    }
}
/// Compares the backend against `oracle` on the first `witnesses` values of
/// the witness stream.
///
/// Each witness is dispatched at workgroup sizes 1 and 64 with a 4-byte
/// output. The first mismatch or dispatch error is returned as the only
/// finding and the run stops; an empty vector means every dispatch matched.
///
/// The shader source is generated once up front rather than once per
/// dispatch, since with the default budget the loop performs millions of
/// dispatches.
#[inline]
pub fn run_transcendental_parity(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    op: TranscendentalOp,
    witnesses: u32,
    oracle: fn(TranscendentalOp, f32, f32) -> f32,
) -> Vec<FloatFinding> {
    let wgsl = (spec.wgsl_fn)();
    // A u32 always fits in usize on 32/64-bit targets; saturate otherwise.
    let budget = usize::try_from(witnesses).unwrap_or(usize::MAX);

    for witness in WitnessStream::new().take(budget) {
        let input = input_for(op, witness);
        let expected = reference_bytes(op, witness, oracle);

        for workgroup_size in [1, 64] {
            let config = ConformDispatchConfig {
                workgroup_size,
                workgroup_count: 1,
                convention: spec.convention,
                lookup_data: None,
                buffer_init: BufferInitPolicy::Zero,
            };
            let message =
                match super::dispatch_legacy_wgsl_probe(backend, &wgsl, &input, 4, config) {
                    Ok(output) if float_bytes_match(&output, &expected) => continue,
                    Ok(output) => format!(
                        "Transcendental `{}` is not correctly rounded for witness 0x{:08X} at workgroup_size={workgroup_size}: gpu={output:02X?}, core_math={expected:02X?}. Fix: use CORE-MATH and require 0 ULP per SPEC.md:325.",
                        spec.id, witness
                    ),
                    Err(error) => format!(
                        "Transcendental `{}` dispatch failed at workgroup_size={workgroup_size}: {error}. Fix: strict float backends must run the CORE-MATH parity shader per SPEC.md:325.",
                        spec.id
                    ),
                };
            // The first failure ends the whole run.
            return vec![FloatFinding::failed("B4", 325, spec.id, message)];
        }
    }
    Vec::new()
}
/// CPU oracle used for the B4 parity check.
///
/// Every operation is evaluated with the corresponding `f32` method from the
/// standard library; `y` is only used by `Pow`.
///
/// NOTE(review): despite the name, this calls `std` math functions, whose
/// precision the standard library documents as unspecified — confirm this is
/// an acceptable stand-in for a correctly-rounded CORE-MATH oracle.
#[inline]
pub fn core_math_reference(op: TranscendentalOp, x: f32, y: f32) -> f32 {
    let unary: fn(f32) -> f32 = match op {
        TranscendentalOp::Pow => return x.powf(y),
        TranscendentalOp::Sin => f32::sin,
        TranscendentalOp::Cos => f32::cos,
        TranscendentalOp::Tan => f32::tan,
        TranscendentalOp::Exp => f32::exp,
        TranscendentalOp::Log => f32::ln,
        TranscendentalOp::Exp2 => f32::exp2,
        TranscendentalOp::Log2 => f32::log2,
    };
    unary(x)
}
/// Reports whether the op admits that its WGSL lowering is not yet
/// correctly rounded.
///
/// True when the generated shader contains
/// `pending_wgsl_correct_rounding: true` or
/// `pending_wgsl_correct_rounding = true`, or when `spec.docs_path` contains
/// `pending_wgsl_correct_rounding=true`.
#[inline]
pub fn op_declares_pending_correct_rounding(spec: &OpSpec) -> bool {
    const SHADER_MARKERS: [&str; 2] = [
        "pending_wgsl_correct_rounding: true",
        "pending_wgsl_correct_rounding = true",
    ];

    let shader = (spec.wgsl_fn)();
    if SHADER_MARKERS.iter().any(|marker| shader.contains(marker)) {
        return true;
    }
    spec.docs_path
        .contains("pending_wgsl_correct_rounding=true")
}
/// Looks for fast or approximate transcendental helpers in shader text.
///
/// The shader is ASCII-lowercased and searched for a fixed list of
/// substrings. One failed B4 finding is returned per token that occurs, in
/// list order, with `subject` as the finding subject.
#[inline]
pub fn inspect_shader_for_transcendental_approximations(
    wgsl: &str,
    subject: &str,
) -> Vec<FloatFinding> {
    const APPROXIMATION_TOKENS: [&str; 17] = [
        "fast_sin",
        "fast_cos",
        "fast_tan",
        "fast_exp",
        "fast_log",
        "approx_sin",
        "approx_cos",
        "approx_tan",
        "approx_exp",
        "approx_log",
        "native_sin",
        "native_cos",
        "native_exp",
        "native_log",
        "sincos",
        "fast_math",
        "fast-math",
    ];

    let haystack = wgsl.to_ascii_lowercase();
    let mut findings = Vec::new();
    for token in APPROXIMATION_TOKENS {
        if !haystack.contains(token) {
            continue;
        }
        findings.push(FloatFinding::failed(
            "B4",
            325,
            subject,
            format!(
                "Forbidden transcendental approximation token `{token}` in strict float shader. Fix: use correctly-rounded CORE-MATH-equivalent lowering per SPEC.md:325."
            ),
        ));
    }
    findings
}
/// Builds the little-endian input buffer for one transcendental witness.
///
/// The buffer always starts with `x_bits`. For `Pow` a second operand is appended, derived
/// from the same witness as `x_bits.rotate_left(13)`.
fn input_for(op: TranscendentalOp, x_bits: u32) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(8);
    buffer.extend_from_slice(&x_bits.to_le_bytes());
    if matches!(op, TranscendentalOp::Pow) {
        let y_bits = x_bits.rotate_left(13);
        buffer.extend_from_slice(&y_bits.to_le_bytes());
    }
    buffer
}
/// Evaluates `oracle` for one witness and returns the result as little-endian bytes.
///
/// The second operand handed to the oracle is `x_bits.rotate_left(13)` reinterpreted as
/// `f32`, the same derivation `input_for` uses for `Pow` inputs.
fn reference_bytes(
    op: TranscendentalOp,
    x_bits: u32,
    oracle: fn(TranscendentalOp, f32, f32) -> f32,
) -> Vec<u8> {
    let y_bits = x_bits.rotate_left(13);
    let expected = oracle(op, f32::from_bits(x_bits), f32::from_bits(y_bits));
    Vec::from(expected.to_bits().to_le_bytes())
}
/// Compares a backend result with the reference, treating any two NaNs as equal.
///
/// Byte-identical buffers always match. Otherwise the buffers match only when both are
/// exactly four bytes long and both decode (little-endian) to an `f32` NaN, regardless of
/// NaN payload or sign.
fn float_bytes_match(output: &[u8], expected: &[u8]) -> bool {
    // Decodes a buffer only when it is exactly one little-endian f32.
    let decode = |bytes: &[u8]| match bytes {
        &[b0, b1, b2, b3] => Some(f32::from_le_bytes([b0, b1, b2, b3])),
        _ => None,
    };
    output == expected
        || matches!(
            (decode(output), decode(expected)),
            (Some(got), Some(want)) if got.is_nan() && want.is_nan()
        )
}
/// Sanity-checks `oracle` against a few exactly-representable identities for `op`.
///
/// Each check is an `(input bits, expected result bits)` pair; the oracle is always called
/// with `1.0` as its second operand. The first mismatch is reported as an `Err` carrying a
/// diagnostic message; `Ok(())` means every check produced the exact expected bits.
fn validate_core_math_oracle(
    op: TranscendentalOp,
    oracle: fn(TranscendentalOp, f32, f32) -> f32,
) -> Result<(), String> {
    // Bit patterns used below: 0x0000_0000 = +0.0, 0x3F80_0000 = 1.0, 0xFF80_0000 = -inf.
    let checks: &[(u32, u32)] = match op {
        TranscendentalOp::Sin | TranscendentalOp::Tan => &[(0x0000_0000, 0x0000_0000)],
        TranscendentalOp::Cos => &[(0x0000_0000, 0x3F80_0000)],
        TranscendentalOp::Exp | TranscendentalOp::Exp2 => {
            &[(0x0000_0000, 0x3F80_0000), (0xFF80_0000, 0x0000_0000)]
        }
        TranscendentalOp::Log | TranscendentalOp::Log2 => {
            &[(0x3F80_0000, 0x0000_0000), (0x0000_0000, 0xFF80_0000)]
        }
        TranscendentalOp::Pow => &[(0x3F80_0000, 0x3F80_0000)],
    };
    checks.iter().try_for_each(|&(x_bits, expected_bits)| {
        let observed = oracle(op, f32::from_bits(x_bits), 1.0).to_bits();
        if observed == expected_bits {
            Ok(())
        } else {
            Err(format!(
                "CORE-MATH oracle self-check failed for {op:?} input 0x{x_bits:08X}: got 0x{observed:08X}, expected 0x{expected_bits:08X}. Fix: link a correctly-rounded oracle before certifying SPEC.md:325."
            ))
        }
    })
}
/// Endless, deterministic stream of `f32` bit patterns used as transcendental witnesses.
///
/// The stream first replays a fixed list of boundary patterns and then continues with a
/// 32-bit linear congruential sequence, so it never returns `None`.
struct WitnessStream {
    /// Current state of the pseudo-random sequence.
    state: u32,
    /// Index of the next boundary pattern still to be emitted.
    boundary_index: usize,
}
impl WitnessStream {
    /// Boundary patterns emitted before the pseudo-random tail: signed zeros, +/-1.0,
    /// infinities, NaNs with different payloads, and the smallest/largest subnormals.
    const BOUNDARIES: [u32; 11] = [
        0x0000_0000,
        0x8000_0000,
        0x3F80_0000,
        0xBF80_0000,
        0x7F80_0000,
        0xFF80_0000,
        0x7FC0_0000,
        0x7FC0_0001,
        0x7F80_0001,
        0x0000_0001,
        0x007F_FFFF,
    ];
    /// Multiplier of the linear congruential step.
    const LCG_MULTIPLIER: u32 = 1_664_525;
    /// Increment of the linear congruential step.
    const LCG_INCREMENT: u32 = 1_013_904_223;

    /// Creates a stream positioned at the first boundary pattern with a fixed seed.
    fn new() -> Self {
        Self {
            state: 0x9E37_79B9,
            boundary_index: 0,
        }
    }
}
impl Iterator for WitnessStream {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        match Self::BOUNDARIES.get(self.boundary_index).copied() {
            Some(boundary) => {
                self.boundary_index += 1;
                Some(boundary)
            }
            None => {
                // Boundaries exhausted: advance the wrapping LCG and emit its new state.
                let advanced = self
                    .state
                    .wrapping_mul(Self::LCG_MULTIPLIER)
                    .wrapping_add(Self::LCG_INCREMENT);
                self.state = advanced;
                Some(advanced)
            }
        }
    }
}
// Unit tests for the B4 transcendental gate.
#[cfg(test)]
mod tests {
use super::*;
use crate::spec::types::conform::Strictness;
use crate::spec::types::{DataType, OpSignature};
use vyre_spec::Category;
/// Mock backend that answers every dispatch by evaluating `oracle` for `op` on the CPU.
struct OracleBackend {
op: TranscendentalOp,
oracle: fn(TranscendentalOp, f32, f32) -> f32,
}
impl WgslBackend for OracleBackend {
fn name(&self) -> &str {
"transcendental-mock"
}
/// Decodes `x` from the first four input bytes and `y` from the next four when present
/// (otherwise `y` is 0.0), then returns the oracle result as little-endian bytes.
fn dispatch(
&self,
_wgsl: &str,
input: &[u8],
_output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
let x =
f32::from_bits(u32::from_le_bytes([input[0], input[1], input[2], input[3]]));
let y = if input.len() >= 8 {
f32::from_bits(u32::from_le_bytes([input[4], input[5], input[6], input[7]]))
} else {
0.0
};
Ok((self.oracle)(self.op, x, y)
.to_bits()
.to_le_bytes()
.to_vec())
}
}
/// Deliberately wrong oracle: the reference result with its bit pattern incremented by one.
fn wrong_reference(op: TranscendentalOp, x: f32, y: f32) -> f32 {
f32::from_bits(core_math_reference(op, x, y).to_bits().wrapping_add(1))
}
/// Placeholder CPU implementation that always returns four zero bytes.
fn cpu(_: &[u8]) -> Vec<u8> {
0_u32.to_le_bytes().to_vec()
}
/// Builds a strict, single-`F32`-input spec with the given id and shader generator.
fn test_spec(id: &'static str, wgsl_fn: fn() -> String) -> OpSpec {
OpSpec::builder(id)
.signature(OpSignature {
inputs: vec![DataType::F32],
output: DataType::F32,
})
.cpu_fn(cpu)
.wgsl_fn(wgsl_fn)
.category(Category::A {
composition_of: vec![],
})
.laws(vec![])
.strictness(Strictness::Strict)
.version(1)
.build()
.expect("registry invariant violated")
}
/// Shader text without any pending marker or forbidden token.
fn regular_wgsl() -> String {
"fn vyre_op(index: u32, input_len: u32) -> u32 { return 0u; }".to_string()
}
/// Shader text consisting only of the `pending_wgsl_correct_rounding: true` marker.
fn pending_wgsl() -> String {
"// pending_wgsl_correct_rounding: true".to_string()
}
// A backend that reproduces the reference oracle is expected to yield no findings.
#[test]
fn positive_core_math_backend_passes() {
let spec = test_spec("primitive.float.fsin", regular_wgsl);
let findings = enforce_transcendentals_with_witnesses(
&OracleBackend {
op: TranscendentalOp::Sin,
oracle: core_math_reference,
},
&spec,
64,
);
assert!(findings.is_empty(), "{findings:?}");
}
// A backend that is off by one bit must produce a B4 finding mentioning "0 ULP".
#[test]
fn negative_synthetic_wrong_reference_fails() {
let spec = test_spec("primitive.float.fsin", regular_wgsl);
let findings = enforce_transcendentals_with_witnesses(
&OracleBackend {
op: TranscendentalOp::Sin,
oracle: wrong_reference,
},
&spec,
1,
);
assert_eq!(findings[0].rule, "B4");
assert!(findings[0].message.contains("0 ULP"));
}
// NOTE(review): the test name says the pending marker is "not failure or pass", yet the
// assertion expects `Failed` — confirm which of the two reflects the intended contract.
#[test]
fn boundary_pending_status_is_not_failure_or_pass() {
let spec = test_spec("primitive.float.fcos", pending_wgsl);
let findings = enforce_transcendentals_with_witnesses(
&OracleBackend {
op: TranscendentalOp::Cos,
oracle: wrong_reference,
},
&spec,
1,
);
assert_eq!(findings[0].status, super::super::FloatGateStatus::Failed);
}
// Two NaNs with different payloads must compare as matching.
#[test]
fn boundary_nan_payloads_compare_as_nan() {
assert!(float_bytes_match(
&0x7FC0_0001_u32.to_le_bytes(),
&0x7FC0_0000_u32.to_le_bytes()
));
}
// A shader calling `native_sin` must be flagged by the static token scan.
#[test]
fn negative_fast_transcendental_token_fails() {
let findings = inspect_shader_for_transcendental_approximations(
"fn x(v: f32) -> f32 { return native_sin(v); }",
"primitive.float.fsin",
);
assert_eq!(findings[0].rule, "B4");
}
}
}
pub mod b5_div_sqrt {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::OpSpec;
use std::hint::black_box;
use crate::enforce::enforcers::float_semantics::{lower_id, FloatFinding};
/// Number of witness pairs `enforce_div_sqrt` probes when the caller supplies no count.
const DEFAULT_WITNESSES: u32 = 10_000_000;
/// Operation families covered by the B5 division / square-root gate.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DivSqrtOp {
/// Binary division; witnesses supply both `x` and `y` and the reference is `x / y`.
Div,
/// Unary square root; witnesses supply only `x` and the reference is `x.sqrt()`.
Sqrt,
}
/// Classifies an operation id as a division or square-root operation.
///
/// Only the last `.`-separated segment of the normalised id is examined: a segment ending
/// in `div` is a [`DivSqrtOp::Div`], one ending in `sqrt` is a [`DivSqrtOp::Sqrt`], and
/// anything else is `None`.
///
/// NOTE(review): suffix matching also accepts segments such as `rsqrt` or `idiv` — confirm
/// that no registered id with such a tail should be excluded from this gate.
#[inline]
pub fn div_sqrt_op(op_id: &str) -> Option<DivSqrtOp> {
    let normalized = lower_id(op_id);
    let tail = match normalized.rsplit('.').next() {
        Some(segment) => segment,
        None => normalized.as_str(),
    };
    // Exact names ("div", "fdiv", "sqrt", "fsqrt") are covered by the suffix tests, so a
    // single `ends_with` per family is sufficient.
    if tail.ends_with("div") {
        return Some(DivSqrtOp::Div);
    }
    if tail.ends_with("sqrt") {
        return Some(DivSqrtOp::Sqrt);
    }
    None
}
/// Runs the B5 division / square-root gate for `spec` with the default witness budget.
///
/// Thin wrapper over `enforce_div_sqrt_with_witnesses` using `DEFAULT_WITNESSES`.
#[inline]
pub(crate) fn enforce_div_sqrt(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> Vec<FloatFinding> {
    let witness_budget = DEFAULT_WITNESSES;
    enforce_div_sqrt_with_witnesses(backend, spec, witness_budget)
}
/// Runs the B5 gate for `spec`, probing at most `witnesses` input pairs.
///
/// Specs whose id is not a division or square-root operation produce no findings. The
/// shader text is scanned for forbidden approximations first; if that scan reports
/// anything, those findings are returned and the backend parity probe is skipped.
#[inline]
pub(crate) fn enforce_div_sqrt_with_witnesses(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    witnesses: u32,
) -> Vec<FloatFinding> {
    let op = match div_sqrt_op(spec.id) {
        Some(op) => op,
        None => return Vec::new(),
    };
    let static_findings = inspect_shader_for_div_sqrt_approximations(&(spec.wgsl_fn)(), op);
    if static_findings.is_empty() {
        run_div_sqrt_parity(backend, spec, op, witnesses)
    } else {
        static_findings
    }
}
/// Scans shader text for approximate division / square-root constructs.
///
/// The shader is lower-cased (ASCII) and searched for a fixed list of forbidden substrings;
/// each hit yields one `B5` failure with subject `"WGSL"`. For [`DivSqrtOp::Sqrt`] three
/// further checks run, in this order: a call to `inverseSqrt(`, a `1 / sqrt(...)` shortcut,
/// and a `pow(..., ±0.5)` shortcut.
///
/// Matching is purely textual, so tokens inside comments or longer identifiers count too.
#[inline]
pub fn inspect_shader_for_div_sqrt_approximations(
    wgsl: &str,
    op: DivSqrtOp,
) -> Vec<FloatFinding> {
    let lower = wgsl.to_ascii_lowercase();
    // Every finding from this scan shares the same rule, spec line and subject.
    let b5 = |message: String| FloatFinding::failed("B5", 330, "WGSL", message);
    let mut findings: Vec<FloatFinding> = [
        "fast-math",
        "fast_math",
        "approx",
        "rcp(",
        "rsqrt(",
        "native_divide",
        "native_sqrt",
        "reciprocal",
        "fast_div",
        "fast_sqrt",
    ]
    .into_iter()
    .filter(|token| lower.contains(*token))
    .map(|token| {
        b5(format!(
            "Forbidden fast float token `{token}` in FDiv/FSqrt shader. Fix: remove approximation paths per SPEC.md:330."
        ))
    })
    .collect();
    if matches!(op, DivSqrtOp::Sqrt) {
        if lower.contains("inversesqrt(") {
            findings.push(b5(
                "FSqrt shader uses inverseSqrt. Fix: use correctly-rounded sqrt semantics; reciprocal-square-root approximations are forbidden by SPEC.md:330."
                    .to_string(),
            ));
        }
        if contains_reciprocal_sqrt_shortcut(&lower) {
            findings.push(b5(
                "FSqrt shader contains `1.0 / sqrt(...)`. Fix: do not replace declared FSqrt with reciprocal-sqrt or split reciprocal shortcuts per SPEC.md:330."
                    .to_string(),
            ));
        }
        if contains_pow_sqrt_shortcut(&lower) {
            findings.push(b5(
                "FSqrt shader uses pow(..., 0.5) or pow(..., -0.5). Fix: use correctly-rounded sqrt semantics per SPEC.md:330."
                    .to_string(),
            ));
        }
    }
    findings
}
/// Runs the FDiv/FSqrt bit-parity probe against `backend`.
///
/// For each of the first `witnesses` pairs of the witness stream, the shader is dispatched
/// at workgroup sizes 1 and 64 and its four-byte result is compared with the CPU
/// reference. The probe stops at the first mismatch or dispatch error, so the returned
/// vector holds at most one `B5` finding and is empty when every dispatch matched.
///
/// NaN results are compared as "both NaN" rather than byte-for-byte: the witness stream
/// contains NaN operands, and the payload/sign of a NaN result is not pinned down by
/// correct rounding, so demanding identical NaN bits would fail conforming backends. This
/// mirrors the NaN handling of the B4 transcendental gate.
#[inline]
pub fn run_div_sqrt_parity(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    op: DivSqrtOp,
    witnesses: u32,
) -> Vec<FloatFinding> {
    // The shader text does not depend on the witness; generate it once instead of once
    // per dispatch (the default budget is millions of dispatches).
    let wgsl = (spec.wgsl_fn)();
    let mut findings = Vec::new();
    for (x_bits, y_bits) in WitnessStream::new().take(witnesses as usize) {
        let input = input_for(op, x_bits, y_bits);
        let expected = reference_bytes(op, x_bits, y_bits);
        for workgroup_size in [1, 64] {
            match super::dispatch_legacy_wgsl_probe(
                backend,
                &wgsl,
                &input,
                4,
                ConformDispatchConfig {
                    workgroup_size,
                    workgroup_count: 1,
                    convention: spec.convention,
                    lookup_data: None,
                    buffer_init: BufferInitPolicy::Zero,
                },
            ) {
                Ok(output) if div_sqrt_bytes_match(&output, &expected) => {}
                Ok(output) => {
                    findings.push(FloatFinding::failed(
                        "B5",
                        330,
                        spec.id,
                        format!(
                            "`{}` is not correctly rounded for witness x=0x{x_bits:08X}, y=0x{y_bits:08X} at workgroup_size={workgroup_size}: gpu={output:02X?}, expected={expected:02X?}. Fix: use IEEE 754 correctly-rounded f32 operations per SPEC.md:330.",
                            spec.id
                        ),
                    ));
                    break;
                }
                Err(error) => {
                    findings.push(FloatFinding::failed(
                        "B5",
                        330,
                        spec.id,
                        format!(
                            "`{}` dispatch failed during FDiv/FSqrt parity at workgroup_size={workgroup_size}: {error}. Fix: strict float backends must execute the conformance shader per SPEC.md:330.",
                            spec.id
                        ),
                    ));
                    break;
                }
            }
        }
        // One finding is enough to fail the gate; do not keep probing.
        if !findings.is_empty() {
            break;
        }
    }
    findings
}
/// Compares a backend result with the reference, treating any two `f32` NaNs as equal.
///
/// Byte-identical buffers always match; otherwise both buffers must be exactly four bytes
/// and both must decode (little-endian) to NaN.
fn div_sqrt_bytes_match(output: &[u8], expected: &[u8]) -> bool {
    if output == expected {
        return true;
    }
    match (output, expected) {
        (&[o0, o1, o2, o3], &[e0, e1, e2, e3]) => {
            f32::from_le_bytes([o0, o1, o2, o3]).is_nan()
                && f32::from_le_bytes([e0, e1, e2, e3]).is_nan()
        }
        _ => false,
    }
}
/// Detects a textual `1 / sqrt(` reciprocal-square-root shortcut.
///
/// All whitespace is removed first, then the text is searched for a literal one
/// (`1.0`, `1.`, `1f` or `1`), optionally wrapped in parentheses, immediately followed by
/// `/sqrt(`. Matching is case-sensitive substring search.
fn contains_reciprocal_sqrt_shortcut(wgsl: &str) -> bool {
    const ONE_LITERALS: [&str; 4] = ["1.0", "1.", "1f", "1"];
    let squeezed = wgsl.split_whitespace().collect::<String>();
    ONE_LITERALS.iter().any(|one| {
        let bare = format!("{one}/sqrt(");
        let parenthesized = format!("({one})/sqrt(");
        squeezed.contains(bare.as_str()) || squeezed.contains(parenthesized.as_str())
    })
}
/// Detects a textual `pow(..., ±0.5)` square-root shortcut.
///
/// After removing all whitespace, the text must contain `pow(` and, anywhere, a final
/// argument spelled `,0.5)`, `,.5)`, `,-0.5)` or `,-.5)`. The two conditions are checked
/// independently, so the exponent is not required to belong to the `pow` call.
fn contains_pow_sqrt_shortcut(wgsl: &str) -> bool {
    const HALF_EXPONENTS: [&str; 4] = [",0.5)", ",.5)", ",-0.5)", ",-.5)"];
    let squeezed = wgsl.split_whitespace().collect::<String>();
    if !squeezed.contains("pow(") {
        return false;
    }
    HALF_EXPONENTS
        .iter()
        .any(|exponent| squeezed.contains(*exponent))
}
/// Builds the little-endian input buffer for one division / square-root witness.
///
/// `x_bits` is always written; `y_bits` is appended only for [`DivSqrtOp::Div`].
fn input_for(op: DivSqrtOp, x_bits: u32, y_bits: u32) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(8);
    buffer.extend_from_slice(&x_bits.to_le_bytes());
    if matches!(op, DivSqrtOp::Div) {
        buffer.extend_from_slice(&y_bits.to_le_bytes());
    }
    buffer
}
/// Computes the CPU reference for one witness and returns it as little-endian bytes.
///
/// Division evaluates `x / y`; square root evaluates `x.sqrt()` and ignores `y`.
/// Operands and result pass through `black_box` — presumably so the arithmetic is
/// performed at run time rather than folded by the optimizer; confirm the intent.
fn reference_bytes(op: DivSqrtOp, x_bits: u32, y_bits: u32) -> Vec<u8> {
    let (x, y) = (
        black_box(f32::from_bits(x_bits)),
        black_box(f32::from_bits(y_bits)),
    );
    let result = black_box(match op {
        DivSqrtOp::Div => x / y,
        DivSqrtOp::Sqrt => x.sqrt(),
    });
    Vec::from(result.to_bits().to_le_bytes())
}
/// Endless, deterministic stream of `(x, y)` bit-pattern pairs for the B5 parity probe.
///
/// A fixed list of boundary pairs is replayed first; afterwards pairs are generated from a
/// 32-bit xorshift sequence, with `y` derived as the new state rotated left by 7 bits. The
/// stream never returns `None`.
struct WitnessStream {
    /// Current state of the xorshift sequence.
    state: u32,
    /// Index of the next boundary pair still to be emitted.
    boundary_index: usize,
}
impl WitnessStream {
    /// Boundary pairs emitted before the pseudo-random tail. They cover signed zeros,
    /// division by +0.0 and -0.0, infinities, NaNs and subnormal numerators.
    const BOUNDARIES: [(u32, u32); 10] = [
        (0x0000_0000, 0x3F80_0000),
        (0x8000_0000, 0x3F80_0000),
        (0x3F80_0000, 0x0000_0000),
        (0xBF80_0000, 0x8000_0000),
        (0x7F80_0000, 0x3F80_0000),
        (0xFF80_0000, 0x3F80_0000),
        (0x7FC0_0000, 0x3F80_0000),
        (0x7F80_0001, 0x3F80_0000),
        (0x0000_0001, 0x3F80_0000),
        (0x007F_FFFF, 0x3F80_0000),
    ];

    /// Creates a stream positioned at the first boundary pair with a fixed seed.
    fn new() -> Self {
        Self {
            state: 0xA341_316C,
            boundary_index: 0,
        }
    }
}
impl Iterator for WitnessStream {
    type Item = (u32, u32);

    fn next(&mut self) -> Option<Self::Item> {
        match Self::BOUNDARIES.get(self.boundary_index).copied() {
            Some(pair) => {
                self.boundary_index += 1;
                Some(pair)
            }
            None => {
                // Boundaries exhausted: one xorshift32 step (13, 17, 5).
                let mut next_state = self.state;
                next_state ^= next_state << 13;
                next_state ^= next_state >> 17;
                next_state ^= next_state << 5;
                self.state = next_state;
                Some((next_state, next_state.rotate_left(7)))
            }
        }
    }
}
// Unit tests for the B5 division / square-root gate.
//
// NOTE(review): the gate entry points take `&dyn vyre::VyreBackend`, while the mock below
// implements `WgslBackend`, and the shared `dispatch_legacy_wgsl_probe` helper visible at
// the top of this file always returns `Err`. Confirm the backend-driven tests here still
// compile and that the "positive" case can actually pass.
#[cfg(test)]
mod tests {
use super::*;
use crate::spec::types::conform::Strictness;
use crate::spec::types::{DataType, OpSignature};
use vyre_spec::Category;
/// Mock backend returning the CPU reference for `op`, optionally perturbed by one bit.
struct DivSqrtBackend {
op: DivSqrtOp,
approximate: bool,
}
impl WgslBackend for DivSqrtBackend {
fn name(&self) -> &str {
"div-sqrt-mock"
}
/// Decodes `x` from the first four bytes and `y` from the next four when present
/// (otherwise the bit pattern `1`), then returns the reference bytes; when
/// `approximate` is set the lowest bit of the first output byte is flipped.
fn dispatch(
&self,
_wgsl: &str,
input: &[u8],
_output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
let x_bits = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
let y_bits = if input.len() >= 8 {
u32::from_le_bytes([input[4], input[5], input[6], input[7]])
} else {
1
};
let mut out = reference_bytes(self.op, x_bits, y_bits);
if self.approximate {
out[0] ^= 1;
}
Ok(out)
}
}
/// Placeholder CPU implementation that always returns four zero bytes.
fn cpu(_: &[u8]) -> Vec<u8> {
0_u32.to_le_bytes().to_vec()
}
/// Shader text that contains none of the forbidden approximation tokens.
fn regular_wgsl() -> String {
"fn vyre_op(index: u32, input_len: u32) -> u32 { return 0u; }".to_string()
}
/// Builds a strict, single-`F32`-input spec with the given id and the regular shader.
fn test_spec(id: &'static str) -> OpSpec {
OpSpec::builder(id)
.signature(OpSignature {
inputs: vec![DataType::F32],
output: DataType::F32,
})
.cpu_fn(cpu)
.wgsl_fn(regular_wgsl)
.category(Category::A {
composition_of: vec![],
})
.laws(vec![])
.strictness(Strictness::Strict)
.version(1)
.build()
.expect("registry invariant violated")
}
// A backend that reproduces the reference division is expected to yield no findings.
#[test]
fn positive_correct_division_passes() {
let spec = test_spec("primitive.float.fdiv");
let findings = enforce_div_sqrt_with_witnesses(
&DivSqrtBackend {
op: DivSqrtOp::Div,
approximate: false,
},
&spec,
64,
);
assert!(findings.is_empty(), "{findings:?}");
}
// `inverseSqrt(` in an FSqrt shader yields exactly one finding citing SPEC.md:330.
#[test]
fn negative_inverse_sqrt_shader_fails() {
let findings = inspect_shader_for_div_sqrt_approximations(
"fn x(v: f32) -> f32 { return inverseSqrt(v); }",
DivSqrtOp::Sqrt,
);
assert_eq!(findings.len(), 1);
assert!(findings[0].message.contains("SPEC.md:330"));
}
// `pow(v, 0.5)` in an FSqrt shader yields exactly one finding mentioning `pow`.
#[test]
fn negative_pow_sqrt_shader_fails() {
let findings = inspect_shader_for_div_sqrt_approximations(
"fn x(v: f32) -> f32 { return pow(v, 0.5); }",
DivSqrtOp::Sqrt,
);
assert_eq!(findings.len(), 1);
assert!(findings[0].message.contains("pow"));
}
// The inverseSqrt rejection is purely textual; the message must name `inverseSqrt`.
#[test]
fn boundary_correctly_rounded_inverse_sqrt_token_is_still_rejected() {
let findings = inspect_shader_for_div_sqrt_approximations(
"fn x(v: f32) -> f32 { return inverseSqrt(v); }",
DivSqrtOp::Sqrt,
);
assert_eq!(findings.len(), 1);
assert!(findings[0].message.contains("inverseSqrt"));
}
// The pair (1.0, +0.0) must appear among the first eight witnesses.
#[test]
fn boundary_zero_divisor_is_injected() {
assert!(WitnessStream::new()
.take(8)
.any(|(x, y)| x == 0x3F80_0000 && y == 0x0000_0000));
}
// A one-bit perturbation of the sqrt result must be reported under rule B5.
#[test]
fn approximation_output_fails_parity() {
let spec = test_spec("primitive.float.fsqrt");
let findings = enforce_div_sqrt_with_witnesses(
&DivSqrtBackend {
op: DivSqrtOp::Sqrt,
approximate: true,
},
&spec,
1,
);
assert_eq!(findings[0].rule, "B5");
}
}
}
pub mod b6_tensor {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::OpSpec;
use crate::enforce::enforcers::float_semantics::{lower_id, FloatFinding};
/// Classification of a matmul probe result by the accumulator precision it reveals.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccumulatorPrecision {
/// The result bits equal the F32 reference of the probe.
F32,
/// The result bits equal the reference computed with TF32-rounded operands.
Tf32,
/// The result bits match neither reference.
Underspecified,
}
/// Runs the B6 tensor-precision gate for `spec`.
///
/// Non-matmul specs produce no findings. A matmul spec that does not declare an F32
/// accumulator is rejected immediately. Otherwise the shader text is scanned for TF32
/// downgrade tokens; only if that scan is clean is the backend precision probe executed.
#[inline]
pub(crate) fn enforce_tensor_precision(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> Vec<FloatFinding> {
    if !is_matmul(spec) {
        return Vec::new();
    }
    if declares_f32_accumulator(spec) {
        let static_findings = inspect_shader_for_tf32(&(spec.wgsl_fn)());
        return if static_findings.is_empty() {
            run_matmul_precision_probe(backend, spec)
        } else {
            static_findings
        };
    }
    let message = format!(
        "MatMul `{}` omits accumulator_type: F32. Fix: declare accumulator_type explicitly; underspecified tensor precision is rejected by SPEC.md:334.",
        spec.id
    );
    vec![FloatFinding::failed("B6", 334, spec.id, message)]
}
/// Scans shader text for markers of a silent TF32 precision downgrade.
///
/// The shader is lower-cased (ASCII). Each unconditional downgrade token found as a
/// substring yields one `B6` failure with subject `"WGSL"`. In addition, `allow_tf32` is
/// tolerated only when it is explicitly disabled: every occurrence must be immediately
/// followed by `=false`, ` = false` or `: false`. A single finding is emitted if any
/// occurrence is not.
///
/// Checking each occurrence (rather than the mere presence of one disabled spelling
/// somewhere in the text) prevents a `// allow_tf32=false` comment from masking a later
/// `allow_tf32 = true`.
#[inline]
pub fn inspect_shader_for_tf32(wgsl: &str) -> Vec<FloatFinding> {
    const DOWNGRADE_TOKENS: [&str; 4] = [
        "tensor_op_math",
        "fast_matmul",
        "use_tf32=true",
        "tf32=true",
    ];
    const ALLOW_TOKEN: &str = "allow_tf32";
    const EXPLICIT_FALSE: [&str; 3] = ["=false", " = false", ": false"];

    let lower = wgsl.to_ascii_lowercase();
    let mut findings: Vec<FloatFinding> = DOWNGRADE_TOKENS
        .iter()
        .copied()
        .filter(|token| lower.contains(*token))
        .map(|token| {
            FloatFinding::failed(
                "B6",
                334,
                "WGSL",
                format!(
                    "Tensor precision downgrade token `{token}` detected. Fix: use F32 accumulators for accumulator_type: F32 per SPEC.md:334."
                ),
            )
        })
        .collect();

    // An occurrence is "unguarded" when the text right after it is not an explicit false.
    let has_unguarded_allow = lower.match_indices(ALLOW_TOKEN).any(|(start, matched)| {
        let rest = &lower[start + matched.len()..];
        !EXPLICIT_FALSE
            .iter()
            .any(|disabled| rest.starts_with(*disabled))
    });
    if has_unguarded_allow {
        findings.push(FloatFinding::failed(
            "B6",
            334,
            "WGSL",
            "Tensor precision downgrade token `allow_tf32` detected without an explicit false value. Fix: disable TF32 for accumulator_type: F32 per SPEC.md:334.",
        ));
    }
    findings
}
/// Classifies the bit pattern returned by the matmul probe.
///
/// The F32 reference is checked first, then the TF32 reference; a pattern equal to neither
/// is [`AccumulatorPrecision::Underspecified`].
#[inline]
pub fn classify_accumulator_precision(output_bits: u32) -> AccumulatorPrecision {
    match output_bits {
        bits if bits == f32_matmul_probe_result().to_bits() => AccumulatorPrecision::F32,
        bits if bits == tf32_matmul_probe_result().to_bits() => AccumulatorPrecision::Tf32,
        _ => AccumulatorPrecision::Underspecified,
    }
}
/// Dispatches the matmul precision probe and converts the outcome into findings.
///
/// The probe requests four output bytes from a single workgroup of size one. An output of
/// at least four bytes is decoded (little-endian) and classified: an F32-level result
/// produces no finding, while a TF32-level or unrecognised result, a short output, or a
/// dispatch error each produce exactly one `B6` failure.
fn run_matmul_precision_probe(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> Vec<FloatFinding> {
    let input = matmul_probe_input();
    let shader = (spec.wgsl_fn)();
    let config = ConformDispatchConfig {
        workgroup_size: 1,
        workgroup_count: 1,
        convention: spec.convention,
        lookup_data: None,
        buffer_init: BufferInitPolicy::Zero,
    };
    let outcome = super::dispatch_legacy_wgsl_probe(backend, &shader, &input, 4, config);
    // Every failure path differs only in its message, so compute that first.
    let message = match outcome {
        Ok(output) if output.len() >= 4 => {
            let bits = u32::from_le_bytes([output[0], output[1], output[2], output[3]]);
            match classify_accumulator_precision(bits) {
                AccumulatorPrecision::F32 => return Vec::new(),
                AccumulatorPrecision::Tf32 => format!(
                    "MatMul `{}` produced TF32-level accumulator result. Fix: disable silent TF32 substitution and use F32 accumulation per SPEC.md:334.",
                    spec.id
                ),
                AccumulatorPrecision::Underspecified => format!(
                    "MatMul `{}` produced unknown accumulator precision bits 0x{bits:08X}. Fix: match the F32 reference exactly per SPEC.md:334.",
                    spec.id
                ),
            }
        }
        Ok(output) => format!(
            "MatMul `{}` returned {} bytes for precision probe. Fix: return one f32 output for the 16x16x16 mixed-sign conformance matmul.",
            spec.id,
            output.len()
        ),
        Err(error) => format!(
            "MatMul `{}` precision probe failed: {error}. Fix: strict tensor backends must expose F32 accumulation or report UnsupportedByBackend per SPEC.md:334.",
            spec.id
        ),
    };
    vec![FloatFinding::failed("B6", 334, spec.id, message)]
}
/// Returns `true` when the normalised operation id contains `matmul` anywhere.
fn is_matmul(spec: &OpSpec) -> bool {
    lower_id(spec.id).contains("matmul")
}
/// Reports whether `spec` explicitly declares an F32 accumulator.
///
/// The declaration is accepted in the lower-cased shader text (several spellings) or in
/// `docs_path`, which is matched case-sensitively against `accumulator_type=F32` and
/// `accumulator_type=f32`.
fn declares_f32_accumulator(spec: &OpSpec) -> bool {
    const SHADER_SPELLINGS: [&str; 5] = [
        "accumulator_type: f32",
        "accumulator_type = \"f32\"",
        "accumulator_type = 'f32'",
        "accumulator_type = f32",
        "accumulator = \"f32\"",
    ];
    const DOCS_SPELLINGS: [&str; 2] = ["accumulator_type=F32", "accumulator_type=f32"];

    let shader = (spec.wgsl_fn)().to_ascii_lowercase();
    if SHADER_SPELLINGS
        .iter()
        .any(|spelling| shader.contains(*spelling))
    {
        return true;
    }
    DOCS_SPELLINGS
        .iter()
        .any(|spelling| spec.docs_path.contains(*spelling))
}
/// Serialises the matmul probe into a little-endian input buffer.
///
/// Layout: three `u32` values of 16, followed by the `f32` values of `matmul_probe_a`
/// and then those of `matmul_probe_b`.
///
/// NOTE(review): the header announces 16/16/16 and capacity is reserved for two 16x16
/// operands, yet each operand helper yields only 16 values — confirm the intended shape.
fn matmul_probe_input() -> Vec<u8> {
    let mut buffer = Vec::with_capacity((3 * 4) + (16 * 16 * 2 * 4));
    for _ in 0..3 {
        buffer.extend_from_slice(&16_u32.to_le_bytes());
    }
    for operand in [matmul_probe_a(), matmul_probe_b()] {
        for value in operand {
            buffer.extend_from_slice(&value.to_le_bytes());
        }
    }
    buffer
}
/// F32 reference for the probe: the dot product of the two probe operands, accumulated
/// left to right in `f32` starting from `0.0`.
fn f32_matmul_probe_result() -> f32 {
    let lhs = matmul_probe_a();
    let rhs = matmul_probe_b();
    let mut acc = 0.0_f32;
    for (a, b) in lhs.into_iter().zip(rhs) {
        acc += a * b;
    }
    acc
}
/// TF32 reference for the probe: the same left-to-right `f32` accumulation as the F32
/// reference, but with both operands of every product first passed through `tf32_round`.
fn tf32_matmul_probe_result() -> f32 {
    let lhs = matmul_probe_a();
    let rhs = matmul_probe_b();
    let mut acc = 0.0_f32;
    for (a, b) in lhs.into_iter().zip(rhs) {
        acc += tf32_round(a) * tf32_round(b);
    }
    acc
}
/// Left operand of the probe: 16 values of magnitude `1 + (i + 1) * 2^-12`, positive at
/// even indices and negative at odd ones.
fn matmul_probe_a() -> Vec<f32> {
    let step = 2.0_f32.powi(-12);
    let mut values = Vec::with_capacity(16);
    for index in 0..16 {
        let magnitude = 1.0_f32 + ((index + 1) as f32) * step;
        values.push(if index % 2 == 0 { magnitude } else { -magnitude });
    }
    values
}
/// Right operand of the probe: 16 values of magnitude `1 - (i + 1) * 2^-13`, negative at
/// indices divisible by three and positive elsewhere.
fn matmul_probe_b() -> Vec<f32> {
    let step = 2.0_f32.powi(-13);
    let mut values = Vec::with_capacity(16);
    for index in 0..16 {
        let magnitude = 1.0_f32 - ((index + 1) as f32) * step;
        values.push(if index % 3 == 0 { -magnitude } else { magnitude });
    }
    values
}
/// Rounds an `f32` to TF32 precision (the low 13 mantissa bits cleared) using
/// round-to-nearest, ties-to-even.
///
/// Infinities and NaNs are returned unchanged. Running the bit arithmetic on them would
/// corrupt NaNs: clearing the low mantissa bits can turn a NaN into an infinity, and the
/// rounding increment can carry through the exponent into the sign bit.
fn tf32_round(value: f32) -> f32 {
    /// Mantissa bits that TF32 cannot represent.
    const DROPPED_MASK: u32 = 0x1FFF;
    /// Exactly half of the least significant kept bit.
    const HALF: u32 = 0x1000;
    /// Least significant mantissa bit that survives the rounding.
    const KEPT_LSB: u32 = 0x2000;

    if !value.is_finite() {
        return value;
    }
    let mut bits = value.to_bits();
    let dropped = bits & DROPPED_MASK;
    let round_up = dropped > HALF || (dropped == HALF && bits & KEPT_LSB != 0);
    if round_up {
        // `dropped >= HALF` here, so adding HALF always carries into the kept bits; the
        // leftover low bits are cleared by the mask below.
        bits = bits.wrapping_add(HALF);
    }
    f32::from_bits(bits & !DROPPED_MASK)
}
// Unit tests for the B6 tensor-precision gate.
//
// NOTE(review): `enforce_tensor_precision` takes `&dyn vyre::VyreBackend`, while the mock
// below implements `WgslBackend`, and the shared `dispatch_legacy_wgsl_probe` helper
// visible at the top of this file always returns `Err`. Confirm the backend-driven tests
// here still compile and that the "positive" case can actually pass.
#[cfg(test)]
mod tests {
use super::*;
use crate::spec::types::conform::Strictness;
use crate::spec::types::{DataType, OpSignature};
use vyre_spec::Category;
/// Mock backend that returns the probe reference matching the configured precision.
struct MatmulBackend {
precision: AccumulatorPrecision,
}
impl WgslBackend for MatmulBackend {
fn name(&self) -> &str {
"matmul-mock"
}
/// Ignores its inputs and returns the F32 reference, the TF32 reference, or 0.0
/// (for `Underspecified`) as four little-endian bytes.
fn dispatch(
&self,
_wgsl: &str,
_input: &[u8],
_output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
let value = match self.precision {
AccumulatorPrecision::F32 => f32_matmul_probe_result(),
AccumulatorPrecision::Tf32 => tf32_matmul_probe_result(),
AccumulatorPrecision::Underspecified => 0.0,
};
Ok(value.to_bits().to_le_bytes().to_vec())
}
}
/// Placeholder CPU implementation that always returns four zero bytes.
fn cpu(_: &[u8]) -> Vec<u8> {
0_u32.to_le_bytes().to_vec()
}
/// Shader text that declares an F32 accumulator.
fn accumulator_wgsl() -> String {
"// accumulator_type: F32".to_string()
}
/// Shader text without any accumulator declaration.
fn missing_accumulator_wgsl() -> String {
"fn main() {}".to_string()
}
/// Builds a strict `primitive.tensor.matmul` spec whose shader declares an F32 accumulator.
fn matmul_spec() -> OpSpec {
OpSpec::builder("primitive.tensor.matmul")
.signature(OpSignature {
inputs: vec![DataType::Tensor, DataType::Tensor],
output: DataType::Tensor,
})
.cpu_fn(cpu)
.wgsl_fn(accumulator_wgsl)
.category(Category::A {
composition_of: vec![],
})
.laws(vec![])
.strictness(Strictness::Strict)
.version(1)
.build()
.expect("registry invariant violated")
}
// A backend returning the F32 reference is expected to yield no findings.
#[test]
fn positive_f32_accumulation_passes() {
let findings = enforce_tensor_precision(
&MatmulBackend {
precision: AccumulatorPrecision::F32,
},
&matmul_spec(),
);
assert!(findings.is_empty(), "{findings:?}");
}
// A backend returning the TF32 reference must be rejected under rule B6.
#[test]
fn negative_tf32_output_is_rejected() {
let findings = enforce_tensor_precision(
&MatmulBackend {
precision: AccumulatorPrecision::Tf32,
},
&matmul_spec(),
);
assert_eq!(findings[0].rule, "B6");
assert!(findings[0].message.contains("SPEC.md:334"));
}
// A matmul shader without an accumulator declaration is rejected before any dispatch.
#[test]
fn boundary_missing_accumulator_type_rejected() {
let mut spec = matmul_spec();
spec.wgsl_fn = missing_accumulator_wgsl;
let findings = enforce_tensor_precision(
&MatmulBackend {
precision: AccumulatorPrecision::F32,
},
&spec,
);
assert!(findings[0].message.contains("accumulator_type"));
}
// An explicitly disabled `allow_tf32` must not be reported.
#[test]
fn boundary_allow_tf32_false_is_not_rejected() {
let findings = inspect_shader_for_tf32("// allow_tf32=false");
assert!(findings.is_empty(), "{findings:?}");
}
// 0x3F80_1000 lies exactly halfway between two TF32 values; the even one must win.
#[test]
fn negative_tf32_tie_rounds_to_even() {
let value = f32::from_bits(0x3F80_1000);
assert_eq!(tf32_round(value).to_bits(), 0x3F80_0000);
}
// Specs whose id does not contain `matmul` are skipped by the gate entirely.
#[test]
fn boundary_tensor_non_matmul_is_ignored() {
let mut spec = matmul_spec();
spec.id = "primitive.tensor.relu";
let findings = enforce_tensor_precision(
&MatmulBackend {
precision: AccumulatorPrecision::Underspecified,
},
&spec,
);
assert!(findings.is_empty(), "{findings:?}");
}
}
}
pub mod precision {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::DataType;
use crate::spec::types::OpSpec;
use crate::enforce::enforcers::float_semantics::FloatFinding;
/// Checks that an operation declared at a non-F32 float width matches its CPU reference.
///
/// Specs without float types, or whose declared width is F32, produce no findings. For
/// F16/F64 specs each width witness is run through `spec.cpu_fn` and dispatched to the
/// backend; the first output mismatch or dispatch error is returned as a single
/// `B-Precision` failure and the remaining witnesses are skipped.
#[inline]
pub fn enforce_declared_float_width(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> Vec<FloatFinding> {
    let width = match declared_float_width(spec) {
        Some(FloatWidth::F32) | None => return Vec::new(),
        Some(other) => other,
    };
    let first_failure = width_witnesses(spec, width).into_iter().find_map(|input| {
        let expected = (spec.cpu_fn)(&input);
        // Prefer the spec's explicit output size; otherwise use the larger of the
        // signature's minimum output size and the declared float width.
        let output_size = spec
            .expected_output_bytes
            .unwrap_or_else(|| spec.signature.output.min_bytes().max(width.bytes()));
        let outcome = super::dispatch_legacy_wgsl_probe(
            backend,
            &(spec.wgsl_fn)(),
            &input,
            output_size,
            ConformDispatchConfig {
                workgroup_size: spec.workgroup_size.unwrap_or(1),
                workgroup_count: 1,
                convention: spec.convention,
                lookup_data: None,
                buffer_init: BufferInitPolicy::Zero,
            },
        );
        let message = match outcome {
            Ok(output) if output == expected => return None,
            Ok(output) => format!("`{}` declared {width:?} but failed declared-width GPU parity: gpu={output:02X?}, cpu={expected:02X?}. Fix: run the kernel at the declared precision; do not substitute f32.", spec.id),
            Err(error) => format!("`{}` declared {width:?} but dispatch failed: {error}. Fix: implement native declared-precision execution or reject registration.", spec.id),
        };
        Some(FloatFinding::failed("B-Precision", 338, spec.id, message))
    });
    first_failure.into_iter().collect()
}
/// Floating-point storage widths the precision gate distinguishes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FloatWidth {
    /// 16-bit float.
    F16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
}
impl FloatWidth {
    /// Size of one value of this width in bytes.
    fn bytes(self) -> usize {
        let bits: usize = match self {
            Self::F16 => 16,
            Self::F32 => 32,
            Self::F64 => 64,
        };
        bits / 8
    }
}
/// Determines the float width an operation declares through its signature.
///
/// Input types are examined in order, followed by the output type. The first `F16` or
/// `F64` encountered decides the width. If neither occurs, `F32` is reported when at least
/// one `F32` type is present, and `None` when the signature has no float type at all.
fn declared_float_width(spec: &OpSpec) -> Option<FloatWidth> {
    let signature_types = || {
        spec.signature
            .inputs
            .iter()
            .chain(std::iter::once(&spec.signature.output))
    };
    let non_default = signature_types().find_map(|ty| match ty {
        DataType::F16 => Some(FloatWidth::F16),
        DataType::F64 => Some(FloatWidth::F64),
        _ => None,
    });
    if non_default.is_some() {
        return non_default;
    }
    if signature_types().any(|ty| matches!(ty, DataType::F32)) {
        Some(FloatWidth::F32)
    } else {
        None
    }
}
/// Builds the input buffers used to probe an operation at its declared float width.
///
/// Each width has a fixed list of little-endian bit patterns. One input buffer is produced
/// per pattern: for every input type of the signature, the pattern bytes are repeated
/// cyclically until that type's `min_bytes()` are filled.
fn width_witnesses(spec: &OpSpec, width: FloatWidth) -> Vec<Vec<u8>> {
    let patterns: Vec<Vec<u8>> = match width {
        FloatWidth::F16 => [0x3C00_u16, 0x0001_u16]
            .iter()
            .map(|bits| bits.to_le_bytes().to_vec())
            .collect(),
        FloatWidth::F32 => vec![0x3F80_0000_u32.to_le_bytes().to_vec()],
        FloatWidth::F64 => [0x3FF0_0000_0000_0000_u64, 0x0000_0000_0000_0001_u64]
            .iter()
            .map(|bits| bits.to_le_bytes().to_vec())
            .collect(),
    };
    let mut witnesses = Vec::with_capacity(patterns.len());
    for pattern in &patterns {
        let mut input = Vec::with_capacity(spec.signature.min_input_bytes());
        for ty in &spec.signature.inputs {
            let len = ty.min_bytes();
            if len > 0 {
                input.extend(pattern.iter().copied().cycle().take(len));
            }
        }
        witnesses.push(input);
    }
    witnesses
}
// Unit tests for declared-width detection.
#[cfg(test)]
mod tests {
use super::*;
use crate::spec::types::conform::Strictness;
use crate::spec::types::OpSignature;
use vyre_spec::Category;
/// Placeholder CPU implementation that returns no bytes.
fn cpu(_: &[u8]) -> Vec<u8> {
Vec::new()
}
/// Placeholder shader generator that returns an empty string.
fn wgsl() -> String {
String::new()
}
/// Builds a strict `primitive.float.mixed` spec with the given signature.
fn spec(inputs: Vec<DataType>, output: DataType) -> OpSpec {
OpSpec::builder("primitive.float.mixed")
.signature(OpSignature { inputs, output })
.cpu_fn(cpu)
.wgsl_fn(wgsl)
.category(Category::A {
composition_of: vec![],
})
.laws(vec![])
.strictness(Strictness::Strict)
.version(1)
.build()
.expect("registry invariant violated")
}
// An F64 output must win over an F32 input.
#[test]
fn boundary_f64_is_not_hidden_by_f32_input() {
assert_eq!(
declared_float_width(&spec(vec![DataType::F32], DataType::F64)),
Some(FloatWidth::F64)
);
}
// An F16 output must win over an F32 input.
#[test]
fn boundary_f16_is_not_hidden_by_f32_input() {
assert_eq!(
declared_float_width(&spec(vec![DataType::F32], DataType::F16)),
Some(FloatWidth::F16)
);
}
}
}
pub mod rounding {
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use crate::spec::types::conform::BufferInitPolicy;
use crate::spec::types::DataType;
use crate::spec::types::OpSpec;
use crate::enforce::enforcers::float_semantics::{is_float_spec, FloatFinding};
/// Checks that a float operation behaves as round-to-nearest-even on rounding witnesses.
///
/// Non-float specs produce no findings. The shader text is scanned for explicit non-RNE
/// rounding tokens first; any hit is returned without dispatching. Otherwise every
/// rounding witness is run through `spec.cpu_fn` and dispatched to the backend, and the
/// first output mismatch or dispatch error is returned as a single `B-Rounding` failure.
#[inline]
pub fn enforce_round_to_nearest_even(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> Vec<FloatFinding> {
    if !is_float_spec(spec) {
        return Vec::new();
    }
    let token_findings = inspect_shader_for_rounding_tokens(&(spec.wgsl_fn)(), spec.id);
    if !token_findings.is_empty() {
        return token_findings;
    }
    let first_failure = rounding_witnesses(spec)
        .into_iter()
        .find_map(|(label, input)| {
            let expected = (spec.cpu_fn)(&input);
            let outcome = super::dispatch_legacy_wgsl_probe(
                backend,
                &(spec.wgsl_fn)(),
                &input,
                expected.len(),
                ConformDispatchConfig {
                    workgroup_size: spec.workgroup_size.unwrap_or(1),
                    workgroup_count: 1,
                    convention: spec.convention,
                    lookup_data: None,
                    buffer_init: BufferInitPolicy::Zero,
                },
            );
            let message = match outcome {
                Ok(output) if output == expected => return None,
                Ok(output) => format!(
                    "`{}` failed RNE-sensitive witness `{label}`: gpu={output:02X?}, cpu={expected:02X?}. Fix: dispatch strict IEEE round-to-nearest-even semantics; do not use RTZ/RDN/RUP substitution.",
                    spec.id
                ),
                Err(error) => format!(
                    "`{}` dispatch failed for RNE-sensitive witness `{label}`: {error}. Fix: strict float backends must execute rounding probes before certification.",
                    spec.id
                ),
            };
            Some(FloatFinding::failed("B-Rounding", 338, spec.id, message))
        });
    first_failure.into_iter().collect()
}
/// Scans shader text for tokens that select a rounding mode other than nearest-even.
///
/// The shader is lower-cased (ASCII) and each token is matched as a plain substring, so a
/// token embedded in a longer identifier or a comment is reported as well. One
/// `B-Rounding` failure is produced per matching token, in table order.
fn inspect_shader_for_rounding_tokens(wgsl: &str, subject: &str) -> Vec<FloatFinding> {
    const NON_RNE_TOKENS: [&str; 6] = [
        "rtz",
        "rdn",
        "rup",
        "round_toward_zero",
        "round_down",
        "round_up",
    ];
    let lower = wgsl.to_ascii_lowercase();
    let mut findings = Vec::new();
    for token in NON_RNE_TOKENS {
        if !lower.contains(token) {
            continue;
        }
        findings.push(FloatFinding::failed(
            "B-Rounding",
            338,
            subject,
            format!(
                "Forbidden non-RNE rounding token `{token}` in strict float shader. Fix: remove explicit RTZ/RDN/RUP paths or declare a separate operation with matching semantics."
            ),
        ));
    }
    findings
}
/// Builds the labelled input buffers used by the rounding probe.
///
/// Each witness starts from one 32-bit pattern; the buffer is formed by appending that
/// pattern, encoded per input type by `append_pattern`, for every input of the signature.
fn rounding_witnesses(spec: &OpSpec) -> Vec<(String, Vec<u8>)> {
    const PATTERNS: [(&str, u32); 4] = [
        ("half-ulp-even", 0x3F80_0000_u32),
        ("half-ulp-odd", 0x3F80_0001_u32),
        ("subnormal-tie", 0x0000_0001_u32),
        ("negative-tie", 0xBF80_0001_u32),
    ];
    let mut witnesses = Vec::with_capacity(PATTERNS.len());
    for (label, bits) in PATTERNS {
        let mut input = Vec::with_capacity(spec.signature.min_input_bytes());
        for ty in &spec.signature.inputs {
            append_pattern(&mut input, ty, bits);
        }
        witnesses.push((label.to_string(), input));
    }
    witnesses
}
/// Appends the little-endian encoding of `bits` for one input of type `ty` to `out`.
///
/// * 16-bit floats receive the low 16 bits of `bits`.
/// * 32-bit scalars and `Bool` receive `bits` unchanged.
/// * 64-bit scalars receive `bits` in the high word and `bits.rotate_left(13)` in the low
///   word.
/// * `Vec2U32` / `Vec4U32` receive `bits` followed by derived lanes.
/// * `Bytes`, `Array` and `Tensor` contribute nothing.
///
/// NOTE(review): for `F16`/`BF16` the `as u16` cast discards the sign and exponent of the
/// 32-bit pattern — confirm that truncation (rather than a value conversion) is intended.
fn append_pattern(out: &mut Vec<u8>, ty: &DataType, bits: u32) {
    let rotated = bits.rotate_left(13);
    match ty {
        DataType::F16 | DataType::BF16 => out.extend((bits as u16).to_le_bytes()),
        DataType::F32 | DataType::U32 | DataType::I32 | DataType::Bool => {
            out.extend(bits.to_le_bytes());
        }
        DataType::F64 | DataType::U64 => {
            let wide = (u64::from(bits) << 32) | u64::from(rotated);
            out.extend(wide.to_le_bytes());
        }
        DataType::Vec2U32 => {
            out.extend([bits, rotated].into_iter().flat_map(u32::to_le_bytes));
        }
        DataType::Vec4U32 => {
            let lanes = [bits, rotated, bits.rotate_right(7), bits ^ 0x8000_0000];
            out.extend(lanes.into_iter().flat_map(u32::to_le_bytes));
        }
        DataType::Bytes | DataType::Array { .. } | DataType::Tensor => {}
    }
}
}
/// Enforcement gate that runs the float-semantics checks over every float spec.
pub struct FloatSemanticsEnforcer;
impl crate::enforce::EnforceGate for FloatSemanticsEnforcer {
    /// Stable identifier of this gate.
    fn id(&self) -> &'static str {
        "float_semantics"
    }

    /// Display name of this gate; identical to its id.
    fn name(&self) -> &'static str {
        "float_semantics"
    }

    /// Runs `enforce_float_semantics` for every float spec in `ctx` and reports the
    /// collected finding messages. Without a backend a single aggregate finding is
    /// returned that asks for one.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let backend = match ctx.backend {
            Some(backend) => backend,
            None => {
                return vec![crate::enforce::aggregate_finding(self.id(), vec!["float_semantics: backend is required. Fix: provide a VyreBackend in EnforceCtx.".to_string()])];
            }
        };
        let mut messages = Vec::new();
        for spec in ctx.specs.iter() {
            if !is_float_spec(spec) {
                continue;
            }
            messages.extend(
                enforce_float_semantics(backend, spec)
                    .into_iter()
                    .map(|finding| finding.message),
            );
        }
        crate::enforce::finding_result(self.id(), messages)
    }
}
/// Constant instance of [`FloatSemanticsEnforcer`] — presumably picked up by the gate
/// registry; confirm against the registration site.
pub const REGISTERED: FloatSemanticsEnforcer = FloatSemanticsEnforcer;