Skip to main content

diskann_quantization/algorithms/transforms/
double_hadamard.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11    Rng,
12    distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19    TargetDim,
20    utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25    algorithms::hadamard_transform,
26    alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27    utils,
28};
29
30/// A Double Hadamard transform that applies the signed Hadamard Transform to a head of the
31/// vector and then the tail.
32///
33/// This struct performs the transformation
34/// ```math
35/// [I 0; 0 H/sqrt(t)] · D1 · [H/sqrt(t) 0; 0 I] · zeropad(D0 · x)
36/// ```
37///
38/// * `n` is the dimensionality of the input vector.
39/// * `m` is the desired output dimensionality.
40/// * `o = max(n, m)` is an intermediate dimension.
41/// * `t` is the largest power of 2 less than or equal to `o`.
42/// * `H` is a Hadamard Matrix of dimension `t`,
43/// * `I` is the identity matrix of dimension `n - t`
44/// * `D0` and `D1` are diagonal matrices with diagonal entries in `{-1, +1}` drawn
45///   uniformly at random with lengths `n` and `o` respectively.
46/// * `x` is the input vector of dimension `n`
47/// * `[A 0; 0 B]` represents a block diagonal matrix with blocks `A` and `B`.
48/// * `zeropad` indicates that the result `D0 · x` is zero-padded to the dimension `o` if
49///   needed.
50///
51/// As indicated above, if `o` is a power of two, then only a single transformation is
52/// applied. Further, if `o` exceeds `n`, then the input vector is zero padded at the end to
53/// `o` dimensions.
54#[derive(Debug)]
55#[cfg_attr(test, derive(PartialEq))]
56pub struct DoubleHadamard<A>
57where
58    A: Allocator,
59{
60    /// Vectors of `+/-1` used for to add randomness to the Hadamard transform
61    /// in each step.
62    ///
63    /// These are stored as a slice of `u32` where each value is either `0` or `0x8000_0000`,
64    /// corresponding to the sign-bit for an `f32` value, allowing sign flipping using
65    /// a cheap `xor` operation.
66    signs0: Poly<[u32], A>,
67    signs1: Poly<[u32], A>,
68
69    /// The target output dimension of the transformation.
70    target_dim: usize,
71
72    /// Optional array storing (in sorted order) the indices to sample if `target_dim < dim`
73    subsample: Option<Poly<[u32], A>>,
74}
75
76impl<A> DoubleHadamard<A>
77where
78    A: Allocator,
79{
80    /// Construct a new `DoubleHadamard` that transforms input vectors of dimension `dim`.
81    ///
82    /// The parameter `rng` is used to randomly initialize the diagonal matrices portion of
83    /// the transform.
84    ///
85    /// The following dimensionalities will be configured depending on the value of `target`:
86    ///
87    /// * [`TargetDim::Same`]
88    ///   - `self.input_dim() == dim.get()`
89    ///   - `self.output_dim() == dim.get()`
90    /// * [`TargetDim::Natural`]
91    ///   - `self.input_dim() == dim.get()`
92    ///   - `self.output_dim() == dim.get()`
93    /// * [`TargetDim::Override`]
94    ///   - `self.input_dim() == dim.get()`
95    ///   - `self.output_dim()`: The value provided by the override.
96    ///
97    /// Subsampling occurs if `self.output_dim()` is smaller than `self.input_dim()`.
98    pub fn new<R>(
99        dim: NonZeroUsize,
100        target_dim: TargetDim,
101        rng: &mut R,
102        allocator: A,
103    ) -> Result<Self, AllocatorError>
104    where
105        R: Rng + ?Sized,
106    {
107        let dim = dim.get();
108
109        let target_dim = match target_dim {
110            TargetDim::Override(target) => target.get(),
111            TargetDim::Same => dim,
112            TargetDim::Natural => dim,
113        };
114
115        // The intermediate dimension after applying the first transform.
116        //
117        // If `target_dim` exceeds `dim`, then we perform zero padding up to `target_dim`
118        // for this stage.
119        let intermediate_dim = dim.max(target_dim);
120
121        // Generate random signs for the diagonal matrices
122        let mut sample = |_: usize| {
123            let sign: bool = StandardUniform {}.sample(rng);
124            if sign { 0x8000_0000 } else { 0 }
125        };
126
127        // Since implicit zero padding is used for this stage, we only create space for
128        // `dim` values.
129        let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
130        let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
131
132        let subsample = if dim > target_dim {
133            Some(subsample_indices(rng, dim, target_dim, allocator)?)
134        } else {
135            None
136        };
137
138        Ok(Self {
139            signs0,
140            signs1,
141            target_dim,
142            subsample,
143        })
144    }
145
146    pub fn try_from_parts(
147        signs0: Poly<[u32], A>,
148        signs1: Poly<[u32], A>,
149        subsample: Option<Poly<[u32], A>>,
150    ) -> Result<Self, DoubleHadamardError> {
151        type E = DoubleHadamardError;
152        if signs0.is_empty() {
153            return Err(E::Signs0Empty);
154        }
155        if signs1.len() < signs0.len() {
156            return Err(E::Signs1TooSmall);
157        }
158        if !signs0.iter().copied().all(is_sign) {
159            return Err(E::Signs0Invalid);
160        }
161        if !signs1.iter().copied().all(is_sign) {
162            return Err(E::Signs1Invalid);
163        }
164
165        // Some preliminary checks on `subsample` that must always hold if it is present.
166        let target_dim = if let Some(ref subsample) = subsample {
167            if !utils::is_strictly_monotonic(subsample.iter()) {
168                return Err(E::SubsampleNotMonotonic);
169            }
170
171            match subsample.last() {
172                Some(last) => {
173                    if *last as usize >= signs1.len() {
174                        // Since the entries in `subsample` are used to index an output
175                        // vector of lengths `signs1`, the last element must be strictly
176                        // less than this length.
177                        //
178                        // From the monotonicity check, we can therefore deduce that *all*
179                        // entries are in-bounds.
180                        return Err(E::LastSubsampleTooLarge);
181                    }
182                }
183                None => {
184                    // Subsample cannot be empty.
185                    return Err(E::InvalidSubsampleLength);
186                }
187            }
188
189            debug_assert!(
190                subsample.len() < signs1.len(),
191                "since we've verified monotonicity and the last element, this is implied"
192            );
193
194            subsample.len()
195        } else {
196            // With no subsampling, the target dim is the length of `signs1`.
197            signs1.len()
198        };
199
200        Ok(Self {
201            signs0,
202            signs1,
203            target_dim,
204            subsample,
205        })
206    }
207
208    /// Return the input dimension for the transformation.
209    pub fn input_dim(&self) -> usize {
210        self.signs0.len()
211    }
212
213    /// Return the output dimension for the transformation.
214    pub fn output_dim(&self) -> usize {
215        self.target_dim
216    }
217
218    /// Return whether or not the transform preserves norms.
219    ///
220    /// For this transform, norms are not preserved when the output dimensionality is less
221    /// than the input dimensionality.
222    pub fn preserves_norms(&self) -> bool {
223        self.subsample.is_none()
224    }
225
226    fn intermediate_dim(&self) -> usize {
227        self.input_dim().max(self.output_dim())
228    }
229
230    /// Perform the transformation of the `src` vector into the `dst` vector.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if
235    ///
236    /// * `src.len() != self.input_dim()`.
237    /// * `dst.len() != self.output_dim()`.
238    pub fn transform_into(
239        &self,
240        dst: &mut [f32],
241        src: &[f32],
242        allocator: ScopedAllocator<'_>,
243    ) -> Result<(), TransformFailed> {
244        check_dims(dst, src, self.input_dim(), self.output_dim())?;
245
246        // Copy and flip signs
247        let intermediate_dim = self.intermediate_dim();
248        let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
249
250        std::iter::zip(tmp.iter_mut(), src.iter())
251            .zip(self.signs0.iter())
252            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
253
254        let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
255
256        // `split` is less than or equal to `tmp` and is a power to 2.
257        //
258        // If it is equal to the size of `tmp`, then we only run the first transform. Otherwise,
259        // we perform two transforms on the head and tail of `tmp`.
260        #[allow(clippy::unwrap_used)]
261        hadamard_transform(&mut tmp[..split]).unwrap();
262
263        // Apply the second transformation.
264        // Since random signs are applied to the intermediate value, the second transform
265        // does not undo the first.
266        tmp.iter_mut()
267            .zip(self.signs1.iter())
268            .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
269
270        #[allow(clippy::unwrap_used)]
271        hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
272
273        match self.subsample.as_ref() {
274            None => {
275                dst.copy_from_slice(&tmp);
276            }
277            Some(indices) => {
278                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
279                debug_assert_eq!(dst.len(), indices.len());
280                dst.iter_mut()
281                    .zip(indices.iter())
282                    .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
283            }
284        }
285
286        Ok(())
287    }
288}
289
290impl<A> TryClone for DoubleHadamard<A>
291where
292    A: Allocator,
293{
294    fn try_clone(&self) -> Result<Self, AllocatorError> {
295        Ok(Self {
296            signs0: self.signs0.try_clone()?,
297            signs1: self.signs1.try_clone()?,
298            target_dim: self.target_dim,
299            subsample: self.subsample.try_clone()?,
300        })
301    }
302}
303
304#[derive(Debug, Clone, Copy, Error, PartialEq)]
305#[non_exhaustive]
306pub enum DoubleHadamardError {
307    #[error("first signs stage cannot be empty")]
308    Signs0Empty,
309    #[error("first signs stage has invalid coding")]
310    Signs0Invalid,
311
312    #[error("invalid sign representation for second stage")]
313    Signs1Invalid,
314    #[error("second sign stage must be at least as large as the first stage")]
315    Signs1TooSmall,
316
317    #[error("subsample length must equal `target_dim`")]
318    InvalidSubsampleLength,
319    #[error("subsample indices is not monotonic")]
320    SubsampleNotMonotonic,
321    #[error("last subsample index exceeded intermediate dim")]
322    LastSubsampleTooLarge,
323
324    #[error(transparent)]
325    AllocatorError(#[from] AllocatorError),
326}
327
328// Serialization
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Pack into a [`crate::flatbuffers::transforms::DoubleHadamard`] serialized
    /// representation.
    ///
    /// Both sign vectors are written as booleans (via `sign_to_bool`). The subsample
    /// indices are written only when this transform subsamples its output.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Child vectors must be created before the table that references them.
        let signs0 = buf.create_vector_from_iter(self.signs0.iter().copied().map(sign_to_bool));
        let signs1 = buf.create_vector_from_iter(self.signs1.iter().copied().map(sign_to_bool));
        let subsample = self
            .subsample
            .as_deref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(signs0),
            signs1: Some(signs1),
            subsample,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::DoubleHadamard`]
    /// serialized representation.
    ///
    /// # Errors
    ///
    /// Returns an error if allocation fails or if the decoded components are rejected by
    /// [`Self::try_from_parts`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let signs0 = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let signs1 = Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;

        // `transpose` turns `Option<Result<_, _>>` into `Result<Option<_>, _>` so that an
        // allocation failure still propagates while an absent field stays `None`.
        let subsample = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(signs0, signs1, subsample)
    }
}
380
381///////////
382// Tests //
383///////////
384
#[cfg(test)]
#[cfg(not(miri))]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::{Transform, TransformKind, test_utils},
        alloc::GlobalAllocator,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    /// Checks the reported dimensions, `preserves_norms`, and the error bounds of both the
    /// concrete `DoubleHadamard` and the `Transform` wrapper over a range of input/output
    /// dimension combinations.
    #[test]
    fn test_double_hadamard() {
        // Inner product computations are more susceptible to floating point error.
        // Instead of using ULP here, we fall back to using absolute and relative error.
        //
        // These error bounds are used whenever no subsampling occurs, i.e. whenever the
        // output dimension is at least the input dimension.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // NOTE: Subsampling introduces high variance in the norm and L2, so our error
        // bounds need to be looser.
        //
        // Subsampling results in poor preservation of inner products, so we skip it
        // altogether.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        // Each entry: (input dim, expected output dim, expected `preserves_norms`,
        // requested target, error bounds).
        let dim_combos = [
            // Natural
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Larger output (input is zero padded).
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Sub-Sampling.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The leading `min(input, output)` entries must be changed by the
                // transform, and the error bounds for this combination must hold.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Test the underlying transformer.
                {
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Round-trips transforms of various shapes through flatbuffers, and checks that
        /// malformed components are rejected with the expected error.
        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Test various dimension combinations
            let test_cases = [
                // No sub or super sampling
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                // Super sampling to a non-power-of-two intermediate dim (distinct,
                // overlapping head and tail transforms)
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                // Power-of-two intermediate dim (head and tail transforms both span the
                // whole vector)
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                // Sub sampling.
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Pack `x` (which may be invalid) and return the error reported when
            // unpacking it.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            // Each entry: (signs0, signs1, target_dim, subsample, expected error).
            // `target_dim` is not serialized by `pack`; it is only needed to build the
            // struct directly.
            let error_cases = [
                // Signs1TooSmall: signs0.len() > signs1.len()
                (
                    vec![0, 0, 0, 0, 0], // 5 elements
                    vec![0, 0, 0, 0],    // 4 elements
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                // Signs0Empty
                (
                    vec![], // empty signs0
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                // SubsampleNotMonotonic: subsample indices not in increasing order
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]), // not monotonic
                    E::SubsampleNotMonotonic,
                ),
                // SubsampleNotMonotonic: duplicate values
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]), // duplicate values
                    E::SubsampleNotMonotonic,
                ),
                // LastSubsampleTooLarge: exceeds intermediate dim with signs1
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![0, 3]), // index 3 >= intermediate_dim(3,2) = 3
                    E::LastSubsampleTooLarge,
                ),
                // InvalidSubsampleLength: subsample cannot be empty
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![]), // empty
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}