Skip to main content

vyre_driver/
specialization.rs

1//! Backend-neutral specialization values and cache key inputs.
2
3use std::collections::BTreeMap;
4
5use vyre_foundation::ir::Program;
6use vyre_spec::data_type::DataType;
7
8/// One specializable scalar attribute value.
9///
10/// Not `Copy` because the `DType(DataType)` variant carries a
11/// `vyre_spec::DataType` whose payload-bearing variants
12/// (`Array { element_size }`, `Vec { .. }`, `Handle(_)`) are not
13/// trivially copyable. Cloning is cheap regardless  -  the enum is
14/// small and tag-discriminated.
15#[derive(Debug, Clone, PartialEq)]
16#[non_exhaustive]
17pub enum SpecValue {
18    /// Unsigned 32-bit integer.
19    U32(u32),
20    /// Signed 32-bit integer.
21    I32(i32),
22    /// 32-bit float, cache-hashed by its bit pattern.
23    F32(f32),
24    /// Boolean flag.
25    Bool(bool),
26    /// Element data type. ROADMAP F3  -  dtype-specialized kernel variants
27    /// flow through the same `SpecMap` cache as tile-size and unroll
28    /// choices, so the F1 specialization-cache key already separates
29    /// (matmul, F32) from (matmul, F16) without any backend-specific
30    /// extension.
31    DType(DataType),
32}
33
34impl SpecValue {
35    /// Convert to a lossless scalar form for backends whose override API
36    /// accepts numeric constants through a common floating-point carrier.
37    #[must_use]
38    pub fn as_pipeline_f64(&self) -> f64 {
39        match self {
40            SpecValue::U32(value) => f64::from(*value),
41            SpecValue::I32(value) => f64::from(*value),
42            SpecValue::F32(value) => f64::from(*value),
43            SpecValue::Bool(value) => f64::from(u8::from(*value)),
44            SpecValue::DType(dtype) => f64::from(dtype_tag(dtype)),
45        }
46    }
47
48    /// Hash this value into a 64-bit backend-neutral cache contribution.
49    #[must_use]
50    pub fn cache_hash(&self) -> u64 {
51        match self {
52            SpecValue::U32(value) => u64::from(*value) << 8,
53            SpecValue::I32(value) => (1u64) | ((*value as u32 as u64) << 8),
54            SpecValue::F32(value) => (2u64) | ((value.to_bits() as u64) << 8),
55            SpecValue::Bool(value) => (3u64) | (u64::from(u8::from(*value)) << 8),
56            SpecValue::DType(dtype) => (4u64) | (u64::from(dtype_tag(dtype)) << 8),
57        }
58    }
59}
60
61/// Stable u32 tag for each `DataType` variant. Used to seed
62/// `SpecValue::DType` into the F1 cache hash deterministically.
63/// Adding a new `DataType` variant must extend this table; the
64/// `dtype_tag_covers_every_data_type` test enforces it.
65///
66/// Tags mirror the wire-format `data_type_tag` table so the cache
67/// key, the on-disk artifact, and the conformance metadata all
68/// agree. Parameterised variants (`Vec`, `TensorShaped`, `Array`,
69/// `Handle`, `Opaque`, `Sparse*`, `DeviceMesh`) hash by their
70/// outer-discriminant tag; consumers that need parameter-aware
71/// keys must extend `SpecValue` rather than collapsing distinct
72/// shapes here.
73fn dtype_tag(dtype: &DataType) -> u32 {
74    match dtype {
75        DataType::U32 => 0x01,
76        DataType::I32 => 0x02,
77        DataType::U64 => 0x03,
78        DataType::Vec2U32 => 0x04,
79        DataType::Vec4U32 => 0x05,
80        DataType::Bool => 0x06,
81        DataType::Bytes => 0x07,
82        DataType::Array { .. } => 0x08,
83        DataType::F16 => 0x09,
84        DataType::BF16 => 0x0A,
85        DataType::F32 => 0x0B,
86        DataType::F64 => 0x0C,
87        DataType::Tensor => 0x0D,
88        DataType::U8 => 0x0E,
89        DataType::U16 => 0x0F,
90        DataType::I8 => 0x10,
91        DataType::I16 => 0x11,
92        DataType::I64 => 0x12,
93        DataType::Handle(_) => 0x13,
94        DataType::Vec { .. } => 0x14,
95        DataType::TensorShaped { .. } => 0x15,
96        DataType::SparseCsr { .. } => 0x16,
97        DataType::SparseCoo { .. } => 0x17,
98        DataType::SparseBsr { .. } => 0x18,
99        DataType::F8E4M3 => 0x19,
100        DataType::F8E5M2 => 0x1A,
101        DataType::I4 => 0x1B,
102        DataType::FP4 => 0x1C,
103        DataType::NF4 => 0x1D,
104        DataType::DeviceMesh { .. } => 0x1E,
105        DataType::Opaque(_) => 0x80,
106        // Truly unknown variant  -  sentinel collision is a soundness
107        // bug at the spec-cache layer (different DType values would
108        // collapse onto one cache key and serve the wrong shader),
109        // so any future variant MUST get an explicit tag here.
110        _ => 0xFFFF_FFFF,
111    }
112}
113
114/// Ordered specialization map.
115#[derive(Debug, Default, Clone)]
116pub struct SpecMap {
117    entries: BTreeMap<String, SpecValue>,
118}
119
120impl SpecMap {
121    /// Empty map.
122    #[must_use]
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    /// Insert or replace a `(name, value)` pair.
128    pub fn insert(&mut self, name: impl Into<String>, value: SpecValue) {
129        self.entries.insert(name.into(), value);
130    }
131
132    /// Number of entries.
133    #[must_use]
134    pub fn len(&self) -> usize {
135        self.entries.len()
136    }
137
138    /// Whether the map is empty.
139    #[must_use]
140    pub fn is_empty(&self) -> bool {
141        self.entries.is_empty()
142    }
143
144    /// Iterate `(name, value)` pairs in deterministic order.
145    pub fn iter(&self) -> impl Iterator<Item = (&str, &SpecValue)> {
146        self.entries
147            .iter()
148            .map(|(key, value)| (key.as_str(), value))
149    }
150
151    /// Convert to a deterministic numeric constant map.
152    #[must_use]
153    pub fn to_numeric_constants(&self) -> std::collections::HashMap<String, f64> {
154        let mut out = std::collections::HashMap::with_capacity(self.entries.len());
155        for (key, value) in &self.entries {
156            out.insert(key.clone(), value.as_pipeline_f64());
157        }
158        out
159    }
160
161    /// Compute this map's 64-bit cache contribution.
162    #[must_use]
163    pub fn cache_hash(&self) -> u64 {
164        let mut hash: u64 = 0xcbf29ce484222325;
165        for (name, value) in self.iter() {
166            for byte in name.as_bytes() {
167                hash ^= u64::from(*byte);
168                hash = hash.wrapping_mul(0x100000001b3);
169            }
170            for byte in value.cache_hash().to_le_bytes() {
171                hash ^= u64::from(byte);
172                hash = hash.wrapping_mul(0x100000001b3);
173            }
174        }
175        hash
176    }
177}
178
179/// Cache key extending a backend pipeline identity with specialization values.
180#[derive(Debug, Clone, PartialEq, Eq, Hash)]
181pub struct SpecCacheKey {
182    /// Hash of the shader or target module.
183    pub shader_hash: u64,
184    /// Stable signature of the binding layout.
185    pub binding_sig: u64,
186    /// Workgroup size in the dispatch.
187    pub workgroup_size: [u32; 3],
188    /// Hash of specialization values.
189    pub spec_hash: u64,
190}
191
192impl SpecCacheKey {
193    /// Fold a [`SpecMap`] into a cache key.
194    #[must_use]
195    pub fn new(
196        shader_hash: u64,
197        binding_sig: u64,
198        workgroup_size: [u32; 3],
199        specs: &SpecMap,
200    ) -> Self {
201        Self {
202            shader_hash,
203            binding_sig,
204            workgroup_size,
205            spec_hash: specs.cache_hash(),
206        }
207    }
208}
209
210/// Build the backend-neutral VSA specialization key used by shader caches.
211///
212/// The high half is the low 64 bits of the VSA fingerprint; the low half is
213/// the specialization hash. Keeping this in `vyre-driver` prevents concrete
214/// backends from each reimplementing the same identity folding.
215#[must_use]
216pub fn vsa_specialization_key(program: &Program, spec_hash: u64) -> u128 {
217    let fingerprint = crate::launch::program_vsa_fingerprint_words(program);
218    let fp_lo = fingerprint
219        .iter()
220        .take(2)
221        .enumerate()
222        .fold(0_u64, |acc, (i, &word)| {
223            acc | (u64::from(word) << (32 * (i as u32)))
224        });
225    ((fp_lo as u128) << 64) | u128::from(spec_hash)
226}
227
228/// Deterministic hex key for a backend specialization artifact.
229///
230/// Concrete backends use this for AOT artifacts whose identity is
231/// `(cache-version, specialization hash, backend fingerprint)`. Keeping the
232/// length-delimited hash format here prevents each backend from inventing a
233/// subtly different concatenation scheme.
234#[must_use]
235pub fn versioned_specialization_artifact_key(
236    cache_version: u32,
237    spec_hash: &str,
238    backend_fingerprint: &str,
239) -> String {
240    let mut hasher = blake3::Hasher::new();
241    hasher.update(b"vyre-specialization-artifact-key-v1\0version\0");
242    hasher.update(&cache_version.to_le_bytes());
243    hasher.update(b"\0spec\0");
244    hasher.update(&(spec_hash.len() as u64).to_le_bytes());
245    hasher.update(spec_hash.as_bytes());
246    hasher.update(b"\0backend\0");
247    hasher.update(&(backend_fingerprint.len() as u64).to_le_bytes());
248    hasher.update(backend_fingerprint.as_bytes());
249    let hash = hasher.finalize();
250    let mut key = String::with_capacity(64);
251    push_lower_hex(hash.as_bytes(), &mut key);
252    key
253}
254
/// Append the lowercase hexadecimal encoding of `bytes` to `out`, two
/// characters per byte with the high nibble first.
///
/// # Panics
///
/// Panics if the doubled length overflows `usize` or if `out` cannot reserve
/// the extra capacity.
fn push_lower_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let additional = match bytes.len().checked_mul(2) {
        Some(len) => len,
        None => panic!(
            "hex encoding input length {} overflows output capacity. Fix: shard artifact-key material before encoding.",
            bytes.len()
        ),
    };
    if let Err(error) = out.try_reserve(additional) {
        panic!(
            "hex encoding could not reserve {additional} output byte(s): {error}. Fix: shard artifact-key material before encoding."
        );
    }

    out.extend(bytes.iter().flat_map(|&byte| {
        [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]].map(char::from)
    }));
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Build a `SpecMap` by inserting `pairs` in the order given.
    fn spec_map(pairs: impl IntoIterator<Item = (&'static str, SpecValue)>) -> SpecMap {
        let mut map = SpecMap::new();
        for (name, value) in pairs {
            map.insert(name, value);
        }
        map
    }

    #[test]
    fn spec_map_ordering_is_commutative() {
        let forward = spec_map([("A", SpecValue::U32(1)), ("B", SpecValue::U32(2))]);
        let reversed = spec_map([("B", SpecValue::U32(2)), ("A", SpecValue::U32(1))]);
        assert_eq!(forward.cache_hash(), reversed.cache_hash());
    }

    #[test]
    fn cache_key_differs_by_spec_hash() {
        let one = spec_map([("K", SpecValue::U32(1))]);
        let two = spec_map([("K", SpecValue::U32(2))]);
        let key_for = |specs: &SpecMap| SpecCacheKey::new(0xdead, 0xbeef, [64, 1, 1], specs);
        assert_ne!(key_for(&one), key_for(&two));
    }

    #[test]
    fn vsa_specialization_key_changes_only_low_half_for_spec_hash() {
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
            [1, 1, 1],
            vec![Node::store("out", Expr::u32(0), Expr::u32(7))],
        );
        let (a, b) = (
            vsa_specialization_key(&program, 0x11),
            vsa_specialization_key(&program, 0x22),
        );
        assert_eq!(
            a >> 64,
            b >> 64,
            "Fix: VSA specialization keys must keep program identity independent from specialization values."
        );
        assert_ne!(
            a as u64, b as u64,
            "Fix: VSA specialization keys must include the specialization hash."
        );
    }

    #[test]
    fn versioned_artifact_key_separates_variable_length_fields() {
        assert_ne!(
            versioned_specialization_artifact_key(1, "ab", "cd"),
            versioned_specialization_artifact_key(1, "abc", "d"),
            "Fix: specialization artifact keys must length-prefix variable fields."
        );
    }

    // ---------------- F3 dtype-spec ----------------

    #[test]
    fn dtype_spec_value_round_trips() {
        let value = SpecValue::DType(DataType::F32);
        assert!(
            matches!(value, SpecValue::DType(DataType::F32)),
            "expected DType(F32); got {value:?}"
        );
    }

    #[test]
    fn dtype_spec_distinct_dtypes_hash_distinct() {
        let [f32_hash, u32_hash, i32_hash] = [DataType::F32, DataType::U32, DataType::I32]
            .map(|dtype| SpecValue::DType(dtype).cache_hash());
        assert_ne!(f32_hash, u32_hash);
        assert_ne!(u32_hash, i32_hash);
        assert_ne!(f32_hash, i32_hash);
    }

    #[test]
    fn dtype_spec_equal_dtypes_hash_equal() {
        let first = SpecValue::DType(DataType::F32).cache_hash();
        let second = SpecValue::DType(DataType::F32).cache_hash();
        assert_eq!(first, second);
    }

    #[test]
    fn dtype_spec_does_not_collide_with_other_variants() {
        // DType hashes carry variant tag 4 in the low byte, while the zero
        // values of the other variants carry tags 0..=3 there, so none of
        // them can equal the DType hash.
        let dtype_hash = SpecValue::DType(DataType::U32).cache_hash();
        let others = [
            SpecValue::U32(0),
            SpecValue::I32(0),
            SpecValue::F32(0.0),
            SpecValue::Bool(false),
        ];
        for other in others {
            assert_ne!(dtype_hash, other.cache_hash());
        }
    }

    #[test]
    fn dtype_spec_separates_cache_key_in_specmap() {
        let f32_specs = spec_map([("dtype", SpecValue::DType(DataType::F32))]);
        let u32_specs = spec_map([("dtype", SpecValue::DType(DataType::U32))]);
        assert_ne!(
            f32_specs.cache_hash(),
            u32_specs.cache_hash(),
            "Fix: dtype-keyed SpecMaps must produce distinct cache hashes."
        );
        assert_ne!(
            SpecCacheKey::new(0, 0, [1, 1, 1], &f32_specs),
            SpecCacheKey::new(0, 0, [1, 1, 1], &u32_specs)
        );
    }

    #[test]
    fn dtype_tag_covers_every_data_type() {
        // Soundness gate: each listed DataType variant must have its own
        // explicit, unique tag rather than the 0xFFFF_FFFF fallback.
        let known = [
            DataType::U32,
            DataType::I32,
            DataType::U64,
            DataType::Vec2U32,
            DataType::Vec4U32,
            DataType::Bool,
            DataType::Bytes,
            DataType::Array { element_size: 1 },
            DataType::F16,
            DataType::BF16,
            DataType::F32,
            DataType::F64,
            DataType::Tensor,
            DataType::U8,
            DataType::U16,
            DataType::I8,
            DataType::I16,
            DataType::I64,
            DataType::Handle(vyre_spec::data_type::TypeId(0)),
            DataType::Vec {
                element: Box::new(DataType::U32),
                count: 1,
            },
            DataType::TensorShaped {
                element: Box::new(DataType::U32),
                shape: smallvec::smallvec![1],
            },
            DataType::SparseCsr {
                element: Box::new(DataType::U32),
            },
            DataType::SparseCoo {
                element: Box::new(DataType::U32),
            },
            DataType::SparseBsr {
                element: Box::new(DataType::U32),
                block_rows: 1,
                block_cols: 1,
            },
            DataType::F8E4M3,
            DataType::F8E5M2,
            DataType::I4,
            DataType::FP4,
            DataType::NF4,
            DataType::DeviceMesh {
                axes: smallvec::smallvec![1],
            },
        ];
        let mut seen = std::collections::BTreeSet::new();
        for dtype in &known {
            let tag = dtype_tag(dtype);
            assert_ne!(
                tag, 0xFFFF_FFFF,
                "Fix: dtype_tag missing arm for {dtype:?}  -  extend specialization.rs::dtype_tag."
            );
            assert!(
                seen.insert(tag),
                "Fix: dtype_tag returned duplicate tag {tag} for {dtype:?}."
            );
        }
    }
}