amadeus_utils/
blake3.rs

/// Translated from <https://github.com/vans163/blake3>
/// Infallible implementation of Blake3 hashing algorithm
///
/// A thin newtype over [`blake3::Hasher`]. `finalize` and `finalize_xof`
/// hand back plain byte arrays / vectors instead of `blake3` types.
///
/// NOTE(review): "infallible" does not hold for `update_rayon` when the
/// `rayon` feature is disabled — that stub panics.
pub struct Hasher(blake3::Hasher);
4
5impl Default for Hasher {
6    fn default() -> Self {
7        Self::new()
8    }
9}
10
11impl Hasher {
12    pub fn new() -> Self {
13        Self(blake3::Hasher::new())
14    }
15    pub fn new_keyed(key: &[u8; 32]) -> Self {
16        Self(blake3::Hasher::new_keyed(key))
17    }
18    pub fn update(&mut self, buf: &[u8]) -> &mut blake3::Hasher {
19        self.0.update(buf)
20    }
21    #[cfg(feature = "rayon")]
22    pub fn update_rayon(&mut self, buf: &[u8]) {
23        self.0.update_rayon(buf);
24    }
25    #[cfg(not(feature = "rayon"))]
26    pub fn update_rayon(&mut self, _buf: &[u8]) {
27        panic!("Blake3.update_rayon() called without rayon feature enabled");
28    }
29    pub fn reset(&mut self) {
30        self.0.reset();
31    }
32    pub fn finalize(&self) -> [u8; 32] {
33        self.0.finalize().as_bytes().to_owned()
34    }
35    pub fn finalize_xof(&self, output_size: usize) -> Vec<u8> {
36        let mut out = vec![0u8; output_size];
37        let mut x = self.0.finalize_xof();
38        x.fill(&mut out);
39        out
40    }
41    pub fn finalize_xof_reader(&self) -> blake3::OutputReader {
42        self.0.finalize_xof()
43    }
44}
45
46pub fn hash(buf: &[u8]) -> [u8; 32] {
47    blake3::hash(buf).as_bytes().to_owned()
48}
49
50pub fn derive_key(context: &str, input_key: &[u8]) -> [u8; 32] {
51    blake3::derive_key(context, input_key)
52}
53
54pub fn keyed_hash(key: &[u8; 32], buf: &[u8]) -> [u8; 32] {
55    blake3::keyed_hash(key, buf).as_bytes().to_owned()
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hash_and_hasher_consistency() {
        let message = b"hello world";
        let expected = hash(message);

        // Feed the same bytes in two pieces through the streaming wrapper.
        let mut streaming = Hasher::new();
        for piece in [&b"hello"[..], &b" world"[..]] {
            streaming.update(piece);
        }
        let incremental = streaming.finalize();
        assert_eq!(expected, incremental);

        // The XOF stream has the requested length and starts with the digest.
        let extended = streaming.finalize_xof(64);
        assert_eq!(extended.len(), 64);
        assert_eq!(incremental.as_slice(), &extended[..32]);

        // Cross-check against the upstream crate directly.
        assert_eq!(expected, *blake3::hash(message).as_bytes());
    }

    #[test]
    fn derive_and_keyed_hash_match_reference() {
        const CONTEXT: &str = "context7:test";
        let key = [7u8; 32];
        let message = b"abc";

        assert_eq!(
            derive_key(CONTEXT, b"input_key"),
            blake3::derive_key(CONTEXT, b"input_key")
        );
        assert_eq!(
            keyed_hash(&key, message),
            *blake3::keyed_hash(&key, message).as_bytes()
        );
    }

    #[test]
    fn freivalds_is_deterministic() {
        // 240-byte seed head followed by the 1024-byte tail read as C.
        let tensor = [0u8; 240 + 1024];
        let first = freivalds(&tensor);
        let second = freivalds(&tensor);
        assert_eq!(first, second);
    }
}
102
103#[cfg(target_arch = "x86_64")]
104use std::arch::x86_64::*;
105use std::{cell::RefCell, mem, mem::MaybeUninit, ptr, slice};
106
/// Scratch storage for one Freivalds check.
///
/// `freivalds` fills `A`, `B`, `B2` and `Rs` (in that order) from one BLAKE3
/// XOF stream. `freivalds_e260` fills `A`, `B`, `B2` from that stream and `Rs`
/// from a second one. Both copy the claimed product into `C`.
/// `#[repr(C)]` pins the declaration order.
///
/// NOTE(review): `align(4096)` presumably page-aligns the buffer — confirm intent.
// Capitalized field names match mathematical matrix notation
#[allow(non_snake_case)]
#[repr(C, align(4096))]
struct AMAMatMul {
    // 16 × 50_240 unsigned left-hand matrix.
    pub A: [[u8; 50240]; 16],
    // 50_240 × 16 signed right-hand matrix.
    pub B: [[i8; 16]; 50240],
    // 16 × 64 signed block; filled from the XOF but never passed to `freivalds_inner`.
    pub B2: [[i8; 64]; 16],
    // Three random 16-element vectors used for the Freivalds projections.
    pub Rs: [[i8; 16]; 3],
    // Claimed 16 × 16 product, copied byte-for-byte from the solution tail.
    pub C: [[i32; 16]; 16],
}
114    pub Rs: [[i8; 16]; 3],
115    pub C: [[i32; 16]; 16],
116}
117
// One lazily allocated `AMAMatMul` per thread. `borrow_scratch` moves it out
// (leaving `None` while it is in use) and `ScratchGuard`'s `Drop` puts it back.
thread_local! {
    static SCRATCH: RefCell<Option<Box<AMAMatMul>>> = const { RefCell::new(None) };
}
121
/// Holds this thread's scratch buffer for the duration of one check and
/// returns it to `SCRATCH` when dropped. Derefs to `AMAMatMul`.
struct ScratchGuard {
    // Set to `Some` by `borrow_scratch`; only `Drop` takes it out.
    buf: Option<Box<AMAMatMul>>,
}
125
126impl std::ops::Deref for ScratchGuard {
127    type Target = AMAMatMul;
128    fn deref(&self) -> &Self::Target {
129        self.buf.as_ref().expect("buffer disappeared")
130    }
131}
132impl std::ops::DerefMut for ScratchGuard {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        self.buf.as_mut().expect("buffer disappeared")
135    }
136}
137impl Drop for ScratchGuard {
138    fn drop(&mut self) {
139        if let Some(buf) = self.buf.take() {
140            SCRATCH.with(|tls| *tls.borrow_mut() = Some(buf));
141        }
142    }
143}
144
145/// Obtain the per‑thread scratch buffer, allocating it the first time.
146fn borrow_scratch() -> ScratchGuard {
147    SCRATCH.with(|tls| {
148        let mut slot = tls.borrow_mut();
149        let buf = slot.take().unwrap_or_else(|| {
150            // first time on this thread: allocate **uninitialised** memory
151            let boxed_uninit: Box<MaybeUninit<AMAMatMul>> = Box::new_uninit(); // ≈ zero cost for the OS here
152            // SAFETY: we promise to fully overwrite every byte before reading
153            unsafe { boxed_uninit.assume_init() }
154        });
155        ScratchGuard { buf: Some(buf) }
156    })
157}
158
159pub fn freivalds(tensor: &[u8]) -> bool {
160    let mut scratch = borrow_scratch();
161
162    let mut hasher = blake3::Hasher::new();
163    hasher.update(&tensor[..240]);
164    let mut xof = hasher.finalize_xof();
165
166    let head_bytes = 16 * 50_240         // A
167        + 50_240 * 16         // B
168        + 16 * 64             // B2
169        + 3 * 16; // Rs
170
171    unsafe {
172        let dest = ptr::slice_from_raw_parts_mut((&mut scratch.A) as *mut _ as *mut u8, head_bytes) as *mut [u8];
173        xof.fill(&mut *dest);
174    }
175
176    let tail = &tensor[tensor.len() - 1024..];
177    unsafe {
178        let dst = &mut scratch.C as *mut _ as *mut u8;
179        ptr::copy_nonoverlapping(tail.as_ptr(), dst, 1024);
180    }
181
182    freivalds_inner(&scratch.Rs, &scratch.A, &scratch.B, &scratch.C)
183}
184
185pub fn freivalds_e260(tensor: &[u8], vr_b3: &[u8]) -> bool {
186    let mut scratch = borrow_scratch();
187
188    let head = &tensor[..240];
189    let tail = &tensor[tensor.len() - 1024..];
190
191    let mut hasher = blake3::Hasher::new();
192    hasher.update(head);
193    let mut xof = hasher.finalize_xof();
194
195    let ab_bytes = 16 * 50_240           // A
196        + 50_240 * 16         // B
197        + 16 * 64; // B2
198
199    unsafe {
200        let dest = ptr::slice_from_raw_parts_mut((&mut scratch.A) as *mut _ as *mut u8, ab_bytes) as *mut [u8];
201        xof.fill(&mut *dest);
202    }
203
204    unsafe {
205        let dst = &mut scratch.C as *mut _ as *mut u8;
206        ptr::copy_nonoverlapping(tail.as_ptr(), dst, 1024);
207    }
208
209    //Take R from entire sol + VRF
210    let mut hasher_rs = blake3::Hasher::new();
211    hasher_rs.update(tensor);
212    hasher_rs.update(vr_b3);
213    let mut xof_rs = hasher_rs.finalize_xof();
214
215    unsafe {
216        let p = (&mut scratch.Rs) as *mut _ as *mut u8;
217        let n = mem::size_of_val(&scratch.Rs);
218        let dst = slice::from_raw_parts_mut(p, n);
219        xof_rs.fill(dst);
220    }
221
222    freivalds_inner(&scratch.Rs, &scratch.A, &scratch.B, &scratch.C)
223}
224
225// Capitalized parameters match mathematical matrix notation
226#[allow(non_snake_case)]
227pub fn freivalds_inner(
228    Rs: &[[i8; 16]; 3],
229    A: &[[u8; 50_240]; 16],
230    B: &[[i8; 16]; 50_240],
231    C: &[[i32; 16]; 16],
232) -> bool {
233    #[cfg(target_arch = "x86_64")]
234    {
235        if std::is_x86_feature_detected!("avx2") {
236            unsafe {
237                return freivalds_inner_avx2(Rs, A, B, C);
238            }
239        }
240    }
241    freivalds_inner_scalar(Rs, A, B, C)
242}
243
/// Horizontal sum of the eight `i32` lanes of `v`. The additions use
/// `_mm_add_epi32`, which wraps on overflow.
///
/// # Safety
/// Uses the AVX2 intrinsic `_mm256_extracti128_si256`, so it must only run on
/// CPUs with AVX2. In the visible source its only caller is `freivalds_inner_avx2`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn hsum256_epi32(v: __m256i) -> i32 {
    // reduce 8 × i32 → scalar
    unsafe {
        let hi = _mm256_extracti128_si256(v, 1); // lanes 4..8
        let lo = _mm256_castsi256_si128(v); // lanes 0..4
        let sum128 = _mm_add_epi32(lo, hi); // 4 lanes
        // fold lanes 2,3 onto lanes 0,1
        let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
        // fold lane 1 onto lane 0
        let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
        _mm_cvtsi128_si32(sum32)
    }
}
257
/// Sixteen `i32` values held as two AVX registers, as produced by
/// `load_i8x16_as_i32`: `lo` holds elements 0..8 and `hi` elements 8..16.
#[cfg(target_arch = "x86_64")]
#[repr(C)]
struct I32x16 {
    lo: __m256i,
    hi: __m256i,
}
264
/// Load 16×i8 and sign-extend to 16×i32 (as two 256-bit halves).
///
/// `lo` receives elements 0..8 and `hi` elements 8..16.
///
/// # Safety
/// `ptr` must be valid for an (unaligned) 16-byte read. The CPU must support
/// AVX2, because `_mm256_cvtepi8_epi32` is an AVX2 intrinsic.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn load_i8x16_as_i32(ptr: *const i8) -> I32x16 {
    // load 16 bytes
    unsafe {
        let v = _mm_loadu_si128(ptr as *const __m128i);
        let lo = _mm256_cvtepi8_epi32(v); // first 8
        // shift the upper 8 bytes down, then sign-extend them
        let hi = _mm256_cvtepi8_epi32(_mm_srli_si128(v, 8));
        I32x16 { lo, hi }
    }
}
277
/// AVX2 kernel for the Freivalds check: for each random vector `r` in `Rs`,
/// compares `C·r` against `A·(B·r)` and returns `true` only if all three
/// projections agree for every row.
///
/// All products and sums use `_mm256_mullo_epi32` / `_mm256_add_epi32`, which
/// wrap on overflow. This matches the scalar kernel's `wrapping_*` arithmetic.
/// Each call allocates three `Vec<i32>` of length 50_240 on the heap.
///
/// # Safety
/// The CPU must support AVX2 (this function is compiled with
/// `#[target_feature(enable = "avx2")]`). `freivalds_inner` checks this with
/// `is_x86_feature_detected!` before calling it.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// Capitalized parameters match mathematical matrix notation
#[allow(non_snake_case)]
pub unsafe fn freivalds_inner_avx2(
    Rs: &[[i8; 16]; 3],
    A: &[[u8; 50_240]; 16],
    B: &[[i8; 16]; 50_240],
    C: &[[i32; 16]; 16],
) -> bool {
    // All loads below stay within the fixed-size input arrays: C rows are
    // 16 i32 (two 8-lane loads), Rs/B rows are 16 i8 (one 16-byte load), and
    // A/P are stepped 8 elements at a time over N = 50_240 (a multiple of 8).
    unsafe {
        const N: usize = 50_240;
        let mut U = [[0i32; 16]; 3];

        // --- Stage 1: U = C × R --------------------------------------------------
        let r0_i32 = load_i8x16_as_i32(Rs[0].as_ptr());
        let r1_i32 = load_i8x16_as_i32(Rs[1].as_ptr());
        let r2_i32 = load_i8x16_as_i32(Rs[2].as_ptr());

        for i in 0..16 {
            // C[i][0..8] and C[i][8..16]
            let c_lo = _mm256_loadu_si256(C[i].as_ptr() as *const __m256i);
            let c_hi = _mm256_loadu_si256(C[i].as_ptr().add(8) as *const __m256i);

            let u0 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r0_i32.lo), _mm256_mullo_epi32(c_hi, r0_i32.hi));
            let u1 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r1_i32.lo), _mm256_mullo_epi32(c_hi, r1_i32.hi));
            let u2 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r2_i32.lo), _mm256_mullo_epi32(c_hi, r2_i32.hi));

            U[0][i] = hsum256_epi32(u0);
            U[1][i] = hsum256_epi32(u1);
            U[2][i] = hsum256_epi32(u2);
        }

        // --- Stage 2: P(k) = B[k]·R -------------------------------------------
        let mut P0 = vec![0i32; N];
        let mut P1 = vec![0i32; N];
        let mut P2 = vec![0i32; N];

        // R vectors sign-extended to 16 × i16 for `_mm256_madd_epi16`
        let r0_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[0].as_ptr() as *const _));
        let r1_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[1].as_ptr() as *const _));
        let r2_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[2].as_ptr() as *const _));

        for k in 0..N {
            let row_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(B[k].as_ptr() as *const _));

            // madd: 16 i16×i16 products summed pairwise into 8 i32, then reduced
            P0[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r0_i16));
            P1[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r1_i16));
            P2[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r2_i16));
        }

        // --- Stage 3: dot( A[i], P ) --------------------------------------------
        for i in 0..16 {
            let mut acc0 = _mm256_setzero_si256();
            let mut acc1 = _mm256_setzero_si256();
            let mut acc2 = _mm256_setzero_si256();

            // N is a multiple of 8, so this covers every column exactly once.
            for k in (0..N).step_by(8) {
                // 8 × u8 from A[i][k..k+8], zero-extended to 8 × i32
                let a_i32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64(A[i].as_ptr().add(k) as *const _));
                let p0 = _mm256_loadu_si256(P0.as_ptr().add(k) as *const _);
                let p1 = _mm256_loadu_si256(P1.as_ptr().add(k) as *const _);
                let p2 = _mm256_loadu_si256(P2.as_ptr().add(k) as *const _);

                acc0 = _mm256_add_epi32(acc0, _mm256_mullo_epi32(a_i32, p0));
                acc1 = _mm256_add_epi32(acc1, _mm256_mullo_epi32(a_i32, p1));
                acc2 = _mm256_add_epi32(acc2, _mm256_mullo_epi32(a_i32, p2));
            }

            // Row i of A·(B·r) must equal row i of C·r for all three vectors.
            if hsum256_epi32(acc0) != U[0][i] || hsum256_epi32(acc1) != U[1][i] || hsum256_epi32(acc2) != U[2][i] {
                return false;
            }
        }
        true
    }
}
353
/// Portable Freivalds kernel: probabilistically checks `A × B == C`.
///
/// For each of the three random vectors `r` in `Rs`, it compares `C·r` with
/// `A·(B·r)`. All arithmetic is wrapping 32-bit, matching the AVX2 kernel,
/// whose `_mm256_mullo_epi32` / `_mm256_add_epi32` also wrap.
///
/// Returns `true` only if all three projections agree for every row of `A`.
/// It stops at the first mismatching row.
// Capitalized parameters match mathematical matrix notation
#[allow(non_snake_case)]
fn freivalds_inner_scalar(
    Rs: &[[i8; 16]; 3],
    A: &[[u8; 50_240]; 16],
    B: &[[i8; 16]; 50_240],
    C: &[[i32; 16]; 16],
) -> bool {
    // U[r][i] = (C · Rs[r])[i]
    let mut U = [[0i32; 16]; 3];
    for (u, r) in U.iter_mut().zip(Rs) {
        for (u_i, c_row) in u.iter_mut().zip(C) {
            *u_i = c_row
                .iter()
                .zip(r)
                .fold(0i32, |acc, (&c, &r_j)| acc.wrapping_add(c.wrapping_mul(i32::from(r_j))));
        }
    }

    // P[k][r] = (B · Rs[r])[k]. This is ~600 KB, so it goes on the heap; the
    // previous `[[i32; 3]; 50_240]` stack array risked overflowing small
    // thread stacks.
    let P: Vec<[i32; 3]> = B
        .iter()
        .map(|b_row| {
            let mut p = [0i32; 3];
            for (p_r, r) in p.iter_mut().zip(Rs) {
                *p_r = b_row.iter().zip(r).fold(0i32, |acc, (&b, &r_j)| {
                    acc.wrapping_add(i32::from(b).wrapping_mul(i32::from(r_j)))
                });
            }
            p
        })
        .collect();

    // Row i of A·P must equal U[·][i] for all three projections.
    A.iter().enumerate().all(|(i, a_row)| {
        let mut v = [0i32; 3];
        for (&a, p) in a_row.iter().zip(&P) {
            let a = i32::from(a);
            for (v_r, &p_r) in v.iter_mut().zip(p) {
                *v_r = v_r.wrapping_add(a.wrapping_mul(p_r));
            }
        }
        v == [U[0][i], U[1][i], U[2][i]]
    })
}
405
406    true
407}