Skip to main content

diskann_quantization/algorithms/transforms/
padding_hadamard.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11    Rng,
12    distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19    TargetDim,
20    utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25    algorithms::hadamard_transform,
26    alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27    utils,
28};
29
30/// A Hadamard transform that zero pads non-power-of-two dimensions to the next power of two.
31///
32/// This struct performs the transformation
33/// ```math
34/// HDx / sqrt(n)
35/// ```
36/// where
37///
38/// * `H` is a Hadamard Matrix
39/// * `D` is a diagonal matrix with diagonal entries in `{-1, +1}`.
40/// * `x` is the vector to transform, zero padded to have a length that is a multiple of two.
41/// * `n` is the output-dimension.
42#[derive(Debug)]
43#[cfg_attr(test, derive(PartialEq))]
44pub struct PaddingHadamard<A>
45where
46    A: Allocator,
47{
48    /// A vector of `+/-1` used to add randomness to the Hadamard transform.
49    ///
50    /// This is stored as a `Vec<u32>` instead of something more representative like a
51    /// `Vec<bool>` because we store the sign-bits for the `f32` representation explicitly
52    /// to turn sign flipping into a cheap `xor` operation.
53    ///
54    /// An internal invariant is that each value is either `0` or `0x8000_0000`.
55    ///
56    /// # Details
57    ///
58    /// On `x86` - a vectorized `xor` has a latency of 1 and a clocks-per-instruction (CPI)
59    /// of 0.333 where-as a `f32` multiply has a latency of 4 and a CPI 0.5.
60    signs: Poly<[u32], A>,
61
62    /// The padded-up dimension pre-rotation. This should always be a power of two and
63    /// greater than `signs`.
64    padded_dim: usize,
65
66    /// Indices of sub-sampled points. This should be sorted to provide more efficient
67    /// memory access. If `None`, then no subsampling is performed.
68    subsample: Option<Poly<[u32], A>>,
69}
70
71impl<A> PaddingHadamard<A>
72where
73    A: Allocator,
74{
75    /// Construct a new `PaddingHadamard` that transforms input vectors of dimension `dim`.
76    ///
77    /// The parameter `rng` is used to randomly initialize the diagonal matrix portion of
78    /// the transform.
79    ///
80    /// The following dimensionalities will be configured depending on the value of `target`:
81    ///
82    /// * [`TargetDim::Same`]
83    ///   - `self.input_dim() == dim.get()`
84    ///   - `self.output_dim() == dim.get()`
85    /// * [`TargetDim::Natural`]
86    ///   - `self.input_dim() == dim.get()`
87    ///   - `self.output_dim() == dim.get().next_power_of_two()`
88    /// * [`TargetDim::Override`]
89    ///   - `self.input_dim() == dim.get()`
90    ///   - `self.output_dim()`: The value provided by the override.
91    ///
92    /// Subsampling occurs if `self.output_dim()` is not a power of two and greater-than or
93    /// equal to `self.input_dim()`.
94    pub fn new<R>(
95        dim: NonZeroUsize,
96        target: TargetDim,
97        rng: &mut R,
98        allocator: A,
99    ) -> Result<Self, AllocatorError>
100    where
101        R: Rng + ?Sized,
102    {
103        let signs = Poly::from_iter(
104            (0..dim.get()).map(|_| {
105                let sign: bool = StandardUniform {}.sample(rng);
106                if sign { 0x8000_0000 } else { 0 }
107            }),
108            allocator.clone(),
109        )?;
110
111        let (padded_dim, target_dim) = match target {
112            TargetDim::Same => (dim.get().next_power_of_two(), dim.get()),
113            TargetDim::Natural => {
114                let next = dim.get().next_power_of_two();
115                (next, next)
116            }
117            TargetDim::Override(target) => {
118                (target.max(dim).get().next_power_of_two(), target.get())
119            }
120        };
121
122        let subsample = if padded_dim > target_dim {
123            Some(subsample_indices(rng, padded_dim, target_dim, allocator)?)
124        } else {
125            None
126        };
127
128        Ok(Self {
129            signs,
130            padded_dim,
131            subsample,
132        })
133    }
134
135    /// Construct `Self` from constituent parts. This validates that the necessary
136    /// invariants hold for the constituent parts, returning an error if they do not.
137    pub fn try_from_parts(
138        signs: Poly<[u32], A>,
139        padded_dim: usize,
140        subsample: Option<Poly<[u32], A>>,
141    ) -> Result<Self, PaddingHadamardError> {
142        if !signs.iter().copied().all(is_sign) {
143            return Err(PaddingHadamardError::InvalidSignRepresentation);
144        }
145
146        if signs.len() > padded_dim {
147            return Err(PaddingHadamardError::SignsTooLong);
148        }
149
150        if !padded_dim.is_power_of_two() {
151            return Err(PaddingHadamardError::DimNotPowerOfTwo);
152        }
153
154        if let Some(ref subsample) = subsample {
155            if !utils::is_strictly_monotonic(subsample.iter()) {
156                return Err(PaddingHadamardError::SubsampleNotMonotonic);
157            }
158
159            if let Some(last) = subsample.last() {
160                if *last as usize >= padded_dim {
161                    return Err(PaddingHadamardError::LastSubsampleTooLarge);
162                }
163            } else {
164                return Err(PaddingHadamardError::SubsampleEmpty);
165            }
166        }
167
168        Ok(Self {
169            signs,
170            padded_dim,
171            subsample,
172        })
173    }
174
175    /// Return the input dimension for the transformation.
176    pub fn input_dim(&self) -> usize {
177        self.signs.len()
178    }
179
180    /// Return the output dimension for the transformation.
181    pub fn output_dim(&self) -> usize {
182        match &self.subsample {
183            None => self.padded_dim,
184            Some(v) => v.len(),
185        }
186    }
187
188    /// Return whether or not the transform preserves norms.
189    ///
190    /// For this transform, norms are not preserved when the output dimensionality is not a
191    /// power of two greater than or equal to the input dimensionality.
192    pub fn preserves_norms(&self) -> bool {
193        self.subsample.is_none()
194    }
195
196    /// An internal helper for performing the sign flipping operation.
197    //A
198    /// # Preconditions
199    ///
200    /// This function requires (but only checks in debug build) the following pre-conditions
201    ///
202    /// * `src.len() == self.input_dim()`.
203    /// * `dst.len() == self.output_dim()`.
204    fn copy_and_flip_signs(&self, dst: &mut [f32], src: &[f32]) {
205        debug_assert_eq!(dst.len(), self.padded_dim);
206        debug_assert_eq!(src.len(), self.input_dim());
207
208        // Copy the sign bits.
209        std::iter::zip(dst.iter_mut(), src.iter())
210            .zip(self.signs.iter())
211            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
212
213        // Pad the rest to zero.
214        dst.iter_mut()
215            .skip(self.input_dim())
216            .for_each(|dst| *dst = 0.0);
217    }
218
219    /// Perform the transformation of the `src` vector into the `dst` vector.
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if
224    ///
225    /// * `src.len() != self.input_dim()`.
226    /// * `dst.len() != self.output_dim()`.
227    pub fn transform_into(
228        &self,
229        dst: &mut [f32],
230        src: &[f32],
231        allocator: ScopedAllocator<'_>,
232    ) -> Result<(), TransformFailed> {
233        let input_dim = self.input_dim();
234        let output_dim = self.output_dim();
235        check_dims(dst, src, input_dim, output_dim)?;
236
237        // If we are not sub-sampling, then we can transform directly into the destination.
238        match &self.subsample {
239            None => {
240                // Copy over values from `src`, applying the sign flipping.
241                self.copy_and_flip_signs(dst, src);
242
243                // Lint: We satisfy the pre-condidions for `hadamard_transform` because:
244                //
245                // 1. `output_dim` is a power of 2 by construction.
246                // 2. We've checked that `dst.len() == output_dim`.
247                #[allow(clippy::unwrap_used)]
248                hadamard_transform(dst).unwrap();
249            }
250            Some(indices) => {
251                let mut tmp = Poly::broadcast(0.0f32, self.padded_dim, allocator)?;
252
253                self.copy_and_flip_signs(&mut tmp, src);
254
255                // Lint: We satisfy the pre-condidions for `hadamard_transform` because:
256                //
257                // 1. `padded_dim` is a power of 2 by construction.
258                // 2. We've checked that `tmp.len() == padded_dim`.
259                #[allow(clippy::unwrap_used)]
260                hadamard_transform(&mut tmp).unwrap();
261
262                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
263                debug_assert_eq!(dst.len(), indices.len());
264                std::iter::zip(dst.iter_mut(), indices.iter()).for_each(
265                    |(d, i): (&mut f32, &u32)| {
266                        *d = tmp[*i as usize] * rescale;
267                    },
268                );
269            }
270        }
271
272        Ok(())
273    }
274}
275
276impl<A> TryClone for PaddingHadamard<A>
277where
278    A: Allocator,
279{
280    fn try_clone(&self) -> Result<Self, AllocatorError> {
281        Ok(Self {
282            signs: self.signs.try_clone()?,
283            padded_dim: self.padded_dim,
284            subsample: self.subsample.try_clone()?,
285        })
286    }
287}
288
289/// Errors that may occur while constructing a [`PaddingHadamard`] from constituent parts.
290#[derive(Debug, Clone, Copy, Error, PartialEq)]
291#[non_exhaustive]
292pub enum PaddingHadamardError {
293    #[error("an invalid sign representation was discovered")]
294    InvalidSignRepresentation,
295    #[error("`signs` length exceeds `padded_dim`")]
296    SignsTooLong,
297    #[error("padded dim is not a power of two")]
298    DimNotPowerOfTwo,
299    #[error("subsample indices cannot be empty")]
300    SubsampleEmpty,
301    #[error("subsample indices is not monotonic")]
302    SubsampleNotMonotonic,
303    #[error("last subsample index exceeded `padded_dim`")]
304    LastSubsampleTooLarge,
305    #[error(transparent)]
306    AllocatorError(#[from] AllocatorError),
307}
308
#[cfg(feature = "flatbuffers")]
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Serialize `self` as a [`crate::flatbuffers::transforms::PaddingHadamard`] table.
    ///
    /// Signs are written as booleans, `padded_dim` as a `u32`, and the subsample indices
    /// (when present) as a vector of `u32`.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::PaddingHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let sign_flags = self.signs.iter().copied().map(sign_to_bool);
        let signs = buf.create_vector_from_iter(sign_flags);

        let subsample = self
            .subsample
            .as_deref()
            .map(|indices| buf.create_vector(indices));

        // `padded_dim` is a power of two, so any value that does not fit in a `u32`
        // truncates to `0` here, which `try_unpack` then rejects as not a power of two.
        let padded_dim = self.padded_dim as u32;

        let args = fb::transforms::PaddingHadamardArgs {
            signs: Some(signs),
            padded_dim,
            subsample,
        };
        fb::transforms::PaddingHadamard::create(buf, &args)
    }

    /// Rebuild a transform from a [`crate::flatbuffers::transforms::PaddingHadamard`]
    /// serialized representation.
    ///
    /// All buffers are allocated from `alloc`, and the decoded parts are validated through
    /// [`PaddingHadamard::try_from_parts`]. Any allocation or validation failure is
    /// returned as a [`PaddingHadamardError`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::PaddingHadamard<'_>,
    ) -> Result<Self, PaddingHadamardError> {
        let sign_bits = proto.signs().iter().map(bool_to_sign);
        let signs = Poly::from_iter(sign_bits, alloc.clone())?;

        let subsample = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        let padded_dim = proto.padded_dim() as usize;
        Self::try_from_parts(signs, padded_dim, subsample)
    }
}
358
359///////////
360// Tests //
361///////////
362
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::{Transform, TransformKind, test_utils},
        alloc::GlobalAllocator,
    };

    // Since we use a slightly non-obvious strategy for applying the +/-1 diagonal (xor-ing
    // sign bits instead of multiplying), we test its behavior explicitly.
    #[test]
    fn test_sign_flipping() {
        let mut rng = StdRng::seed_from_u64(0xf8ee12b1e9f33dbd);
        let dim = 14;

        let transform = PaddingHadamard::new(
            NonZeroUsize::new(dim).unwrap(),
            TargetDim::Same,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        assert_eq!(transform.input_dim(), dim);
        assert_eq!(transform.output_dim(), dim);

        let positive = vec![1.0f32; dim];
        let negative = vec![-1.0f32; dim];

        // `copy_and_flip_signs` expects a buffer of length `padded_dim` (16 for `dim == 14`).
        let mut output = vec![f32::INFINITY; 16];

        // Transform positive numbers
        transform.copy_and_flip_signs(&mut output, &positive);

        let mut unflipped = 0;
        let mut flipped = 0;
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    flipped += 1;
                    assert_eq!(*o, -1.0, "expected entry {} to be flipped", i);
                } else {
                    unflipped += 1;
                    assert_eq!(*o, 1.0, "expected entry {} to be unchanged", i);
                }
            });

        // Check that we have a mixture of flipped and unflipped signs.
        assert!(unflipped > 0);
        assert!(flipped > 0);

        // Assert that everything else was zero padded.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);

        // Transform negative numbers
        output.fill(f32::INFINITY);
        transform.copy_and_flip_signs(&mut output, &negative);
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    assert_eq!(*o, 1.0, "expected entry {} to be flipped", i);
                } else {
                    assert_eq!(*o, -1.0, "expected entry {} to be unchanged", i);
                }
            });

        // Assert that everything else was zero padded.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);
    }

    // Without the `flatbuffers` feature, `try_from_parts` is not otherwise exercised, so
    // check its basic acceptance and rejection paths directly.
    #[test]
    fn test_try_from_parts() {
        let alloc = GlobalAllocator;
        let make = |v: &[u32]| Poly::from_iter(v.iter().copied(), alloc).unwrap();

        // Valid parts with subsampling.
        let transform =
            PaddingHadamard::try_from_parts(make(&[0, 0x8000_0000, 0]), 4, Some(make(&[0, 3])))
                .unwrap();
        assert_eq!(transform.input_dim(), 3);
        assert_eq!(transform.output_dim(), 2);
        assert!(!transform.preserves_norms());

        // Valid parts without subsampling.
        let transform =
            PaddingHadamard::try_from_parts(make(&[0, 0x8000_0000, 0]), 4, None).unwrap();
        assert_eq!(transform.input_dim(), 3);
        assert_eq!(transform.output_dim(), 4);
        assert!(transform.preserves_norms());

        // Sign entries other than `0` and `0x8000_0000` are rejected.
        let err = PaddingHadamard::try_from_parts(make(&[0, 1]), 2, None).unwrap_err();
        assert_eq!(err, PaddingHadamardError::InvalidSignRepresentation);
    }

    test_utils::delegate_transformer!(PaddingHadamard<GlobalAllocator>);

    // This tests the natural hadamard transform where the output dimension is upgraded
    // to the next power of 2.
    #[test]
    fn test_padding_hadamard() {
        // Inner product computations are more susceptible to floating point error.
        // Instead of using ULP here, we fall back to using absolute and relative error.
        //
        // These error bounds are for when we set the output dimension to a power of 2 that
        // is higher than input dimension.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(4),
            l2: test_utils::Check::ulp(4),
            ip: test_utils::Check::absrel(5.0e-6, 2e-4),
        };

        // NOTE: Subsampling introduces high variance in the norm and L2, so our error
        // bounds need to be looser.
        //
        // Subsampling results in poor preservation of inner products, so we skip it
        // altogether.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 1e-1),
            l2: test_utils::Check::absrel(0.0, 1e-1),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        let dim_combos = [
            // Natural
            (15, 16, true, target_dim(16), &natural_errors),
            (15, 16, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (16, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (100, 128, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, target_dim(256), &natural_errors),
            // Subsampled
            (1000, 1000, false, TargetDim::Same, &subsampled_errors),
            (500, 1000, false, target_dim(1000), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe0626147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    assert_ne!(io.input0, &io.output0[..input]);
                    assert_ne!(io.input1, &io.output1[..input]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Base Transformer
                {
                    let transformer =
                        PaddingHadamard::new(input_nz, target, &mut rng, GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        &ctx,
                    )
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::PaddingHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        &ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::{flatbuffers::to_flatbuffer, poly};

        #[test]
        fn padding_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Round-trip various dimension combinations, with and without subsampling.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = PaddingHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                let reloaded = PaddingHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Packs `x` and returns the error produced when unpacking it.
            let gen_err = |x: PaddingHadamard<_>| -> PaddingHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                PaddingHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // Signs too long.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(), // longer than `padded_dim`.
                    padded_dim: 4,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::SignsTooLong);
            }

            // Dim not a power of 2.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 5, // not a power of 2
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::DimNotPowerOfTwo);
            }

            // Subsample empty.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([], alloc).unwrap()), // empty
                });

                assert_eq!(err, PaddingHadamardError::SubsampleEmpty);
            }

            // Not strictly monotonic.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 2, 2], alloc).unwrap()), // repeated index
                });
                assert_eq!(err, PaddingHadamardError::SubsampleNotMonotonic);
            }

            // More subsample indices than `padded_dim`.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 1, 2, 3, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }

            // Subsample index too large.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }
        }
    }
}