// diskann_quantization/algorithms/hadamard.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_wide::Architecture;
7#[cfg(any(test, target_arch = "x86_64"))]
8use diskann_wide::{SIMDMulAdd, SIMDVector};
9use thiserror::Error;
10
/// Implicitly multiply the argument `x` by a Hadamard matrix and scale the results by `1 / x.len().sqrt()`.
///
/// This function does not allocate and operates in place.
///
/// # Error
///
/// Returns an error if `x.len()` is not a power of two.
///
/// # See Also
///
/// * <https://en.wikipedia.org/wiki/Hadamard_matrix>
pub fn hadamard_transform(x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
    // Runtime dispatch to the best available `Target1` impl of `HadamardTransform`
    // for this CPU (scalar, V3, or V4 — see the impls below).
    //
    // Defer application of target features until after retargeting `V4` to `V3`.
    //
    // This is because we do not have a better implementation for `V4`.
    diskann_wide::arch::dispatch1_no_features(HadamardTransform, x)
}
28
/// Error returned by [`hadamard_transform`] when the input slice's length is not a
/// power of two. Note that `0` is not a power of two, so the empty slice also
/// produces this error.
#[derive(Debug, Error)]
#[error("Hadamard input vector must have a length that is a power of two")]
pub struct NotPowerOfTwo;
32
/// Implicitly multiply the argument `x` by a Hadamard matrix and scale the results by `1 / x.len().sqrt()`.
///
/// This function does not allocate and operates in place.
///
/// This unit struct is the dispatch functor used by [`hadamard_transform`]: one
/// `Target1` impl exists per supported architecture (scalar, x86-64-v3, x86-64-v4).
///
/// # Error
///
/// Returns an error if `x.len()` is not a power of two.
///
/// # See Also
///
/// * <https://en.wikipedia.org/wiki/Hadamard_matrix>
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransform;
46
impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, Result<(), NotPowerOfTwo>, &mut [f32]>
    for HadamardTransform
{
    /// Scalar (portable) entry point: forwards directly to the shared outer driver.
    /// `inline(never)` keeps this as a stable dispatch boundary.
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        (HadamardTransformOuter).run(arch, x)
    }
}
55
#[cfg(target_arch = "x86_64")]
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V3,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    /// x86-64-v3 entry point. Goes through `arch.run1` (rather than calling `run`
    /// directly) so that the V3 target features are applied to the inner
    /// computation — dispatch itself happened without features (see
    /// `hadamard_transform`).
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.run1(HadamardTransformOuter, x)
    }
}
69
#[cfg(target_arch = "x86_64")]
impl
    diskann_wide::arch::Target1<
        diskann_wide::arch::x86_64::V4,
        Result<(), NotPowerOfTwo>,
        &mut [f32],
    > for HadamardTransform
{
    /// x86-64-v4 entry point. There is no dedicated V4 kernel, so this retargets
    /// down to the V3 implementation and runs that instead.
    #[inline(never)]
    fn run(self, arch: diskann_wide::arch::x86_64::V4, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
        arch.retarget().run1(HadamardTransformOuter, x)
    }
}
83
84////////////////////
85// Implementation //
86////////////////////
87
/// Architecture-generic driver for the transform: validates the input length,
/// runs [`HadamardTransformRecursive`], then applies the `1 / sqrt(len)` scaling.
#[derive(Debug, Clone, Copy)]
pub struct HadamardTransformOuter;
90
91impl<A> diskann_wide::arch::Target1<A, Result<(), NotPowerOfTwo>, &mut [f32]>
92    for HadamardTransformOuter
93where
94    A: diskann_wide::Architecture,
95    HadamardTransformRecursive: for<'a> diskann_wide::arch::Target1<A, (), &'a mut [f32]>,
96{
97    #[inline(always)]
98    fn run(self, arch: A, x: &mut [f32]) -> Result<(), NotPowerOfTwo> {
99        let len = x.len();
100
101        if !len.is_power_of_two() {
102            return Err(NotPowerOfTwo);
103        }
104
105        // Nothing to do for length-1 transforms.
106        if len == 1 {
107            return Ok(());
108        }
109
110        // Perform the implicit matrix multiplication.
111        arch.run1(HadamardTransformRecursive, x);
112
113        // Scale the result.
114        let m = 1.0 / (x.len() as f32).sqrt();
115        x.iter_mut().for_each(|i| *i *= m);
116
117        Ok(())
118    }
119}
120
/// Divide-and-conquer core of the transform: recursively transforms both halves
/// of the slice, then combines them with element-wise add/subtract butterflies.
#[derive(Debug, Clone, Copy)]
struct HadamardTransformRecursive;
123
124impl diskann_wide::arch::Target1<diskann_wide::arch::Scalar, (), &mut [f32]>
125    for HadamardTransformRecursive
126{
127    /// A recursive helper for the divide-and-conquer step of Hadamard matrix multplication.
128    ///
129    /// # Preconditions
130    ///
131    /// This function is private with the following pre-conditions:
132    ///
133    /// * `x.len()` must be a power of 2.
134    /// * `x.len()` must be at least 2.
135    #[inline]
136    fn run(self, arch: diskann_wide::arch::Scalar, x: &mut [f32]) {
137        let len = x.len();
138        debug_assert!(len.is_power_of_two());
139        debug_assert!(len >= 2);
140
141        if len == 2 {
142            let l = x[0];
143            let r = x[1];
144            x[0] = l + r;
145            x[1] = l - r;
146        } else {
147            // Recursive case - divide and conquer.
148            let (left, right) = x.split_at_mut(len / 2);
149
150            arch.run1(self, left);
151            arch.run1(self, right);
152
153            std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
154                let a = *l + *r;
155                let b = *l - *r;
156                *l = a;
157                *r = b;
158            });
159        }
160    }
161}
162
#[cfg(target_arch = "x86_64")]
impl diskann_wide::arch::Target1<diskann_wide::arch::x86_64::V3, (), &mut [f32]>
    for HadamardTransformRecursive
{
    /// A recursive helper for the divide-and-conquer step of Hadamard matrix multiplication.
    ///
    /// Structurally identical to the scalar implementation, except that slices of
    /// exactly 64 elements are handled by the SIMD `micro_kernel_64` instead of
    /// recursing further.
    ///
    /// # Preconditions
    ///
    /// This function is private with the following pre-conditions:
    ///
    /// * `x.len()` must be a power of 2.
    /// * `x.len()` must be at least 2.
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, x: &mut [f32]) {
        let len = x.len();
        debug_assert!(len.is_power_of_two());
        debug_assert!(len >= 2);

        // The `TryInto` conversion succeeds exactly when `len == 64`; on failure
        // `x` is implicitly reborrowed and remains usable in the other branches.
        if let Ok(array) = <&mut [f32] as TryInto<&mut [f32; 64]>>::try_into(x) {
            // We have a faster implementation for working with 64-elements at a time. Invoke
            // that if possible.
            //
            // Lint: This conversion into an array will never fail because we've checked that the
            // length is indeed 64.
            micro_kernel_64(arch, array);
        } else if len == 2 {
            // This is only reachable if the original argument to `hadamard_transform` was
            // shorter than 64.
            let l = x[0];
            let r = x[1];
            x[0] = l + r;
            x[1] = l - r;
        } else {
            // Recursive case - divide and conquer.
            let (left, right) = x.split_at_mut(len / 2);

            arch.run1(self, left);
            arch.run1(self, right);

            // Combine the two half-transforms with element-wise butterflies.
            std::iter::zip(left.iter_mut(), right.iter_mut()).for_each(|(l, r)| {
                let a = *l + *r;
                let b = *l - *r;
                *l = a;
                *r = b;
            });
        }
    }
}
211
/// The 8x8 Hadamard matrix (Sylvester construction, natural order).
///
/// Note that this matrix is symmetric (`H = H^T`); `micro_kernel_64` relies on
/// this when it accumulates rows instead of columns.
#[cfg(any(test, target_arch = "x86_64"))]
const HADAMARD_8: [[f32; 8]; 8] = [
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
    [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0],
    [1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0],
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
    [1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0],
];
224
/// This micro-kernel computes a full 64-element Hadamard transform by first computing
/// eight 8-element transforms via a matrix multiplication kernel and then using the
/// recursive formulation to compute the full 64-element transform.
///
/// Note: this computes the *unscaled* transform; the `1 / sqrt(64)` normalization
/// is applied by the caller (`HadamardTransformOuter`).
#[cfg(any(test, target_arch = "x86_64"))]
#[inline(always)]
fn micro_kernel_64<A>(arch: A, x: &mut [f32; 64])
where
    A: Architecture,
{
    // Output registers. After the patch loop below, `d<i>` holds
    // `sum_k x[8*i + k] * HADAMARD_8[k]` — i.e. the (unscaled) 8-point transform
    // of the i-th 8-element chunk of `x` (using the symmetry of `HADAMARD_8`).
    let mut d0 = A::f32x8::splat(arch, 0.0);
    let mut d1 = A::f32x8::splat(arch, 0.0);
    let mut d2 = A::f32x8::splat(arch, 0.0);
    let mut d3 = A::f32x8::splat(arch, 0.0);
    let mut d4 = A::f32x8::splat(arch, 0.0);
    let mut d5 = A::f32x8::splat(arch, 0.0);
    let mut d6 = A::f32x8::splat(arch, 0.0);
    let mut d7 = A::f32x8::splat(arch, 0.0);

    // `p` walks the 64 entries of the flattened 8x8 Hadamard matrix; `src` walks
    // the 64 input elements.
    let p: *const f32 = HADAMARD_8.as_ptr().cast();
    let src: *const f32 = x.as_ptr();
    // Each call accumulates the contribution of matrix rows `offset` and
    // `offset + 1` into all eight accumulators.
    let mut process_patch = |offset: usize| {
        // SAFETY: The unsafe actions in the enclosing block all consist of performing
        // arithmetic on pointers and dereferencing said pointers.
        //
        // The pointers accessed are the pointer for the 8x8 Hadamard matrix (with 64 valid
        // entries), and the input array (also with valid entries).
        //
        // The argument `offset` takes the values 0, 2, 4, 6.
        //
        // All the pointer arithmetic is performed so that accesses with `offset` as one
        // of these four values is in-bounds.
        unsafe {
            // Rows `offset` and `offset + 1` of the Hadamard matrix.
            let c0 = A::f32x8::load_simd(arch, p.add(8 * offset));
            let c1 = A::f32x8::load_simd(arch, p.add(8 * (offset + 1)));

            // Chunks 0 and 1: broadcast one input element, FMA against a row.
            let r0 = A::f32x8::splat(arch, src.add(offset).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8).read());
            d0 = r0.mul_add_simd(c0, d0);
            d1 = r1.mul_add_simd(c0, d1);

            let r0 = A::f32x8::splat(arch, src.add(offset + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 9).read());
            d0 = r0.mul_add_simd(c1, d0);
            d1 = r1.mul_add_simd(c1, d1);

            // Chunks 2 and 3.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3).read());
            d2 = r0.mul_add_simd(c0, d2);
            d3 = r1.mul_add_simd(c0, d3);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 2 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 3 + 1).read());
            d2 = r0.mul_add_simd(c1, d2);
            d3 = r1.mul_add_simd(c1, d3);

            // Chunks 4 and 5.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5).read());
            d4 = r0.mul_add_simd(c0, d4);
            d5 = r1.mul_add_simd(c0, d5);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 4 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 5 + 1).read());
            d4 = r0.mul_add_simd(c1, d4);
            d5 = r1.mul_add_simd(c1, d5);

            // Chunks 6 and 7.
            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7).read());
            d6 = r0.mul_add_simd(c0, d6);
            d7 = r1.mul_add_simd(c0, d7);

            let r0 = A::f32x8::splat(arch, src.add(offset + 8 * 6 + 1).read());
            let r1 = A::f32x8::splat(arch, src.add(offset + 8 * 7 + 1).read());
            d6 = r0.mul_add_simd(c1, d6);
            d7 = r1.mul_add_simd(c1, d7);
        }
    };

    // Do the 8x8 matrix multiplication to compute the eight individual 8-element transforms
    // and store the results into `d0-7`.
    for o in 0..4 {
        process_patch(2 * o);
    }

    // Now that we have the individual 8-dimensional transformations, we can begin swizzling
    // them together to construct the full 64-element transform.
    //
    // This computes four 16-element transforms.
    let e0 = d0 + d1;
    let e1 = d0 - d1;

    let e2 = d2 + d3;
    let e3 = d2 - d3;

    let e4 = d4 + d5;
    let e5 = d4 - d5;

    let e6 = d6 + d7;
    let e7 = d6 - d7;

    // Compute two 32-element transforms.
    let f0 = e0 + e2;
    let f1 = e1 + e3;

    let f2 = e0 - e2;
    let f3 = e1 - e3;

    let f4 = e4 + e6;
    let f5 = e5 + e7;

    let f6 = e4 - e6;
    let f7 = e5 - e7;

    // Compute the full 64-element transform and write-back the results.
    let dst: *mut f32 = x.as_mut_ptr();

    // SAFETY: The pointer `dst` is valid for writing up to 64-elements, which is what
    // we do here.
    unsafe {
        (f0 + f4).store_simd(dst);
        (f1 + f5).store_simd(dst.add(8));
        (f2 + f6).store_simd(dst.add(16));
        (f3 + f7).store_simd(dst.add(24));
        (f0 - f4).store_simd(dst.add(32));
        (f1 - f5).store_simd(dst.add(40));
        (f2 - f6).store_simd(dst.add(48));
        (f3 - f7).store_simd(dst.add(56));
    }
}
354
355///////////
356// Tests //
357///////////
358
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use diskann_utils::views::{self, Matrix, MatrixView};

    /// Retrieve the 8x8 hadamard matrix as a `Matrix`.
    fn get_hadamard_8() -> Matrix<f32> {
        let v: Box<[f32]> = HADAMARD_8.iter().flatten().copied().collect();
        Matrix::try_from(v, 8, 8).unwrap()
    }

    /// Construct the `dim x dim` Hadamard matrix using Sylvester's recursive
    /// construction: `H_{2n} = [[H_n, H_n], [H_n, -H_n]]`.
    ///
    /// Panics if `dim` is 0; `dim` is expected to be a power of two.
    fn hadamard_by_sylvester(dim: usize) -> Matrix<f32> {
        assert_ne!(dim, 0);
        // Base case.
        if dim == 1 {
            Matrix::new(1.0, dim, dim)
        } else {
            let half = dim / 2;
            let sub = hadamard_by_sylvester(half);
            let mut m = Matrix::<f32>::new(0.0, dim, dim);

            for c in 0..m.ncols() {
                for r in 0..m.nrows() {
                    let mut v = sub[(r % half, c % half)];
                    // Only the bottom-right quadrant is negated.
                    if c >= half && r >= half {
                        v = -v;
                    }
                    // NOTE(review): the read above indexes `sub` as (row, col) but
                    // the write below is `(c, r)` — transposed. The result is
                    // still correct because Sylvester Hadamard matrices are
                    // symmetric, but confirm the transposed write is intentional.
                    m[(c, r)] = v;
                }
            }
            m
        }
    }

    // Ensure that our 8x8 constant Hadamard matrix stays consistent.
    #[test]
    fn test_hadamard_8() {
        let h8 = get_hadamard_8();
        let reference = hadamard_by_sylvester(8);
        assert_eq!(h8.as_slice(), reference.as_slice());
    }

    // A naive reference implementation.
    fn matmul(a: MatrixView<f32>, b: MatrixView<f32>) -> Matrix<f32> {
        assert_eq!(a.ncols(), b.nrows());
        let mut c = Matrix::new(0.0, a.nrows(), b.ncols());

        for i in 0..c.nrows() {
            for j in 0..c.ncols() {
                let mut v = 0.0;
                for k in 0..a.ncols() {
                    v = a[(i, k)].mul_add(b[(k, j)], v);
                }
                c[(i, j)] = v;
            }
        }
        c
    }

    // Compare `micro_kernel_64` against a dense matrix-vector product with the
    // 64x64 Sylvester matrix (the kernel is unscaled, so no normalization here).
    #[test]
    fn test_micro_kernel_64() {
        let mut src = {
            let mut rng = StdRng::seed_from_u64(0xde1936d651285fc8);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, 64, 1)
        };

        let h = hadamard_by_sylvester(64);
        let reference = matmul(h.as_view(), src.as_view());

        micro_kernel_64(diskann_wide::ARCH, src.as_mut_slice().try_into().unwrap());

        assert_eq!(reference.nrows(), src.nrows());
        assert_eq!(reference.ncols(), 1);
        assert_eq!(src.ncols(), 1);

        for j in 0..src.nrows() {
            let src = src[(j, 0)];
            let reference = reference[(j, 0)];

            // NOTE(review): if `src` and `reference` were both exactly 0.0 this
            // would be 0/0 = NaN and the assertion would fail. The random input
            // makes that effectively impossible here, but worth confirming.
            let relative_error = (src - reference).abs() / src.abs().max(reference.abs());
            assert!(
                relative_error < 5e-6,
                "Got a relative error of {} for row {} - reference = {}, got = {}",
                relative_error,
                j,
                reference,
                src
            );
        }
    }

    // End-to-end tests.
    //
    // Runs every implementation available at runtime (public entry point, scalar,
    // and - on x86_64 with V3 support - the SIMD path) on a random `dim`-element
    // vector and compares against a scaled dense matrix-multiplication reference.
    fn test_hadamard_transform(dim: usize, seed: u64) {
        let src = {
            let mut rng = StdRng::seed_from_u64(seed);
            let init = views::Init(|| StandardUniform {}.sample(&mut rng));
            Matrix::new(init, dim, 1)
        };

        let h = hadamard_by_sylvester(dim);

        // The reference includes the `1 / sqrt(dim)` normalization applied by
        // `hadamard_transform`.
        let mut reference = matmul(h.as_view(), src.as_view());
        reference
            .as_mut_slice()
            .iter_mut()
            .for_each(|i| *i /= (dim as f32).sqrt());

        // Queue up a list of implementations.
        type Implementation = Box<dyn Fn(&mut [f32])>;

        #[cfg_attr(not(target_arch = "x86_64"), expect(unused_mut))]
        let mut impls: Vec<(Implementation, &'static str)> = vec![
            (
                Box::new(|x| hadamard_transform(x).unwrap()),
                "public entry point",
            ),
            (
                Box::new(|x| {
                    diskann_wide::arch::Scalar::new()
                        .run1(HadamardTransform, x)
                        .unwrap()
                }),
                "scalar recursive implementation",
            ),
        ];

        // The V3 path is only exercised when the running CPU supports it.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            impls.push((
                Box::new(move |x| arch.run1(HadamardTransform, x).unwrap()),
                "x86-64-v3",
            ));
        }

        for (f, kernel) in impls.into_iter() {
            let mut src_clone = src.clone();
            f(src_clone.as_mut_slice());

            assert_eq!(reference.nrows(), src_clone.nrows());
            assert_eq!(reference.ncols(), 1);
            assert_eq!(src_clone.ncols(), 1);

            for j in 0..src_clone.nrows() {
                let src_clone = src_clone[(j, 0)];
                let reference = reference[(j, 0)];

                let relative_error =
                    (src_clone - reference).abs() / src_clone.abs().max(reference.abs());
                assert!(
                    relative_error < 5e-5,
                    "Got a relative error of {} for row {} - reference = {}, got = {} -- dim = {}: kernel = {}",
                    relative_error,
                    j,
                    reference,
                    src_clone,
                    dim,
                    kernel,
                );
            }
        }
    }

    #[test]
    fn test_hadamard_transform_1() {
        test_hadamard_transform(1, 0xcdb7283f806f237d);
    }

    #[test]
    fn test_hadamard_transform_2() {
        test_hadamard_transform(2, 0x1e8bba190423842c);
    }

    #[test]
    fn test_hadamard_transform_4() {
        test_hadamard_transform(4, 0x6cdcb7e1fe0fa296);
    }

    #[test]
    fn test_hadamard_transform_8() {
        test_hadamard_transform(8, 0xd120b32a83158c80);
    }

    #[test]
    fn test_hadamard_transform_16() {
        test_hadamard_transform(16, 0x56ef310cc7e42faa);
    }

    #[test]
    fn test_hadamard_transform_32() {
        test_hadamard_transform(32, 0xf2a1395699390b95);
    }

    #[test]
    fn test_hadamard_transform_64() {
        test_hadamard_transform(64, 0x31e6a1bfe4958c8a);
    }

    #[test]
    fn test_hadamard_transform_128() {
        test_hadamard_transform(128, 0xe13a35f4b9392747);
    }

    #[test]
    fn test_hadamard_transform_256() {
        test_hadamard_transform(256, 0xf71bb8e26e79681c);
    }

    // Test the error cases.
    #[test]
    fn test_error() {
        // Supplying an empty-slice is an error.
        assert!(matches!(hadamard_transform(&mut []), Err(NotPowerOfTwo)));

        for dim in [3, 31, 33, 40, 63, 65, 100, 127, 129] {
            let mut v = vec![0.0f32; dim];
            assert!(matches!(hadamard_transform(&mut v), Err(NotPowerOfTwo)));
        }
    }
}