// diskann_quantization/algorithms/transforms/mod.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

// imports
use std::num::NonZeroUsize;

#[cfg(feature = "flatbuffers")]
use flatbuffers::{FlatBufferBuilder, WIPOffset};
use rand::RngCore;
use thiserror::Error;

use crate::alloc::{Allocator, AllocatorError, ScopedAllocator, TryClone};
#[cfg(feature = "flatbuffers")]
use crate::flatbuffers as fb;

// modules
mod double_hadamard;
mod null;
mod padding_hadamard;

// The random-rotation transform needs linear-algebra support, so it is only
// compiled when the `linalg` feature is enabled.
crate::utils::features! {
    #![feature = "linalg"]
    mod random_rotation;
}

mod utils;

// Shared test scaffolding; not compiled under Miri.
#[cfg(test)]
#[cfg(not(miri))]
mod test_utils;

// reexports
pub use double_hadamard::{DoubleHadamard, DoubleHadamardError};
pub use null::NullTransform;
pub use padding_hadamard::{PaddingHadamard, PaddingHadamardError};
pub use utils::TransformFailed;

crate::utils::features! {
    #![feature = "linalg"]
    pub use random_rotation::RandomRotation;
}

// NOTE(review): `RandomRotationError` additionally requires `flatbuffers` —
// presumably it only arises during (de)serialization; confirm in `random_rotation`.
crate::utils::features! {
    #![all(feature = "linalg", feature = "flatbuffers")]
    pub use random_rotation::RandomRotationError;
}

// `NullTransformError` likewise only exists for the serialized representation.
crate::utils::features! {
    #![feature = "flatbuffers"]
    pub use null::NullTransformError;
}
///////////////
// Transform //
///////////////

/// The transform variants that may be requested when constructing a
/// [`Transform`] via [`Transform::new`].
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum TransformKind {
    /// Use a Hadamard transform
    /// ```math
    /// HDx / sqrt(n)
    /// ```
    /// where
    ///
    /// * `H` is an (implicit) [Hadamard matrix](https://en.wikipedia.org/wiki/Hadamard_matrix)
    /// * `D` is a diagonal matrix with `+/-1` on the diagonal.
    /// * `x` is the input vector.
    /// * `n` is the number of rows in `x`.
    ///
    /// Unlike [`Self::RandomRotation`], this method does not require matrix-vector
    /// multiplication and is therefore much faster for high-dimensional vectors.
    ///
    /// The Hadamard multiplication requires dimensions to be a power of two. Internally,
    /// this method will pad `x` with zeros up to the next power of two and transform the
    /// result.
    PaddingHadamard { target_dim: TargetDim },

    /// Use a Double Hadamard transform, which applies two Hadamard transformations
    /// in sequence; first to the head of the vector and then to the tail.
    ///
    /// This approach does not have any requirement on the input dimension to apply
    /// the distance preserving transformation using Hadamard multiplication,
    /// unlike [`PaddingHadamard`].
    ///
    /// Empirically, this approach seems to give better recall performance than
    /// applying [`PaddingHadamard`] and sampling down when `self.output_dim() == self.dim()`
    /// and the dimension is not a power of two.
    ///
    /// See [`DoubleHadamard`] for the implementation details.
    DoubleHadamard { target_dim: TargetDim },

    /// A naive transform that copies source into destination.
    Null,

    /// Use a full-dimensional, randomly sampled orthogonal matrix to transform vectors.
    ///
    /// Transformation involves matrix multiplication and may be slow for high-dimensional
    /// vectors.
    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    RandomRotation { target_dim: TargetDim },
}
106
/// Errors that can occur while constructing a [`Transform`] via [`Transform::new`].
#[derive(Debug, Clone, Error)]
pub enum NewTransformError {
    /// The requested [`TransformKind`] needs a random number generator but
    /// `rng` was `None`.
    #[error("random number generator is required for {0:?}")]
    RngMissing(TransformKind),
    /// Memory allocation failed while building the transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
114
/// A constructed transform, parameterized by the allocator `A` used for its
/// internal state. See [`TransformKind`] for a description of each variant.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Transform<A>
where
    A: Allocator,
{
    PaddingHadamard(PaddingHadamard<A>),
    DoubleHadamard(DoubleHadamard<A>),
    Null(NullTransform),

    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    RandomRotation(RandomRotation),
}
129
130impl<A> Transform<A>
131where
132    A: Allocator,
133{
134    /// Construct a new `Transform` from a `TransformKind`, input dimension
135    /// and an optional rng (if needed).
136    ///
137    /// Currently, `rng` should be supplied for the following transforms:
138    /// - [`RandomRotation`]
139    /// - [`PaddingHadamard`]
140    /// - [`DoubleHadamard`]
141    ///
142    /// The [`NullTransform`] can be initialized without `rng`.
143    pub fn new(
144        transform_kind: TransformKind,
145        dim: NonZeroUsize,
146        rng: Option<&mut dyn RngCore>,
147        allocator: A,
148    ) -> Result<Self, NewTransformError> {
149        match transform_kind {
150            TransformKind::PaddingHadamard { target_dim } => {
151                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
152                Ok(Transform::PaddingHadamard(PaddingHadamard::new(
153                    dim, target_dim, rng, allocator,
154                )?))
155            }
156            TransformKind::DoubleHadamard { target_dim } => {
157                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
158                Ok(Transform::DoubleHadamard(DoubleHadamard::new(
159                    dim, target_dim, rng, allocator,
160                )?))
161            }
162            TransformKind::Null => Ok(Transform::Null(NullTransform::new(dim))),
163            #[cfg(feature = "linalg")]
164            TransformKind::RandomRotation { target_dim } => {
165                let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
166                Ok(Transform::RandomRotation(RandomRotation::new(
167                    dim, target_dim, rng,
168                )))
169            }
170        }
171    }
172
173    pub(crate) fn input_dim(&self) -> usize {
174        match self {
175            Self::PaddingHadamard(t) => t.input_dim(),
176            Self::DoubleHadamard(t) => t.input_dim(),
177            Self::Null(t) => t.dim(),
178            #[cfg(feature = "linalg")]
179            Self::RandomRotation(t) => t.input_dim(),
180        }
181    }
182    pub(crate) fn output_dim(&self) -> usize {
183        match self {
184            Self::PaddingHadamard(t) => t.output_dim(),
185            Self::DoubleHadamard(t) => t.output_dim(),
186            Self::Null(t) => t.dim(),
187            #[cfg(feature = "linalg")]
188            Self::RandomRotation(t) => t.output_dim(),
189        }
190    }
191
192    pub(crate) fn preserves_norms(&self) -> bool {
193        match self {
194            Self::PaddingHadamard(t) => t.preserves_norms(),
195            Self::DoubleHadamard(t) => t.preserves_norms(),
196            Self::Null(t) => t.preserves_norms(),
197            #[cfg(feature = "linalg")]
198            Self::RandomRotation(t) => t.preserves_norms(),
199        }
200    }
201
202    pub(crate) fn transform_into(
203        &self,
204        dst: &mut [f32],
205        src: &[f32],
206        allocator: ScopedAllocator<'_>,
207    ) -> Result<(), TransformFailed> {
208        match self {
209            Self::PaddingHadamard(t) => t.transform_into(dst, src, allocator),
210            Self::DoubleHadamard(t) => t.transform_into(dst, src, allocator),
211            Self::Null(t) => t.transform_into(dst, src),
212            #[cfg(feature = "linalg")]
213            Self::RandomRotation(t) => t.transform_into(dst, src),
214        }
215    }
216}
217
218impl<A> TryClone for Transform<A>
219where
220    A: Allocator,
221{
222    fn try_clone(&self) -> Result<Self, AllocatorError> {
223        match self {
224            Self::PaddingHadamard(t) => Ok(Self::PaddingHadamard(t.try_clone()?)),
225            Self::DoubleHadamard(t) => Ok(Self::DoubleHadamard(t.try_clone()?)),
226            Self::Null(t) => Ok(Self::Null(t.clone())),
227            #[cfg(feature = "linalg")]
228            Self::RandomRotation(t) => Ok(Self::RandomRotation(t.clone())),
229        }
230    }
231}
232
/// Errors that can occur while unpacking a [`Transform`] from its serialized
/// flatbuffers representation (see [`Transform::try_unpack`]).
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum TransformError {
    /// Unpacking the [`PaddingHadamard`] payload failed.
    #[error(transparent)]
    PaddingHadamardError(#[from] PaddingHadamardError),
    /// Unpacking the [`DoubleHadamard`] payload failed.
    #[error(transparent)]
    DoubleHadamardError(#[from] DoubleHadamardError),
    /// Unpacking the [`NullTransform`] payload failed.
    #[error(transparent)]
    NullTransformError(#[from] NullTransformError),
    /// Unpacking the [`RandomRotation`] payload failed.
    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    #[error(transparent)]
    RandomRotationError(#[from] RandomRotationError),
    /// The serialized union tag did not match any transform this build supports
    /// (including variants compiled out by disabled features).
    #[error("invalid transform kind")]
    InvalidTransformKind,
}
251
#[cfg(feature = "flatbuffers")]
impl<A> Transform<A>
where
    A: Allocator,
{
    /// Pack into a [`crate::flatbuffers::transforms::Transform`] serialized representation.
    ///
    /// Each variant serializes its payload and pairs it with the matching
    /// flatbuffers union tag; the two are then written into the wrapping
    /// `Transform` table together.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::Transform<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let (kind, offset) = match self {
            Self::PaddingHadamard(t) => (
                fb::transforms::TransformKind::PaddingHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::DoubleHadamard(t) => (
                fb::transforms::TransformKind::DoubleHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::Null(t) => (
                fb::transforms::TransformKind::NullTransform,
                t.pack(buf).as_union_value(),
            ),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(t) => (
                fb::transforms::TransformKind::RandomRotation,
                t.pack(buf).as_union_value(),
            ),
        };

        fb::transforms::Transform::create(
            buf,
            &fb::transforms::TransformArgs {
                transform_type: kind,
                transform: Some(offset),
            },
        )
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::Transform`] serialized
    /// representation, returning any error if encountered.
    ///
    /// Each `transform_as_*` accessor is tried in turn; presumably (flatbuffers
    /// union convention) at most one returns `Some`, so the order of the checks
    /// should not matter — confirm against the generated `fb::transforms` code.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::Transform<'_>,
    ) -> Result<Self, TransformError> {
        if let Some(transform) = proto.transform_as_padding_hadamard() {
            return Ok(Self::PaddingHadamard(PaddingHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        // Only checked when `linalg` is enabled; otherwise a serialized
        // RandomRotation falls through to `InvalidTransformKind` below.
        #[cfg(feature = "linalg")]
        if let Some(transform) = proto.transform_as_random_rotation() {
            return Ok(Self::RandomRotation(RandomRotation::try_unpack(transform)?));
        }

        if let Some(transform) = proto.transform_as_double_hadamard() {
            return Ok(Self::DoubleHadamard(DoubleHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        if let Some(transform) = proto.transform_as_null_transform() {
            return Ok(Self::Null(NullTransform::try_unpack(transform)?));
        }

        // No accessor matched: unknown or feature-disabled union tag.
        Err(TransformError::InvalidTransformKind)
    }
}
324
/// Transformations possess the ability to keep dimensionality the same, increase it, or
/// decrease it.
///
/// This struct enables the caller to communicate the desired behavior upon transform
/// construction.
#[derive(Debug, Clone, Copy)]
pub enum TargetDim {
    /// Keep the output dimensionality the same as the input dimensionality.
    ///
    /// # Note
    ///
    /// When the input dimensionality is less than the "natural" dimensionality
    /// (see [`Self::Natural`]), post-transformed sampling may be invoked where only a subset
    /// of the transformed vector's dimensions are retained.
    ///
    /// For low dimensional embeddings, this sampling may result in high norm variance and
    /// poor recall.
    Same,

    /// Use the "natural" dimensionality for the output.
    ///
    /// This allows transformations like [`PaddingHadamard`] to increase the dimensionality
    /// to the next power of two if needed. This will usually provide better accuracy than
    /// [`Self::Same`] but may result in a worse compression ratio.
    Natural,

    /// Set a hard value for the output dimensionality.
    ///
    /// This may result in arbitrary subsampling (see the note in [`TargetDim::Same`]) or
    /// supersampling (zero padding the pretransformed vector). Use with care.
    Override(NonZeroUsize),
}
357
// Run the shared transformer test-suite against `Transform` backed by the
// global allocator; skipped under Miri.
#[cfg(test)]
#[cfg(not(miri))]
test_utils::delegate_transformer!(Transform<crate::alloc::GlobalAllocator>);