use crate::spec::law::canonical_law_id;
use crate::spec::law::AlgebraicLaw;
use crate::spec::types::{BoundaryValue, DataType, EquivalenceClass, OpSpec, Strictness};
/// Failure raised while computing a spec fingerprint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FingerprintError {
    /// The op's CPU reference function panicked on a sampled input.
    CpuFnPanic {
        /// Identifier of the op whose CPU reference panicked.
        op_id: &'static str,
        /// The exact input bytes that triggered the panic.
        input: Vec<u8>,
    },
}

impl std::fmt::Display for FingerprintError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the pattern is irrefutable, and adding a variant later will
        // turn this into a compile error just like an exhaustive `match` would.
        let Self::CpuFnPanic { op_id, input } = self;
        write!(f, "CPU reference for {op_id} panicked on input {input:?}")
    }
}

impl std::error::Error for FingerprintError {}
/// A fingerprint recorded for one `(id, version)` pair of an op spec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublishedSpecHash {
    /// Identifier of the op the fingerprint belongs to.
    pub id: &'static str,
    /// Spec version the fingerprint was recorded for.
    pub version: u32,
    /// 32-byte digest; presumably the output of `fingerprint_spec` — confirm with publishers.
    pub fingerprint: [u8; 32],
}

/// Registry of published spec fingerprints (currently empty).
pub const PUBLISHED_SPEC_HASHES: &[PublishedSpecHash] = &[];
#[inline]
pub fn fingerprint_spec(spec: &OpSpec) -> Result<[u8; 32], FingerprintError> {
let mut hash = Sha256::new();
hash_str(&mut hash, "id", spec.id);
hash.update(&spec.version.to_le_bytes());
hash_signature(&mut hash, &spec.signature.inputs, &spec.signature.output);
hash_strictness(&mut hash, spec.strictness);
hash_str(&mut hash, "category", &format!("{:?}", spec.category));
hash_str(&mut hash, "comparator", &format!("{:?}", spec.comparator));
hash_str(&mut hash, "convention", &format!("{:?}", spec.convention));
hash.update(&spec.since_version.major.to_le_bytes());
hash.update(&spec.since_version.minor.to_le_bytes());
hash.update(&spec.since_version.patch.to_le_bytes());
hash_str(&mut hash, "docs_path", spec.docs_path);
hash_workgroup_size(&mut hash, spec.workgroup_size);
hash_alt_wgsl(&mut hash, &spec.alt_wgsl_fns);
hash_equivalence_classes(&mut hash, &spec.equivalence_classes);
hash_boundary_values(&mut hash, &spec.boundary_values);
hash_str(
&mut hash,
"oracle_override",
&format!("{:?}", spec.oracle_override),
);
for law in &spec.laws {
hash.update(law_fingerprint(law).as_bytes());
hash.update(&[0]);
}
for seed in 0..256_u64 {
let input = sample_input(&spec.signature.inputs, seed);
let output = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
canonicalize_output((spec.cpu_fn)(&input), &spec.signature.output)
})) {
Ok(output) => output,
Err(_) => {
return Err(FingerprintError::CpuFnPanic {
op_id: spec.id,
input: input.clone(),
});
}
};
hash.update(&(input.len() as u64).to_le_bytes());
hash.update(&input);
hash.update(&(output.len() as u64).to_le_bytes());
hash.update(&output);
}
Ok(hash.finish())
}
/// Absorbs a tagged, length-prefixed string into `hash`.
///
/// The byte layout is:
/// - the `tag` bytes;
/// - a `0x00` separator;
/// - `value.len()` as a little-endian `u64`;
/// - the `value` bytes;
/// - a trailing `0xff` marker.
fn hash_str(hash: &mut Sha256, tag: &str, value: &str) {
    let len_prefix = (value.len() as u64).to_le_bytes();
    let fields: [&[u8]; 5] = [
        tag.as_bytes(),
        &[0x00],
        &len_prefix,
        value.as_bytes(),
        &[0xff],
    ];
    for field in fields {
        hash.update(field);
    }
}
/// Absorbs an op signature into `hash`.
///
/// The input arity is written first as a little-endian `u64`. Each input type follows in
/// order, and the output type comes last.
fn hash_signature(hash: &mut Sha256, inputs: &[DataType], output: &DataType) {
    let arity = inputs.len() as u64;
    hash.update(&arity.to_le_bytes());
    for ty in inputs.iter().chain(std::iter::once(output)) {
        hash_data_type(hash, ty);
    }
}
/// Absorbs a single data type into `hash` as a `"type"`-tagged lowercase name.
///
/// `Array` types additionally append their `element_size` as a little-endian `u64`.
fn hash_data_type(hash: &mut Sha256, ty: &DataType) {
    let name = match ty {
        DataType::U32 => "u32",
        DataType::I32 => "i32",
        DataType::U64 => "u64",
        DataType::Bool => "bool",
        DataType::Vec2U32 => "vec2u32",
        DataType::Vec4U32 => "vec4u32",
        DataType::Bytes => "bytes",
        DataType::Array { .. } => "array",
        DataType::F16 => "f16",
        DataType::BF16 => "bf16",
        DataType::F32 => "f32",
        DataType::F64 => "f64",
        DataType::Tensor => "tensor",
    };
    hash_str(hash, "type", name);
    if let DataType::Array { element_size } = ty {
        hash.update(&(*element_size as u64).to_le_bytes());
    }
}
/// Absorbs a strictness setting into `hash`.
///
/// The setting is written as a `"strictness"`-tagged name. `Approximate` also appends the
/// little-endian bytes of its `max_ulps` tolerance.
fn hash_strictness(hash: &mut Sha256, strictness: Strictness) {
    let label = match strictness {
        Strictness::Strict => "strict",
        Strictness::Approximate { .. } => "approximate",
    };
    hash_str(hash, "strictness", label);
    if let Strictness::Approximate { max_ulps } = strictness {
        hash.update(&max_ulps.to_le_bytes());
    }
}
/// Absorbs an optional workgroup size into `hash`.
///
/// `Some(size)` is encoded as a `0x01` presence byte followed by `size` as little-endian
/// `u32`. `None` is encoded as a single `0x00` byte.
fn hash_workgroup_size(hash: &mut Sha256, workgroup_size: Option<u32>) {
    if let Some(size) = workgroup_size {
        hash.update(&[1]);
        hash.update(&size.to_le_bytes());
    } else {
        hash.update(&[0]);
    }
}
/// Absorbs the alternate WGSL implementations into `hash`.
///
/// The entry count is written first as a little-endian `u64`. For each `(label, source_fn)`
/// entry, the label is hashed, then `source_fn` is invoked and the source it returns is hashed.
fn hash_alt_wgsl(hash: &mut Sha256, alt_wgsl_fns: &[crate::spec::types::AltWgslSource]) {
    let count = alt_wgsl_fns.len() as u64;
    hash.update(&count.to_le_bytes());
    for entry in alt_wgsl_fns {
        let (label, source_fn) = entry;
        hash_str(hash, "alt_label", label);
        let source = source_fn();
        hash_str(hash, "alt_source", &source);
    }
}
/// Absorbs the equivalence classes into `hash`.
///
/// The class count is written first as a little-endian `u64`. Each class then contributes:
/// - its tagged description;
/// - a `0x01`/`0x00` byte for `universal`;
/// - the number of representative values as a little-endian `u64`;
/// - each representative's little-endian bytes.
fn hash_equivalence_classes(hash: &mut Sha256, classes: &[EquivalenceClass]) {
    let class_count = classes.len() as u64;
    hash.update(&class_count.to_le_bytes());
    for class in classes {
        hash_str(hash, "equivalence", class.description);
        let universal_flag: u8 = if class.universal { 1 } else { 0 };
        hash.update(&[universal_flag]);
        let representatives = &class.representative;
        hash.update(&(representatives.len() as u64).to_le_bytes());
        for value in representatives.iter() {
            hash.update(&value.to_le_bytes());
        }
    }
}
/// Absorbs the boundary values into `hash`.
///
/// The entry count is written first as a little-endian `u64`. Each entry then contributes its
/// tagged label, the number of input words as a little-endian `u64`, and each input word's
/// little-endian bytes.
fn hash_boundary_values(hash: &mut Sha256, values: &[BoundaryValue]) {
    let entry_count = values.len() as u64;
    hash.update(&entry_count.to_le_bytes());
    for boundary in values {
        hash_str(hash, "boundary", boundary.label);
        let words = &boundary.inputs;
        hash.update(&(words.len() as u64).to_le_bytes());
        for word in words.iter() {
            hash.update(&word.to_le_bytes());
        }
    }
}
/// Looks up the published fingerprint for `id` at `version`.
///
/// Returns the first matching entry of `PUBLISHED_SPEC_HASHES`, or `None` when that
/// `(id, version)` pair has not been published.
#[inline]
pub fn published_hash(id: &str, version: u32) -> Option<&'static PublishedSpecHash> {
    for entry in PUBLISHED_SPEC_HASHES {
        if entry.version == version && entry.id == id {
            return Some(entry);
        }
    }
    None
}
/// Builds a deterministic pseudo-random input buffer for an op signature.
///
/// Values come from a `SplitMix64` stream seeded with `seed` XOR a fixed constant. Each
/// declared input appends little-endian bytes in declaration order:
/// - 4 bytes for `U32`/`I32`/`Bool`/`F32`;
/// - 8 bytes for `U64`/`F64`;
/// - 8 and 16 bytes for `Vec2U32`/`Vec4U32`;
/// - 2 bytes for `F16`/`BF16`;
/// - 32 bytes for `Bytes`/`Array`/`Tensor`.
///
/// `Bool` slots receive arbitrary 32-bit words, not only 0/1.
fn sample_input(inputs: &[DataType], seed: u64) -> Vec<u8> {
    let mut stream = SplitMix64::new(seed ^ 0x9E37_79B9_7F4A_7C15);
    let mut buf = Vec::new();
    for ty in inputs {
        match ty {
            DataType::U64 | DataType::F64 => {
                buf.extend_from_slice(&stream.next_u64().to_le_bytes());
            }
            DataType::U32 | DataType::I32 | DataType::Bool | DataType::F32 => {
                buf.extend_from_slice(&stream.next_u32().to_le_bytes());
            }
            DataType::Vec2U32 => {
                for _lane in 0..2 {
                    buf.extend_from_slice(&stream.next_u32().to_le_bytes());
                }
            }
            DataType::Vec4U32 => {
                for _lane in 0..4 {
                    buf.extend_from_slice(&stream.next_u32().to_le_bytes());
                }
            }
            DataType::F16 | DataType::BF16 => {
                let half = stream.next_u32() as u16;
                buf.extend_from_slice(&half.to_le_bytes());
            }
            DataType::Bytes | DataType::Array { .. } | DataType::Tensor => {
                buf.extend((0..32).map(|_| stream.next_u32() as u8));
            }
        }
    }
    buf
}
/// Returns the textual identifier used for `law` when fingerprinting a spec.
///
/// This is a thin wrapper that forwards to `canonical_law_id`.
fn law_fingerprint(law: &AlgebraicLaw) -> String {
    canonical_law_id(law)
}
/// Normalises a CPU reference output so that equivalent NaN encodings hash identically.
///
/// For `F32`, `F64`, `F16` and `BF16` outputs, every NaN element is rewritten to one fixed
/// quiet-NaN bit pattern, whatever its sign or payload. All other bit patterns are kept
/// as-is, including `-0.0` and infinities. Elements are read as little-endian, and trailing
/// bytes that do not form a whole element are left untouched.
///
/// Integer, boolean, vector and byte-like outputs are returned unchanged.
fn canonicalize_output(mut output: Vec<u8>, ty: &DataType) -> Vec<u8> {
    match ty {
        DataType::F32 => {
            for chunk in output.chunks_exact_mut(4) {
                let bits = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                if f32::from_bits(bits).is_nan() {
                    chunk.copy_from_slice(&0x7fc0_0000_u32.to_le_bytes());
                }
            }
        }
        DataType::F64 => {
            for chunk in output.chunks_exact_mut(8) {
                let bits = u64::from_le_bytes([
                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
                ]);
                if f64::from_bits(bits).is_nan() {
                    chunk.copy_from_slice(&0x7ff8_0000_0000_0000_u64.to_le_bytes());
                }
            }
        }
        // Half-precision types: exponent mask and mantissa mask identify NaNs.
        DataType::F16 => canonicalize_u16_nan(&mut output, 0x7e00, 0x7c00, 0x03ff),
        DataType::BF16 => canonicalize_u16_nan(&mut output, 0x7fc0, 0x7f80, 0x007f),
        // These payloads have no alternative encodings to fold together. Decoding and
        // re-encoding them as little-endian would be an identity transform, so they pass
        // through untouched. The arms stay explicit so that new `DataType` variants force a
        // decision here.
        DataType::U32
        | DataType::I32
        | DataType::Bool
        | DataType::Vec2U32
        | DataType::Vec4U32
        | DataType::U64
        | DataType::Bytes
        | DataType::Array { .. }
        | DataType::Tensor => {}
    }
    output
}
/// Rewrites every NaN in a little-endian buffer of 16-bit floats to `canonical_nan`.
///
/// An element counts as NaN when all `exponent_mask` bits are set and at least one
/// `mantissa_mask` bit is set. The sign bit is ignored, so negative NaNs are folded as well.
/// A trailing odd byte is left untouched.
fn canonicalize_u16_nan(
    output: &mut [u8],
    canonical_nan: u16,
    exponent_mask: u16,
    mantissa_mask: u16,
) {
    let replacement = canonical_nan.to_le_bytes();
    output
        .chunks_exact_mut(2)
        .filter(|pair| {
            let bits = u16::from_le_bytes([pair[0], pair[1]]);
            let exponent_saturated = bits & exponent_mask == exponent_mask;
            let has_payload = bits & mantissa_mask != 0;
            exponent_saturated && has_payload
        })
        .for_each(|pair| pair.copy_from_slice(&replacement));
}
/// SplitMix64 pseudo-random generator: a Weyl-sequence counter passed through a 64-bit mixer.
///
/// It is used only to derive reproducible sample inputs, not for anything security-sensitive.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Weyl increment added to the state before each output.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64-bit output.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let mut z = self.state;
        z ^= z >> 30;
        z = z.wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z ^= z >> 27;
        z = z.wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns the low 32 bits of the next 64-bit output.
    fn next_u32(&mut self) -> u32 {
        (self.next_u64() & 0xFFFF_FFFF) as u32
    }
}
/// Minimal streaming SHA-256 (FIPS 180-4) used to build spec fingerprints.
///
/// Bytes are absorbed with [`Sha256::update`] in any number of pieces. [`Sha256::finish`]
/// then pads the message and returns the 32-byte digest.
struct Sha256 {
    /// Chaining value (`H0..H7`).
    state: [u32; 8],
    /// Partially filled message block awaiting compression.
    buffer: [u8; 64],
    /// Number of valid bytes in `buffer`; always `< 64` between calls.
    buffer_len: usize,
    /// Total number of message bytes absorbed (modulo 2^64).
    byte_len: u64,
}

impl Sha256 {
    fn new() -> Self {
        Self {
            state: [
                0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
                0x5be0cd19,
            ],
            buffer: [0; 64],
            buffer_len: 0,
            byte_len: 0,
        }
    }

    /// Absorbs `input` into the running hash.
    ///
    /// Splitting a message across several calls yields the same digest as a single call.
    fn update(&mut self, mut input: &[u8]) {
        self.byte_len = self.byte_len.wrapping_add(input.len() as u64);

        // First top up a partially filled block from a previous call.
        if self.buffer_len > 0 {
            let take = (64 - self.buffer_len).min(input.len());
            self.buffer[self.buffer_len..self.buffer_len + take].copy_from_slice(&input[..take]);
            self.buffer_len += take;
            input = &input[take..];
            if self.buffer_len < 64 {
                // Input ran out before the block filled. Keep the buffered bytes; the old code
                // fell through and reset `buffer_len` to 0 here, silently dropping them.
                return;
            }
            compress(&mut self.state, &self.buffer);
            self.buffer_len = 0;
        }

        // Compress whole blocks straight from the input, then stash the tail.
        let mut blocks = input.chunks_exact(64);
        for chunk in &mut blocks {
            let mut block = [0u8; 64];
            block.copy_from_slice(chunk);
            compress(&mut self.state, &block);
        }
        let tail = blocks.remainder();
        self.buffer[..tail.len()].copy_from_slice(tail);
        self.buffer_len = tail.len();
    }

    /// Applies SHA-256 padding and returns the digest as big-endian bytes.
    ///
    /// The padding is a `0x80` byte, zero fill, and the 64-bit big-endian message length in
    /// bits.
    fn finish(mut self) -> [u8; 32] {
        let bit_len = self.byte_len.wrapping_mul(8);
        // `buffer_len < 64` is guaranteed by `update`, so this index is in bounds.
        self.buffer[self.buffer_len] = 0x80;
        self.buffer_len += 1;
        if self.buffer_len > 56 {
            // No room for the 8-byte length: flush this block and pad a fresh one.
            self.buffer[self.buffer_len..].fill(0);
            compress(&mut self.state, &self.buffer);
            self.buffer_len = 0;
        }
        self.buffer[self.buffer_len..56].fill(0);
        self.buffer[56..].copy_from_slice(&bit_len.to_be_bytes());
        compress(&mut self.state, &self.buffer);
        let mut out = [0u8; 32];
        for (chunk, word) in out.chunks_exact_mut(4).zip(self.state) {
            chunk.copy_from_slice(&word.to_be_bytes());
        }
        out
    }
}

/// SHA-256 compression function: folds one 64-byte block into `state`.
fn compress(state: &mut [u32; 8], block: &[u8; 64]) {
    // Round constants: first 32 bits of the fractional parts of the cube roots of the first
    // 64 primes (FIPS 180-4, section 4.2.2).
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Message schedule: 16 big-endian words from the block, expanded to 64.
    let mut w = [0u32; 64];
    for (idx, chunk) in block.chunks_exact(4).take(16).enumerate() {
        w[idx] = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
    }
    for i in 16..64 {
        let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
        let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
        w[i] = w[i - 16]
            .wrapping_add(s0)
            .wrapping_add(w[i - 7])
            .wrapping_add(s1);
    }

    // 64 rounds over the working variables.
    let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state;
    for i in 0..64 {
        let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
        let ch = (e & f) ^ ((!e) & g);
        let temp1 = h
            .wrapping_add(s1)
            .wrapping_add(ch)
            .wrapping_add(K[i])
            .wrapping_add(w[i]);
        let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
        let maj = (a & b) ^ (a & c) ^ (b & c);
        let temp2 = s0.wrapping_add(maj);
        h = g;
        g = f;
        f = e;
        e = d.wrapping_add(temp1);
        d = c;
        c = b;
        b = a;
        a = temp1.wrapping_add(temp2);
    }

    // Feed-forward into the chaining value.
    for (slot, value) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
        *slot = slot.wrapping_add(value);
    }
}