Skip to main content

diskann_quantization/algorithms/transforms/
random_rotation.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use diskann_linalg::{self, Transpose};
11use rand::Rng;
12#[cfg(feature = "flatbuffers")]
13use thiserror::Error;
14
15use super::{
16    utils::{check_dims, TransformFailed},
17    TargetDim,
18};
19#[cfg(feature = "flatbuffers")]
20use crate::flatbuffers as fb;
21
22//////////////////////
23// Dense Transforms //
24//////////////////////
25
26/// A distance-perserving transformation from `N`-dimensions to `N`-dimensions.
27///
28/// This struct materializes a full `NxN` transformation matrix and mechanically applies
29/// transformations via matrix multiplication.
30#[derive(Debug, Clone)]
31#[cfg_attr(test, derive(PartialEq))]
32pub struct RandomRotation {
33    /// This data structure maintains the invariant that this **must** be a square matrix.
34    transform: diskann_utils::views::Matrix<f32>,
35}
36
37impl RandomRotation {
38    /// Construct a new `RandomRotation` that transforms input vectors of dimension `dim`.
39    ///
40    /// The parameter `rng` is used to randomly sample the transformation matrix.
41    ///
42    /// The following dimensionalities will be configured depending on the value of `target`:
43    ///
44    /// * [`TargetDim::Same`]
45    ///   - `self.input_dim() == dim.get()`
46    ///   - `self.output_dim() == dim.get()`
47    /// * [`TargetDim::Natural`]
48    ///   - `self.input_dim() == dim.get()`
49    ///   - `self.output_dim() == dim.get()`
50    /// * [`TargetDim::Override`]
51    ///   - `self.input_dim() == dim.get()`
52    ///   - `self.output_dim()`: The value provided by the override.
53    ///
54    /// Sub-sampling occurs if `self.output_dim()` is less than `self.input_dim()`.
55    pub fn new<R>(dim: NonZeroUsize, target_dim: TargetDim, rng: &mut R) -> Self
56    where
57        R: Rng + ?Sized,
58    {
59        let dim = dim.get();
60
61        // There are three cases we need to consider:
62        //
63        // 1. If the target dim is the **same** as `dim`, then our transformation matrix can
64        //    be square.
65        //
66        // 2. If the target dim is **less** than `dim`, then we generate a random `dim x dim`
67        //    matrix and remove the output rows that would not be included in the output,
68        //    resulting in a `target_dim x dim` matrix.
69        //
70        // 3. If the target dim is **greater** than `dim`, then we generate a
71        //    `target_dim x target_dim` random matrix and remove columns to end up with a
72        //    `target_dim x dim` matrix.
73        //
74        //    Removing columns is equivalent to zero padding the original vector up to
75        //    `target_dim` and multiplying by the full matrix.
76        let (target_dim, matrix_dim) = match target_dim {
77            TargetDim::Same | TargetDim::Natural => (dim, dim),
78            TargetDim::Override(target) => {
79                let target_dim = target.get();
80                if target_dim <= dim {
81                    (target_dim, dim)
82                } else {
83                    (target_dim, target_dim)
84                }
85            }
86        };
87
88        // Lint: By construction, the matrix returned from
89        // `diskann_linalg::random_distance_preserving_matrix` will by `matrix_dim x matrix_dim`.
90        #[allow(clippy::unwrap_used)]
91        let initial = diskann_utils::views::Matrix::try_from(
92            diskann_linalg::random_distance_preserving_matrix(matrix_dim, rng).into(),
93            matrix_dim,
94            matrix_dim,
95        )
96        .unwrap();
97
98        // Restructure the matrix as needed to apply the desired sub/super sampling.
99        let transform = match target_dim.cmp(&dim) {
100            std::cmp::Ordering::Equal => initial,
101            std::cmp::Ordering::Less => {
102                let indices = rand::seq::index::sample(rng, dim, target_dim);
103                let scaling = (dim as f32 / target_dim as f32).sqrt();
104
105                let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
106                std::iter::zip(transform.row_iter_mut(), indices.iter()).for_each(|(ro, ri)| {
107                    std::iter::zip(ro.iter_mut(), initial.row(ri).iter()).for_each(|(o, i)| {
108                        *o = scaling * (*i);
109                    })
110                });
111                transform
112            }
113            std::cmp::Ordering::Greater => {
114                let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
115                std::iter::zip(transform.row_iter_mut(), initial.row_iter())
116                    .for_each(|(o, i)| o.copy_from_slice(&i[..dim]));
117                transform
118            }
119        };
120
121        Self { transform }
122    }
123
124    /// Return the input dimension for the transformation.
125    pub fn input_dim(&self) -> usize {
126        self.transform.ncols()
127    }
128
129    /// Return the output dimension for the transformation.
130    pub fn output_dim(&self) -> usize {
131        self.transform.nrows()
132    }
133
134    /// Return whether or not the transform preserves norms.
135    ///
136    /// For this transform, norms are not preserved when the output dimensionality is less
137    /// than the input dimensionality.
138    pub fn preserves_norms(&self) -> bool {
139        self.output_dim() >= self.input_dim()
140    }
141
142    /// Perform the transformation of the `src` vector into the `dst` vector.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if
147    ///
148    /// * `src.len() != self.input_dim()`.
149    /// * `dst.len() != self.output_dim()`.
150    pub fn transform_into(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
151        let input_dim = self.input_dim();
152        let output_dim = self.output_dim();
153        check_dims(dst, src, input_dim, output_dim)?;
154        diskann_linalg::sgemm(
155            Transpose::None,
156            Transpose::None,
157            output_dim,
158            1,
159            input_dim,
160            1.0,
161            self.transform.as_slice(),
162            src,
163            None,
164            dst,
165        );
166        Ok(())
167    }
168}
169
/// Errors reported when unpacking a serialized [`RandomRotation`].
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Error)]
pub enum RandomRotationError {
    /// The serialized data could not be interpreted as a matrix with the serialized
    /// number of rows and columns.
    #[error("buffer size not product of rows and columns")]
    IncorrectDim,

    /// The serialized number of rows was zero.
    #[error("number of rows cannot be zero")]
    RowsZero,

    /// The serialized number of columns was zero.
    #[error("number of cols cannot be zero")]
    ColsZero,
}
182
// Serialization
#[cfg(feature = "flatbuffers")]
impl RandomRotation {
    /// Write this transform into `buf` as a
    /// [`crate::flatbuffers::transforms::RandomRotation`], returning the offset of the
    /// serialized table.
    ///
    /// The matrix contents are stored as a flat vector alongside the row and column counts.
    pub(crate) fn pack<'a, A>(
        &self,
        buf: &mut FlatBufferBuilder<'a, A>,
    ) -> WIPOffset<fb::transforms::RandomRotation<'a>>
    where
        A: flatbuffers::Allocator + 'a,
    {
        let matrix = &self.transform;
        let data = buf.create_vector(matrix.as_slice());

        // NOTE(review): `as u32` silently truncates if a dimension exceeds `u32::MAX` --
        // confirm such matrices cannot reach this point.
        let args = fb::transforms::RandomRotationArgs {
            data: Some(data),
            nrows: matrix.nrows() as u32,
            ncols: matrix.ncols() as u32,
        };

        fb::transforms::RandomRotation::create(buf, &args)
    }

    /// Reconstruct a `RandomRotation` from its
    /// [`crate::flatbuffers::transforms::RandomRotation`] serialized representation.
    ///
    /// # Errors
    ///
    /// * [`RandomRotationError::RowsZero`] if the serialized row count is zero.
    /// * [`RandomRotationError::ColsZero`] if the serialized column count is zero.
    /// * [`RandomRotationError::IncorrectDim`] if the serialized data cannot be turned into
    ///   a matrix with the serialized row and column counts.
    pub(crate) fn try_unpack(
        proto: fb::transforms::RandomRotation<'_>,
    ) -> Result<Self, RandomRotationError> {
        // The row count is validated before the column count.
        let (nrows, ncols) = match (proto.nrows(), proto.ncols()) {
            (0, _) => return Err(RandomRotationError::RowsZero),
            (_, 0) => return Err(RandomRotationError::ColsZero),
            (nrows, ncols) => (nrows as usize, ncols as usize),
        };

        let data = proto.data().into_iter().collect();
        match diskann_utils::views::Matrix::try_from(data, nrows, ncols) {
            Ok(transform) => Ok(Self { transform }),
            Err(_) => Err(RandomRotationError::IncorrectDim),
        }
    }
}
229
230///////////
231// Tests //
232///////////
233
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformFailed, TransformKind},
        alloc::GlobalAllocator,
    };

    // Adapter so the shared transform test-suite can drive `RandomRotation` directly.
    impl test_utils::Transformer for RandomRotation {
        fn input_dim_(&self) -> usize {
            self.input_dim()
        }
        fn output_dim_(&self) -> usize {
            self.output_dim()
        }
        fn transform_into_(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
            self.transform_into(dst, src)
        }
    }

    /// Exercise `RandomRotation` (both directly and through the abstract `Transform`)
    /// across same-dimension, super-sampled, and sub-sampled configurations.
    #[test]
    fn test_transform_matrix() {
        let nonsubsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(10),
            l2: test_utils::Check::ulp(10),
            ip: test_utils::Check::absrel(2e-5, 1e-4),
        };

        // Because we're using relatively low dimensions, subsampling yields pretty large
        // variances. We can't use higher dimensionality, though, because then the tests
        // would never complete.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 0.18),
            l2: test_utils::Check::absrel(0.0, 0.18),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // Combinations of input to output dimensions, laid out as
        // `(input dim, expected output dim, expected preserves_norms, target, tolerances)`.
        let dim_combos = [
            // Same dimension
            (15, 15, true, TargetDim::Same, &nonsubsampled_errors),
            (15, 15, true, TargetDim::Natural, &nonsubsampled_errors),
            (16, 16, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Natural, &nonsubsampled_errors),
            (256, 256, true, TargetDim::Same, &nonsubsampled_errors),
            // Super Sampling
            (15, 20, true, target_dim(20), &nonsubsampled_errors),
            // Sub Sampling
            (256, 200, false, target_dim(200), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 50;

        let mut rng = StdRng::seed_from_u64(0x30e37c10c36cc64b);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the Rng state so the abstract transform behaves the same.
                let mut rng_clone = rng.clone();

                // Test the underlying transformer.
                {
                    let transformer = RandomRotation::new(input_nz, target, &mut rng);
                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    );
                }

                // Abstract Transformer
                {
                    let kind = TransformKind::RandomRotation { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    // Carry the same context as the concrete checks so a failure here
                    // identifies the dimension combination and trial.
                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Round-trip `RandomRotation` through its flatbuffer representation and check
        /// that malformed payloads are rejected with the expected error.
        #[test]
        fn random_rotation() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);

            // Test various dimension combinations
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform =
                    RandomRotation::new(NonZeroUsize::new(dim).unwrap(), target_dim, &mut rng);
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let reloaded = RandomRotation::try_unpack(proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Test error cases for invalid dimensions
            {
                let data = to_flatbuffer(|buf| {
                    let data = buf.create_vector::<f32>(&[1.0, 0.0, 0.0, 1.0]); // 2x2 matrix
                    fb::transforms::RandomRotation::create(
                        buf,
                        &fb::transforms::RandomRotationArgs {
                            data: Some(data),
                            nrows: 0, // Invalid: zero rows
                            ncols: 2,
                        },
                    )
                });

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let err = RandomRotation::try_unpack(proto).unwrap_err();
                assert_eq!(err, RandomRotationError::RowsZero);
            }

            {
                let data = to_flatbuffer(|buf| {
                    let data = buf.create_vector::<f32>(&[1.0, 0.0, 0.0, 1.0]); // 2x2 matrix
                    fb::transforms::RandomRotation::create(
                        buf,
                        &fb::transforms::RandomRotationArgs {
                            data: Some(data),
                            nrows: 2,
                            ncols: 0, // Invalid: zero cols
                        },
                    )
                });

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let err = RandomRotation::try_unpack(proto).unwrap_err();
                assert_eq!(err, RandomRotationError::ColsZero);
            }

            {
                let data = to_flatbuffer(|buf| {
                    let data = buf.create_vector::<f32>(&[1.0, 0.0, 0.0]); // 3 elements
                    fb::transforms::RandomRotation::create(
                        buf,
                        &fb::transforms::RandomRotationArgs {
                            data: Some(data),
                            nrows: 2,
                            ncols: 2, // Should be 4 elements for 2x2 matrix
                        },
                    )
                });

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let err = RandomRotation::try_unpack(proto).unwrap_err();
                assert_eq!(err, RandomRotationError::IncorrectDim);
            }
        }
    }
}