Skip to main content

diskann_quantization/algorithms/transforms/
double_hadamard.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11    Rng,
12    distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19    TargetDim,
20    utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25    algorithms::hadamard_transform,
26    alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27    utils,
28};
29
/// A Double Hadamard transform that applies the signed Hadamard Transform to a head of the
/// vector and then the tail.
///
/// This struct performs the transformation
/// ```math
/// [I 0; 0 H/sqrt(t)] · D1 · [H/sqrt(t) 0; 0 I] · zeropad(D0 · x)
/// ```
///
/// * `n` is the dimensionality of the input vector.
/// * `m` is the desired output dimensionality.
/// * `o = max(n, m)` is an intermediate dimension.
/// * `t` is the largest power of 2 less than or equal to `o`.
/// * `H` is a Hadamard Matrix of dimension `t`.
/// * `I` is the identity matrix of dimension `o - t`.
/// * `D0` and `D1` are diagonal matrices with diagonal entries in `{-1, +1}` drawn
///   uniformly at random with lengths `n` and `o` respectively.
/// * `x` is the input vector of dimension `n`.
/// * `[A 0; 0 B]` represents a block diagonal matrix with blocks `A` and `B`.
/// * `zeropad` indicates that the result `D0 · x` is zero-padded to the dimension `o` if
///   needed.
///
/// The first Hadamard block covers the leading `t` entries of the intermediate vector and
/// the second covers the trailing `t` entries. If `o` is a power of two, then `t == o`, the
/// identity blocks are empty, and both Hadamard transforms span the entire vector.
/// Further, if `o` exceeds `n`, then the input vector is zero padded at the end to
/// `o` dimensions.
///
/// If `m` is smaller than `n`, the final output is formed by sampling `m` entries of the
/// intermediate vector (see the `subsample` field) and rescaling them.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct DoubleHadamard<A>
where
    A: Allocator,
{
    /// Vectors of `+/-1` used to add randomness to the Hadamard transform
    /// in each step.
    ///
    /// These are stored as a slice of `u32` where each value is either `0` or `0x8000_0000`,
    /// corresponding to the sign-bit for an `f32` value, allowing sign flipping using
    /// a cheap `xor` operation.
    ///
    /// `signs0` is applied to the raw input, so its length is the input dimension.
    signs0: Poly<[u32], A>,
    /// Sign bits applied between the two Hadamard stages (same encoding as `signs0`).
    signs1: Poly<[u32], A>,

    /// The target output dimension of the transformation.
    target_dim: usize,

    /// Optional array storing (in sorted order) the indices to sample if `target_dim < dim`
    subsample: Option<Poly<[u32], A>>,
}
75
76impl<A> DoubleHadamard<A>
77where
78    A: Allocator,
79{
80    /// Construct a new `DoubleHadamard` that transforms input vectors of dimension `dim`.
81    ///
82    /// The parameter `rng` is used to randomly initialize the diagonal matrices portion of
83    /// the transform.
84    ///
85    /// The following dimensionalities will be configured depending on the value of `target`:
86    ///
87    /// * [`TargetDim::Same`]
88    ///   - `self.input_dim() == dim.get()`
89    ///   - `self.output_dim() == dim.get()`
90    /// * [`TargetDim::Natural`]
91    ///   - `self.input_dim() == dim.get()`
92    ///   - `self.output_dim() == dim.get()`
93    /// * [`TargetDim::Override`]
94    ///   - `self.input_dim() == dim.get()`
95    ///   - `self.output_dim()`: The value provided by the override.
96    ///
97    /// Subsampling occurs if `self.output_dim()` is smaller than `self.input_dim()`.
98    pub fn new<R>(
99        dim: NonZeroUsize,
100        target_dim: TargetDim,
101        rng: &mut R,
102        allocator: A,
103    ) -> Result<Self, AllocatorError>
104    where
105        R: Rng + ?Sized,
106    {
107        let dim = dim.get();
108
109        let target_dim = match target_dim {
110            TargetDim::Override(target) => target.get(),
111            TargetDim::Same => dim,
112            TargetDim::Natural => dim,
113        };
114
115        // The intermediate dimension after applying the first transform.
116        //
117        // If `target_dim` exceeds `dim`, then we perform zero padding up to `target_dim`
118        // for this stage.
119        let intermediate_dim = dim.max(target_dim);
120
121        // Generate random signs for the diagonal matrices
122        let mut sample = |_: usize| {
123            let sign: bool = StandardUniform {}.sample(rng);
124            if sign { 0x8000_0000 } else { 0 }
125        };
126
127        // Since implicit zero padding is used for this stage, we only create space for
128        // `dim` values.
129        let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
130        let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
131
132        let subsample = if dim > target_dim {
133            Some(subsample_indices(rng, dim, target_dim, allocator)?)
134        } else {
135            None
136        };
137
138        Ok(Self {
139            signs0,
140            signs1,
141            target_dim,
142            subsample,
143        })
144    }
145
146    pub fn try_from_parts(
147        signs0: Poly<[u32], A>,
148        signs1: Poly<[u32], A>,
149        subsample: Option<Poly<[u32], A>>,
150    ) -> Result<Self, DoubleHadamardError> {
151        type E = DoubleHadamardError;
152        if signs0.is_empty() {
153            return Err(E::Signs0Empty);
154        }
155        if signs1.len() < signs0.len() {
156            return Err(E::Signs1TooSmall);
157        }
158        if !signs0.iter().copied().all(is_sign) {
159            return Err(E::Signs0Invalid);
160        }
161        if !signs1.iter().copied().all(is_sign) {
162            return Err(E::Signs1Invalid);
163        }
164
165        // Some preliminary checks on `subsample` that must always hold if it is present.
166        let target_dim = if let Some(ref subsample) = subsample {
167            if !utils::is_strictly_monotonic(subsample.iter()) {
168                return Err(E::SubsampleNotMonotonic);
169            }
170
171            match subsample.last() {
172                Some(last) => {
173                    if *last as usize >= signs1.len() {
174                        // Since the entries in `subsample` are used to index an output
175                        // vector of lengths `signs1`, the last element must be strictly
176                        // less than this length.
177                        //
178                        // From the monotonicity check, we can therefore deduce that *all*
179                        // entries are in-bounds.
180                        return Err(E::LastSubsampleTooLarge);
181                    }
182                }
183                None => {
184                    // Subsample cannot be empty.
185                    return Err(E::InvalidSubsampleLength);
186                }
187            }
188
189            debug_assert!(
190                subsample.len() < signs1.len(),
191                "since we've verified monotonicity and the last element, this is implied"
192            );
193
194            subsample.len()
195        } else {
196            // With no subsampling, the target dim is the length of `signs1`.
197            signs1.len()
198        };
199
200        Ok(Self {
201            signs0,
202            signs1,
203            target_dim,
204            subsample,
205        })
206    }
207
    /// Return the input dimension for the transformation.
    ///
    /// This is the length of the first-stage sign vector (`signs0`).
    pub fn input_dim(&self) -> usize {
        self.signs0.len()
    }
212
    /// Return the output dimension for the transformation.
    ///
    /// This is the stored `target_dim`, fixed when the transform was constructed.
    pub fn output_dim(&self) -> usize {
        self.target_dim
    }
217
218    /// Return whether or not the transform preserves norms.
219    ///
220    /// For this transform, norms are not preserved when the output dimensionality is less
221    /// than the input dimensionality.
222    pub fn preserves_norms(&self) -> bool {
223        self.subsample.is_none()
224    }
225
226    fn intermediate_dim(&self) -> usize {
227        self.input_dim().max(self.output_dim())
228    }
229
230    /// Perform the transformation of the `src` vector into the `dst` vector.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if
235    ///
236    /// * `src.len() != self.input_dim()`.
237    /// * `dst.len() != self.output_dim()`.
238    pub fn transform_into(
239        &self,
240        dst: &mut [f32],
241        src: &[f32],
242        allocator: ScopedAllocator<'_>,
243    ) -> Result<(), TransformFailed> {
244        check_dims(dst, src, self.input_dim(), self.output_dim())?;
245
246        // Copy and flip signs
247        let intermediate_dim = self.intermediate_dim();
248        let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
249
250        std::iter::zip(tmp.iter_mut(), src.iter())
251            .zip(self.signs0.iter())
252            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
253
254        let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
255
256        // `split` is less than or equal to `tmp` and is a power to 2.
257        //
258        // If it is equal to the size of `tmp`, then we only run the first transform. Otherwise,
259        // we perform two transforms on the head and tail of `tmp`.
260        #[allow(clippy::unwrap_used)]
261        hadamard_transform(&mut tmp[..split]).unwrap();
262
263        // Apply the second transformation.
264        // Since random signs are applied to the intermediate value, the second transform
265        // does not undo the first.
266        tmp.iter_mut()
267            .zip(self.signs1.iter())
268            .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
269
270        #[allow(clippy::unwrap_used)]
271        hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
272
273        match self.subsample.as_ref() {
274            None => {
275                dst.copy_from_slice(&tmp);
276            }
277            Some(indices) => {
278                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
279                debug_assert_eq!(dst.len(), indices.len());
280                dst.iter_mut()
281                    .zip(indices.iter())
282                    .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
283            }
284        }
285
286        Ok(())
287    }
288}
289
290impl<A> TryClone for DoubleHadamard<A>
291where
292    A: Allocator,
293{
294    fn try_clone(&self) -> Result<Self, AllocatorError> {
295        Ok(Self {
296            signs0: self.signs0.try_clone()?,
297            signs1: self.signs1.try_clone()?,
298            target_dim: self.target_dim,
299            subsample: self.subsample.try_clone()?,
300        })
301    }
302}
303
/// Errors produced when validating the components of a [`DoubleHadamard`], for example in
/// [`DoubleHadamard::try_from_parts`].
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum DoubleHadamardError {
    /// The first-stage sign vector contains no entries.
    #[error("first signs stage cannot be empty")]
    Signs0Empty,
    /// An entry of the first-stage sign vector is not a valid sign encoding.
    #[error("first signs stage has invalid coding")]
    Signs0Invalid,

    /// An entry of the second-stage sign vector is not a valid sign encoding.
    #[error("invalid sign representation for second stage")]
    Signs1Invalid,
    /// The second-stage sign vector is shorter than the first-stage sign vector.
    #[error("second sign stage must be at least as large as the first stage")]
    Signs1TooSmall,

    /// The subsample index list has an invalid length. In `try_from_parts` this is
    /// returned when a subsample is supplied but contains no indices.
    #[error("subsample length must equal `target_dim`")]
    InvalidSubsampleLength,
    /// The subsample indices are not strictly increasing.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The largest subsample index is out of bounds for the intermediate vector it is
    /// meant to index.
    #[error("last subsample index exceeded intermediate dim")]
    LastSubsampleTooLarge,

    /// An allocation failed while building the transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
327
328// Serialization
329#[cfg(feature = "flatbuffers")]
330impl<A> DoubleHadamard<A>
331where
332    A: Allocator,
333{
334    /// Pack into a [`crate::flatbuffers::transforms::DoubleHadamard`] serialized
335    /// represntation.
336    pub(crate) fn pack<'a, FA>(
337        &self,
338        buf: &mut FlatBufferBuilder<'a, FA>,
339    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
340    where
341        FA: flatbuffers::Allocator + 'a,
342    {
343        // Store the sign vectors.
344        let signs0 = buf.create_vector_from_iter(self.signs0.iter().copied().map(sign_to_bool));
345        let signs1 = buf.create_vector_from_iter(self.signs1.iter().copied().map(sign_to_bool));
346
347        // If subsample indices are present - save those as well.
348        let subsample = self
349            .subsample
350            .as_ref()
351            .map(|indices| buf.create_vector(indices));
352
353        fb::transforms::DoubleHadamard::create(
354            buf,
355            &fb::transforms::DoubleHadamardArgs {
356                signs0: Some(signs0),
357                signs1: Some(signs1),
358                subsample,
359            },
360        )
361    }
362
363    /// Attempt to unpack from a [`crate::flatbuffers::transforms::DoubleHadamard`]
364    /// serialized representation, returning any error if encountered.
365    pub(crate) fn try_unpack(
366        alloc: A,
367        proto: fb::transforms::DoubleHadamard<'_>,
368    ) -> Result<Self, DoubleHadamardError> {
369        let signs0 = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
370        let signs1 = Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;
371
372        let subsample = match proto.subsample() {
373            Some(subsample) => Some(Poly::from_iter(subsample.into_iter(), alloc)?),
374            None => None,
375        };
376
377        Self::try_from_parts(signs0, signs1, subsample)
378    }
379}
380
381///////////
382// Tests //
383///////////
384
385#[cfg(test)]
386mod tests {
387    use diskann_utils::lazy_format;
388    use rand::{SeedableRng, rngs::StdRng};
389
390    use super::*;
391    use crate::{
392        algorithms::transforms::{Transform, TransformKind, test_utils},
393        alloc::GlobalAllocator,
394    };
395
396    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);
397
    // Exercises `DoubleHadamard` directly and through the abstract `Transform` wrapper for
    // same-size, enlarged, and subsampled output dimensions, checking the reported
    // dimensions and the numerical error of norms, L2 distances, and inner products.
    #[test]
    fn test_double_hadamard() {
        // Inner product computations are more susceptible to floating point error.
        // Instead of using ULP here, we fall back to using absolute and relative error.
        //
        // These error bounds are used for every case without subsampling, i.e. when the
        // output dimension is the same as or higher than the input dimension.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // NOTE: Subsampling introduces high variance in the norm and L2, so our error
        // bounds need to be looser.
        //
        // Subsampling results in poor preservation of inner products, so we skip it
        // altogether.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        // Each entry is:
        // (input dim, expected output dim, expected `preserves_norms`, target, error bounds)
        let dim_combos = [
            // Natural
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Larger
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Sub-Sampling.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The transform must actually change the data (over the common prefix of
                // input and output) in addition to meeting the error bounds.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Test the underlying transformer.
                {
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }
511
    // Round-trip and error-path tests for the flatbuffers (de)serialization of
    // `DoubleHadamard`.
    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Test various dimension combinations
            let test_cases = [
                // No sub or super sampling
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                // Super sampling with both stages
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                // Super sample with one stage
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                // Sub sampling.
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            // Packing and then unpacking must reproduce the original transform exactly.
            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Serialize a (deliberately malformed) transform and return the error produced
            // when unpacking it.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            // Each entry is: (signs0, signs1, target_dim, subsample, expected error)
            let error_cases = [
                // Signs1TooSmall: signs0.len() > signs1.len()
                (
                    vec![0, 0, 0, 0, 0], // 5 elements
                    vec![0, 0, 0, 0],    // 4 elements
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                // Signs0Empty
                (
                    vec![], // empty signs0
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                // SubsampleNotMonotonic: subsample indices not in increasing order
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]), // not monotonic
                    E::SubsampleNotMonotonic,
                ),
                // SubsampleNotMonotonic: duplicate values
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]), // duplicate values
                    E::SubsampleNotMonotonic,
                ),
                // LastSubsampleTooLarge: exceeds intermediate dim with signs1
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![0, 3]), // index 3 >= intermediate_dim(3,2) = 3
                    E::LastSubsampleTooLarge,
                ),
                // InvalidSubsampleLength: subsample is present but empty
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![]), // empty
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                // Build the struct directly (bypassing validation) so that the invalid
                // state can be serialized.
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
638}