/// Thin newtype wrapper around [`blake3::Hasher`] exposing a fixed-size
/// `[u8; 32]` digest API instead of blake3's own `Hash` type.
pub struct Hasher(blake3::Hasher);
4
5impl Default for Hasher {
6 fn default() -> Self {
7 Self::new()
8 }
9}
10
impl Hasher {
    /// Creates a new unkeyed (regular-mode) hasher.
    pub fn new() -> Self {
        Self(blake3::Hasher::new())
    }
    /// Creates a hasher in BLAKE3 keyed mode with the given 32-byte key.
    pub fn new_keyed(key: &[u8; 32]) -> Self {
        Self(blake3::Hasher::new_keyed(key))
    }
    /// Absorbs `buf` into the hash state.
    ///
    /// NOTE(review): this returns the *inner* `blake3::Hasher` (for call
    /// chaining), which leaks the wrapped type through this API. Returning
    /// `&mut Self` would be cleaner but is a breaking change for callers that
    /// chain blake3 methods, so it is left as-is.
    pub fn update(&mut self, buf: &[u8]) -> &mut blake3::Hasher {
        self.0.update(buf)
    }
    /// Multi-threaded `update` via rayon (available only with the `rayon` feature).
    #[cfg(feature = "rayon")]
    pub fn update_rayon(&mut self, buf: &[u8]) {
        self.0.update_rayon(buf);
    }
    /// Stub kept so call sites compile without the `rayon` feature; panics if
    /// actually invoked.
    #[cfg(not(feature = "rayon"))]
    pub fn update_rayon(&mut self, _buf: &[u8]) {
        panic!("Blake3.update_rayon() called without rayon feature enabled");
    }
    /// Resets the hasher to its initial (no-input) state.
    pub fn reset(&mut self) {
        self.0.reset();
    }
    /// Returns the 32-byte digest of everything absorbed so far.
    /// Takes `&self`, so more input may still be added afterwards.
    pub fn finalize(&self) -> [u8; 32] {
        self.0.finalize().as_bytes().to_owned()
    }
    /// Returns `output_size` bytes of extendable output (XOF).
    /// The first 32 bytes equal the digest from [`Self::finalize`]
    /// (exercised by the tests below).
    pub fn finalize_xof(&self, output_size: usize) -> Vec<u8> {
        let mut out = vec![0u8; output_size];
        let mut x = self.0.finalize_xof();
        x.fill(&mut out);
        out
    }
    /// Returns a streaming XOF reader for callers that want incremental output.
    pub fn finalize_xof_reader(&self) -> blake3::OutputReader {
        self.0.finalize_xof()
    }
}
45
46pub fn hash(buf: &[u8]) -> [u8; 32] {
47 blake3::hash(buf).as_bytes().to_owned()
48}
49
/// Derives a 32-byte key from `input_key` using BLAKE3's key-derivation mode.
/// `context` is the KDF context string — per BLAKE3's contract it should be a
/// hardcoded, globally unique, application-specific constant; NOTE(review):
/// confirm call sites pass constant contexts, not runtime data.
pub fn derive_key(context: &str, input_key: &[u8]) -> [u8; 32] {
    blake3::derive_key(context, input_key)
}
53
54pub fn keyed_hash(key: &[u8; 32], buf: &[u8]) -> [u8; 32] {
55 blake3::keyed_hash(key, buf).as_bytes().to_owned()
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hash_and_hasher_consistency() {
        // Incremental updates must produce the same digest as the one-shot hash.
        let data = b"hello world";
        let one_shot = hash(data);
        let mut h = Hasher::new();
        h.update(b"hello");
        h.update(b" world");
        let inc = h.finalize();
        assert_eq!(one_shot, inc);

        // The first 32 bytes of the XOF output equal the default-length digest.
        let xof = h.finalize_xof(64);
        assert_eq!(xof.len(), 64);
        assert_eq!(inc.as_slice(), &xof[..32]);

        // Cross-check against the blake3 crate's reference API directly.
        assert_eq!(one_shot, blake3::hash(data).as_bytes().to_owned());
    }

    #[test]
    fn derive_and_keyed_hash_match_reference() {
        // The thin wrappers must agree byte-for-byte with the blake3 crate.
        let key = [7u8; 32];
        let msg = b"abc";
        let kd = derive_key("context7:test", b"input_key");
        assert_eq!(kd, blake3::derive_key("context7:test", b"input_key"));

        let ours = keyed_hash(&key, msg);
        let theirs = blake3::keyed_hash(&key, msg).as_bytes().to_owned();
        assert_eq!(ours, theirs);
    }

    #[test]
    fn freivalds_is_deterministic() {
        // Minimal valid tensor: 240-byte header + 1024-byte tail, all zeros.
        // The check is seeded purely from the tensor, so repeated runs must agree.
        let tensor = vec![0u8; 240 + 1024];
        let a = freivalds(&tensor);
        let b = freivalds(&tensor);
        assert_eq!(a, b);
    }
}
102
103#[cfg(target_arch = "x86_64")]
104use std::arch::x86_64::*;
105use std::{cell::RefCell, mem, mem::MaybeUninit, ptr, slice};
106
/// Page-aligned scratch arena for the Freivalds matrix-product check.
///
/// Field order matters: `freivalds`/`freivalds_e260` fill `A`, `B`, `B2` and
/// `Rs` as a single contiguous byte range from a BLAKE3 XOF stream, relying on
/// the `repr(C)` layout placing them back-to-back. All four are arrays of
/// byte-sized elements, so there is no interior padding, and their combined
/// size is a multiple of 4, so `C` (align 4) follows without a gap either.
#[allow(non_snake_case)]
#[repr(C, align(4096))]
struct AMAMatMul {
    // 16×50240 left operand (unsigned bytes).
    pub A: [[u8; 50240]; 16],
    // 50240×16 right operand (signed bytes).
    pub B: [[i8; 16]; 50240],
    // Filled from the XOF stream but not read by `freivalds_inner` —
    // NOTE(review): appears to be reserved/skipped material; confirm intent.
    pub B2: [[i8; 64]; 16],
    // Three 16-wide challenge vectors r.
    pub Rs: [[i8; 16]; 3],
    // Claimed 16×16 product, copied in from the tensor tail (1024 bytes).
    pub C: [[i32; 16]; 16],
}
117
thread_local! {
    // One lazily-allocated scratch arena per thread, parked here between uses
    // so repeated verifications avoid re-allocating ~1.6 MiB each call.
    static SCRATCH: RefCell<Option<Box<AMAMatMul>>> = const { RefCell::new(None) };
}
121
/// RAII handle to the thread-local scratch arena: holds the buffer while
/// alive and returns it to `SCRATCH` on drop (see the `Drop` impl).
struct ScratchGuard {
    // `Option` so `drop` can move the box back into TLS; it is only `None`
    // transiently during drop.
    buf: Option<Box<AMAMatMul>>,
}
125
126impl std::ops::Deref for ScratchGuard {
127 type Target = AMAMatMul;
128 fn deref(&self) -> &Self::Target {
129 self.buf.as_ref().expect("buffer disappeared")
130 }
131}
132impl std::ops::DerefMut for ScratchGuard {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 self.buf.as_mut().expect("buffer disappeared")
135 }
136}
137impl Drop for ScratchGuard {
138 fn drop(&mut self) {
139 if let Some(buf) = self.buf.take() {
140 SCRATCH.with(|tls| *tls.borrow_mut() = Some(buf));
141 }
142 }
143}
144
145fn borrow_scratch() -> ScratchGuard {
147 SCRATCH.with(|tls| {
148 let mut slot = tls.borrow_mut();
149 let buf = slot.take().unwrap_or_else(|| {
150 let boxed_uninit: Box<MaybeUninit<AMAMatMul>> = Box::new_uninit(); unsafe { boxed_uninit.assume_init() }
154 });
155 ScratchGuard { buf: Some(buf) }
156 })
157}
158
/// Freivalds-style probabilistic verification over `tensor`.
///
/// Expected layout: the first 240 bytes are a header that seeds a BLAKE3 XOF,
/// from which the operands `A`, `B` (plus `B2`) and the challenge vectors
/// `Rs` are expanded; the last 1024 bytes are the claimed 16×16 i32 product
/// `C`. Panics if `tensor.len() < 1264` — NOTE(review): confirm callers
/// guarantee the minimum length.
pub fn freivalds(tensor: &[u8]) -> bool {
    let mut scratch = borrow_scratch();

    // Expand the 240-byte header into pseudo-random operand/challenge bytes.
    let mut hasher = blake3::Hasher::new();
    hasher.update(&tensor[..240]);
    let mut xof = hasher.finalize_xof();

    // Combined size of the A, B, B2 and Rs fields, which repr(C) lays out
    // contiguously at the start of the scratch struct (all byte-sized
    // elements, so no interior padding).
    let head_bytes = 16 * 50_240 + 50_240 * 16 + 16 * 64 + 3 * 16;
    unsafe {
        // SAFETY: `head_bytes` equals the total size of A+B+B2+Rs, all within
        // the same allocation starting at `scratch.A`, and any byte pattern
        // is a valid value for these u8/i8 arrays.
        let dest = ptr::slice_from_raw_parts_mut((&mut scratch.A) as *mut _ as *mut u8, head_bytes) as *mut [u8];
        xof.fill(&mut *dest);
    }

    // Copy the trailing 1024 bytes over `C` (16*16*4 bytes), reinterpreting
    // them as i32s in native byte order.
    let tail = &tensor[tensor.len() - 1024..];
    unsafe {
        // SAFETY: `C` is exactly 1024 bytes, `tail` is exactly 1024 bytes,
        // and i32 accepts any initialized byte pattern.
        let dst = &mut scratch.C as *mut _ as *mut u8;
        ptr::copy_nonoverlapping(tail.as_ptr(), dst, 1024);
    }

    freivalds_inner(&scratch.Rs, &scratch.A, &scratch.B, &scratch.C)
}
184
/// Variant of [`freivalds`] in which the challenge vectors `Rs` are derived
/// from the *entire* tensor plus an external value `vr_b3`, rather than from
/// the header XOF stream.
///
/// Panics if `tensor.len() < 1264` (240-byte header + 1024-byte tail) —
/// NOTE(review): confirm callers guarantee this.
pub fn freivalds_e260(tensor: &[u8], vr_b3: &[u8]) -> bool {
    let mut scratch = borrow_scratch();

    let head = &tensor[..240];
    let tail = &tensor[tensor.len() - 1024..];

    // Expand A, B and B2 (but NOT Rs) from the header.
    let mut hasher = blake3::Hasher::new();
    hasher.update(head);
    let mut xof = hasher.finalize_xof();

    // Size of A + B + B2 only; Rs is filled separately below.
    let ab_bytes = 16 * 50_240 + 50_240 * 16 + 16 * 64;
    unsafe {
        // SAFETY: `ab_bytes` covers exactly the contiguous repr(C) fields
        // A, B and B2 starting at `scratch.A`; any byte pattern is valid
        // for these u8/i8 arrays.
        let dest = ptr::slice_from_raw_parts_mut((&mut scratch.A) as *mut _ as *mut u8, ab_bytes) as *mut [u8];
        xof.fill(&mut *dest);
    }

    // The claimed product C comes from the trailing 1024 bytes (16×16 i32).
    unsafe {
        // SAFETY: `C` and `tail` are both exactly 1024 bytes.
        let dst = &mut scratch.C as *mut _ as *mut u8;
        ptr::copy_nonoverlapping(tail.as_ptr(), dst, 1024);
    }

    // Bind the challenge vectors to the full tensor and `vr_b3`.
    let mut hasher_rs = blake3::Hasher::new();
    hasher_rs.update(tensor);
    hasher_rs.update(vr_b3);
    let mut xof_rs = hasher_rs.finalize_xof();

    unsafe {
        // SAFETY: writes exactly `size_of_val(&scratch.Rs)` bytes into `Rs`;
        // i8 accepts any byte pattern.
        let p = (&mut scratch.Rs) as *mut _ as *mut u8;
        let n = mem::size_of_val(&scratch.Rs);
        let dst = slice::from_raw_parts_mut(p, n);
        xof_rs.fill(dst);
    }

    freivalds_inner(&scratch.Rs, &scratch.A, &scratch.B, &scratch.C)
}
224
225#[allow(non_snake_case)]
227pub fn freivalds_inner(
228 Rs: &[[i8; 16]; 3],
229 A: &[[u8; 50_240]; 16],
230 B: &[[i8; 16]; 50_240],
231 C: &[[i32; 16]; 16],
232) -> bool {
233 #[cfg(target_arch = "x86_64")]
234 {
235 if std::is_x86_feature_detected!("avx2") {
236 unsafe {
237 return freivalds_inner_avx2(Rs, A, B, C);
238 }
239 }
240 }
241 freivalds_inner_scalar(Rs, A, B, C)
242}
243
/// Horizontal wrapping sum of the eight i32 lanes of `v`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn hsum256_epi32(v: __m256i) -> i32 {
    unsafe {
        // Fold the upper 128-bit half onto the lower half.
        let hi = _mm256_extracti128_si256(v, 1);
        let lo = _mm256_castsi256_si128(v);
        let sum128 = _mm_add_epi32(lo, hi);
        // Then fold 4 lanes -> 2 -> 1 via byte shifts.
        let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
        let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
        _mm_cvtsi128_si32(sum32)
    }
}
257
/// Sixteen i32 lanes split across two 256-bit vectors (`lo` = lanes 0..8,
/// `hi` = lanes 8..16, matching `load_i8x16_as_i32`).
#[cfg(target_arch = "x86_64")]
#[repr(C)]
struct I32x16 {
    lo: __m256i,
    hi: __m256i,
}
264
/// Loads 16 i8 values from `ptr` and sign-extends them into two vectors of
/// eight i32 lanes each.
///
/// # Safety
/// `ptr` must point to at least 16 readable bytes (unaligned access is fine:
/// the load uses `_mm_loadu_si128`).
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn load_i8x16_as_i32(ptr: *const i8) -> I32x16 {
    unsafe {
        let v = _mm_loadu_si128(ptr as *const __m128i);
        // Sign-extend the low 8 bytes, then the high 8 bytes.
        let lo = _mm256_cvtepi8_epi32(v);
        let hi = _mm256_cvtepi8_epi32(_mm_srli_si128(v, 8));
        I32x16 { lo, hi }
    }
}
277
/// AVX2 implementation of the Freivalds check. Mirrors
/// `freivalds_inner_scalar` exactly: all arithmetic wraps (SIMD integer ops
/// wrap by definition, matching the scalar `wrapping_*` calls).
///
/// # Safety
/// The CPU must support AVX2; `freivalds_inner` gates this behind a runtime
/// feature check.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(non_snake_case)]
pub unsafe fn freivalds_inner_avx2(
    Rs: &[[i8; 16]; 3],
    A: &[[u8; 50_240]; 16],
    B: &[[i8; 16]; 50_240],
    C: &[[i32; 16]; 16],
) -> bool {
    unsafe {
        const N: usize = 50_240;
        // U[r][i] = C[i] . Rs[r] — the expected projections derived from the
        // claimed product C.
        let mut U = [[0i32; 16]; 3];

        let r0_i32 = load_i8x16_as_i32(Rs[0].as_ptr());
        let r1_i32 = load_i8x16_as_i32(Rs[1].as_ptr());
        let r2_i32 = load_i8x16_as_i32(Rs[2].as_ptr());

        for i in 0..16 {
            // Each C row is 16 i32s: two 256-bit loads.
            let c_lo = _mm256_loadu_si256(C[i].as_ptr() as *const __m256i);
            let c_hi = _mm256_loadu_si256(C[i].as_ptr().add(8) as *const __m256i);

            // `mullo` keeps the low 32 bits, i.e. wraps like `wrapping_mul`.
            let u0 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r0_i32.lo), _mm256_mullo_epi32(c_hi, r0_i32.hi));
            let u1 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r1_i32.lo), _mm256_mullo_epi32(c_hi, r1_i32.hi));
            let u2 = _mm256_add_epi32(_mm256_mullo_epi32(c_lo, r2_i32.lo), _mm256_mullo_epi32(c_hi, r2_i32.hi));

            U[0][i] = hsum256_epi32(u0);
            U[1][i] = hsum256_epi32(u1);
            U[2][i] = hsum256_epi32(u2);
        }

        // P[k] = B[k] . Rs[r] for each challenge r.
        // NOTE(review): three N-sized Vecs are allocated per call; could be
        // hoisted into the scratch arena if this shows up in profiles.
        let mut P0 = vec![0i32; N];
        let mut P1 = vec![0i32; N];
        let mut P2 = vec![0i32; N];

        let r0_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[0].as_ptr() as *const _));
        let r1_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[1].as_ptr() as *const _));
        let r2_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(Rs[2].as_ptr() as *const _));

        for k in 0..N {
            let row_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128(B[k].as_ptr() as *const _));

            // i8*i8 products fit in i16 (|x| <= 127*127), and `madd` widens the
            // pairwise sums to i32, so no overflow occurs here.
            P0[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r0_i16));
            P1[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r1_i16));
            P2[k] = hsum256_epi32(_mm256_madd_epi16(row_i16, r2_i16));
        }

        // Accumulate A[i] . P for each challenge, 8 columns per iteration
        // (N = 50_240 is divisible by 8), and compare against U.
        for i in 0..16 {
            let mut acc0 = _mm256_setzero_si256();
            let mut acc1 = _mm256_setzero_si256();
            let mut acc2 = _mm256_setzero_si256();

            for k in (0..N).step_by(8) {
                // Zero-extend 8 u8 entries of A to i32 lanes.
                let a_i32 = _mm256_cvtepu8_epi32(_mm_loadl_epi64(A[i].as_ptr().add(k) as *const _));
                let p0 = _mm256_loadu_si256(P0.as_ptr().add(k) as *const _);
                let p1 = _mm256_loadu_si256(P1.as_ptr().add(k) as *const _);
                let p2 = _mm256_loadu_si256(P2.as_ptr().add(k) as *const _);

                acc0 = _mm256_add_epi32(acc0, _mm256_mullo_epi32(a_i32, p0));
                acc1 = _mm256_add_epi32(acc1, _mm256_mullo_epi32(a_i32, p1));
                acc2 = _mm256_add_epi32(acc2, _mm256_mullo_epi32(a_i32, p2));
            }

            // Early-out on the first row whose projection disagrees.
            if hsum256_epi32(acc0) != U[0][i] || hsum256_epi32(acc1) != U[1][i] || hsum256_epi32(acc2) != U[2][i] {
                return false;
            }
        }
        true
    }
}
353
/// Portable scalar implementation of the Freivalds check.
///
/// Verifies `A * (B * r) == C * r` for the three challenge vectors `r` in
/// `Rs`, using wrapping i32 arithmetic throughout (matching the AVX2 path,
/// where SIMD integer ops wrap by definition).
#[allow(non_snake_case)]
fn freivalds_inner_scalar(
    Rs: &[[i8; 16]; 3],
    A: &[[u8; 50_240]; 16],
    B: &[[i8; 16]; 50_240],
    C: &[[i32; 16]; 16],
) -> bool {
    // Wrapping dot product of a 16-wide i32 row with a 16-wide i8 vector.
    let dot16 = |row: &[i32; 16], r: &[i8; 16]| -> i32 {
        row.iter()
            .zip(r.iter())
            .fold(0i32, |acc, (&c, &x)| acc.wrapping_add(c.wrapping_mul(x as i32)))
    };

    // U[r][i] = C[i] . Rs[r] — expected projections from the claimed C.
    let mut U = [[0i32; 16]; 3];
    for (r, rv) in Rs.iter().enumerate() {
        for (i, crow) in C.iter().enumerate() {
            U[r][i] = dot16(crow, rv);
        }
    }

    // P[k][r] = B[k] . Rs[r] — fold B's 16 columns into three scalars per row.
    // (Wrapping i32 addition is commutative/associative, so the accumulation
    // order here matches the reference result exactly.)
    let mut P = [[0i32; 3]; 50_240];
    for (pk, brow) in P.iter_mut().zip(B.iter()) {
        for j in 0..16 {
            let b = brow[j] as i32;
            for r in 0..3 {
                pk[r] = pk[r].wrapping_add(b.wrapping_mul(Rs[r][j] as i32));
            }
        }
    }

    // Check A[i] . P == U[.][i] for every row, bailing on the first mismatch.
    for (i, arow) in A.iter().enumerate() {
        let mut v = [0i32; 3];
        for (a_byte, pk) in arow.iter().zip(P.iter()) {
            let a = *a_byte as i32;
            for r in 0..3 {
                v[r] = v[r].wrapping_add(a.wrapping_mul(pk[r]));
            }
        }
        if v != [U[0][i], U[1][i], U[2][i]] {
            return false;
        }
    }

    true
}