1use std::collections::BTreeMap;
4
5use vyre_foundation::ir::Program;
6use vyre_spec::data_type::DataType;
7
8#[derive(Debug, Clone, PartialEq)]
16#[non_exhaustive]
17pub enum SpecValue {
18 U32(u32),
20 I32(i32),
22 F32(f32),
24 Bool(bool),
26 DType(DataType),
32}
33
34impl SpecValue {
35 #[must_use]
38 pub fn as_pipeline_f64(&self) -> f64 {
39 match self {
40 SpecValue::U32(value) => f64::from(*value),
41 SpecValue::I32(value) => f64::from(*value),
42 SpecValue::F32(value) => f64::from(*value),
43 SpecValue::Bool(value) => f64::from(u8::from(*value)),
44 SpecValue::DType(dtype) => f64::from(dtype_tag(dtype)),
45 }
46 }
47
48 #[must_use]
50 pub fn cache_hash(&self) -> u64 {
51 match self {
52 SpecValue::U32(value) => u64::from(*value) << 8,
53 SpecValue::I32(value) => (1u64) | ((*value as u32 as u64) << 8),
54 SpecValue::F32(value) => (2u64) | ((value.to_bits() as u64) << 8),
55 SpecValue::Bool(value) => (3u64) | (u64::from(u8::from(*value)) << 8),
56 SpecValue::DType(dtype) => (4u64) | (u64::from(dtype_tag(dtype)) << 8),
57 }
58 }
59}
60
61fn dtype_tag(dtype: &DataType) -> u32 {
74 match dtype {
75 DataType::U32 => 0x01,
76 DataType::I32 => 0x02,
77 DataType::U64 => 0x03,
78 DataType::Vec2U32 => 0x04,
79 DataType::Vec4U32 => 0x05,
80 DataType::Bool => 0x06,
81 DataType::Bytes => 0x07,
82 DataType::Array { .. } => 0x08,
83 DataType::F16 => 0x09,
84 DataType::BF16 => 0x0A,
85 DataType::F32 => 0x0B,
86 DataType::F64 => 0x0C,
87 DataType::Tensor => 0x0D,
88 DataType::U8 => 0x0E,
89 DataType::U16 => 0x0F,
90 DataType::I8 => 0x10,
91 DataType::I16 => 0x11,
92 DataType::I64 => 0x12,
93 DataType::Handle(_) => 0x13,
94 DataType::Vec { .. } => 0x14,
95 DataType::TensorShaped { .. } => 0x15,
96 DataType::SparseCsr { .. } => 0x16,
97 DataType::SparseCoo { .. } => 0x17,
98 DataType::SparseBsr { .. } => 0x18,
99 DataType::F8E4M3 => 0x19,
100 DataType::F8E5M2 => 0x1A,
101 DataType::I4 => 0x1B,
102 DataType::FP4 => 0x1C,
103 DataType::NF4 => 0x1D,
104 DataType::DeviceMesh { .. } => 0x1E,
105 DataType::Opaque(_) => 0x80,
106 _ => 0xFFFF_FFFF,
111 }
112}
113
114#[derive(Debug, Default, Clone)]
116pub struct SpecMap {
117 entries: BTreeMap<String, SpecValue>,
118}
119
120impl SpecMap {
121 #[must_use]
123 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn insert(&mut self, name: impl Into<String>, value: SpecValue) {
129 self.entries.insert(name.into(), value);
130 }
131
132 #[must_use]
134 pub fn len(&self) -> usize {
135 self.entries.len()
136 }
137
138 #[must_use]
140 pub fn is_empty(&self) -> bool {
141 self.entries.is_empty()
142 }
143
144 pub fn iter(&self) -> impl Iterator<Item = (&str, &SpecValue)> {
146 self.entries
147 .iter()
148 .map(|(key, value)| (key.as_str(), value))
149 }
150
151 #[must_use]
153 pub fn to_numeric_constants(&self) -> std::collections::HashMap<String, f64> {
154 let mut out = std::collections::HashMap::with_capacity(self.entries.len());
155 for (key, value) in &self.entries {
156 out.insert(key.clone(), value.as_pipeline_f64());
157 }
158 out
159 }
160
161 #[must_use]
163 pub fn cache_hash(&self) -> u64 {
164 let mut hash: u64 = 0xcbf29ce484222325;
165 for (name, value) in self.iter() {
166 for byte in name.as_bytes() {
167 hash ^= u64::from(*byte);
168 hash = hash.wrapping_mul(0x100000001b3);
169 }
170 for byte in value.cache_hash().to_le_bytes() {
171 hash ^= u64::from(byte);
172 hash = hash.wrapping_mul(0x100000001b3);
173 }
174 }
175 hash
176 }
177}
178
179#[derive(Debug, Clone, PartialEq, Eq, Hash)]
181pub struct SpecCacheKey {
182 pub shader_hash: u64,
184 pub binding_sig: u64,
186 pub workgroup_size: [u32; 3],
188 pub spec_hash: u64,
190}
191
192impl SpecCacheKey {
193 #[must_use]
195 pub fn new(
196 shader_hash: u64,
197 binding_sig: u64,
198 workgroup_size: [u32; 3],
199 specs: &SpecMap,
200 ) -> Self {
201 Self {
202 shader_hash,
203 binding_sig,
204 workgroup_size,
205 spec_hash: specs.cache_hash(),
206 }
207 }
208}
209
210#[must_use]
216pub fn vsa_specialization_key(program: &Program, spec_hash: u64) -> u128 {
217 let fingerprint = crate::launch::program_vsa_fingerprint_words(program);
218 let fp_lo = fingerprint
219 .iter()
220 .take(2)
221 .enumerate()
222 .fold(0_u64, |acc, (i, &word)| {
223 acc | (u64::from(word) << (32 * (i as u32)))
224 });
225 ((fp_lo as u128) << 64) | u128::from(spec_hash)
226}
227
228#[must_use]
235pub fn versioned_specialization_artifact_key(
236 cache_version: u32,
237 spec_hash: &str,
238 backend_fingerprint: &str,
239) -> String {
240 let mut hasher = blake3::Hasher::new();
241 hasher.update(b"vyre-specialization-artifact-key-v1\0version\0");
242 hasher.update(&cache_version.to_le_bytes());
243 hasher.update(b"\0spec\0");
244 hasher.update(&(spec_hash.len() as u64).to_le_bytes());
245 hasher.update(spec_hash.as_bytes());
246 hasher.update(b"\0backend\0");
247 hasher.update(&(backend_fingerprint.len() as u64).to_le_bytes());
248 hasher.update(backend_fingerprint.as_bytes());
249 let hash = hasher.finalize();
250 let mut key = String::with_capacity(64);
251 push_lower_hex(hash.as_bytes(), &mut key);
252 key
253}
254
/// Appends the lowercase hexadecimal encoding of `bytes` to `out`, two
/// characters per input byte (high nibble first).
///
/// # Panics
///
/// Panics if `bytes.len() * 2` overflows `usize`, or if `out` cannot reserve
/// that many additional bytes.
fn push_lower_hex(bytes: &[u8], out: &mut String) {
    const LOWER_HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let additional = match bytes.len().checked_mul(2) {
        Some(doubled) => doubled,
        None => panic!(
            "hex encoding input length {} overflows output capacity. Fix: shard artifact-key material before encoding.",
            bytes.len()
        ),
    };

    // Reserve up front so an allocation failure surfaces with a clear message.
    if let Err(error) = out.try_reserve(additional) {
        panic!(
            "hex encoding could not reserve {additional} output byte(s): {error}. Fix: shard artifact-key material before encoding."
        );
    }

    out.extend(
        bytes
            .iter()
            .flat_map(|&byte| {
                [
                    LOWER_HEX_DIGITS[usize::from(byte >> 4)],
                    LOWER_HEX_DIGITS[usize::from(byte & 0x0f)],
                ]
            })
            .map(char::from),
    );
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Builds a `SpecMap` by inserting `pairs` in the order given.
    fn spec_map_of(pairs: Vec<(&str, SpecValue)>) -> SpecMap {
        let mut map = SpecMap::new();
        for (name, value) in pairs {
            map.insert(name, value);
        }
        map
    }

    /// Insertion order must not influence the map hash.
    #[test]
    fn spec_map_ordering_is_commutative() {
        let forward = spec_map_of(vec![("A", SpecValue::U32(1)), ("B", SpecValue::U32(2))]);
        let reversed = spec_map_of(vec![("B", SpecValue::U32(2)), ("A", SpecValue::U32(1))]);
        assert_eq!(forward.cache_hash(), reversed.cache_hash());
    }

    /// Keys that differ only in a specialization value must differ.
    #[test]
    fn cache_key_differs_by_spec_hash() {
        let key_for = |value: u32| {
            let specs = spec_map_of(vec![("K", SpecValue::U32(value))]);
            SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], &specs)
        };
        assert_ne!(key_for(1), key_for(2));
    }

    /// The specialization hash only affects the low 64 bits of the VSA key.
    #[test]
    fn vsa_specialization_key_changes_only_low_half_for_spec_hash() {
        let buffers = vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)];
        let body = vec![Node::store("out", Expr::u32(0), Expr::u32(7))];
        let program = Program::wrapped(buffers, [1, 1, 1], body);

        let first = vsa_specialization_key(&program, 0x11);
        let second = vsa_specialization_key(&program, 0x22);
        assert_eq!(
            first >> 64,
            second >> 64,
            "Fix: VSA specialization keys must keep program identity independent from specialization values."
        );
        assert_ne!(
            first as u64, second as u64,
            "Fix: VSA specialization keys must include the specialization hash."
        );
    }

    /// Shifting a byte between the two string fields must change the key.
    #[test]
    fn versioned_artifact_key_separates_variable_length_fields() {
        let short_spec = versioned_specialization_artifact_key(1, "ab", "cd");
        let long_spec = versioned_specialization_artifact_key(1, "abc", "d");
        assert_ne!(
            short_spec, long_spec,
            "Fix: specialization artifact keys must length-prefix variable fields."
        );
    }

    /// A `DType` value keeps the data type it was built with.
    #[test]
    fn dtype_spec_value_round_trips() {
        let value = SpecValue::DType(DataType::F32);
        match value {
            SpecValue::DType(DataType::F32) => {}
            other => panic!("expected DType(F32); got {other:?}"),
        }
    }

    /// Different data types must hash differently.
    #[test]
    fn dtype_spec_distinct_dtypes_hash_distinct() {
        let hash_of = |dtype: DataType| SpecValue::DType(dtype).cache_hash();
        let f32_hash = hash_of(DataType::F32);
        let u32_hash = hash_of(DataType::U32);
        let i32_hash = hash_of(DataType::I32);
        assert_ne!(f32_hash, u32_hash);
        assert_ne!(u32_hash, i32_hash);
        assert_ne!(f32_hash, i32_hash);
    }

    /// Hashing the same data type twice yields the same value.
    #[test]
    fn dtype_spec_equal_dtypes_hash_equal() {
        let first = SpecValue::DType(DataType::F32).cache_hash();
        let second = SpecValue::DType(DataType::F32).cache_hash();
        assert_eq!(first, second);
    }

    /// A `DType` hash must not coincide with zero-valued hashes of the other variants.
    #[test]
    fn dtype_spec_does_not_collide_with_other_variants() {
        let dtype_hash = SpecValue::DType(DataType::U32).cache_hash();
        let others = [
            SpecValue::U32(0),
            SpecValue::I32(0),
            SpecValue::F32(0.0),
            SpecValue::Bool(false),
        ];
        for other in others {
            assert_ne!(dtype_hash, other.cache_hash());
        }
    }

    /// Maps that differ only in a dtype entry produce different hashes and keys.
    #[test]
    fn dtype_spec_separates_cache_key_in_specmap() {
        let float_map = spec_map_of(vec![("dtype", SpecValue::DType(DataType::F32))]);
        let uint_map = spec_map_of(vec![("dtype", SpecValue::DType(DataType::U32))]);
        assert_ne!(
            float_map.cache_hash(),
            uint_map.cache_hash(),
            "Fix: dtype-keyed SpecMaps must produce distinct cache hashes."
        );
        assert_ne!(
            SpecCacheKey::new(0, 0, [1, 1, 1], &float_map),
            SpecCacheKey::new(0, 0, [1, 1, 1], &uint_map)
        );
    }

    /// Every listed data type has its own arm in `dtype_tag` and no two share a tag.
    #[test]
    fn dtype_tag_covers_every_data_type() {
        let known = [
            DataType::U32,
            DataType::I32,
            DataType::U64,
            DataType::Vec2U32,
            DataType::Vec4U32,
            DataType::Bool,
            DataType::Bytes,
            DataType::Array { element_size: 1 },
            DataType::F16,
            DataType::BF16,
            DataType::F32,
            DataType::F64,
            DataType::Tensor,
            DataType::U8,
            DataType::U16,
            DataType::I8,
            DataType::I16,
            DataType::I64,
            DataType::Handle(vyre_spec::data_type::TypeId(0)),
            DataType::Vec {
                element: Box::new(DataType::U32),
                count: 1,
            },
            DataType::TensorShaped {
                element: Box::new(DataType::U32),
                shape: smallvec::smallvec![1],
            },
            DataType::SparseCsr {
                element: Box::new(DataType::U32),
            },
            DataType::SparseCoo {
                element: Box::new(DataType::U32),
            },
            DataType::SparseBsr {
                element: Box::new(DataType::U32),
                block_rows: 1,
                block_cols: 1,
            },
            DataType::F8E4M3,
            DataType::F8E5M2,
            DataType::I4,
            DataType::FP4,
            DataType::NF4,
            DataType::DeviceMesh {
                axes: smallvec::smallvec![1],
            },
        ];

        let mut seen_tags = std::collections::BTreeSet::new();
        for dtype in known {
            let tag = dtype_tag(&dtype);
            assert_ne!(
                tag, 0xFFFF_FFFF,
                "Fix: dtype_tag missing arm for {dtype:?} - extend specialization.rs::dtype_tag."
            );
            let newly_inserted = seen_tags.insert(tag);
            assert!(
                newly_inserted,
                "Fix: dtype_tag returned duplicate tag {tag} for {dtype:?}."
            );
        }
    }
}