1use diskann_wide::Architecture;
7#[cfg(any(test, target_arch = "x86_64"))]
8use diskann_wide::{SIMDMulAdd, SIMDVector};
9use thiserror::Error;
10
/// Apply the normalized Walsh–Hadamard transform to `x` in place.
///
/// The result is scaled by `1 / sqrt(x.len())`, making the overall transform
/// orthonormal. Dispatches at runtime to the best available architecture.
///
/// # Errors
///
/// Returns [`NotPowerOfTwo`] if `x.len()` is not a power of two (this
/// includes the empty slice, since zero is not a power of two).
pub fn hadamard_transform(x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
    diskann_wide::arch::dispatch1_no_features(HadamardTransform, x)
}
28
/// Error returned by [`hadamard_transform`] when the input length is not a
/// power of two (a zero-length input also produces this error).
#[derive(Debug, Error)]
#[error("Hadamard input vector must have a length that is a power of two")]
pub struct NotPowerOfTwo;
32
/// Architecture-dispatch entry point for the Hadamard transform.
///
/// Implements `Target1` for each supported architecture token; every
/// implementation forwards to [`HadamardTransformOuter`] under that
/// architecture.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransform;
46
// Scalar (portable, no-SIMD) dispatch target: runs the generic outer
// transform directly.
impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, Result<(), NotPowerOfTwo>, &mut [f32]>
    for HadamardTransform
{
    // Kept out-of-line (`inline(never)`) — presumably so each architecture
    // gets its own distinct compiled entry point; confirm against the
    // dispatch machinery in `diskann_wide`.
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        (HadamardTransformOuter).run(arch, x)
    }
}
55
#[cfg(target_arch = "x86_64")]
// x86-64-v3 dispatch target: runs the outer transform under `V3`, which
// enables the SIMD micro-kernel specialization of the recursive step below.
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V3,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.run1(HadamardTransformOuter, x)
    }
}
69
#[cfg(target_arch = "x86_64")]
// x86-64-v4 dispatch target: there is no dedicated V4 kernel, so this
// retargets the architecture token (presumably down to V3 — confirm in
// `diskann_wide`) and reuses that code path.
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V4,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V4, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.retarget().run1(HadamardTransformOuter, x)
    }
}
83
#[cfg(target_arch = "aarch64")]
// aarch64 NEON dispatch target. NOTE(review): the destination of
// `retarget()` is not visible in this file — presumably a generic or scalar
// path, since no NEON-specific recursive impl exists below; confirm in
// `diskann_wide`.
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::aarch64::Neon,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    #[inline(never)]
    fn run(
        self,
        arch: diskann_wide::arch::aarch64::Neon,
        x: &mut [f32],
    ) -> Result<(), NotPowerOfTwo> {
        arch.retarget().run1(HadamardTransformOuter, x)
    }
}
101
/// Architecture-generic wrapper around the recursive kernel: validates the
/// power-of-two precondition, invokes [`HadamardTransformRecursive`], then
/// applies the `1 / sqrt(len)` normalization.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransformOuter;
108
109impl<A> diskann_wide::arch::Target1<A, Result<(), NotPowerOfTwo>, &mut [f32]>
110 for HadamardTransformOuter
111where
112 A: diskann_wide::Architecture,
113 HadamardTransformRecursive: for<'a> diskann_wide::arch::Target1<A, (), &'a mut [f32]>,
114{
115 #[inline(always)]
116 fn run(self, arch: A, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
117 let len = x.len();
118
119 if !len.is_power_of_two() {
120 return Err(NotPowerOfTwo);
121 }
122
123 if len == 1 {
125 return Ok(());
126 }
127
128 arch.run1(HadamardTransformRecursive, x);
130
131 let m = 1.0 / (x.len() as f32).sqrt();
133 x.iter_mut().for_each(|i| *i *= m);
134
135 Ok(())
136 }
137}
138
/// Recursive, unnormalized core of the transform. Callers must pass a slice
/// whose length is a power of two and at least 2 (enforced via
/// `debug_assert!` in the implementations).
#[derive(Debug, Clone, Copy)]
struct HadamardTransformRecursive;
141
142impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, (), &mut [f32]>
143 for HadamardTransformRecursive
144{
145 #[inline]
154 fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) {
155 let len = x.len();
156 debug_assert!(len.is_power_of_two());
157 debug_assert!(len >= 2);
158
159 if len == 2 {
160 let l = x[0];
161 let r = x[1];
162 x[0] = l + r;
163 x[1] = l - r;
164 } else {
165 let (left, right) = x.split_at_mut(len / 2);
167
168 arch.run1(self, left);
169 arch.run1(self, right);
170
171 std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
172 let a = *l + *r;
173 let b = *l - *r;
174 *l = a;
175 *r = b;
176 });
177 }
178 }
179}
180
181#[cfg(target_arch = "x86_64")]
182impl diskann_wide::arch::Target1<diskann_wide::arch::x86_64::V3, (), &mut [f32]>
183 for HadamardTransformRecursive
184{
185 #[inline(always)]
194 fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) {
195 let len = x.len();
196 debug_assert!(len.is_power_of_two());
197 debug_assert!(len >= 2);
198
199 if let Ok(array) = <&mut [f32] as TryInto<&mut [f32; 64]>>::try_into(x) {
200 micro_kernel_64(arch, array);
206 } else if len == 2 {
207 let l = x[0];
210 let r = x[1];
211 x[0] = l + r;
212 x[1] = l - r;
213 } else {
214 let (left, right) = x.split_at_mut(len / 2);
216
217 arch.run1(self, left);
218 arch.run1(self, right);
219
220 std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
221 let a = *l + *r;
222 let b = *l - *r;
223 *l = a;
224 *r = b;
225 });
226 }
227 }
228}
229
#[cfg(any(test, target_arch = "x86_64"))]
// Rows of the 8x8 Sylvester-ordered Hadamard matrix H_8 (entries +/-1).
// H_8 is symmetric, so rows and columns are interchangeable; the micro
// kernel relies on this when it accumulates row vectors. The table is
// checked against the recursive construction in `tests::test_hadamard_8`.
const HADAMARD_8: [[f32; 8]; 8] = [
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
    [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0],
    [1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0],
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
    [1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0],
];
242
#[cfg(any(test, target_arch = "x86_64"))]
/// Fully-unrolled, unnormalized 64-point Hadamard transform.
///
/// Treats `x` as eight contiguous chunks of eight floats and uses the
/// Kronecker factorization `H_64 = H_8 (x) H_8`:
///
/// 1. each accumulator `d_k` is built up to `H_8 * chunk_k` by summing
///    `chunk_k[j] * HADAMARD_8[j]` over all `j` (equivalent to a
///    matrix-vector product because `H_8` is symmetric);
/// 2. three add/subtract butterfly stages (`e*`, `f*`, and the final
///    stores) then apply an 8-point Hadamard transform *across* the chunk
///    results.
///
/// Verified against a dense `H_64` matrix-vector product in
/// `tests::test_micro_kernel_64`.
#[inline(always)]
fn micro_kernel_64<A>(arch: A, x: &mut [f32; 64])
where
    A: Architecture,
{
    // Accumulators: d_k ends up holding the 8-point transform of chunk k.
    let mut d0 = A::f32x8::splat(arch, 0.0);
    let mut d1 = A::f32x8::splat(arch, 0.0);
    let mut d2 = A::f32x8::splat(arch, 0.0);
    let mut d3 = A::f32x8::splat(arch, 0.0);
    let mut d4 = A::f32x8::splat(arch, 0.0);
    let mut d5 = A::f32x8::splat(arch, 0.0);
    let mut d6 = A::f32x8::splat(arch, 0.0);
    let mut d7 = A::f32x8::splat(arch, 0.0);

    let p: *const f32 = HADAMARD_8.as_ptr().cast();
    let src: *const f32 = x.as_ptr();
    // Accumulates the contribution of input elements `offset` and
    // `offset + 1` of every chunk:
    //   d_k += chunk_k[offset] * H_8[offset] + chunk_k[offset + 1] * H_8[offset + 1]
    let mut process_patch = |offset: usize| {
        // SAFETY: callers pass `offset` in {0, 2, 4, 6}. Table loads read
        // 8 f32s from rows `offset` and `offset + 1` of the 8x8 HADAMARD_8
        // array (all in bounds), and the largest `src` index read is
        // `6 + 8 * 7 + 1 = 63`, within the 64-element input.
        unsafe {
            let c0 = A::f32x8::load_simd(arch, p.add(8 * offset));
            let c1 = A::f32x8::load_simd(arch, p.add(8 * (offset + 1)));

            let r0 = A::f32x8::splat(arch, src.add(offset).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8).read());
            d0 = r0.mul_add_simd(c0, d0);
            d1 = r1.mul_add_simd(c0, d1);

            let r0 = A::f32x8::splat(arch, src.add(offset + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 9).read());
            d0 = r0.mul_add_simd(c1, d0);
            d1 = r1.mul_add_simd(c1, d1);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3).read());
            d2 = r0.mul_add_simd(c0, d2);
            d3 = r1.mul_add_simd(c0, d3);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3 + 1).read());
            d2 = r0.mul_add_simd(c1, d2);
            d3 = r1.mul_add_simd(c1, d3);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5).read());
            d4 = r0.mul_add_simd(c0, d4);
            d5 = r1.mul_add_simd(c0, d5);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5 + 1).read());
            d4 = r0.mul_add_simd(c1, d4);
            d5 = r1.mul_add_simd(c1, d5);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7).read());
            d6 = r0.mul_add_simd(c0, d6);
            d7 = r1.mul_add_simd(c0, d7);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7 + 1).read());
            d6 = r0.mul_add_simd(c1, d6);
            d7 = r1.mul_add_simd(c1, d7);
        }
    };

    // Each call consumes two input positions per chunk; four calls cover
    // all eight positions.
    for o in 0..4 {
        process_patch(2 * o);
    }

    // Cross-chunk butterfly, stage 1: combine adjacent chunk pairs.
    let e0 = d0 + d1;
    let e1 = d0 - d1;

    let e2 = d2 + d3;
    let e3 = d2 - d3;

    let e4 = d4 + d5;
    let e5 = d4 - d5;

    let e6 = d6 + d7;
    let e7 = d6 - d7;

    // Cross-chunk butterfly, stage 2: combine at stride two.
    let f0 = e0 + e2;
    let f1 = e1 + e3;

    let f2 = e0 - e2;
    let f3 = e1 - e3;

    let f4 = e4 + e6;
    let f5 = e5 + e7;

    let f6 = e4 - e6;
    let f7 = e5 - e7;

    let dst: *mut f32 = x.as_mut_ptr();

    // Final butterfly stage (stride four) fused with the stores.
    // SAFETY: `dst` points to 64 f32s; each store writes 8 lanes at offsets
    // 0, 8, ..., 56, all in bounds.
    unsafe {
        (f0 + f4).store_simd(dst);
        (f1 + f5).store_simd(dst.add(8));
        (f2 + f6).store_simd(dst.add(16));
        (f3 + f7).store_simd(dst.add(24));
        (f0 - f4).store_simd(dst.add(32));
        (f1 - f5).store_simd(dst.add(40));
        (f2 - f6).store_simd(dst.add(48));
        (f3 - f7).store_simd(dst.add(56));
    }
}
372
373#[cfg(test)]
378mod tests {
379 use rand::{
380 SeedableRng,
381 distr::{Distribution, StandardUniform},
382 rngs::StdRng,
383 };
384
385 use super::*;
386 use diskann_utils::views::{self, Matrix, MatrixView};
387
    /// Builds an 8x8 `Matrix` from the hard-coded `HADAMARD_8` table.
    fn get_hadamard_8() -> Matrix<f32> {
        let v: Box<[f32]> = HADAMARD_8.iter().flatten().copied().collect();
        Matrix::try_from(v, 8, 8).unwrap()
    }
393
    /// Reference Hadamard matrix built via Sylvester's recursive
    /// construction: `H_1 = [1]`, `H_2n = [[H_n, H_n], [H_n, -H_n]]`.
    ///
    /// Callers in this module only pass power-of-two `dim`s; other values
    /// would silently recurse on `dim / 2` and produce a non-Hadamard result.
    fn hadamard_by_sylvester(dim: usize) -> Matrix<f32> {
        assert_ne!(dim, 0);
        if dim == 1 {
            Matrix::new(1.0, dim, dim)
        } else {
            let half = dim / 2;
            let sub = hadamard_by_sylvester(half);
            let mut m = Matrix::<f32>::new(0.0, dim, dim);

            for c in 0..m.ncols() {
                for r in 0..m.nrows() {
                    let mut v = sub[(r % half, c % half)];
                    if c >= half && r >= half {
                        v = -v;
                    }
                    // NOTE(review): the write index (c, r) is transposed
                    // relative to the read above; Sylvester Hadamard
                    // matrices are symmetric, so the result is unchanged.
                    m[(c, r)] = v;
                }
            }
            m
        }
    }
416
    // The hard-coded HADAMARD_8 table must match Sylvester's construction.
    #[test]
    fn test_hadamard_8() {
        let h8 = get_hadamard_8();
        let reference = hadamard_by_sylvester(8);
        assert_eq!(h8.as_slice(), reference.as_slice());
    }
424
    /// Naive triple-loop matrix product used as a test reference
    /// (accumulates with `mul_add`, matching the kernel's FMA usage).
    fn matmul(a: MatrixView<f32>, b: MatrixView<f32>) -> Matrix<f32> {
        assert_eq!(a.ncols(), b.nrows());
        let mut c = Matrix::new(0.0, a.nrows(), b.ncols());

        for i in 0..c.nrows() {
            for j in 0..c.ncols() {
                let mut v = 0.0;
                for k in 0..a.ncols() {
                    v = a[(i, k)].mul_add(b[(k, j)], v);
                }
                c[(i, j)] = v;
            }
        }
        c
    }
441
    #[test]
    fn test_micro_kernel_64() {
        // Random 64-element column vector, seeded for reproducibility.
        let mut src = {
            let mut rng = StdRng::seed_from_u64(0xde1936d651285fc8);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, 64, 1)
        };

        // Reference: the (unnormalized) dense product H_64 * src.
        let h = hadamard_by_sylvester(64);
        let reference = matmul(h.as_view(), src.as_view());

        micro_kernel_64(diskann_wide::ARCH, src.as_mut_slice().try_into().unwrap());

        assert_eq!(reference.nrows(), src.nrows());
        assert_eq!(reference.ncols(), 1);
        assert_eq!(src.ncols(), 1);

        // The kernel sums in a different order than the dense reference, so
        // compare with a small relative tolerance instead of exact equality.
        for j in 0..src.nrows() {
            let src = src[(j, 0)];
            let reference = reference[(j, 0)];

            let relative_error = (src - reference).abs() / src.abs().max(reference.abs());
            assert!(
                relative_error < 5e-6,
                "Got a relative error of {} for row {} - reference = {}, got = {}",
                relative_error,
                j,
                reference,
                src
            );
        }
    }
474
    /// End-to-end check: runs every implementation available on this machine
    /// over a random `dim`-element vector and compares each result against
    /// the normalized reference product `H_dim * x / sqrt(dim)`.
    fn test_hadamard_transform(dim: usize, seed: u64) {
        let src = {
            let mut rng = StdRng::seed_from_u64(seed);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, dim, 1)
        };

        let h = hadamard_by_sylvester(dim);

        // Normalize the reference to match the orthonormal transform.
        let mut reference = matmul(h.as_view(), src.as_view());
        reference
            .as_mut_slice()
            .iter_mut()
            .for_each(|i| *i /= (dim as f32).sqrt());

        type Implementation = Box<dyn Fn(&mut [f32])>;

        // On targets with neither SIMD arm below, `impls` is never mutated.
        #[cfg_attr(
            not(any(target_arch = "x86_64", target_arch = "aarch64")),
            expect(unused_mut)
        )]
        let mut impls: Vec<(Implementation, &'static str)> = vec![
            (
                Box::new(|x| hadamard_transform(x).unwrap()),
                "public entry point",
            ),
            (
                Box::new(|x| {
                    diskann_wide::arch::Scalar::new()
                        .run1(HadamardTransform, x)
                        .unwrap()
                }),
                "scalar recursive implementation",
            ),
        ];

        // SIMD variants are exercised only when the CPU supports them.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "x86-64-v3",
            ));
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "neon",
            ));
        }

        for (f, kernel) in impls.into_iter() {
            let mut src_clone = src.clone();
            f(src_clone.as_mut_slice());

            assert_eq!(reference.nrows(), src_clone.nrows());
            assert_eq!(reference.ncols(), 1);
            assert_eq!(src_clone.ncols(), 1);

            // Element-wise relative-error comparison against the reference;
            // the failure message names the kernel for easier triage.
            for j in 0..src_clone.nrows() {
                let src_clone = src_clone[(j, 0)];
                let reference = reference[(j, 0)];

                let relative_error =
                    (src_clone - reference).abs() / src_clone.abs().max(reference.abs());
                assert!(
                    relative_error < 5e-5,
                    "Got a relative error of {} for row {} - reference = {}, got = {} -- dim = {}: kernel = {}",
                    relative_error,
                    j,
                    reference,
                    src_clone,
                    dim,
                    kernel,
                );
            }
        }
    }
556
    // dim = 1: the smallest valid input (early-return identity path).
    #[test]
    fn test_hadamard_transform_1() {
        test_hadamard_transform(1, 0xcdb7283f806f237d);
    }
561
    // dim = 2: the recursive base case (single butterfly).
    #[test]
    fn test_hadamard_transform_2() {
        test_hadamard_transform(2, 0x1e8bba190423842c);
    }
566
    // Small power-of-two sizes below 64 take the split/butterfly path.
    #[test]
    fn test_hadamard_transform_4() {
        test_hadamard_transform(4, 0x6cdcb7e1fe0fa296);
    }
571
    #[test]
    fn test_hadamard_transform_8() {
        test_hadamard_transform(8, 0xd120b32a83158c80);
    }
576
    #[test]
    fn test_hadamard_transform_16() {
        test_hadamard_transform(16, 0x56ef310cc7e42faa);
    }
581
    #[test]
    fn test_hadamard_transform_32() {
        test_hadamard_transform(32, 0xf2a1395699390b95);
    }
586
    // dim = 64: on x86-64-v3 this is handled entirely by the micro-kernel.
    #[test]
    fn test_hadamard_transform_64() {
        test_hadamard_transform(64, 0x31e6a1bfe4958c8a);
    }
591
    // dims above 64 exercise the recursive split around the micro-kernel.
    #[test]
    fn test_hadamard_transform_128() {
        test_hadamard_transform(128, 0xe13a35f4b9392747);
    }
596
    #[test]
    fn test_hadamard_transform_256() {
        test_hadamard_transform(256, 0xf71bb8e26e79681c);
    }
601
    // Non-power-of-two lengths (including zero) must be rejected.
    #[test]
    fn test_error() {
        assert!(matches!(hadamard_transform(&mut []), Err(NotPowerOfTwo)));

        for dim in [3, 31, 33, 40, 63, 65, 100, 127, 129] {
            let mut v = vec![0.0f32; dim];
            assert!(matches!(hadamard_transform(&mut v), Err(NotPowerOfTwo)));
        }
    }
613}