Skip to main content

gam_runtime/warm_start/
key.rs

1//! Fingerprint keying for the warm-start store.
2//!
3//! A [`Fingerprint`] is a SHA-256 hash; two fits whose (data, spec) byte
4//! representations agree under [`Fingerprinter`] absorption produce the same
5//! key. Adversarial collisions don't matter — per-variant warm-start
6//! validators are the correctness fail-safe; the fingerprint is just a fast
7//! filter.
8
9use serde::de::{self, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use sha2::{Digest, Sha256};
12use std::fmt;
13
/// Fixed-width warm-start key: the 32 raw bytes of a SHA-256 digest.
///
/// Instances are produced by [`Fingerprinter`]; equality and hashing compare
/// all 32 bytes.
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub struct Fingerprint([u8; 32]);
17
18impl Serialize for Fingerprint {
19    /// Serialize as the canonical 64-char lowercase hex string so on-disk
20    /// payloads carrying a `Fingerprint` (e.g. the cross-fit `FitArtifact`
21    /// term identities) are stable and human-readable.
22    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
23        serializer.serialize_str(&self.to_hex())
24    }
25}
26
27impl<'de> Deserialize<'de> for Fingerprint {
28    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
29        struct HexVisitor;
30        impl Visitor<'_> for HexVisitor {
31            type Value = Fingerprint;
32            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33                f.write_str("a 64-character hex-encoded SHA-256 fingerprint")
34            }
35            fn visit_str<E: de::Error>(self, v: &str) -> Result<Fingerprint, E> {
36                Fingerprint::from_hex(v).ok_or_else(|| de::Error::custom("invalid hex fingerprint"))
37            }
38        }
39        deserializer.deserialize_str(HexVisitor)
40    }
41}
42
43impl Fingerprint {
44    pub const fn as_bytes(&self) -> &[u8; 32] {
45        &self.0
46    }
47
48    pub fn to_hex(&self) -> String {
49        let mut s = String::with_capacity(64);
50        for b in &self.0 {
51            use std::fmt::Write;
52            write!(&mut s, "{:02x}", b).expect("writing to String is infallible");
53        }
54        s
55    }
56
57    pub fn from_hex(s: &str) -> Option<Self> {
58        if s.len() != 64 {
59            return None;
60        }
61        let bytes = s.as_bytes();
62        let mut out = [0u8; 32];
63        for i in 0..32 {
64            let hi = from_hex_nibble(bytes[2 * i])?;
65            let lo = from_hex_nibble(bytes[2 * i + 1])?;
66            out[i] = (hi << 4) | lo;
67        }
68        Some(Fingerprint(out))
69    }
70}
71
/// Decode one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its 4-bit value;
/// any other byte yields `None`.
const fn from_hex_nibble(c: u8) -> Option<u8> {
    let value = match c {
        b'0'..=b'9' => c - b'0',
        b'a'..=b'f' => c - (b'a' - 10),
        b'A'..=b'F' => c - (b'A' - 10),
        _ => return None,
    };
    Some(value)
}
80
81impl fmt::Debug for Fingerprint {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        write!(f, "Fingerprint({})", self.to_hex())
84    }
85}
86
87impl fmt::Display for Fingerprint {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.write_str(&self.to_hex())
90    }
91}
92
93/// Streaming hasher for building a [`Fingerprint`].
94///
95/// Each `absorb_*` writes a per-type discriminator byte, the caller's content
96/// tag, and a length before the data, so `absorb_f64(b"x", 0.5)` cannot
97/// collide with `absorb_bytes(b"x", <the 8 little-endian bytes of 0.5>)`, nor
98/// with `absorb_u64(b"x", 0.5f64.to_bits())` — heterogeneous fields sharing a
99/// tag can never alias.
100pub struct Fingerprinter {
101    h: Sha256,
102}
103
/// Frame discriminators for the tagged `absorb_*` family.
///
/// Each one is hashed ahead of the content tag, so two absorptions of
/// different primitive types under the same tag, whose payload bytes happen
/// to coincide, still yield distinct digests.
mod type_code {
    /// Bare structural tag with no payload.
    pub const TAG: u8 = 0x00;
    /// Length-prefixed opaque bytes.
    pub const BYTES: u8 = 0x01;
    /// Length-prefixed UTF-8 string.
    pub const STR: u8 = 0x02;
    /// Single `u64`.
    pub const U64: u8 = 0x03;
    /// Single `f64`.
    pub const F64: u8 = 0x04;
    /// Length-prefixed `f64` slice.
    pub const F64_SLICE: u8 = 0x05;
    /// Shape-prefixed 2-D `f64` block.
    pub const F64_2D: u8 = 0x06;
}
116
117impl Fingerprinter {
118    pub fn new() -> Self {
119        Self { h: Sha256::new() }
120    }
121
122    /// Write one frame header: type discriminator, then length-prefixed tag.
123    fn frame(&mut self, code: u8, tag: &[u8]) {
124        self.h.update([code]);
125        self.h.update((tag.len() as u32).to_le_bytes());
126        self.h.update(tag);
127    }
128
129    /// Absorb a tag with no payload. Useful for structural separators.
130    pub fn absorb_tag(&mut self, tag: &[u8]) {
131        self.frame(type_code::TAG, tag);
132    }
133
134    pub fn absorb_bytes(&mut self, tag: &[u8], data: &[u8]) {
135        self.frame(type_code::BYTES, tag);
136        self.h.update((data.len() as u64).to_le_bytes());
137        self.h.update(data);
138    }
139
140    pub fn absorb_str(&mut self, tag: &[u8], s: &str) {
141        self.frame(type_code::STR, tag);
142        self.h.update((s.len() as u64).to_le_bytes());
143        self.h.update(s.as_bytes());
144    }
145
146    pub fn absorb_u64(&mut self, tag: &[u8], v: u64) {
147        self.frame(type_code::U64, tag);
148        self.h.update(v.to_le_bytes());
149    }
150
151    pub fn absorb_f64(&mut self, tag: &[u8], v: f64) {
152        self.frame(type_code::F64, tag);
153        self.h.update(v.to_bits().to_le_bytes());
154    }
155
156    pub fn absorb_f64_slice(&mut self, tag: &[u8], xs: &[f64]) {
157        self.frame(type_code::F64_SLICE, tag);
158        self.h.update((xs.len() as u64).to_le_bytes());
159        absorb_f64_bytes(&mut self.h, xs);
160    }
161
162    pub fn absorb_f64_2d(&mut self, tag: &[u8], rows: usize, cols: usize, xs: &[f64]) {
163        self.frame(type_code::F64_2D, tag);
164        self.h.update((rows as u64).to_le_bytes());
165        self.h.update((cols as u64).to_le_bytes());
166        absorb_f64_bytes(&mut self.h, xs);
167    }
168
169    pub fn finalize(self) -> Fingerprint {
170        let out = self.h.finalize();
171        let mut bytes = [0u8; 32];
172        bytes.copy_from_slice(&out);
173        Fingerprint(bytes)
174    }
175
176    // ------------------------------------------------------------------
177    // Untagged write_* API — drop-in replacement for the formerly-separate
178    // `StableHasher` (warm-start) and `CacheDigestBuilder` (latent_cache)
179    // hashers. Callers that use this API are responsible for their own
180    // type-disambiguation (typically by writing a leading namespace string
181    // via `write_str`); the `absorb_*` family above prepends per-call tags
182    // and is the safer choice for new code. Callers that use the untagged API
183    // must preserve their own hand-framed input protocol.
184    // ------------------------------------------------------------------
185
186    pub fn write_bytes(&mut self, data: &[u8]) {
187        self.h.update(data);
188    }
189
190    pub fn write_u8(&mut self, value: u8) {
191        self.h.update([value]);
192    }
193
194    pub fn write_bool(&mut self, value: bool) {
195        self.h.update([u8::from(value)]);
196    }
197
198    pub fn write_u64(&mut self, value: u64) {
199        self.h.update(value.to_le_bytes());
200    }
201
202    pub fn write_usize(&mut self, value: usize) {
203        self.h.update((value as u64).to_le_bytes());
204    }
205
206    pub fn write_f64(&mut self, value: f64) {
207        // Normalize -0.0 to +0.0 so signed-zero comparison ambiguity does
208        // not split warm-start key buckets — matches the prior `StableHasher`
209        // contract that warm-start keys depended on.
210        let normalized = if value == 0.0 { 0.0 } else { value };
211        self.h.update(normalized.to_bits().to_le_bytes());
212    }
213
214    pub fn write_str(&mut self, value: &str) {
215        self.write_usize(value.len());
216        self.h.update(value.as_bytes());
217    }
218
219    /// Absorb a length-prefixed `f64` slice using the per-element
220    /// [`Fingerprinter::write_f64`] contract (so `-0.0` is normalized to
221    /// `+0.0`). Canonical home for the byte-identical `len`-then-each-`f64`
222    /// hashing that previously lived as module-local `hash_f64_slice` /
223    /// `hash_vector` copies in `solver/latent_cache`. Uses a bulk byte path
224    /// only when it can emit exactly the same bytes as the element-wise
225    /// normalizing protocol.
226    pub fn write_f64_slice(&mut self, values: &[f64]) {
227        self.write_usize(values.len());
228        self.write_f64_slice_payload(values);
229    }
230
231    fn write_f64_slice_payload(&mut self, values: &[f64]) {
232        #[cfg(target_endian = "little")]
233        {
234            let needs_normalization = values
235                .iter()
236                .any(|&value| value.is_nan() || (value == 0.0 && value.is_sign_negative()));
237            if !needs_normalization {
238                // SAFETY: values.as_ptr() is valid for values.len() contiguous
239                // f64s, f64 has no padding, and reborrowing as bytes is confined
240                // to this update call. Little-endian storage matches write_f64's
241                // to_bits().to_le_bytes() byte stream for non-normalized values.
242                let bytes = unsafe {
243                    std::slice::from_raw_parts(
244                        values.as_ptr() as *const u8,
245                        std::mem::size_of_val(values),
246                    )
247                };
248                self.h.update(bytes);
249                return;
250            }
251        }
252        self.write_f64_slice_payload_slow(values);
253    }
254
255    fn write_f64_slice_payload_slow(&mut self, values: &[f64]) {
256        for &value in values {
257            self.write_f64(value);
258        }
259    }
260
261    /// Absorb a 1D `f64` array as `len` followed by every element via
262    /// [`Fingerprinter::write_f64`]. Canonical home for the byte-identical
263    /// `hash_vector` copy that previously lived in `solver/latent_cache`.
264    pub fn write_f64_array1(&mut self, values: &ndarray::Array1<f64>) {
265        self.write_usize(values.len());
266        if let Some(slice) = values.as_slice() {
267            self.write_f64_slice_payload(slice);
268        } else {
269            self.write_f64_slice_payload_slow_iter(values.iter().copied());
270        }
271    }
272
273    /// Absorb a 2D `f64` array as `(nrows, ncols)` followed by every element in
274    /// iteration order, each via [`Fingerprinter::write_f64`]. Canonical home
275    /// for the byte-identical heuristic that previously lived as module-local
276    /// `write_array2_fingerprint` (`solver/arrow_schur`) and `hash_matrix`
277    /// (`solver/latent_cache`) copies.
278    pub fn write_f64_array2(&mut self, values: &ndarray::Array2<f64>) {
279        self.write_usize(values.nrows());
280        self.write_usize(values.ncols());
281        if let Some(slice) = values.as_slice() {
282            self.write_f64_slice_payload(slice);
283        } else {
284            self.write_f64_slice_payload_slow_iter(values.iter().copied());
285        }
286    }
287
288    fn write_f64_slice_payload_slow_iter<I>(&mut self, values: I)
289    where
290        I: IntoIterator<Item = f64>,
291    {
292        for value in values {
293            self.write_f64(value);
294        }
295    }
296
297    /// Finalize and return the first 8 bytes of the SHA-256 digest as a
298    /// little-endian `u64`. Used by callers that need a compact in-process
299    /// identifier (manifold mode fingerprints, registry fingerprints, …)
300    /// rather than the full 32-byte [`Fingerprint`].
301    pub fn finish_u64(self) -> u64 {
302        let out = self.h.finalize();
303        let mut bytes = [0u8; 8];
304        bytes.copy_from_slice(&out[..8]);
305        u64::from_le_bytes(bytes)
306    }
307
308    /// Finalize and return a zero-padded 16-character hex representation
309    /// of [`Fingerprinter::finish_u64`], suitable for embedding directly
310    /// in cache-key strings.
311    pub fn finish_hex(self) -> String {
312        format!("{:016x}", self.finish_u64())
313    }
314}
315
316/// Feed `xs` to the hasher in one bulk `update` instead of one 8-byte
317/// `to_le_bytes` call per element. On little-endian hosts we reinterpret the
318/// `&[f64]` storage directly as `&[u8]`; on big-endian hosts we fall back to
319/// a per-element loop so the fingerprint stays endian-stable across machines.
320#[inline]
321fn absorb_f64_bytes(h: &mut Sha256, xs: &[f64]) {
322    #[cfg(target_endian = "little")]
323    {
324        // SAFETY: xs.as_ptr() is non-null/aligned (slice invariant); f64
325        // has no padding and any bit pattern is a valid u8; size_of_val
326        // covers exactly xs's bytes and the borrow lives within this call.
327        let bytes = unsafe {
328            std::slice::from_raw_parts(xs.as_ptr() as *const u8, std::mem::size_of_val(xs))
329        };
330        h.update(bytes);
331    }
332    #[cfg(not(target_endian = "little"))]
333    {
334        for &x in xs {
335            h.update(x.to_bits().to_le_bytes());
336        }
337    }
338}
339
340impl Default for Fingerprinter {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
#[cfg(test)]
mod tests {
    //! Unit tests for hex encoding, tagged framing, and the untagged
    //! `write_*` protocol (including its bulk/element-wise equivalence).

    use super::*;

    #[test]
    fn hex_roundtrips() {
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"family", "standard");
        fp.absorb_f64_slice(b"y", &[1.0, 2.0, 3.0]);
        let key = fp.finalize();
        let hex = key.to_hex();
        assert_eq!(hex.len(), 64);
        let parsed = Fingerprint::from_hex(&hex).unwrap();
        assert_eq!(key, parsed);
    }

    #[test]
    fn tagged_absorptions_dont_collide() {
        let mut a = Fingerprinter::new();
        a.absorb_f64(b"x", 0.5);
        let ka = a.finalize();

        let mut b = Fingerprinter::new();
        // Same bytes, different tag: must produce a different key.
        b.absorb_bytes(b"y", &0.5f64.to_bits().to_le_bytes());
        let kb = b.finalize();

        assert_ne!(ka, kb);
    }

    #[test]
    fn same_tag_cross_type_absorptions_dont_collide() {
        // The documented contract: heterogeneous fields sharing a tag whose
        // payload bytes coincide must still produce distinct fingerprints.
        let v = 0.5f64;

        let mut f = Fingerprinter::new();
        f.absorb_f64(b"t", v);
        let kf = f.finalize();

        let mut u = Fingerprinter::new();
        u.absorb_u64(b"t", v.to_bits());
        let ku = u.finalize();

        let mut raw = Fingerprinter::new();
        raw.absorb_bytes(b"t", &v.to_bits().to_le_bytes());
        let kraw = raw.finalize();

        assert_ne!(kf, ku, "f64/u64 type confusion under a shared tag");
        assert_ne!(kf, kraw, "f64/bytes type confusion under a shared tag");
        assert_ne!(ku, kraw, "u64/bytes type confusion under a shared tag");

        let mut s = Fingerprinter::new();
        s.absorb_str(b"k", "AB");
        let ks = s.finalize();
        let mut sb = Fingerprinter::new();
        sb.absorb_bytes(b"k", b"AB");
        let ksb = sb.finalize();
        assert_ne!(ks, ksb, "str/bytes type confusion under a shared tag");

        // A bare structural tag must not alias an empty-payload absorption.
        let mut t = Fingerprinter::new();
        t.absorb_tag(b"sep");
        let kt = t.finalize();
        let mut e = Fingerprinter::new();
        e.absorb_bytes(b"sep", b"");
        let ke = e.finalize();
        assert_ne!(kt, ke, "tag/empty-bytes confusion under a shared tag");
    }

    #[test]
    fn different_data_yields_different_keys() {
        let mut a = Fingerprinter::new();
        a.absorb_f64_slice(b"y", &[1.0, 2.0]);
        let mut b = Fingerprinter::new();
        b.absorb_f64_slice(b"y", &[1.0, 3.0]);
        assert_ne!(a.finalize(), b.finalize());
    }

    #[test]
    fn same_input_yields_same_key() {
        let mut a = Fingerprinter::new();
        a.absorb_str(b"f", "binomial");
        a.absorb_f64_2d(b"x", 2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut b = Fingerprinter::new();
        b.absorb_str(b"f", "binomial");
        b.absorb_f64_2d(b"x", 2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(a.finalize(), b.finalize());
    }

    #[test]
    fn write_f64_slice_bulk_matches_element_protocol() {
        fn pseudo_random_values(n: usize) -> Vec<f64> {
            let mut state = 0x4d59_5df4_d0f3_3173_u64;
            let mut values = Vec::with_capacity(n);
            for idx in 0..n {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                let mantissa = state >> 12;
                let unit = f64::from_bits(0x3ff0_0000_0000_0000 | mantissa) - 1.0;
                values.push((unit - 0.5) * ((idx % 17) as f64 + 1.0));
            }
            values
        }

        fn fast_key(values: &[f64]) -> Fingerprint {
            let mut fp = Fingerprinter::new();
            fp.write_str("write_f64_slice_bulk_matches_element_protocol");
            fp.write_f64_slice(values);
            fp.finalize()
        }

        fn slow_key(values: &[f64]) -> Fingerprint {
            let mut fp = Fingerprinter::new();
            fp.write_str("write_f64_slice_bulk_matches_element_protocol");
            fp.write_usize(values.len());
            fp.write_f64_slice_payload_slow(values);
            fp.finalize()
        }

        let clean = pseudo_random_values(257);
        assert_eq!(fast_key(&clean), slow_key(&clean));

        let mut normalized = clean.clone();
        normalized[7] = -0.0;
        normalized[113] = f64::from_bits(0x7ff8_0000_0000_0042);
        assert_eq!(fast_key(&normalized), slow_key(&normalized));
    }

    #[test]
    fn write_f64_arrays_match_element_protocol() {
        let values = ndarray::Array2::from_shape_vec(
            (3, 4),
            vec![
                1.25,
                -2.5,
                3.75,
                4.0,
                5.5,
                -0.0,
                7.25,
                8.5,
                9.75,
                10.0,
                f64::from_bits(0x7ff8_0000_0000_0100),
                12.25,
            ],
        )
        .expect("test array shape is valid");

        let mut fast = Fingerprinter::new();
        fast.write_str("write_f64_arrays_match_element_protocol");
        fast.write_f64_array2(&values);

        let mut slow = Fingerprinter::new();
        slow.write_str("write_f64_arrays_match_element_protocol");
        slow.write_usize(values.nrows());
        slow.write_usize(values.ncols());
        slow.write_f64_slice_payload_slow_iter(values.iter().copied());

        assert_eq!(fast.finalize(), slow.finalize());
    }

    #[test]
    fn write_f64_folds_negative_zero() {
        // The untagged protocol treats -0.0 and +0.0 as the same value.
        let scalar_key = |v: f64| {
            let mut fp = Fingerprinter::new();
            fp.write_f64(v);
            fp.finalize()
        };
        assert_eq!(scalar_key(-0.0), scalar_key(0.0));

        let slice_key = |xs: &[f64]| {
            let mut fp = Fingerprinter::new();
            fp.write_f64_slice(xs);
            fp.finalize()
        };
        assert_eq!(slice_key(&[1.0, -0.0, 2.0]), slice_key(&[1.0, 0.0, 2.0]));
    }

    #[test]
    fn absorb_f64_keeps_raw_bits() {
        // Unlike `write_f64`, the tagged family hashes the raw bit pattern.
        let key = |v: f64| {
            let mut fp = Fingerprinter::new();
            fp.absorb_f64(b"z", v);
            fp.finalize()
        };
        assert_ne!(key(-0.0), key(0.0));
    }

    #[test]
    fn write_f64_array2_non_standard_layout_matches_contiguous() {
        let standard =
            ndarray::Array2::from_shape_vec((2, 3), vec![1.0, -2.0, 3.0, -0.0, 5.0, 6.0])
                .expect("test array shape is valid");
        // `reversed_axes` swaps strides without copying, so the (3, 2) result
        // is column-major in memory and takes the element-wise fallback path.
        let transposed = standard.reversed_axes();
        assert!(transposed.as_slice().is_none());
        let contiguous = ndarray::Array2::from_shape_vec(
            (3, 2),
            transposed.iter().copied().collect::<Vec<f64>>(),
        )
        .expect("test array shape is valid");
        assert!(contiguous.as_slice().is_some());

        let key = |a: &ndarray::Array2<f64>| {
            let mut fp = Fingerprinter::new();
            fp.write_f64_array2(a);
            fp.finalize()
        };
        assert_eq!(key(&transposed), key(&contiguous));
    }

    #[test]
    fn write_f64_array1_strided_matches_contiguous() {
        // Every other element of the backing storage: 1.0, -0.0, 3.5.
        let strided = ndarray::Array1::from(vec![1.0, 9.0, -0.0, 9.0, 3.5])
            .slice_move(ndarray::s![..;2]);
        assert!(strided.as_slice().is_none());
        let contiguous = ndarray::Array1::from(vec![1.0, -0.0, 3.5]);

        let key = |a: &ndarray::Array1<f64>| {
            let mut fp = Fingerprinter::new();
            fp.write_f64_array1(a);
            fp.finalize()
        };
        assert_eq!(key(&strided), key(&contiguous));
    }

    #[test]
    fn debug_and_display_render_hex() {
        let key = Fingerprint([0xab; 32]);
        let hex = "ab".repeat(32);
        assert_eq!(key.to_string(), hex);
        assert_eq!(format!("{:?}", key), format!("Fingerprint({})", hex));
    }

    #[test]
    fn from_hex_accepts_uppercase() {
        let key = Fingerprint([0xab; 32]);
        assert_eq!(Fingerprint::from_hex(&"AB".repeat(32)), Some(key));
        assert_eq!(Fingerprint::from_hex(&"aB".repeat(32)), Some(key));
    }

    #[test]
    fn fingerprint_serde_roundtrips_as_hex() {
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"k", "fingerprint-serde");
        let key = fp.finalize();
        let json = serde_json::to_string(&key).expect("serialize");
        // Serialized form is the canonical quoted hex string.
        assert_eq!(json, format!("\"{}\"", key.to_hex()));
        let back: Fingerprint = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(key, back);
        // A malformed hex payload is rejected, not silently aliased.
        assert!(serde_json::from_str::<Fingerprint>("\"not-hex\"").is_err());
    }

    #[test]
    fn invalid_hex_rejected() {
        assert!(Fingerprint::from_hex("not hex").is_none());
        assert!(Fingerprint::from_hex(&"a".repeat(63)).is_none());
        assert!(Fingerprint::from_hex(&"z".repeat(64)).is_none());
    }
}