// diskann_quantization/algorithms/transforms/mod.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6// imports
7use std::num::NonZeroUsize;
8
9#[cfg(feature = "flatbuffers")]
10use flatbuffers::{FlatBufferBuilder, WIPOffset};
11use rand::RngCore;
12use thiserror::Error;
13
14use crate::alloc::{Allocator, AllocatorError, ScopedAllocator, TryClone};
15#[cfg(feature = "flatbuffers")]
16use crate::flatbuffers as fb;
17
18// modules
19mod double_hadamard;
20mod null;
21mod padding_hadamard;
22
23crate::utils::features! {
24    #![feature = "linalg"]
25    mod random_rotation;
26}
27
28mod utils;
29
30#[cfg(test)]
31mod test_utils;
32
33// reexports
34pub use double_hadamard::{DoubleHadamard, DoubleHadamardError};
35pub use null::NullTransform;
36pub use padding_hadamard::{PaddingHadamard, PaddingHadamardError};
37pub use utils::TransformFailed;
38
39crate::utils::features! {
40    #![feature = "linalg"]
41    pub use random_rotation::RandomRotation;
42}
43
44crate::utils::features! {
45    #![all(feature = "linalg", feature = "flatbuffers")]
46    pub use random_rotation::RandomRotationError;
47}
48
49crate::utils::features! {
50    #![feature = "flatbuffers"]
51    pub use null::NullTransformError;
52}
53
54///////////////
55// Transform //
56///////////////
57
/// Selects which distance-preserving transformation [`Transform::new`] constructs.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum TransformKind {
    /// Use a Hadamard transform
    /// ```math
    /// HDx / sqrt(n)
    /// ```
    /// where
    ///
    /// * `H` is an (implicit) [Hadamard matrix](https://en.wikipedia.org/wiki/Hadamard_matrix)
    /// * `D` is a diagonal matrix with `+/-1` on the diagonal.
    /// * `x` is the input vector.
    /// * `n` is the number of rows in `x`.
    ///
    /// Unlike [`Self::RandomRotation`], this method does not require matrix-vector
    /// multiplication and is therefore much faster for high-dimensional vectors.
    ///
    /// The Hadamard multiplication requires dimensions to be a power of two. Internally,
    /// this method will pad `x` with zeros up to the next power of two and transform the
    /// result.
    PaddingHadamard { target_dim: TargetDim },

    /// Use a Double Hadamard transform, which applies two Hadamard transformations
    /// in sequence; first to the head of the vector and then to the tail.
    ///
    /// This approach does not have any requirement on the input dimension to apply
    /// the distance preserving transformation using Hadamard multiplication,
    /// unlike [`PaddingHadamard`].
    ///
    /// Empirically, this approach seems to give better recall performance than
    /// applying [`PaddingHadamard`] and sampling down when `self.output_dim() == self.dim()`
    /// and the dimension is not a power of two.
    ///
    /// See [`DoubleHadamard`] for the implementation details.
    DoubleHadamard { target_dim: TargetDim },

    /// A naive transform that copies source into destination.
    Null,

    /// Use a full-dimensional, randomly sampled orthogonal matrix to transform vectors.
    ///
    /// Transformation involves matrix multiplication and may be slow for high-dimensional
    /// vectors.
    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    RandomRotation { target_dim: TargetDim },
}
105
/// Errors that can arise while constructing a [`Transform`] via [`Transform::new`].
#[derive(Debug, Clone, Error)]
pub enum NewTransformError {
    /// A randomized transform kind was requested but no `rng` was supplied.
    #[error("random number generator is required for {0:?}")]
    RngMissing(TransformKind),
    /// Memory allocation failed while building the transform's internal state.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
113
/// A constructed distance-preserving transformation.
///
/// Each variant wraps the concrete implementation selected by the corresponding
/// [`TransformKind`]; construct instances through [`Transform::new`].
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Transform<A>
where
    A: Allocator,
{
    PaddingHadamard(PaddingHadamard<A>),
    DoubleHadamard(DoubleHadamard<A>),
    Null(NullTransform),

    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    RandomRotation(RandomRotation),
}
128
129impl<A> Transform<A>
130where
131    A: Allocator,
132{
133    /// Construct a new `Transform` from a `TransformKind`, input dimension
134    /// and an optional rng (if needed).
135    ///
136    /// Currently, `rng` should be supplied for the following transforms:
137    /// - [`RandomRotation`]
138    /// - [`PaddingHadamard`]
139    /// - [`DoubleHadamard`]
140    ///
141    /// The [`NullTransform`] can be initialized without `rng`.
142    pub fn new(
143        transform_kind: TransformKind,
144        dim: NonZeroUsize,
145        rng: Option<&mut dyn RngCore>,
146        allocator: A,
147    ) -> Result<Self, NewTransformError> {
148        match transform_kind {
149            TransformKind::PaddingHadamard { target_dim } => {
150                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
151                Ok(Transform::PaddingHadamard(PaddingHadamard::new(
152                    dim, target_dim, rng, allocator,
153                )?))
154            }
155            TransformKind::DoubleHadamard { target_dim } => {
156                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
157                Ok(Transform::DoubleHadamard(DoubleHadamard::new(
158                    dim, target_dim, rng, allocator,
159                )?))
160            }
161            TransformKind::Null => Ok(Transform::Null(NullTransform::new(dim))),
162            #[cfg(feature = "linalg")]
163            TransformKind::RandomRotation { target_dim } => {
164                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
165                Ok(Transform::RandomRotation(RandomRotation::new(
166                    dim, target_dim, rng,
167                )))
168            }
169        }
170    }
171
172    pub(crate) fn input_dim(&self) -> usize {
173        match self {
174            Self::PaddingHadamard(t) => t.input_dim(),
175            Self::DoubleHadamard(t) => t.input_dim(),
176            Self::Null(t) => t.dim(),
177            #[cfg(feature = "linalg")]
178            Self::RandomRotation(t) => t.input_dim(),
179        }
180    }
181    pub(crate) fn output_dim(&self) -> usize {
182        match self {
183            Self::PaddingHadamard(t) => t.output_dim(),
184            Self::DoubleHadamard(t) => t.output_dim(),
185            Self::Null(t) => t.dim(),
186            #[cfg(feature = "linalg")]
187            Self::RandomRotation(t) => t.output_dim(),
188        }
189    }
190
191    pub(crate) fn preserves_norms(&self) -> bool {
192        match self {
193            Self::PaddingHadamard(t) => t.preserves_norms(),
194            Self::DoubleHadamard(t) => t.preserves_norms(),
195            Self::Null(t) => t.preserves_norms(),
196            #[cfg(feature = "linalg")]
197            Self::RandomRotation(t) => t.preserves_norms(),
198        }
199    }
200
201    pub(crate) fn transform_into(
202        &self,
203        dst: &mut [f32],
204        src: &[f32],
205        allocator: ScopedAllocator<'_>,
206    ) -> Result<(), TransformFailed> {
207        match self {
208            Self::PaddingHadamard(t) => t.transform_into(dst, src, allocator),
209            Self::DoubleHadamard(t) => t.transform_into(dst, src, allocator),
210            Self::Null(t) => t.transform_into(dst, src),
211            #[cfg(feature = "linalg")]
212            Self::RandomRotation(t) => t.transform_into(dst, src),
213        }
214    }
215}
216
217impl<A> TryClone for Transform<A>
218where
219    A: Allocator,
220{
221    fn try_clone(&self) -> Result<Self, AllocatorError> {
222        match self {
223            Self::PaddingHadamard(t) => Ok(Self::PaddingHadamard(t.try_clone()?)),
224            Self::DoubleHadamard(t) => Ok(Self::DoubleHadamard(t.try_clone()?)),
225            Self::Null(t) => Ok(Self::Null(t.clone())),
226            #[cfg(feature = "linalg")]
227            Self::RandomRotation(t) => Ok(Self::RandomRotation(t.clone())),
228        }
229    }
230}
231
/// Errors that can arise while unpacking a [`Transform`] from its serialized
/// flatbuffers representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum TransformError {
    /// Unpacking the [`PaddingHadamard`] payload failed.
    #[error(transparent)]
    PaddingHadamardError(#[from] PaddingHadamardError),
    /// Unpacking the [`DoubleHadamard`] payload failed.
    #[error(transparent)]
    DoubleHadamardError(#[from] DoubleHadamardError),
    /// Unpacking the [`NullTransform`] payload failed.
    #[error(transparent)]
    NullTransformError(#[from] NullTransformError),
    /// Unpacking the [`RandomRotation`] payload failed.
    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    #[error(transparent)]
    RandomRotationError(#[from] RandomRotationError),
    /// The serialized union did not contain any recognized transform payload.
    #[error("invalid transform kind")]
    InvalidTransformKind,
}
250
#[cfg(feature = "flatbuffers")]
impl<A> Transform<A>
where
    A: Allocator,
{
    /// Pack into a [`crate::flatbuffers::transforms::Transform`] serialized representation.
    ///
    /// Serializes the inner transform as a flatbuffers union value together with its
    /// kind tag so that [`Self::try_unpack`] can recover the correct variant.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::Transform<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Serialize the wrapped transform first, recording the union tag next to the
        // resulting offset so both can be written into the outer table together.
        let (kind, offset) = match self {
            Self::PaddingHadamard(t) => (
                fb::transforms::TransformKind::PaddingHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::DoubleHadamard(t) => (
                fb::transforms::TransformKind::DoubleHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::Null(t) => (
                fb::transforms::TransformKind::NullTransform,
                t.pack(buf).as_union_value(),
            ),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(t) => (
                fb::transforms::TransformKind::RandomRotation,
                t.pack(buf).as_union_value(),
            ),
        };

        fb::transforms::Transform::create(
            buf,
            &fb::transforms::TransformArgs {
                transform_type: kind,
                transform: Some(offset),
            },
        )
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::Transform`] serialized
    /// representation, returning any error if encountered.
    ///
    /// Returns [`TransformError::InvalidTransformKind`] when no known payload is present;
    /// this also covers a `RandomRotation` payload when the `linalg` feature is disabled,
    /// since that accessor is compiled out below.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::Transform<'_>,
    ) -> Result<Self, TransformError> {
        // NOTE(review): the generated `transform_as_*` accessors are assumed to return
        // `Some` only when the union tag matches, making the probing order below
        // insignificant -- confirm against the generated flatbuffers code.
        if let Some(transform) = proto.transform_as_padding_hadamard() {
            return Ok(Self::PaddingHadamard(PaddingHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        #[cfg(feature = "linalg")]
        if let Some(transform) = proto.transform_as_random_rotation() {
            return Ok(Self::RandomRotation(RandomRotation::try_unpack(transform)?));
        }

        if let Some(transform) = proto.transform_as_double_hadamard() {
            return Ok(Self::DoubleHadamard(DoubleHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        if let Some(transform) = proto.transform_as_null_transform() {
            return Ok(Self::Null(NullTransform::try_unpack(transform)?));
        }

        Err(TransformError::InvalidTransformKind)
    }
}
323
/// Transformations possess the ability to keep dimensionality the same, increase it, or
/// decrease it.
///
/// This enum enables the caller to communicate the desired behavior upon transform
/// construction.
#[derive(Debug, Clone, Copy)]
pub enum TargetDim {
    /// Keep the output dimensionality the same as the input dimensionality.
    ///
    /// # Note
    ///
    /// When the input dimensionality is less than the "natural" dimensionality
    /// (see [`Self::Natural`]), post-transformed sampling may be invoked where only a subset
    /// of the transformed vector's dimensions are retained.
    ///
    /// For low dimensional embeddings, this sampling may result in high norm variance and
    /// poor recall.
    Same,

    /// Use the "natural" dimensionality for the output.
    ///
    /// This allows transformations like [`PaddingHadamard`] to increase the dimensionality
    /// to the next power of two if needed. This will usually provide better accuracy than
    /// [`Self::Same`] but may result in a worse compression ratio.
    Natural,

    /// Set a hard value for the output dimensionality.
    ///
    /// This may result in arbitrary subsampling (see the note in [`TargetDim::Same`]) or
    /// supersampling (zero padding the pretransformed vector). Use with care.
    Override(NonZeroUsize),
}
356
357#[cfg(test)]
358test_utils::delegate_transformer!(Transform<crate::alloc::GlobalAllocator>);