Skip to main content

diskann_quantization/algorithms/transforms/
random_rotation.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use diskann_linalg::{self, Transpose};
11use rand::Rng;
12#[cfg(feature = "flatbuffers")]
13use thiserror::Error;
14
15use super::{
16    utils::{check_dims, TransformFailed},
17    TargetDim,
18};
19#[cfg(feature = "flatbuffers")]
20use crate::flatbuffers as fb;
21
//////////////////////
// Dense Transforms //
//////////////////////
25
26/// A distance-perserving transformation from `N`-dimensions to `N`-dimensions.
27///
28/// This struct materializes a full `NxN` transformation matrix and mechanically applies
29/// transformations via matrix multiplication.
30#[derive(Debug, Clone)]
31#[cfg_attr(test, derive(PartialEq))]
32pub struct RandomRotation {
33    /// This data structure maintains the invariant that this **must** be a square matrix.
34    transform: diskann_utils::views::Matrix<f32>,
35}
36
37impl RandomRotation {
38    /// Construct a new `RandomRotation` that transforms input vectors of dimension `dim`.
39    ///
40    /// The parameter `rng` is used to randomly sample the transformation matrix.
41    ///
42    /// The following dimensionalities will be configured depending on the value of `target`:
43    ///
44    /// * [`TargetDim::Same`]
45    ///   - `self.input_dim() == dim.get()`
46    ///   - `self.output_dim() == dim.get()`
47    /// * [`TargetDim::Natural`]
48    ///   - `self.input_dim() == dim.get()`
49    ///   - `self.output_dim() == dim.get()`
50    /// * [`TargetDim::Override`]
51    ///   - `self.input_dim() == dim.get()`
52    ///   - `self.output_dim()`: The value provided by the override.
53    ///
54    /// Sub-sampling occurs if `self.output_dim()` is less than `self.input_dim()`.
55    pub fn new<R>(dim: NonZeroUsize, target_dim: TargetDim, rng: &mut R) -> Self
56    where
57        R: Rng + ?Sized,
58    {
59        let dim = dim.get();
60
61        // There are three cases we need to consider:
62        //
63        // 1. If the target dim is the **same** as `dim`, then our transformation matrix can
64        //    be square.
65        //
66        // 2. If the target dim is **less** than `dim`, then we generate a random `dim x dim`
67        //    matrix and remove the output rows that would not be included in the output,
68        //    resulting in a `target_dim x dim` matrix.
69        //
70        // 3. If the target dim is **greater** than `dim`, then we generate a
71        //    `target_dim x target_dim` random matrix and remove columns to end up with a
72        //    `target_dim x dim` matrix.
73        //
74        //    Removing columns is equivalent to zero padding the original vector up to
75        //    `target_dim` and multiplying by the full matrix.
76        let (target_dim, matrix_dim) = match target_dim {
77            TargetDim::Same | TargetDim::Natural => (dim, dim),
78            TargetDim::Override(target) => {
79                let target_dim = target.get();
80                if target_dim <= dim {
81                    (target_dim, dim)
82                } else {
83                    (target_dim, target_dim)
84                }
85            }
86        };
87
88        // Lint: By construction, the matrix returned from
89        // `diskann_linalg::random_distance_preserving_matrix` will by `matrix_dim x matrix_dim`.
90        #[allow(clippy::unwrap_used)]
91        let initial = diskann_utils::views::Matrix::try_from(
92            diskann_linalg::random_distance_preserving_matrix(matrix_dim, rng).into(),
93            matrix_dim,
94            matrix_dim,
95        )
96        .unwrap();
97
98        // Restructure the matrix as needed to apply the desired sub/super sampling.
99        let transform = match target_dim.cmp(&dim) {
100            std::cmp::Ordering::Equal => initial,
101            std::cmp::Ordering::Less => {
102                let indices = rand::seq::index::sample(rng, dim, target_dim);
103                let scaling = (dim as f32 / target_dim as f32).sqrt();
104
105                let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
106                std::iter::zip(transform.row_iter_mut(), indices.iter()).for_each(|(ro, ri)| {
107                    std::iter::zip(ro.iter_mut(), initial.row(ri).iter()).for_each(|(o, i)| {
108                        *o = scaling * (*i);
109                    })
110                });
111                transform
112            }
113            std::cmp::Ordering::Greater => {
114                let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
115                std::iter::zip(transform.row_iter_mut(), initial.row_iter())
116                    .for_each(|(o, i)| o.copy_from_slice(&i[..dim]));
117                transform
118            }
119        };
120
121        Self { transform }
122    }
123
124    /// Return the input dimension for the transformation.
125    pub fn input_dim(&self) -> usize {
126        self.transform.ncols()
127    }
128
129    /// Return the output dimension for the transformation.
130    pub fn output_dim(&self) -> usize {
131        self.transform.nrows()
132    }
133
134    /// Return whether or not the transform preserves norms.
135    ///
136    /// For this transform, norms are not preserved when the output dimensionality is less
137    /// than the input dimensionality.
138    pub fn preserves_norms(&self) -> bool {
139        self.output_dim() >= self.input_dim()
140    }
141
142    /// Perform the transformation of the `src` vector into the `dst` vector.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if
147    ///
148    /// * `src.len() != self.input_dim()`.
149    /// * `dst.len() != self.output_dim()`.
150    pub fn transform_into(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
151        let input_dim = self.input_dim();
152        let output_dim = self.output_dim();
153        check_dims(dst, src, input_dim, output_dim)?;
154        diskann_linalg::sgemm(
155            Transpose::None,
156            Transpose::None,
157            output_dim,
158            1,
159            input_dim,
160            1.0,
161            self.transform.as_slice(),
162            src,
163            None,
164            dst,
165        );
166        Ok(())
167    }
168}
169
/// Errors that can occur when deserializing a [`RandomRotation`] from its flatbuffer
/// representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum RandomRotationError {
    /// The number of serialized matrix elements is not `nrows * ncols`.
    #[error("buffer size not product of rows and columns")]
    IncorrectDim,
    /// The serialized matrix reports zero rows.
    #[error("number of rows cannot be zero")]
    RowsZero,
    /// The serialized matrix reports zero columns.
    #[error("number of cols cannot be zero")]
    ColsZero,
}
182
// Serialization
#[cfg(feature = "flatbuffers")]
impl RandomRotation {
    /// Pack into a [`crate::flatbuffers::transforms::RandomRotation`] serialized
    /// representation.
    ///
    /// The matrix elements are written as one flat vector, together with the row and
    /// column counts needed by [`Self::try_unpack`] to rebuild the matrix.
    pub(crate) fn pack<'a, A>(
        &self,
        buf: &mut FlatBufferBuilder<'a, A>,
    ) -> WIPOffset<fb::transforms::RandomRotation<'a>>
    where
        A: flatbuffers::Allocator + 'a,
    {
        let data = buf.create_vector(self.transform.as_slice());

        // The `as u32` casts cannot truncate in practice. Flatbuffers buffers are limited
        // to under 2 GiB, and `data` already holds `nrows * ncols` `f32` values, so each
        // count is far below `u32::MAX`.
        fb::transforms::RandomRotation::create(
            buf,
            &fb::transforms::RandomRotationArgs {
                data: Some(data),
                nrows: self.transform.nrows() as u32,
                ncols: self.transform.ncols() as u32,
            },
        )
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::RandomRotation`]
    /// serialized representation.
    ///
    /// # Errors
    ///
    /// * [`RandomRotationError::RowsZero`] if the stored row count is zero.
    /// * [`RandomRotationError::ColsZero`] if the stored column count is zero.
    /// * [`RandomRotationError::IncorrectDim`] if the number of stored elements is not
    ///   `nrows * ncols`.
    pub(crate) fn try_unpack(
        proto: fb::transforms::RandomRotation<'_>,
    ) -> Result<Self, RandomRotationError> {
        // Widening `u32 -> usize` is lossless on all supported (>= 32-bit) targets.
        let nrows = proto.nrows() as usize;
        let ncols = proto.ncols() as usize;
        if nrows == 0 {
            return Err(RandomRotationError::RowsZero);
        }
        if ncols == 0 {
            return Err(RandomRotationError::ColsZero);
        }

        // Check the element count before copying anything out of the buffer.
        // `checked_mul` prevents `nrows * ncols` from overflowing on 32-bit targets,
        // where two `u32` counts can exceed `usize::MAX`.
        let raw = proto.data();
        if nrows.checked_mul(ncols) != Some(raw.len()) {
            return Err(RandomRotationError::IncorrectDim);
        }

        let data = raw.into_iter().collect();
        let transform = diskann_utils::views::Matrix::try_from(data, nrows, ncols)
            .map_err(|_| RandomRotationError::IncorrectDim)?;

        Ok(Self { transform })
    }
}
229
///////////
// Tests //
///////////
233
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformFailed, TransformKind},
        alloc::GlobalAllocator,
        test_util::Check,
    };

    impl test_utils::Transformer for RandomRotation {
        fn input_dim_(&self) -> usize {
            self.input_dim()
        }
        fn output_dim_(&self) -> usize {
            self.output_dim()
        }
        fn transform_into_(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
            self.transform_into(dst, src)
        }
    }

    #[test]
    fn test_transform_matrix() {
        let nonsubsampled_errors = test_utils::ErrorSetup {
            norm: Check::ulp(10),
            l2: Check::ulp(10),
            ip: Check::absrel(2e-5, 1e-4),
        };

        // Because we're using relatively low dimensions, subsampling yields pretty large
        // variances. We can't use higher dimensionality, though, because then the tests
        // would never complete.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: Check::absrel(0.0, 0.18),
            l2: Check::absrel(0.0, 0.18),
            ip: Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // Tuples of (input dim, output dim, preserves norms, target, error bounds).
        let dim_combos = [
            // Same dimension
            (15, 15, true, TargetDim::Same, &nonsubsampled_errors),
            (15, 15, true, TargetDim::Natural, &nonsubsampled_errors),
            (16, 16, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Natural, &nonsubsampled_errors),
            (256, 256, true, TargetDim::Same, &nonsubsampled_errors),
            // Super Sampling
            (15, 20, true, target_dim(20), &nonsubsampled_errors),
            // Sub Sampling
            (256, 200, false, target_dim(200), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 50;

        let mut rng = StdRng::seed_from_u64(0x30e37c10c36cc64b);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Test the underlying transformer.
                {
                    let transformer = RandomRotation::new(input_nz, target, &mut rng);
                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    );
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::RandomRotation { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    );
                }
            }
        }
    }

    #[test]
    fn test_transform_into_rejects_mismatched_dims() {
        let mut rng = StdRng::seed_from_u64(0x5eed_1234);
        let transformer = RandomRotation::new(
            NonZeroUsize::new(8).unwrap(),
            TargetDim::Override(NonZeroUsize::new(4).unwrap()),
            &mut rng,
        );
        assert_eq!(transformer.input_dim(), 8);
        assert_eq!(transformer.output_dim(), 4);

        let src = [1.0f32; 8];
        let mut dst = [0.0f32; 4];
        assert!(transformer.transform_into(&mut dst, &src).is_ok());

        // Source too short.
        assert!(transformer.transform_into(&mut dst, &src[..7]).is_err());

        // Destination too long.
        let mut long_dst = [0.0f32; 5];
        assert!(transformer.transform_into(&mut long_dst, &src).is_err());
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Serialize a hand-built `RandomRotation` table and attempt to unpack it.
        fn unpack_raw(
            values: &[f32],
            nrows: u32,
            ncols: u32,
        ) -> Result<RandomRotation, RandomRotationError> {
            let data = to_flatbuffer(|buf| {
                let data = buf.create_vector::<f32>(values);
                fb::transforms::RandomRotation::create(
                    buf,
                    &fb::transforms::RandomRotationArgs {
                        data: Some(data),
                        nrows,
                        ncols,
                    },
                )
            });

            let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
            RandomRotation::try_unpack(proto)
        }

        #[test]
        fn random_rotation() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);

            // Round-trip square, super-sampled, and sub-sampled transforms.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform =
                    RandomRotation::new(NonZeroUsize::new(dim).unwrap(), target_dim, &mut rng);
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let reloaded = RandomRotation::try_unpack(proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // A valid 2x2 payload used by the error cases below.
            let identity = [1.0f32, 0.0, 0.0, 1.0];

            // Invalid: zero rows.
            assert_eq!(
                unpack_raw(&identity, 0, 2).unwrap_err(),
                RandomRotationError::RowsZero
            );

            // Invalid: zero cols.
            assert_eq!(
                unpack_raw(&identity, 2, 0).unwrap_err(),
                RandomRotationError::ColsZero
            );

            // Too few elements: a 2x2 matrix needs 4.
            assert_eq!(
                unpack_raw(&identity[..3], 2, 2).unwrap_err(),
                RandomRotationError::IncorrectDim
            );

            // Too many elements for a 2x2 matrix.
            assert_eq!(
                unpack_raw(&[1.0, 0.0, 0.0, 1.0, 0.0], 2, 2).unwrap_err(),
                RandomRotationError::IncorrectDim
            );
        }
    }
}