// diskann_quantization/algorithms/hadamard.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use diskann_wide::Architecture;
#[cfg(any(test, target_arch = "x86_64"))]
use diskann_wide::{SIMDMulAdd, SIMDVector};
use thiserror::Error;

/// Implicitly multiply the argument `x` by a Hadamard matrix and scale the results by `1 / x.len().sqrt()`.
///
/// This function does not allocate and operates in place.
///
/// # Error
///
/// Returns an error if `x.len()` is not a power of two. Note that this includes
/// the empty slice: zero is not a power of two.
///
/// # See Also
///
/// * <https://en.wikipedia.org/wiki/Hadamard_matrix>
pub fn hadamard_transform(x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
    // Defer application of target features until after retargeting `V4` to `V3`.
    //
    // This is because we do not have a better implementation for `V4`.
    diskann_wide::arch::dispatch1_no_features(HadamardTransform, x)
}

/// Error returned by [`hadamard_transform`] when the input slice does not have
/// a power-of-two length (including the empty slice).
#[derive(Debug, Error)]
#[error("Hadamard input vector must have a length that is a power of two")]
pub struct NotPowerOfTwo;

/// Implicitly multiply the argument `x` by a Hadamard matrix and scale the results by `1 / x.len().sqrt()`.
///
/// This function does not allocate and operates in place.
///
/// This is the architecture-dispatch target invoked by [`hadamard_transform`];
/// the per-architecture `Target1` implementations are below.
///
/// # Error
///
/// Returns an error if `x.len()` is not a power of two.
///
/// # See Also
///
/// * <https://en.wikipedia.org/wiki/Hadamard_matrix>
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransform;

47impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, Result<(), NotPowerOfTwo>, &mut [f32]>
48    for HadamardTransform
49{
50    #[inline(never)]
51    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
52        (HadamardTransformOuter).run(arch, x)
53    }
54}
55
56#[cfg(target_arch = "x86_64")]
57impl
58    diskann_wide::arch::Target1<
59        diskann_wide::arch::x86_64::V3,
60        Result<(), NotPowerOfTwo>,
61        &mut [f32],
62    > for HadamardTransform
63{
64    #[inline(never)]
65    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
66        arch.run1(HadamardTransformOuter, x)
67    }
68}
69
70#[cfg(target_arch = "x86_64")]
71impl
72    diskann_wide::arch::Target1<
73        diskann_wide::arch::x86_64::V4,
74        Result<(), NotPowerOfTwo>,
75        &mut [f32],
76    > for HadamardTransform
77{
78    #[inline(never)]
79    fn run(self, arch: diskann_wide::arch::x86_64::V4, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
80        arch.retarget().run1(HadamardTransformOuter, x)
81    }
82}
83
#[cfg(target_arch = "aarch64")]
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::aarch64::Neon,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    /// AArch64 NEON entry point: retarget the token and run the shared outer
    /// kernel through `run1`.
    #[inline(never)]
    fn run(
        self,
        arch: diskann_wide::arch::aarch64::Neon,
        x: &mut [f32],
    ) -> Result<(), NotPowerOfTwo> {
        let retargeted = arch.retarget();
        retargeted.run1(HadamardTransformOuter, x)
    }
}

////////////////////
// Implementation //
////////////////////

/// Architecture-generic implementation target: validates the input length,
/// invokes the recursive kernel, and applies the `1 / sqrt(len)` scaling.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransformOuter;

109impl<A> diskann_wide::arch::Target1<A, Result<(), NotPowerOfTwo>, &mut [f32]>
110    for HadamardTransformOuter
111where
112    A: diskann_wide::Architecture,
113    HadamardTransformRecursive: for<'a> diskann_wide::arch::Target1<A, (), &'a mut [f32]>,
114{
115    #[inline(always)]
116    fn run(self, arch: A, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
117        let len = x.len();
118
119        if !len.is_power_of_two() {
120            return Err(NotPowerOfTwo);
121        }
122
123        // Nothing to do for length-1 transforms.
124        if len == 1 {
125            return Ok(());
126        }
127
128        // Perform the implicit matrix multiplication.
129        arch.run1(HadamardTransformRecursive, x);
130
131        // Scale the result.
132        let m = 1.0 / (x.len() as f32).sqrt();
133        x.iter_mut().for_each(|i| *i *= m);
134
135        Ok(())
136    }
137}
138
/// Divide-and-conquer kernel behind [`HadamardTransformOuter`]; the
/// per-architecture `Target1` implementations follow.
#[derive(Debug, Clone, Copy)]
struct HadamardTransformRecursive;

142impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, (), &mut [f32]>
143    for HadamardTransformRecursive
144{
145    /// A recursive helper for the divide-and-conquer step of Hadamard matrix multplication.
146    ///
147    /// # Preconditions
148    ///
149    /// This function is private with the following pre-conditions:
150    ///
151    /// * `x.len()` must be a power of 2.
152    /// * `x.len()` must be at least 2.
153    #[inline]
154    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) {
155        let len = x.len();
156        debug_assert!(len.is_power_of_two());
157        debug_assert!(len >= 2);
158
159        if len == 2 {
160            let l = x[0];
161            let r = x[1];
162            x[0] = l + r;
163            x[1] = l - r;
164        } else {
165            // Recursive case - divide and conquer.
166            let (left, right) = x.split_at_mut(len / 2);
167
168            arch.run1(self, left);
169            arch.run1(self, right);
170
171            std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
172                let a = *l + *r;
173                let b = *l - *r;
174                *l = a;
175                *r = b;
176            });
177        }
178    }
179}
180
#[cfg(target_arch = "x86_64")]
impl diskann_wide::arch::Target1<diskann_wide::arch::x86_64::V3, (), &mut [f32]>
    for HadamardTransformRecursive
{
    /// A recursive helper for the divide-and-conquer step of Hadamard matrix multiplication.
    ///
    /// # Preconditions
    ///
    /// This function is private with the following pre-conditions:
    ///
    /// * `x.len()` must be a power of 2.
    /// * `x.len()` must be at least 2.
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) {
        let len = x.len();
        debug_assert!(len.is_power_of_two());
        debug_assert!(len >= 2);

        if let Ok(array) = <&mut [f32] as TryInto<&mut [f32; 64]>>::try_into(x) {
            // We have a faster implementation for working with 64 elements at a
            // time. Invoke that if possible.
            //
            // The `try_into` succeeds exactly when `len == 64`, so the
            // conversion itself doubles as the length check.
            micro_kernel_64(arch, array);
        } else if len == 2 {
            // Base-case butterfly. This is only reachable if the original
            // argument to `hadamard_transform` was shorter than 64: power-of-two
            // inputs of 64 or more always bottom out in the micro-kernel above.
            let l = x[0];
            let r = x[1];
            x[0] = l + r;
            x[1] = l - r;
        } else {
            // Recursive case - divide and conquer.
            let (left, right) = x.split_at_mut(len / 2);

            arch.run1(self, left);
            arch.run1(self, right);

            // Combine the two half-transforms with a single butterfly pass.
            std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
                let a = *l + *r;
                let b = *l - *r;
                *l = a;
                *r = b;
            });
        }
    }
}

/// The 8x8 Hadamard matrix.
///
/// Row-major, Sylvester construction: entry `(i, j)` is `+1` when
/// `popcount(i & j)` is even and `-1` when it is odd. The matrix is symmetric,
/// so its rows and columns coincide. Consistency with the recursive Sylvester
/// construction is checked by `tests::test_hadamard_8`.
#[cfg(any(test, target_arch = "x86_64"))]
const HADAMARD_8: [[f32; 8]; 8] = [
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
    [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0],
    [1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0],
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
    [1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0],
];

/// This micro-kernel computes a full 64-element Hadamard transform by first computing
/// eight 8-element transforms via a matrix multiplication kernel and then using the
/// recursive formulation to compute the full 64-element transform.
///
/// The result is *unscaled*; the `1 / sqrt(64)` normalization is applied by the
/// caller (see `HadamardTransformOuter`).
#[cfg(any(test, target_arch = "x86_64"))]
#[inline(always)]
fn micro_kernel_64<A>(arch: A, x: &mut [f32; 64])
where
    A: Architecture,
{
    // Output registers.
    //
    // `dK` accumulates the 8-point transform of chunk `x[8K..8K + 8]`, computed
    // as `sum_j x[8K + j] * HADAMARD_8[j]` via broadcast + FMA below.
    let mut d0 = A::f32x8::splat(arch, 0.0);
    let mut d1 = A::f32x8::splat(arch, 0.0);
    let mut d2 = A::f32x8::splat(arch, 0.0);
    let mut d3 = A::f32x8::splat(arch, 0.0);
    let mut d4 = A::f32x8::splat(arch, 0.0);
    let mut d5 = A::f32x8::splat(arch, 0.0);
    let mut d6 = A::f32x8::splat(arch, 0.0);
    let mut d7 = A::f32x8::splat(arch, 0.0);

    let p: *const f32 = HADAMARD_8.as_ptr().cast();
    let src: *const f32 = x.as_ptr();
    // Each invocation folds Hadamard rows `offset` and `offset + 1` into all
    // eight accumulators.
    let mut process_patch = |offset: usize| {
        // SAFETY: The unsafe actions in the enclosing block all consist of performing
        // arithmetic on pointers and dereferencing said pointers.
        //
        // The pointers accessed are the pointer for the 8x8 Hadamard matrix (with 64 valid
        // entries), and the input array (also with valid entries).
        //
        // The argument `offset` takes the values 0, 2, 4, 6.
        //
        // All the pointer arithmetic is performed so that accesses with `offset` as one
        // of these four values is in-bounds.
        unsafe {
            // Rows `offset` and `offset + 1` of the 8x8 Hadamard matrix.
            let c0 = A::f32x8::load_simd(arch, p.add(8 * offset));
            let c1 = A::f32x8::load_simd(arch, p.add(8 * (offset + 1)));

            // Chunks 0 and 1.
            let r0 = A::f32x8::splat(arch, src.add(offset).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8).read());
            d0 = r0.mul_add_simd(c0, d0);
            d1 = r1.mul_add_simd(c0, d1);

            let r0 = A::f32x8::splat(arch, src.add(offset + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 9).read());
            d0 = r0.mul_add_simd(c1, d0);
            d1 = r1.mul_add_simd(c1, d1);

            // Chunks 2 and 3.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3).read());
            d2 = r0.mul_add_simd(c0, d2);
            d3 = r1.mul_add_simd(c0, d3);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3 + 1).read());
            d2 = r0.mul_add_simd(c1, d2);
            d3 = r1.mul_add_simd(c1, d3);

            // Chunks 4 and 5.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5).read());
            d4 = r0.mul_add_simd(c0, d4);
            d5 = r1.mul_add_simd(c0, d5);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5 + 1).read());
            d4 = r0.mul_add_simd(c1, d4);
            d5 = r1.mul_add_simd(c1, d5);

            // Chunks 6 and 7.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7).read());
            d6 = r0.mul_add_simd(c0, d6);
            d7 = r1.mul_add_simd(c0, d7);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7 + 1).read());
            d6 = r0.mul_add_simd(c1, d6);
            d7 = r1.mul_add_simd(c1, d7);
        }
    };

    // Do the 8x8 matrix multiplication to compute the eight individual 8-element transforms
    // and store the results into `d0-7`.
    for o in 0..4 {
        process_patch(2 * o);
    }

    // Now that we have the individual 8-dimensional transformations, we can begin swizzling
    // them together to construct the full 64-element transform.
    //
    // This computes four 16-element transforms (butterfly over chunk pairs).
    let e0 = d0 + d1;
    let e1 = d0 - d1;

    let e2 = d2 + d3;
    let e3 = d2 - d3;

    let e4 = d4 + d5;
    let e5 = d4 - d5;

    let e6 = d6 + d7;
    let e7 = d6 - d7;

    // Compute two 32-element transforms.
    let f0 = e0 + e2;
    let f1 = e1 + e3;

    let f2 = e0 - e2;
    let f3 = e1 - e3;

    let f4 = e4 + e6;
    let f5 = e5 + e7;

    let f6 = e4 - e6;
    let f7 = e5 - e7;

    // Compute the full 64-element transform and write-back the results.
    let dst: *mut f32 = x.as_mut_ptr();

    // SAFETY: The pointer `dst` is valid for writing up to 64-elements, which is what
    // we do here.
    unsafe {
        (f0 + f4).store_simd(dst);
        (f1 + f5).store_simd(dst.add(8));
        (f2 + f6).store_simd(dst.add(16));
        (f3 + f7).store_simd(dst.add(24));
        (f0 - f4).store_simd(dst.add(32));
        (f1 - f5).store_simd(dst.add(40));
        (f2 - f6).store_simd(dst.add(48));
        (f3 - f7).store_simd(dst.add(56));
    }
}

///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;
    use diskann_utils::views::{self, Matrix, MatrixView};

    /// Retrieve the 8x8 hadamard matrix as a `Matrix`.
    fn get_hadamard_8() -> Matrix<f32> {
        let v: Box<[f32]> = HADAMARD_8.iter().flatten().copied().collect();
        Matrix::try_from(v, 8, 8).unwrap()
    }

    /// Build the `dim x dim` Hadamard matrix using Sylvester's recursive
    /// construction. `dim` must be a non-zero power of two.
    fn hadamard_by_sylvester(dim: usize) -> Matrix<f32> {
        assert_ne!(dim, 0);
        // Base case.
        if dim == 1 {
            Matrix::new(1.0, dim, dim)
        } else {
            let half = dim / 2;
            let sub = hadamard_by_sylvester(half);
            let mut m = Matrix::<f32>::new(0.0, dim, dim);

            for c in 0..m.ncols() {
                for r in 0..m.nrows() {
                    let mut v = sub[(r % half, c % half)];
                    // Sylvester construction: negate the bottom-right quadrant.
                    if c >= half && r >= half {
                        v = -v;
                    }
                    // NOTE(review): the write below is transposed relative to
                    // the read above; this is harmless because matrices built
                    // by this construction are symmetric.
                    m[(c, r)] = v;
                }
            }
            m
        }
    }

    // Ensure that our 8x8 constant Hadamard matrix stays consistent.
    #[test]
    fn test_hadamard_8() {
        let h8 = get_hadamard_8();
        let reference = hadamard_by_sylvester(8);
        assert_eq!(h8.as_slice(), reference.as_slice());
    }

    // A naive reference implementation.
    fn matmul(a: MatrixView<f32>, b: MatrixView<f32>) -> Matrix<f32> {
        assert_eq!(a.ncols(), b.nrows());
        let mut c = Matrix::new(0.0, a.nrows(), b.ncols());

        for i in 0..c.nrows() {
            for j in 0..c.ncols() {
                let mut v = 0.0;
                for k in 0..a.ncols() {
                    v = a[(i, k)].mul_add(b[(k, j)], v);
                }
                c[(i, j)] = v;
            }
        }
        c
    }

    #[test]
    fn test_micro_kernel_64() {
        // A random 64-element column vector.
        let mut src = {
            let mut rng = StdRng::seed_from_u64(0xde1936d651285fc8);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, 64, 1)
        };

        // Reference: the explicit (unscaled) matrix-vector product.
        let h = hadamard_by_sylvester(64);
        let reference = matmul(h.as_view(), src.as_view());

        micro_kernel_64(diskann_wide::ARCH, src.as_mut_slice().try_into().unwrap());

        assert_eq!(reference.nrows(), src.nrows());
        assert_eq!(reference.ncols(), 1);
        assert_eq!(src.ncols(), 1);

        for j in 0..src.nrows() {
            let src = src[(j, 0)];
            let reference = reference[(j, 0)];

            // Compare with a relative tolerance: the kernel's FMA accumulation
            // order differs from the naive reference.
            let relative_error = (src - reference).abs() / src.abs().max(reference.abs());
            assert!(
                relative_error < 5e-6,
                "Got a relative error of {} for row {} - reference = {}, got = {}",
                relative_error,
                j,
                reference,
                src
            );
        }
    }

    // End-to-end tests.
    fn test_hadamard_transform(dim: usize, seed: u64) {
        let src = {
            let mut rng = StdRng::seed_from_u64(seed);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, dim, 1)
        };

        let h = hadamard_by_sylvester(dim);

        // Reference: matrix product followed by the `1 / sqrt(dim)` scaling
        // that `hadamard_transform` applies.
        let mut reference = matmul(h.as_view(), src.as_view());
        reference
            .as_mut_slice()
            .iter_mut()
            .for_each(|i| *i /= (dim as f32).sqrt());

        // Queue up a list of implementations.
        type Implementation = Box<dyn Fn(&mut [f32])>;

        #[cfg_attr(
            not(any(target_arch = "x86_64", target_arch = "aarch64")),
            expect(unused_mut)
        )]
        let mut impls: Vec<(Implementation, &'static str)> = vec![
            (
                Box::new(|x| hadamard_transform(x).unwrap()),
                "public entry point",
            ),
            (
                Box::new(|x| {
                    diskann_wide::arch::Scalar::new()
                        .run1(HadamardTransform, x)
                        .unwrap()
                }),
                "scalar recursive implementation",
            ),
        ];

        // Architecture-specific kernels are only exercised when the running
        // CPU actually supports them.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "x86-64-v3",
            ));
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "neon",
            ));
        }

        for (f, kernel) in impls.into_iter() {
            let mut src_clone = src.clone();
            f(src_clone.as_mut_slice());

            assert_eq!(reference.nrows(), src_clone.nrows());
            assert_eq!(reference.ncols(), 1);
            assert_eq!(src_clone.ncols(), 1);

            for j in 0..src_clone.nrows() {
                let src_clone = src_clone[(j, 0)];
                let reference = reference[(j, 0)];

                let relative_error =
                    (src_clone - reference).abs() / src_clone.abs().max(reference.abs());
                assert!(
                    relative_error < 5e-5,
                    "Got a relative error of {} for row {} - reference = {}, got = {} -- dim = {}: kernel = {}",
                    relative_error,
                    j,
                    reference,
                    src_clone,
                    dim,
                    kernel,
                );
            }
        }
    }

    #[test]
    fn test_hadamard_transform_1() {
        test_hadamard_transform(1, 0xcdb7283f806f237d);
    }

    #[test]
    fn test_hadamard_transform_2() {
        test_hadamard_transform(2, 0x1e8bba190423842c);
    }

    #[test]
    fn test_hadamard_transform_4() {
        test_hadamard_transform(4, 0x6cdcb7e1fe0fa296);
    }

    #[test]
    fn test_hadamard_transform_8() {
        test_hadamard_transform(8, 0xd120b32a83158c80);
    }

    #[test]
    fn test_hadamard_transform_16() {
        test_hadamard_transform(16, 0x56ef310cc7e42faa);
    }

    #[test]
    fn test_hadamard_transform_32() {
        test_hadamard_transform(32, 0xf2a1395699390b95);
    }

    #[test]
    fn test_hadamard_transform_64() {
        test_hadamard_transform(64, 0x31e6a1bfe4958c8a);
    }

    #[test]
    fn test_hadamard_transform_128() {
        test_hadamard_transform(128, 0xe13a35f4b9392747);
    }

    #[test]
    fn test_hadamard_transform_256() {
        test_hadamard_transform(256, 0xf71bb8e26e79681c);
    }

    // Test the error cases.
    #[test]
    fn test_error() {
        // Supplying an empty-slice is an error.
        assert!(matches!(hadamard_transform(&mut []), Err(NotPowerOfTwo)));

        for dim in [3, 31, 33, 40, 63, 65, 100, 127, 129] {
            let mut v = vec![0.0f32; dim];
            assert!(matches!(hadamard_transform(&mut v), Err(NotPowerOfTwo)));
        }
    }
}