Skip to main content

diskann_quantization/algorithms/transforms/
padding_hadamard.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11    distr::{Distribution, StandardUniform},
12    Rng,
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19    utils::{check_dims, is_sign, subsample_indices, TransformFailed},
20    TargetDim,
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25    algorithms::hadamard_transform,
26    alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27    utils,
28};
29
30/// A Hadamard transform that zero pads non-power-of-two dimensions to the next power of two.
31///
32/// This struct performs the transformation
33/// ```math
34/// HDx / sqrt(n)
35/// ```
36/// where
37///
38/// * `H` is a Hadamard Matrix
39/// * `D` is a diagonal matrix with diagonal entries in `{-1, +1}`.
40/// * `x` is the vector to transform, zero padded to have a length that is a multiple of two.
41/// * `n` is the output-dimension.
42#[derive(Debug)]
43#[cfg_attr(test, derive(PartialEq))]
44pub struct PaddingHadamard<A>
45where
46    A: Allocator,
47{
48    /// A vector of `+/-1` used to add randomness to the Hadamard transform.
49    ///
50    /// This is stored as a `Vec<u32>` instead of something more representative like a
51    /// `Vec<bool>` because we store the sign-bits for the `f32` representation explicitly
52    /// to turn sign flipping into a cheap `xor` operation.
53    ///
54    /// An internal invariant is that each value is either `0` or `0x8000_0000`.
55    ///
56    /// # Details
57    ///
58    /// On `x86` - a vectorized `xor` has a latency of 1 and a clocks-per-instruction (CPI)
59    /// of 0.333 where-as a `f32` multiply has a latency of 4 and a CPI 0.5.
60    signs: Poly<[u32], A>,
61
62    /// The padded-up dimension pre-rotation. This should always be a power of two and
63    /// greater than `signs`.
64    padded_dim: usize,
65
66    /// Indices of sub-sampled points. This should be sorted to provide more efficient
67    /// memory access. If `None`, then no subsampling is performed.
68    subsample: Option<Poly<[u32], A>>,
69}
70
71impl<A> PaddingHadamard<A>
72where
73    A: Allocator,
74{
75    /// Construct a new `PaddingHadamard` that transforms input vectors of dimension `dim`.
76    ///
77    /// The parameter `rng` is used to randomly initialize the diagonal matrix portion of
78    /// the transform.
79    ///
80    /// The following dimensionalities will be configured depending on the value of `target`:
81    ///
82    /// * [`TargetDim::Same`]
83    ///   - `self.input_dim() == dim.get()`
84    ///   - `self.output_dim() == dim.get()`
85    /// * [`TargetDim::Natural`]
86    ///   - `self.input_dim() == dim.get()`
87    ///   - `self.output_dim() == dim.get().next_power_of_two()`
88    /// * [`TargetDim::Override`]
89    ///   - `self.input_dim() == dim.get()`
90    ///   - `self.output_dim()`: The value provided by the override.
91    ///
92    /// Subsampling occurs if `self.output_dim()` is not a power of two and greater-than or
93    /// equal to `self.input_dim()`.
94    pub fn new<R>(
95        dim: NonZeroUsize,
96        target: TargetDim,
97        rng: &mut R,
98        allocator: A,
99    ) -> Result<Self, AllocatorError>
100    where
101        R: Rng + ?Sized,
102    {
103        let signs = Poly::from_iter(
104            (0..dim.get()).map(|_| {
105                let sign: bool = StandardUniform {}.sample(rng);
106                if sign {
107                    0x8000_0000
108                } else {
109                    0
110                }
111            }),
112            allocator.clone(),
113        )?;
114
115        let (padded_dim, target_dim) = match target {
116            TargetDim::Same => (dim.get().next_power_of_two(), dim.get()),
117            TargetDim::Natural => {
118                let next = dim.get().next_power_of_two();
119                (next, next)
120            }
121            TargetDim::Override(target) => {
122                (target.max(dim).get().next_power_of_two(), target.get())
123            }
124        };
125
126        let subsample = if padded_dim > target_dim {
127            Some(subsample_indices(rng, padded_dim, target_dim, allocator)?)
128        } else {
129            None
130        };
131
132        Ok(Self {
133            signs,
134            padded_dim,
135            subsample,
136        })
137    }
138
139    /// Construct `Self` from constituent parts. This validates that the necessary
140    /// invariants hold for the constituent parts, returning an error if they do not.
141    pub fn try_from_parts(
142        signs: Poly<[u32], A>,
143        padded_dim: usize,
144        subsample: Option<Poly<[u32], A>>,
145    ) -> Result<Self, PaddingHadamardError> {
146        if !signs.iter().copied().all(is_sign) {
147            return Err(PaddingHadamardError::InvalidSignRepresentation);
148        }
149
150        if signs.len() > padded_dim {
151            return Err(PaddingHadamardError::SignsTooLong);
152        }
153
154        if !padded_dim.is_power_of_two() {
155            return Err(PaddingHadamardError::DimNotPowerOfTwo);
156        }
157
158        if let Some(ref subsample) = subsample {
159            if !utils::is_strictly_monotonic(subsample.iter()) {
160                return Err(PaddingHadamardError::SubsampleNotMonotonic);
161            }
162
163            if let Some(last) = subsample.last() {
164                if *last as usize >= padded_dim {
165                    return Err(PaddingHadamardError::LastSubsampleTooLarge);
166                }
167            } else {
168                return Err(PaddingHadamardError::SubsampleEmpty);
169            }
170        }
171
172        Ok(Self {
173            signs,
174            padded_dim,
175            subsample,
176        })
177    }
178
179    /// Return the input dimension for the transformation.
180    pub fn input_dim(&self) -> usize {
181        self.signs.len()
182    }
183
184    /// Return the output dimension for the transformation.
185    pub fn output_dim(&self) -> usize {
186        match &self.subsample {
187            None => self.padded_dim,
188            Some(v) => v.len(),
189        }
190    }
191
192    /// Return whether or not the transform preserves norms.
193    ///
194    /// For this transform, norms are not preserved when the output dimensionality is not a
195    /// power of two greater than or equal to the input dimensionality.
196    pub fn preserves_norms(&self) -> bool {
197        self.subsample.is_none()
198    }
199
200    /// An internal helper for performing the sign flipping operation.
201    //A
202    /// # Preconditions
203    ///
204    /// This function requires (but only checks in debug build) the following pre-conditions
205    ///
206    /// * `src.len() == self.input_dim()`.
207    /// * `dst.len() == self.output_dim()`.
208    fn copy_and_flip_signs(&self, dst: &mut [f32], src: &[f32]) {
209        debug_assert_eq!(dst.len(), self.padded_dim);
210        debug_assert_eq!(src.len(), self.input_dim());
211
212        // Copy the sign bits.
213        std::iter::zip(dst.iter_mut(), src.iter())
214            .zip(self.signs.iter())
215            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
216
217        // Pad the rest to zero.
218        dst.iter_mut()
219            .skip(self.input_dim())
220            .for_each(|dst| *dst = 0.0);
221    }
222
223    /// Perform the transformation of the `src` vector into the `dst` vector.
224    ///
225    /// # Errors
226    ///
227    /// Returns an error if
228    ///
229    /// * `src.len() != self.input_dim()`.
230    /// * `dst.len() != self.output_dim()`.
231    pub fn transform_into(
232        &self,
233        dst: &mut [f32],
234        src: &[f32],
235        allocator: ScopedAllocator<'_>,
236    ) -> Result<(), TransformFailed> {
237        let input_dim = self.input_dim();
238        let output_dim = self.output_dim();
239        check_dims(dst, src, input_dim, output_dim)?;
240
241        // If we are not sub-sampling, then we can transform directly into the destination.
242        match &self.subsample {
243            None => {
244                // Copy over values from `src`, applying the sign flipping.
245                self.copy_and_flip_signs(dst, src);
246
247                // Lint: We satisfy the pre-condidions for `hadamard_transform` because:
248                //
249                // 1. `output_dim` is a power of 2 by construction.
250                // 2. We've checked that `dst.len() == output_dim`.
251                #[allow(clippy::unwrap_used)]
252                hadamard_transform(dst).unwrap();
253            }
254            Some(indices) => {
255                let mut tmp = Poly::broadcast(0.0f32, self.padded_dim, allocator)?;
256
257                self.copy_and_flip_signs(&mut tmp, src);
258
259                // Lint: We satisfy the pre-condidions for `hadamard_transform` because:
260                //
261                // 1. `padded_dim` is a power of 2 by construction.
262                // 2. We've checked that `tmp.len() == padded_dim`.
263                #[allow(clippy::unwrap_used)]
264                hadamard_transform(&mut tmp).unwrap();
265
266                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
267                debug_assert_eq!(dst.len(), indices.len());
268                std::iter::zip(dst.iter_mut(), indices.iter()).for_each(
269                    |(d, i): (&mut f32, &u32)| {
270                        *d = tmp[*i as usize] * rescale;
271                    },
272                );
273            }
274        }
275
276        Ok(())
277    }
278}
279
280impl<A> TryClone for PaddingHadamard<A>
281where
282    A: Allocator,
283{
284    fn try_clone(&self) -> Result<Self, AllocatorError> {
285        Ok(Self {
286            signs: self.signs.try_clone()?,
287            padded_dim: self.padded_dim,
288            subsample: self.subsample.try_clone()?,
289        })
290    }
291}
292
/// Errors that may occur while constructing a [`PaddingHadamard`] from constituent parts.
///
/// These are produced by [`PaddingHadamard::try_from_parts`]. Allocation failures are
/// wrapped in the `AllocatorError` variant.
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum PaddingHadamardError {
    /// An entry of `signs` did not pass the `is_sign` check. The struct's invariant is
    /// that every entry is either `0` or `0x8000_0000`.
    #[error("an invalid sign representation was discovered")]
    InvalidSignRepresentation,
    /// `signs.len()` (the input dimension) is larger than `padded_dim`.
    #[error("`signs` length exceeds `padded_dim`")]
    SignsTooLong,
    /// `padded_dim` is not a power of two.
    #[error("padded dim is not a power of two")]
    DimNotPowerOfTwo,
    /// Subsample indices were provided, but the list is empty.
    #[error("subsample indices cannot be empty")]
    SubsampleEmpty,
    /// Subsample indices are not strictly increasing.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The last (largest) subsample index is greater than or equal to `padded_dim`.
    #[error("last subsample index exceeded `padded_dim`")]
    LastSubsampleTooLarge,
    /// Allocation of an internal buffer failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
312
#[cfg(feature = "flatbuffers")]
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Pack into a [`crate::flatbuffers::transforms::PaddingHadamard`] serialized representation.
    ///
    /// Sign bits are stored as booleans. Subsample indices are written only when present.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::PaddingHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Sign bits become one boolean per input dimension.
        let sign_flags = self.signs.iter().copied().map(sign_to_bool);
        let signs = buf.create_vector_from_iter(sign_flags);

        // Optional subsample indices are serialized verbatim.
        let subsample = self
            .subsample
            .as_deref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::PaddingHadamardArgs {
            signs: Some(signs),
            padded_dim: self.padded_dim as u32,
            subsample,
        };
        fb::transforms::PaddingHadamard::create(buf, &args)
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::PaddingHadamard`]
    /// serialized representation, returning any error if encountered.
    ///
    /// All invariant checks are delegated to [`PaddingHadamard::try_from_parts`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::PaddingHadamard<'_>,
    ) -> Result<Self, PaddingHadamardError> {
        let padded_dim = proto.padded_dim() as usize;
        let signs = Poly::from_iter(proto.signs().iter().map(bool_to_sign), alloc.clone())?;

        // `Option<Result<_, _>>` -> `Result<Option<_>, _>` so allocation errors propagate.
        let subsample = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(signs, padded_dim, subsample)
    }
}
362
363///////////
364// Tests //
365///////////
366
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformKind},
        alloc::GlobalAllocator,
    };

    // Since we use a slightly non-obvious strategy for applying the +/-1 diagonal (an `xor`
    // on the `f32` sign bit rather than a multiply), we test its behavior explicitly.
    #[test]
    fn test_sign_flipping() {
        let mut rng = StdRng::seed_from_u64(0xf8ee12b1e9f33dbd);
        let dim = 14;

        let transform = PaddingHadamard::new(
            NonZeroUsize::new(dim).unwrap(),
            TargetDim::Same,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        assert_eq!(transform.input_dim(), dim);
        assert_eq!(transform.output_dim(), dim);

        let positive = vec![1.0f32; dim];
        let negative = vec![-1.0f32; dim];

        // 16 is the padded dimension for `dim == 14`. Fill with a sentinel so we can
        // detect entries that were not written.
        let mut output = vec![f32::INFINITY; 16];

        // Transform positive numbers
        transform.copy_and_flip_signs(&mut output, &positive);

        let mut unflipped = 0;
        let mut flipped = 0;
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    flipped += 1;
                    assert_eq!(*o, -1.0, "expected entry {} to be flipped", i);
                } else {
                    unflipped += 1;
                    assert_eq!(*o, 1.0, "expected entry {} to be unchanged", i);
                }
            });

        // Check that we have a mixture of flipped and unflipped signs.
        assert!(unflipped > 0);
        assert!(flipped > 0);

        // Assert that everything else was zero padded.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);

        // Transform negative numbers
        output.fill(f32::INFINITY);
        transform.copy_and_flip_signs(&mut output, &negative);
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    assert_eq!(*o, 1.0, "expected entry {} to be flipped", i);
                } else {
                    assert_eq!(*o, -1.0, "expected entry {} to be unchanged", i);
                }
            });

        // Assert that everything else was zero padded.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);
    }

    test_utils::delegate_transformer!(PaddingHadamard<GlobalAllocator>);

    // This tests the natural hadamard transform where the output dimension is upgraded
    // to the next power of 2, as well as the subsampled configurations.
    #[test]
    fn test_padding_hadamard() {
        // Inner product computations are more susceptible to floating point error.
        // Instead of using ULP here, we fall back to using absolute and relative error.
        //
        // These error bounds are for when we set the output dimension to a power of 2 that
        // is greater than or equal to the input dimension.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(4),
            l2: test_utils::Check::ulp(4),
            ip: test_utils::Check::absrel(5.0e-6, 2e-4),
        };

        // NOTE: Subsampling introduces high variance in the norm and L2, so our error
        // bounds need to be looser.
        //
        // Subsampling results in poor preservation of inner products, so we skip it
        // altogether.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 1e-1),
            l2: test_utils::Check::absrel(0.0, 1e-1),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // Each entry is (input dim, expected output dim, expected `preserves_norms`,
        // target configuration, error bounds).
        let dim_combos = [
            // Natural
            (15, 16, true, target_dim(16), &natural_errors),
            (15, 16, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (16, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (100, 128, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, target_dim(256), &natural_errors),
            // Subsampled,
            (1000, 1000, false, TargetDim::Same, &subsampled_errors),
            (500, 1000, false, target_dim(1000), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe0626147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // Sanity check that the transform actually changed the data, then check
                // the configured error bounds.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    assert_ne!(io.input0, &io.output0[..input]);
                    assert_ne!(io.input1, &io.output1[..input]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Base Transformer
                {
                    let transformer = PaddingHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        &ctx,
                    )
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::PaddingHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        &ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::{flatbuffers::to_flatbuffer, poly};

        // Round-trips several configurations through flatbuffers, then checks that each
        // invariant violation is reported by `try_unpack`.
        #[test]
        fn padding_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Test various dimension combinations
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = PaddingHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                let reloaded = PaddingHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Serialize an (intentionally invalid) transform and return the error produced
            // when deserializing it.
            let gen_err = |x: PaddingHadamard<_>| -> PaddingHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                PaddingHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // Signs too long.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(), // longer than `padded_dim`.
                    padded_dim: 4,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::SignsTooLong);
            }

            // Dim Not a power of 2.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 5, // not a power of 2
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::DimNotPowerOfTwo);
            }

            // Subsample empty
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([], alloc).unwrap()), // empty
                });

                assert_eq!(err, PaddingHadamardError::SubsampleEmpty);
            }

            // Not monotonic
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 2, 2], alloc).unwrap()), // not monotonic
                });
                assert_eq!(err, PaddingHadamardError::SubsampleNotMonotonic);
            }

            // Subsample too long: more indices than `padded_dim`, so the last one is
            // necessarily out of range.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 1, 2, 3, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }

            // Subsample too large: the last index equals `padded_dim`.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }
        }
    }
}