use std::collections::BTreeMap;
use vyre_foundation::ir::Program;
use vyre_spec::data_type::DataType;
/// A single typed value that can be stored in a [`SpecMap`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum SpecValue {
    /// An unsigned 32-bit integer.
    U32(u32),
    /// A signed 32-bit integer.
    I32(i32),
    /// A 32-bit float.
    F32(f32),
    /// A boolean flag.
    Bool(bool),
    /// A data type used as a value.
    DType(DataType),
}
impl SpecValue {
    /// Converts this value to an `f64`.
    ///
    /// `U32`, `I32` and `F32` are widened with `f64::from`; `Bool` becomes
    /// `0.0` or `1.0`; `DType` becomes the numeric tag returned by `dtype_tag`.
    #[must_use]
    pub fn as_pipeline_f64(&self) -> f64 {
        match self {
            Self::U32(raw) => f64::from(*raw),
            Self::I32(raw) => f64::from(*raw),
            Self::F32(raw) => f64::from(*raw),
            Self::Bool(flag) => {
                if *flag {
                    1.0
                } else {
                    0.0
                }
            }
            Self::DType(dtype) => f64::from(dtype_tag(dtype)),
        }
    }

    /// Packs this value into a `u64` suitable for cache hashing.
    ///
    /// The low 8 bits carry a per-variant tag (`U32` = 0, `I32` = 1, `F32` = 2,
    /// `Bool` = 3, `DType` = 4) and the 32-bit payload is placed above it,
    /// shifted left by 8. Floats contribute their raw bit pattern, so `0.0`
    /// and `-0.0` (which compare equal) pack to different values.
    #[must_use]
    pub fn cache_hash(&self) -> u64 {
        let (tag, payload): (u64, u64) = match self {
            Self::U32(raw) => (0, u64::from(*raw)),
            // Reinterpret the two's-complement bits; the cast is not a truncation.
            Self::I32(raw) => (1, u64::from(*raw as u32)),
            Self::F32(raw) => (2, u64::from(raw.to_bits())),
            Self::Bool(flag) => (3, u64::from(*flag)),
            Self::DType(dtype) => (4, u64::from(dtype_tag(dtype))),
        };
        tag | (payload << 8)
    }
}
/// Maps a [`DataType`] to a small numeric tag.
///
/// The nine listed types get the tags `1..=9`. Any other type falls through to
/// the sentinel `0xFFFF_FFFF`.
///
/// NOTE(review): every unlisted type shares the same sentinel, so two
/// `SpecValue::DType` values wrapping different unlisted types are
/// indistinguishable in `SpecValue::cache_hash` and `as_pipeline_f64`. Add an
/// arm here whenever such a type needs to be told apart.
fn dtype_tag(dtype: &DataType) -> u32 {
    // Sentinel for types without a dedicated arm (equal to 0xFFFF_FFFF).
    const UNMAPPED: u32 = u32::MAX;

    match dtype {
        DataType::Bool => 1,
        DataType::U8 => 2,
        DataType::U16 => 3,
        DataType::U32 => 4,
        DataType::I8 => 5,
        DataType::I16 => 6,
        DataType::I32 => 7,
        DataType::F32 => 8,
        DataType::Bytes => 9,
        _ => UNMAPPED,
    }
}
/// A collection of named [`SpecValue`]s.
///
/// Entries are held in a `BTreeMap`, so iteration is always in ascending name
/// order regardless of insertion order.
#[derive(Clone, Debug, Default)]
pub struct SpecMap {
    /// Values keyed by name.
    entries: BTreeMap<String, SpecValue>,
}
impl SpecMap {
    /// Creates an empty map.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `name`, replacing any value already stored under
    /// that name.
    pub fn insert(&mut self, name: impl Into<String>, value: SpecValue) {
        self.entries.insert(name.into(), value);
    }

    /// Returns the number of entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the map holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterates over `(name, value)` pairs in ascending name order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &SpecValue)> {
        self.entries
            .iter()
            .map(|(name, value)| (name.as_str(), value))
    }

    /// Builds a name-to-`f64` table by converting every value with
    /// [`SpecValue::as_pipeline_f64`].
    #[must_use]
    pub fn to_numeric_constants(&self) -> std::collections::HashMap<String, f64> {
        self.entries
            .iter()
            .map(|(name, value)| (name.clone(), value.as_pipeline_f64()))
            .collect()
    }

    /// Computes a 64-bit FNV-1a hash over all entries in ascending name order.
    ///
    /// Because entries are visited in sorted order, the result does not depend
    /// on insertion order.
    ///
    /// Each entry contributes, in sequence: the name's byte length as a
    /// little-endian `u64`, the name bytes, and the little-endian bytes of
    /// [`SpecValue::cache_hash`]. The length prefix makes the byte stream
    /// self-delimiting. Without it, a name containing control bytes can mimic
    /// a "name, value, name" sequence: `{"a": U32(0), "b": U32(0)}` and the
    /// single entry `{"a\0\0\0\0\0\0\0\0b": U32(0)}` would feed identical bytes
    /// to the hasher.
    ///
    /// NOTE(review): adding the length prefix changes the numeric hash values
    /// compared with the unprefixed encoding; confirm nothing persists or pins
    /// the old values.
    #[must_use]
    pub fn cache_hash(&self) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;

        self.entries
            .iter()
            .fold(FNV_OFFSET_BASIS, |hash, (name, value)| {
                let hash = Self::fnv1a(hash, &(name.len() as u64).to_le_bytes());
                let hash = Self::fnv1a(hash, name.as_bytes());
                Self::fnv1a(hash, &value.cache_hash().to_le_bytes())
            })
    }

    /// Folds `bytes` into `hash` using the 64-bit FNV-1a step
    /// (xor the byte, then multiply by the FNV prime).
    fn fnv1a(hash: u64, bytes: &[u8]) -> u64 {
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

        bytes.iter().fold(hash, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    }
}
/// A hashable, comparable key made of four components.
///
/// Two keys are equal only when every field is equal.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SpecCacheKey {
    /// Hash identifying the shader.
    pub shader_hash: u64,
    /// Signature of the bindings.
    pub binding_sig: u64,
    /// Workgroup size, one entry per dimension.
    pub workgroup_size: [u32; 3],
    /// Hash of the specialization values (see `SpecCacheKey::new`).
    pub spec_hash: u64,
}
impl SpecCacheKey {
    /// Builds a key from the given components.
    ///
    /// The first three arguments are stored as-is; `spec_hash` is taken from
    /// `specs.cache_hash()`.
    #[must_use]
    pub fn new(
        shader_hash: u64,
        binding_sig: u64,
        workgroup_size: [u32; 3],
        specs: &SpecMap,
    ) -> Self {
        let spec_hash = specs.cache_hash();
        Self {
            shader_hash,
            binding_sig,
            workgroup_size,
            spec_hash,
        }
    }
}
/// Combines a program fingerprint and a specialization hash into one `u128`.
///
/// The upper 64 bits are assembled from the first two words returned by
/// `crate::launch::program_vsa_fingerprint_words` (word 0 in the low 32 bits,
/// word 1 shifted left by 32). The lower 64 bits are `spec_hash` verbatim, so
/// changing only `spec_hash` leaves the upper half untouched.
#[must_use]
pub fn vsa_specialization_key(program: &Program, spec_hash: u64) -> u128 {
    let words = crate::launch::program_vsa_fingerprint_words(program);

    let mut program_bits = 0_u64;
    for (slot, &word) in words.iter().take(2).enumerate() {
        program_bits |= u64::from(word) << (32 * (slot as u32));
    }

    (u128::from(program_bits) << 64) | u128::from(spec_hash)
}
/// Derives a lowercase-hex BLAKE3 key from a cache version, a specialization
/// hash string and a backend fingerprint string.
///
/// The hashed byte stream is: a fixed domain-separation prefix, the version as
/// little-endian bytes, then for each of the two string fields a label, the
/// field's byte length as a little-endian `u64`, and the field bytes. The
/// length prefixes keep `("ab", "cd")` and `("abc", "d")` from hashing alike.
#[must_use]
pub fn versioned_specialization_artifact_key(
    cache_version: u32,
    spec_hash: &str,
    backend_fingerprint: &str,
) -> String {
    let mut hasher = blake3::Hasher::new();
    hasher.update(b"vyre-specialization-artifact-key-v1\0version\0");
    hasher.update(&cache_version.to_le_bytes());

    let labelled_fields: [(&[u8], &str); 2] = [
        (b"\0spec\0", spec_hash),
        (b"\0backend\0", backend_fingerprint),
    ];
    for (label, field) in labelled_fields {
        hasher.update(label);
        hasher.update(&(field.len() as u64).to_le_bytes());
        hasher.update(field.as_bytes());
    }

    let digest = hasher.finalize();
    let mut key = String::with_capacity(64);
    push_lower_hex(digest.as_bytes(), &mut key);
    key
}
/// Appends the lowercase hexadecimal form of `bytes` to `out`.
///
/// Each input byte yields two characters, high nibble first. Existing content
/// of `out` is kept.
fn push_lower_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    out.reserve(bytes.len().saturating_mul(2));
    out.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0f)],
        ]
        .map(char::from)
    }));
}
// Unit tests for the specialization value, map, cache-key and key-derivation
// helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};
// Inserting the same entries in a different order must yield the same hash.
#[test]
fn spec_map_ordering_is_commutative() {
let mut a = SpecMap::new();
a.insert("A", SpecValue::U32(1));
a.insert("B", SpecValue::U32(2));
let mut b = SpecMap::new();
b.insert("B", SpecValue::U32(2));
b.insert("A", SpecValue::U32(1));
assert_eq!(a.cache_hash(), b.cache_hash());
}
// Two keys that differ only in the value stored under "K" must compare unequal.
#[test]
fn cache_key_differs_by_spec_hash() {
let mut a = SpecMap::new();
a.insert("K", SpecValue::U32(1));
let mut b = SpecMap::new();
b.insert("K", SpecValue::U32(2));
assert_ne!(
SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], &a),
SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], &b)
);
}
// For a fixed program, changing spec_hash must leave the upper 64 bits of the
// key unchanged and must change the lower 64 bits.
#[test]
fn vsa_specialization_key_changes_only_low_half_for_spec_hash() {
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
[1, 1, 1],
vec![Node::store("out", Expr::u32(0), Expr::u32(7))],
);
let a = vsa_specialization_key(&program, 0x11);
let b = vsa_specialization_key(&program, 0x22);
assert_eq!(
a >> 64,
b >> 64,
"Fix: VSA specialization keys must keep program identity independent from specialization values."
);
assert_ne!(
a as u64, b as u64,
"Fix: VSA specialization keys must include the specialization hash."
);
}
// ("ab", "cd") and ("abc", "d") concatenate to the same bytes; the keys must
// still differ.
#[test]
fn versioned_artifact_key_separates_variable_length_fields() {
let a = versioned_specialization_artifact_key(1, "ab", "cd");
let b = versioned_specialization_artifact_key(1, "abc", "d");
assert_ne!(
a, b,
"Fix: specialization artifact keys must length-prefix variable fields."
);
}
// A DType value built from F32 must pattern-match as DType(F32).
#[test]
fn dtype_spec_value_round_trips() {
let v = SpecValue::DType(DataType::F32);
match v {
SpecValue::DType(DataType::F32) => {}
other => panic!("expected DType(F32); got {other:?}"),
}
}
// F32, U32 and I32 data types must hash to pairwise different values.
#[test]
fn dtype_spec_distinct_dtypes_hash_distinct() {
let f32_hash = SpecValue::DType(DataType::F32).cache_hash();
let u32_hash = SpecValue::DType(DataType::U32).cache_hash();
let i32_hash = SpecValue::DType(DataType::I32).cache_hash();
assert_ne!(f32_hash, u32_hash);
assert_ne!(u32_hash, i32_hash);
assert_ne!(f32_hash, i32_hash);
}
// Hashing the same data type twice must give the same value.
#[test]
fn dtype_spec_equal_dtypes_hash_equal() {
assert_eq!(
SpecValue::DType(DataType::F32).cache_hash(),
SpecValue::DType(DataType::F32).cache_hash()
);
}
// DType(U32) must not hash like a zero/false value of any other variant.
#[test]
fn dtype_spec_does_not_collide_with_other_variants() {
let dtype_hash = SpecValue::DType(DataType::U32).cache_hash();
let u32_hash = SpecValue::U32(0).cache_hash();
let i32_hash = SpecValue::I32(0).cache_hash();
let f32_hash = SpecValue::F32(0.0).cache_hash();
let bool_hash = SpecValue::Bool(false).cache_hash();
assert_ne!(dtype_hash, u32_hash);
assert_ne!(dtype_hash, i32_hash);
assert_ne!(dtype_hash, f32_hash);
assert_ne!(dtype_hash, bool_hash);
}
// Maps that differ only in the data type stored under "dtype" must produce
// different map hashes and different cache keys.
#[test]
fn dtype_spec_separates_cache_key_in_specmap() {
let mut a = SpecMap::new();
a.insert("dtype", SpecValue::DType(DataType::F32));
let mut b = SpecMap::new();
b.insert("dtype", SpecValue::DType(DataType::U32));
assert_ne!(
a.cache_hash(),
b.cache_hash(),
"Fix: dtype-keyed SpecMaps must produce distinct cache hashes."
);
assert_ne!(
SpecCacheKey::new(0, 0, [1, 1, 1], &a),
SpecCacheKey::new(0, 0, [1, 1, 1], &b)
);
}
// Every type in the `known` list must map to a unique tag other than the
// 0xFFFF_FFFF fallback. Only the types listed here are checked.
#[test]
fn dtype_tag_covers_every_data_type() {
let known = [
DataType::Bool,
DataType::U8,
DataType::U16,
DataType::U32,
DataType::I8,
DataType::I16,
DataType::I32,
DataType::F32,
DataType::Bytes,
];
// Collects tags seen so far; `insert` returns false on a duplicate.
let mut tags = std::collections::BTreeSet::new();
for dtype in known {
let tag = dtype_tag(&dtype);
assert_ne!(
tag, 0xFFFF_FFFF,
"Fix: dtype_tag missing arm for {dtype:?} — extend specialization.rs::dtype_tag."
);
assert!(
tags.insert(tag),
"Fix: dtype_tag returned duplicate tag {tag} for {dtype:?}."
);
}
}
}