Skip to main content

diskann_quantization/algorithms/transforms/
double_hadamard.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11    distr::{Distribution, StandardUniform},
12    Rng,
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19    utils::{check_dims, is_sign, subsample_indices, TransformFailed},
20    TargetDim,
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25    algorithms::hadamard_transform,
26    alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27    utils,
28};
29
30/// A Double Hadamard transform that applies the signed Hadamard Transform to a head of the
31/// vector and then the tail.
32///
33/// This struct performs the transformation
34/// ```math
35/// [I 0; 0 H/sqrt(t)] · D1 · [H/sqrt(t) 0; 0 I] · zeropad(D0 · x)
36/// ```
37///
38/// * `n` is the dimensionality of the input vector.
39/// * `m` is the desired output dimensionality.
40/// * `o = max(n, m)` is an intermediate dimension.
41/// * `t` is the largest power of 2 less than or equal to `o`.
42/// * `H` is a Hadamard Matrix of dimension `t`,
43/// * `I` is the identity matrix of dimension `n - t`
44/// * `D0` and `D1` are diagonal matrices with diagonal entries in `{-1, +1}` drawn
45///   uniformly at random with lengths `n` and `o` respectively.
46/// * `x` is the input vector of dimension `n`
47/// * `[A 0; 0 B]` represents a block diagonal matrix with blocks `A` and `B`.
48/// * `zeropad` indicates that the result `D0 · x` is zero-padded to the dimension `o` if
49///   needed.
50///
51/// As indicated above, if `o` is a power of two, then only a single transformation is
52/// applied. Further, if `o` exceeds `n`, then the input vector is zero padded at the end to
53/// `o` dimensions.
54#[derive(Debug)]
55#[cfg_attr(test, derive(PartialEq))]
56pub struct DoubleHadamard<A>
57where
58    A: Allocator,
59{
60    /// Vectors of `+/-1` used for to add randomness to the Hadamard transform
61    /// in each step.
62    ///
63    /// These are stored as a slice of `u32` where each value is either `0` or `0x8000_0000`,
64    /// corresponding to the sign-bit for an `f32` value, allowing sign flipping using
65    /// a cheap `xor` operation.
66    signs0: Poly<[u32], A>,
67    signs1: Poly<[u32], A>,
68
69    /// The target output dimension of the transformation.
70    target_dim: usize,
71
72    /// Optional array storing (in sorted order) the indices to sample if `target_dim < dim`
73    subsample: Option<Poly<[u32], A>>,
74}
75
76impl<A> DoubleHadamard<A>
77where
78    A: Allocator,
79{
80    /// Construct a new `DoubleHadamard` that transforms input vectors of dimension `dim`.
81    ///
82    /// The parameter `rng` is used to randomly initialize the diagonal matrices portion of
83    /// the transform.
84    ///
85    /// The following dimensionalities will be configured depending on the value of `target`:
86    ///
87    /// * [`TargetDim::Same`]
88    ///   - `self.input_dim() == dim.get()`
89    ///   - `self.output_dim() == dim.get()`
90    /// * [`TargetDim::Natural`]
91    ///   - `self.input_dim() == dim.get()`
92    ///   - `self.output_dim() == dim.get()`
93    /// * [`TargetDim::Override`]
94    ///   - `self.input_dim() == dim.get()`
95    ///   - `self.output_dim()`: The value provided by the override.
96    ///
97    /// Subsampling occurs if `self.output_dim()` is smaller than `self.input_dim()`.
98    pub fn new<R>(
99        dim: NonZeroUsize,
100        target_dim: TargetDim,
101        rng: &mut R,
102        allocator: A,
103    ) -> Result<Self, AllocatorError>
104    where
105        R: Rng + ?Sized,
106    {
107        let dim = dim.get();
108
109        let target_dim = match target_dim {
110            TargetDim::Override(target) => target.get(),
111            TargetDim::Same => dim,
112            TargetDim::Natural => dim,
113        };
114
115        // The intermediate dimension after applying the first transform.
116        //
117        // If `target_dim` exceeds `dim`, then we perform zero padding up to `target_dim`
118        // for this stage.
119        let intermediate_dim = dim.max(target_dim);
120
121        // Generate random signs for the diagonal matrices
122        let mut sample = |_: usize| {
123            let sign: bool = StandardUniform {}.sample(rng);
124            if sign {
125                0x8000_0000
126            } else {
127                0
128            }
129        };
130
131        // Since implicit zero padding is used for this stage, we only create space for
132        // `dim` values.
133        let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
134        let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
135
136        let subsample = if dim > target_dim {
137            Some(subsample_indices(rng, dim, target_dim, allocator)?)
138        } else {
139            None
140        };
141
142        Ok(Self {
143            signs0,
144            signs1,
145            target_dim,
146            subsample,
147        })
148    }
149
150    pub fn try_from_parts(
151        signs0: Poly<[u32], A>,
152        signs1: Poly<[u32], A>,
153        subsample: Option<Poly<[u32], A>>,
154    ) -> Result<Self, DoubleHadamardError> {
155        type E = DoubleHadamardError;
156        if signs0.is_empty() {
157            return Err(E::Signs0Empty);
158        }
159        if signs1.len() < signs0.len() {
160            return Err(E::Signs1TooSmall);
161        }
162        if !signs0.iter().copied().all(is_sign) {
163            return Err(E::Signs0Invalid);
164        }
165        if !signs1.iter().copied().all(is_sign) {
166            return Err(E::Signs1Invalid);
167        }
168
169        // Some preliminary checks on `subsample` that must always hold if it is present.
170        let target_dim = if let Some(ref subsample) = subsample {
171            if !utils::is_strictly_monotonic(subsample.iter()) {
172                return Err(E::SubsampleNotMonotonic);
173            }
174
175            match subsample.last() {
176                Some(last) => {
177                    if *last as usize >= signs1.len() {
178                        // Since the entries in `subsample` are used to index an output
179                        // vector of lengths `signs1`, the last element must be strictly
180                        // less than this length.
181                        //
182                        // From the monotonicity check, we can therefore deduce that *all*
183                        // entries are in-bounds.
184                        return Err(E::LastSubsampleTooLarge);
185                    }
186                }
187                None => {
188                    // Subsample cannot be empty.
189                    return Err(E::InvalidSubsampleLength);
190                }
191            }
192
193            debug_assert!(
194                subsample.len() < signs1.len(),
195                "since we've verified monotonicity and the last element, this is implied"
196            );
197
198            subsample.len()
199        } else {
200            // With no subsampling, the target dim is the length of `signs1`.
201            signs1.len()
202        };
203
204        Ok(Self {
205            signs0,
206            signs1,
207            target_dim,
208            subsample,
209        })
210    }
211
212    /// Return the input dimension for the transformation.
213    pub fn input_dim(&self) -> usize {
214        self.signs0.len()
215    }
216
217    /// Return the output dimension for the transformation.
218    pub fn output_dim(&self) -> usize {
219        self.target_dim
220    }
221
222    /// Return whether or not the transform preserves norms.
223    ///
224    /// For this transform, norms are not preserved when the output dimensionality is less
225    /// than the input dimensionality.
226    pub fn preserves_norms(&self) -> bool {
227        self.subsample.is_none()
228    }
229
230    fn intermediate_dim(&self) -> usize {
231        self.input_dim().max(self.output_dim())
232    }
233
234    /// Perform the transformation of the `src` vector into the `dst` vector.
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if
239    ///
240    /// * `src.len() != self.input_dim()`.
241    /// * `dst.len() != self.output_dim()`.
242    pub fn transform_into(
243        &self,
244        dst: &mut [f32],
245        src: &[f32],
246        allocator: ScopedAllocator<'_>,
247    ) -> Result<(), TransformFailed> {
248        check_dims(dst, src, self.input_dim(), self.output_dim())?;
249
250        // Copy and flip signs
251        let intermediate_dim = self.intermediate_dim();
252        let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
253
254        std::iter::zip(tmp.iter_mut(), src.iter())
255            .zip(self.signs0.iter())
256            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
257
258        let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
259
260        // `split` is less than or equal to `tmp` and is a power to 2.
261        //
262        // If it is equal to the size of `tmp`, then we only run the first transform. Otherwise,
263        // we perform two transforms on the head and tail of `tmp`.
264        #[allow(clippy::unwrap_used)]
265        hadamard_transform(&mut tmp[..split]).unwrap();
266
267        // Apply the second transformation.
268        // Since random signs are applied to the intermediate value, the second transform
269        // does not undo the first.
270        tmp.iter_mut()
271            .zip(self.signs1.iter())
272            .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
273
274        #[allow(clippy::unwrap_used)]
275        hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
276
277        match self.subsample.as_ref() {
278            None => {
279                dst.copy_from_slice(&tmp);
280            }
281            Some(indices) => {
282                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
283                debug_assert_eq!(dst.len(), indices.len());
284                dst.iter_mut()
285                    .zip(indices.iter())
286                    .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
287            }
288        }
289
290        Ok(())
291    }
292}
293
294impl<A> TryClone for DoubleHadamard<A>
295where
296    A: Allocator,
297{
298    fn try_clone(&self) -> Result<Self, AllocatorError> {
299        Ok(Self {
300            signs0: self.signs0.try_clone()?,
301            signs1: self.signs1.try_clone()?,
302            target_dim: self.target_dim,
303            subsample: self.subsample.try_clone()?,
304        })
305    }
306}
307
308#[derive(Debug, Clone, Copy, Error, PartialEq)]
309#[non_exhaustive]
310pub enum DoubleHadamardError {
311    #[error("first signs stage cannot be empty")]
312    Signs0Empty,
313    #[error("first signs stage has invalid coding")]
314    Signs0Invalid,
315
316    #[error("invalid sign representation for second stage")]
317    Signs1Invalid,
318    #[error("second sign stage must be at least as large as the first stage")]
319    Signs1TooSmall,
320
321    #[error("subsample length must equal `target_dim`")]
322    InvalidSubsampleLength,
323    #[error("subsample indices is not monotonic")]
324    SubsampleNotMonotonic,
325    #[error("last subsample index exceeded intermediate dim")]
326    LastSubsampleTooLarge,
327
328    #[error(transparent)]
329    AllocatorError(#[from] AllocatorError),
330}
331
// Serialization
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Pack into a [`crate::flatbuffers::transforms::DoubleHadamard`] serialized
    /// representation.
    ///
    /// Both sign vectors are written as boolean vectors (converted with `sign_to_bool`).
    /// Subsample indices are written only when present.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let first_stage = self.signs0.iter().copied().map(sign_to_bool);
        let first_stage = buf.create_vector_from_iter(first_stage);

        let second_stage = self.signs1.iter().copied().map(sign_to_bool);
        let second_stage = buf.create_vector_from_iter(second_stage);

        let kept = self
            .subsample
            .as_ref()
            .map(|subsample| buf.create_vector(subsample));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(first_stage),
            signs1: Some(second_stage),
            subsample: kept,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::DoubleHadamard`]
    /// serialized representation.
    ///
    /// # Errors
    ///
    /// Returns an error if allocating the buffers fails, or if the decoded parts fail the
    /// consistency checks of [`DoubleHadamard::try_from_parts`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let first_stage = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let second_stage =
            Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;

        // `Option<Result<_, _>>` -> `Result<Option<_>, _>` so a failed allocation propagates.
        let kept = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(first_stage, second_stage, kept)
    }
}
384
///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformKind},
        alloc::GlobalAllocator,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    #[test]
    fn test_double_hadamard() {
        // Inner product computations are more susceptible to floating point error.
        // Instead of using ULP here, we fall back to using absolute and relative error.
        //
        // These error bounds apply to every configuration that does not subsample, i.e.
        // where the output dimension is at least the input dimension.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // NOTE: Subsampling introduces high variance in the norm and L2, so our error
        // bounds need to be looser.
        //
        // Subsampling results in poor preservation of inner products, so we skip it
        // altogether.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        let dim_combos = [
            // Natural
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Larger
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Sub-Sampling.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Test the underlying transformer.
                {
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Test various dimension combinations
            let test_cases = [
                // No sub or super sampling
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                // Super sampling with both stages
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                // Super sample with one stage
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                // Sub sampling.
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            let error_cases = [
                // Signs1TooSmall: signs0.len() > signs1.len()
                (
                    vec![0, 0, 0, 0, 0], // 5 elements
                    vec![0, 0, 0, 0],    // 4 elements
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                // Signs0Empty
                (
                    vec![], // empty signs0
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                // SubsampleNotMonotonic: subsample indices not in increasing order
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]), // not monotonic
                    E::SubsampleNotMonotonic,
                ),
                // SubsampleNotMonotonic: duplicate values
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]), // duplicate values
                    E::SubsampleNotMonotonic,
                ),
                // LastSubsampleTooLarge: exceeds intermediate dim with signs1
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![0, 3]), // index 3 >= intermediate_dim(3,2) = 3
                    E::LastSubsampleTooLarge,
                ),
                // InvalidSubsampleLength: subsample cannot be empty
                (
                    vec![0, 0, 0], // 3 elements
                    vec![0, 0, 0], // 3 elements
                    2,
                    Some(vec![]), // empty
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}