1use diskann_wide::Architecture;
7#[cfg(any(test, target_arch = "x86_64"))]
8use diskann_wide::{SIMDMulAdd, SIMDVector};
9use thiserror::Error;
10
/// Applies the Walsh–Hadamard transform to `x` in place, scaled by
/// `1 / sqrt(x.len())` so the overall transform is orthonormal.
///
/// Dispatches through `diskann_wide`'s architecture dispatcher, which selects
/// one of the [`HadamardTransform`] target implementations below.
///
/// # Errors
///
/// Returns [`NotPowerOfTwo`] if `x.len()` is not a power of two. Note that
/// zero is not a power of two, so the empty slice is rejected as well.
pub fn hadamard_transform(x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
    diskann_wide::arch::dispatch1_no_features(HadamardTransform, x)
}
28
/// Error returned by [`hadamard_transform`] when the input length is not a
/// power of two (this includes the empty slice, since 0 is not a power of two).
#[derive(Debug, Error)]
#[error("Hadamard input vector must have a length that is a power of two")]
pub struct NotPowerOfTwo;
32
/// Dispatch entry point for the Hadamard transform.
///
/// Implements `Target1` once per supported architecture (scalar, x86-64-v3,
/// x86-64-v4) so the runtime dispatcher can pick a specialized kernel.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransform;
46
// Scalar dispatch target: forwards straight to the architecture-generic
// outer driver with no SIMD specialization.
impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, Result<(), NotPowerOfTwo>, &mut [f32]>
    for HadamardTransform
{
    // NOTE(review): `inline(never)` presumably keeps each dispatch target as
    // a distinct function for the dispatcher instead of inlining it into the
    // caller — confirm against `diskann_wide::arch` conventions.
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        (HadamardTransformOuter).run(arch, x)
    }
}
55
// x86-64-v3 dispatch target: enters the generic outer driver through
// `arch.run1`, which presumably compiles the body with the v3 feature set
// enabled — confirm against `diskann_wide::arch` docs.
#[cfg(target_arch = "x86_64")]
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V3,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.run1(HadamardTransformOuter, x)
    }
}
69
// x86-64-v4 dispatch target: there is no dedicated v4 kernel here, so this
// retargets down and reuses the v3 implementation.
#[cfg(target_arch = "x86_64")]
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V4,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V4, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.retarget().run1(HadamardTransformOuter, x)
    }
}
83
/// Architecture-generic driver for the transform: validates the input length,
/// runs the recursive butterfly, then applies the `1 / sqrt(len)` scaling.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransformOuter;
90
91impl<A> diskann_wide::arch::Target1<A, Result<(), NotPowerOfTwo>, &mut [f32]>
92 for HadamardTransformOuter
93where
94 A: diskann_wide::Architecture,
95 HadamardTransformRecursive: for<'a> diskann_wide::arch::Target1<A, (), &'a mut [f32]>,
96{
97 #[inline(always)]
98 fn run(self, arch: A, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
99 let len = x.len();
100
101 if !len.is_power_of_two() {
102 return Err(NotPowerOfTwo);
103 }
104
105 if len == 1 {
107 return Ok(());
108 }
109
110 arch.run1(HadamardTransformRecursive, x);
112
113 let m = 1.0 / (x.len() as f32).sqrt();
115 x.iter_mut().for_each(|i| *i *= m);
116
117 Ok(())
118 }
119}
120
/// Recursive, unnormalized butterfly step of the transform; specialized per
/// architecture below. Callers are expected to apply the scaling themselves.
#[derive(Debug, Clone, Copy)]
struct HadamardTransformRecursive;
123
124impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, (), &mut [f32]>
125 for HadamardTransformRecursive
126{
127 #[inline]
136 fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) {
137 let len = x.len();
138 debug_assert!(len.is_power_of_two());
139 debug_assert!(len >= 2);
140
141 if len == 2 {
142 let l = x[0];
143 let r = x[1];
144 x[0] = l + r;
145 x[1] = l - r;
146 } else {
147 let (left, right) = x.split_at_mut(len / 2);
149
150 arch.run1(self, left);
151 arch.run1(self, right);
152
153 std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
154 let a = *l + *r;
155 let b = *l - *r;
156 *l = a;
157 *r = b;
158 });
159 }
160 }
161}
162
// x86-64-v3 recursion: identical structure to the scalar implementation,
// except that 64-element inputs are handled by the SIMD micro-kernel.
// Larger inputs recurse until their halves reach 64 and hit that fast path.
#[cfg(target_arch = "x86_64")]
impl diskann_wide::arch::Target1<diskann_wide::arch::x86_64::V3, (), &mut [f32]>
    for HadamardTransformRecursive
{
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) {
        let len = x.len();
        debug_assert!(len.is_power_of_two());
        debug_assert!(len >= 2);

        // Fast path: a 64-element slice converts to `&mut [f32; 64]` and is
        // transformed in one SIMD kernel. (The conversion reborrows `x`, so
        // the other branches may still use it.)
        if let Ok(array) = <&mut [f32] as TryInto<&mut [f32; 64]>>::try_into(x) {
            micro_kernel_64(arch, array);
        } else if len == 2 {
            // Base case: single butterfly.
            let l = x[0];
            let r = x[1];
            x[0] = l + r;
            x[1] = l - r;
        } else {
            // Recurse on both halves, then combine with a butterfly pass.
            let (left, right) = x.split_at_mut(len / 2);

            arch.run1(self, left);
            arch.run1(self, right);

            std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
                let a = *l + *r;
                let b = *l - *r;
                *l = a;
                *r = b;
            });
        }
    }
}
211
/// The 8x8 Walsh–Hadamard matrix, used as the coefficient table by
/// [`micro_kernel_64`]. Verified against the Sylvester construction in
/// `tests::test_hadamard_8`.
#[cfg(any(test, target_arch = "x86_64"))]
const HADAMARD_8: [[f32; 8]; 8] = [
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
    [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0],
    [1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0],
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
    [1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0],
];
224
/// 64-point Walsh–Hadamard transform, in place, without normalization.
///
/// Treats `x` as an 8x8 row-major matrix: the accumulation loop multiplies
/// each row by [`HADAMARD_8`] (one SIMD vector of 8 lanes per row, held in
/// `d0..d7`), and the three butterfly stages below apply an 8-point
/// transform across the rows — together these compose the full 64-point
/// transform.
#[cfg(any(test, target_arch = "x86_64"))]
#[inline(always)]
fn micro_kernel_64<A>(arch: A, x: &mut [f32; 64])
where
    A: Architecture,
{
    // Per-row accumulators: `d{k}` collects the transformed row `k`.
    let mut d0 = A::f32x8::splat(arch, 0.0);
    let mut d1 = A::f32x8::splat(arch, 0.0);
    let mut d2 = A::f32x8::splat(arch, 0.0);
    let mut d3 = A::f32x8::splat(arch, 0.0);
    let mut d4 = A::f32x8::splat(arch, 0.0);
    let mut d5 = A::f32x8::splat(arch, 0.0);
    let mut d6 = A::f32x8::splat(arch, 0.0);
    let mut d7 = A::f32x8::splat(arch, 0.0);

    // Flat views of the 8x8 coefficient table and the input matrix.
    let p: *const f32 = HADAMARD_8.as_ptr().cast();
    let src: *const f32 = x.as_ptr();
    // Consumes columns `offset` and `offset + 1` of every row: each scalar
    // `src[8 * row + col]` is broadcast and fused-multiplied with the
    // matching HADAMARD_8 row (`c0` for `offset`, `c1` for `offset + 1`).
    let mut process_patch = |offset: usize| {
        // SAFETY: `offset` is one of 0, 2, 4, 6 (see the loop below). The
        // largest table read starts at p.add(8 * 7) and spans 8 floats
        // (index 63), and the largest source read is src.add(6 + 8 * 7 + 1)
        // (index 63) — both within their 64-element arrays.
        unsafe {
            let c0 = A::f32x8::load_simd(arch, p.add(8 * offset));
            let c1 = A::f32x8::load_simd(arch, p.add(8 * (offset + 1)));

            // Rows 0 and 1.
            let r0 = A::f32x8::splat(arch, src.add(offset).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8).read());
            d0 = r0.mul_add_simd(c0, d0);
            d1 = r1.mul_add_simd(c0, d1);

            let r0 = A::f32x8::splat(arch, src.add(offset + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 9).read());
            d0 = r0.mul_add_simd(c1, d0);
            d1 = r1.mul_add_simd(c1, d1);

            // Rows 2 and 3.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3).read());
            d2 = r0.mul_add_simd(c0, d2);
            d3 = r1.mul_add_simd(c0, d3);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3 + 1).read());
            d2 = r0.mul_add_simd(c1, d2);
            d3 = r1.mul_add_simd(c1, d3);

            // Rows 4 and 5.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5).read());
            d4 = r0.mul_add_simd(c0, d4);
            d5 = r1.mul_add_simd(c0, d5);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5 + 1).read());
            d4 = r0.mul_add_simd(c1, d4);
            d5 = r1.mul_add_simd(c1, d5);

            // Rows 6 and 7.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7).read());
            d6 = r0.mul_add_simd(c0, d6);
            d7 = r1.mul_add_simd(c0, d7);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7 + 1).read());
            d6 = r0.mul_add_simd(c1, d6);
            d7 = r1.mul_add_simd(c1, d7);
        }
    };

    // Columns are consumed two at a time: offsets 0, 2, 4, 6.
    for o in 0..4 {
        process_patch(2 * o);
    }

    // 8-point butterfly across the row accumulators, as three radix-2 stages.
    // Stage 1: stride-1 pairs.
    let e0 = d0 + d1;
    let e1 = d0 - d1;

    let e2 = d2 + d3;
    let e3 = d2 - d3;

    let e4 = d4 + d5;
    let e5 = d4 - d5;

    let e6 = d6 + d7;
    let e7 = d6 - d7;

    // Stage 2: stride-2 pairs.
    let f0 = e0 + e2;
    let f1 = e1 + e3;

    let f2 = e0 - e2;
    let f3 = e1 - e3;

    let f4 = e4 + e6;
    let f5 = e5 + e7;

    let f6 = e4 - e6;
    let f7 = e5 - e7;

    let dst: *mut f32 = x.as_mut_ptr();

    // Stage 3 (stride-4 pairs), fused with the stores back into `x`.
    // SAFETY: the eight 8-lane stores cover indices 0..=63, exactly the
    // 64 elements of `x`.
    unsafe {
        (f0 + f4).store_simd(dst);
        (f1 + f5).store_simd(dst.add(8));
        (f2 + f6).store_simd(dst.add(16));
        (f3 + f7).store_simd(dst.add(24));
        (f0 - f4).store_simd(dst.add(32));
        (f1 - f5).store_simd(dst.add(40));
        (f2 - f6).store_simd(dst.add(48));
        (f3 - f7).store_simd(dst.add(56));
    }
}
354
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use diskann_utils::views::{self, Matrix, MatrixView};

    // Materializes the `HADAMARD_8` constant as an 8x8 `Matrix`.
    fn get_hadamard_8() -> Matrix<f32> {
        let v: Box<[f32]> = HADAMARD_8.iter().flatten().copied().collect();
        Matrix::try_from(v, 8, 8).unwrap()
    }

    // Reference Hadamard matrix built by the Sylvester recursion:
    // H_{2n} = [[H_n, H_n], [H_n, -H_n]]. `dim` must be a nonzero power of
    // two; other values would recurse with mismatched halves.
    fn hadamard_by_sylvester(dim: usize) -> Matrix<f32> {
        assert_ne!(dim, 0);
        if dim == 1 {
            Matrix::new(1.0, dim, dim)
        } else {
            let half = dim / 2;
            let sub = hadamard_by_sylvester(half);
            let mut m = Matrix::<f32>::new(0.0, dim, dim);

            for c in 0..m.ncols() {
                for r in 0..m.nrows() {
                    let mut v = sub[(r % half, c % half)];
                    if c >= half && r >= half {
                        v = -v;
                    }
                    // NOTE(review): the write index `(c, r)` looks transposed
                    // relative to the read `(r, c)`; Sylvester Hadamard
                    // matrices are symmetric, so the result is unchanged —
                    // worth confirming if this helper is ever generalized.
                    m[(c, r)] = v;
                }
            }
            m
        }
    }

    // Confirms the hard-coded HADAMARD_8 table matches the Sylvester
    // construction.
    #[test]
    fn test_hadamard_8() {
        let h8 = get_hadamard_8();
        let reference = hadamard_by_sylvester(8);
        assert_eq!(h8.as_slice(), reference.as_slice());
    }

    // Naive triple-loop matrix product, used only as a reference.
    fn matmul(a: MatrixView<f32>, b: MatrixView<f32>) -> Matrix<f32> {
        assert_eq!(a.ncols(), b.nrows());
        let mut c = Matrix::new(0.0, a.nrows(), b.ncols());

        for i in 0..c.nrows() {
            for j in 0..c.ncols() {
                let mut v = 0.0;
                for k in 0..a.ncols() {
                    v = a[(i, k)].mul_add(b[(k, j)], v);
                }
                c[(i, j)] = v;
            }
        }
        c
    }

    // Checks the SIMD micro-kernel against a dense 64x64 reference multiply
    // on a seeded random vector, within a small relative tolerance (the
    // kernel accumulates in a different order than the reference).
    #[test]
    fn test_micro_kernel_64() {
        let mut src = {
            let mut rng = StdRng::seed_from_u64(0xde1936d651285fc8);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, 64, 1)
        };

        let h = hadamard_by_sylvester(64);
        let reference = matmul(h.as_view(), src.as_view());

        micro_kernel_64(diskann_wide::ARCH, src.as_mut_slice().try_into().unwrap());

        assert_eq!(reference.nrows(), src.nrows());
        assert_eq!(reference.ncols(), 1);
        assert_eq!(src.ncols(), 1);

        for j in 0..src.nrows() {
            let src = src[(j, 0)];
            let reference = reference[(j, 0)];

            let relative_error = (src - reference).abs() / src.abs().max(reference.abs());
            assert!(
                relative_error < 5e-6,
                "Got a relative error of {} for row {} - reference = {}, got = {}",
                relative_error,
                j,
                reference,
                src
            );
        }
    }

    // Runs every implementation reachable on this machine (public entry
    // point, forced scalar, and — on x86_64 with v3 support — the SIMD path)
    // on the same seeded random vector and compares each against
    // `H * src / sqrt(dim)`.
    fn test_hadamard_transform(dim: usize, seed: u64) {
        let src = {
            let mut rng = StdRng::seed_from_u64(seed);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, dim, 1)
        };

        let h = hadamard_by_sylvester(dim);

        // Reference result includes the orthonormal scaling.
        let mut reference = matmul(h.as_view(), src.as_view());
        reference
            .as_mut_slice()
            .iter_mut()
            .for_each(|i| *i /= (dim as f32).sqrt());

        type Implementation = Box<dyn Fn(&mut [f32])>;

        // `mut` is only needed when the x86_64 push below is compiled in.
        #[cfg_attr(not(target_arch = "x86_64"), expect(unused_mut))]
        let mut impls: Vec<(Implementation, &'static str)> = vec![
            (
                Box::new(|x| hadamard_transform(x).unwrap()),
                "public entry point",
            ),
            (
                Box::new(|x| {
                    diskann_wide::arch::Scalar::new()
                        .run1(HadamardTransform, x)
                        .unwrap()
                }),
                "scalar recursive implementation",
            ),
        ];

        // Only exercised when the running CPU actually supports x86-64-v3.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "x86-64-v3",
            ));
        }

        for (f, kernel) in impls.into_iter() {
            // Each implementation transforms its own copy of the input.
            let mut src_clone = src.clone();
            f(src_clone.as_mut_slice());

            assert_eq!(reference.nrows(), src_clone.nrows());
            assert_eq!(reference.ncols(), 1);
            assert_eq!(src_clone.ncols(), 1);

            for j in 0..src_clone.nrows() {
                let src_clone = src_clone[(j, 0)];
                let reference = reference[(j, 0)];

                let relative_error =
                    (src_clone - reference).abs() / src_clone.abs().max(reference.abs());
                assert!(
                    relative_error < 5e-5,
                    "Got a relative error of {} for row {} - reference = {}, got = {} -- dim = {}: kernel = {}",
                    relative_error,
                    j,
                    reference,
                    src_clone,
                    dim,
                    kernel,
                );
            }
        }
    }

    #[test]
    fn test_hadamard_transform_1() {
        test_hadamard_transform(1, 0xcdb7283f806f237d);
    }

    #[test]
    fn test_hadamard_transform_2() {
        test_hadamard_transform(2, 0x1e8bba190423842c);
    }

    #[test]
    fn test_hadamard_transform_4() {
        test_hadamard_transform(4, 0x6cdcb7e1fe0fa296);
    }

    #[test]
    fn test_hadamard_transform_8() {
        test_hadamard_transform(8, 0xd120b32a83158c80);
    }

    #[test]
    fn test_hadamard_transform_16() {
        test_hadamard_transform(16, 0x56ef310cc7e42faa);
    }

    #[test]
    fn test_hadamard_transform_32() {
        test_hadamard_transform(32, 0xf2a1395699390b95);
    }

    #[test]
    fn test_hadamard_transform_64() {
        test_hadamard_transform(64, 0x31e6a1bfe4958c8a);
    }

    #[test]
    fn test_hadamard_transform_128() {
        test_hadamard_transform(128, 0xe13a35f4b9392747);
    }

    #[test]
    fn test_hadamard_transform_256() {
        test_hadamard_transform(256, 0xf71bb8e26e79681c);
    }

    // Non-power-of-two lengths (including zero) must be rejected.
    #[test]
    fn test_error() {
        assert!(matches!(hadamard_transform(&mut []), Err(NotPowerOfTwo)));

        for dim in [3, 31, 33, 40, 63, 65, 100, 127, 129] {
            let mut v = vec![0.0f32; dim];
            assert!(matches!(hadamard_transform(&mut v), Err(NotPowerOfTwo)));
        }
    }
}