1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use diskann_linalg::{self, Transpose};
11use rand::Rng;
12#[cfg(feature = "flatbuffers")]
13use thiserror::Error;
14
15use super::{
16 utils::{check_dims, TransformFailed},
17 TargetDim,
18};
19#[cfg(feature = "flatbuffers")]
20use crate::flatbuffers as fb;
21
22#[derive(Debug, Clone)]
31#[cfg_attr(test, derive(PartialEq))]
32pub struct RandomRotation {
33 transform: diskann_utils::views::Matrix<f32>,
35}
36
37impl RandomRotation {
38 pub fn new<R>(dim: NonZeroUsize, target_dim: TargetDim, rng: &mut R) -> Self
56 where
57 R: Rng + ?Sized,
58 {
59 let dim = dim.get();
60
61 let (target_dim, matrix_dim) = match target_dim {
77 TargetDim::Same | TargetDim::Natural => (dim, dim),
78 TargetDim::Override(target) => {
79 let target_dim = target.get();
80 if target_dim <= dim {
81 (target_dim, dim)
82 } else {
83 (target_dim, target_dim)
84 }
85 }
86 };
87
88 #[allow(clippy::unwrap_used)]
91 let initial = diskann_utils::views::Matrix::try_from(
92 diskann_linalg::random_distance_preserving_matrix(matrix_dim, rng).into(),
93 matrix_dim,
94 matrix_dim,
95 )
96 .unwrap();
97
98 let transform = match target_dim.cmp(&dim) {
100 std::cmp::Ordering::Equal => initial,
101 std::cmp::Ordering::Less => {
102 let indices = rand::seq::index::sample(rng, dim, target_dim);
103 let scaling = (dim as f32 / target_dim as f32).sqrt();
104
105 let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
106 std::iter::zip(transform.row_iter_mut(), indices.iter()).for_each(|(ro, ri)| {
107 std::iter::zip(ro.iter_mut(), initial.row(ri).iter()).for_each(|(o, i)| {
108 *o = scaling * (*i);
109 })
110 });
111 transform
112 }
113 std::cmp::Ordering::Greater => {
114 let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
115 std::iter::zip(transform.row_iter_mut(), initial.row_iter())
116 .for_each(|(o, i)| o.copy_from_slice(&i[..dim]));
117 transform
118 }
119 };
120
121 Self { transform }
122 }
123
124 pub fn input_dim(&self) -> usize {
126 self.transform.ncols()
127 }
128
129 pub fn output_dim(&self) -> usize {
131 self.transform.nrows()
132 }
133
134 pub fn preserves_norms(&self) -> bool {
139 self.output_dim() >= self.input_dim()
140 }
141
142 pub fn transform_into(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
151 let input_dim = self.input_dim();
152 let output_dim = self.output_dim();
153 check_dims(dst, src, input_dim, output_dim)?;
154 diskann_linalg::sgemm(
155 Transpose::None,
156 Transpose::None,
157 output_dim,
158 1,
159 input_dim,
160 1.0,
161 self.transform.as_slice(),
162 src,
163 None,
164 dst,
165 );
166 Ok(())
167 }
168}
169
/// Reasons a serialized [`RandomRotation`] is rejected while being unpacked.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Clone, Copy, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum RandomRotationError {
    /// The stored entries could not be arranged into the stored row and column counts.
    #[error("buffer size not product of rows and columns")]
    IncorrectDim,
    /// The stored row count is zero.
    #[error("number of rows cannot be zero")]
    RowsZero,
    /// The stored column count is zero.
    #[error("number of cols cannot be zero")]
    ColsZero,
}
182
#[cfg(feature = "flatbuffers")]
impl RandomRotation {
    /// Writes the matrix entries and its row/column counts into `buf` and returns the
    /// offset of the resulting `RandomRotation` table.
    pub(crate) fn pack<'a, A>(
        &self,
        buf: &mut FlatBufferBuilder<'a, A>,
    ) -> WIPOffset<fb::transforms::RandomRotation<'a>>
    where
        A: flatbuffers::Allocator + 'a,
    {
        // The entry vector has to be written before the table that refers to it.
        let data = buf.create_vector(self.transform.as_slice());

        // NOTE(review): `as u32` silently truncates a dimension above `u32::MAX`; the
        // flatbuffer size limit presumably makes such a matrix unrepresentable -- confirm.
        let args = fb::transforms::RandomRotationArgs {
            data: Some(data),
            nrows: self.transform.nrows() as u32,
            ncols: self.transform.ncols() as u32,
        };

        fb::transforms::RandomRotation::create(buf, &args)
    }

    /// Rebuilds a [`RandomRotation`] from its serialized form.
    ///
    /// # Errors
    ///
    /// * [`RandomRotationError::RowsZero`] if the stored row count is zero.
    /// * [`RandomRotationError::ColsZero`] if the stored column count is zero.
    /// * [`RandomRotationError::IncorrectDim`] if the stored entries cannot be turned into
    ///   a matrix with the stored row and column counts.
    pub(crate) fn try_unpack(
        proto: fb::transforms::RandomRotation<'_>,
    ) -> Result<Self, RandomRotationError> {
        let (nrows, ncols) = (proto.nrows(), proto.ncols());

        // Rows are validated before columns, so a doubly-empty shape reports `RowsZero`.
        if nrows == 0 {
            return Err(RandomRotationError::RowsZero);
        }
        if ncols == 0 {
            return Err(RandomRotationError::ColsZero);
        }

        let data = proto.data().into_iter().collect();
        match diskann_utils::views::Matrix::try_from(data, nrows as usize, ncols as usize) {
            Ok(transform) => Ok(Self { transform }),
            Err(_) => Err(RandomRotationError::IncorrectDim),
        }
    }
}
229
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformFailed, TransformKind},
        alloc::GlobalAllocator,
    };

    // Adapter that lets the shared transform test helpers drive `RandomRotation` directly.
    impl test_utils::Transformer for RandomRotation {
        fn input_dim_(&self) -> usize {
            RandomRotation::input_dim(self)
        }
        fn output_dim_(&self) -> usize {
            RandomRotation::output_dim(self)
        }
        fn transform_into_(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
            RandomRotation::transform_into(self, dst, src)
        }
    }

    /// Checks reported dimensions and error bounds for several input/target combinations,
    /// both through `RandomRotation` itself and through the `Transform` wrapper.
    #[test]
    fn test_transform_matrix() {
        // Tight bounds for transforms whose output is at least as long as the input.
        let nonsubsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(10),
            l2: test_utils::Check::ulp(10),
            ip: test_utils::Check::absrel(2e-5, 1e-4),
        };

        // Looser bounds for the sub-sampled transform; inner products are not checked.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 0.18),
            l2: test_utils::Check::absrel(0.0, 0.18),
            ip: test_utils::Check::skip(),
        };

        let override_to = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // (input dim, expected output dim, expected `preserves_norms`, target, tolerances)
        let dim_combos = [
            (15, 15, true, TargetDim::Same, &nonsubsampled_errors),
            (15, 15, true, TargetDim::Natural, &nonsubsampled_errors),
            (16, 16, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Natural, &nonsubsampled_errors),
            (256, 256, true, TargetDim::Same, &nonsubsampled_errors),
            (15, 20, true, override_to(20), &nonsubsampled_errors),
            (256, 200, false, override_to(200), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 50;

        let mut rng = StdRng::seed_from_u64(0x30e37c10c36cc64b);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut check = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    test_utils::check_errors(io, context, errors);
                };

                // Fork the generator before it advances so that the wrapper below starts
                // from the same state as the direct construction.
                let mut forked_rng = rng.clone();

                // Direct construction.
                let direct = RandomRotation::new(input_nz, target, &mut rng);
                assert_eq!(direct.input_dim(), input, "{}", ctx);
                assert_eq!(direct.output_dim(), output, "{}", ctx);
                assert_eq!(direct.preserves_norms(), preserves_norms, "{}", ctx);

                test_utils::test_transform(&direct, trials_per_dim, &mut check, &mut rng, ctx);

                // Construction through the `Transform` wrapper.
                let kind = TransformKind::RandomRotation { target_dim: target };
                let wrapped =
                    Transform::new(kind, input_nz, Some(&mut forked_rng), GlobalAllocator)
                        .unwrap();

                assert_eq!(wrapped.input_dim(), input);
                assert_eq!(wrapped.output_dim(), output);
                assert_eq!(wrapped.preserves_norms(), preserves_norms);

                test_utils::test_transform(
                    &wrapped,
                    trials_per_dim,
                    &mut check,
                    &mut forked_rng,
                    ctx,
                );
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Serializes a table holding `values` with the given shape, then returns the error
        /// produced when unpacking it.
        fn unpack_failure(values: &[f32], nrows: u32, ncols: u32) -> RandomRotationError {
            let data = to_flatbuffer(|buf| {
                let data = buf.create_vector::<f32>(values);
                fb::transforms::RandomRotation::create(
                    buf,
                    &fb::transforms::RandomRotationArgs {
                        data: Some(data),
                        nrows,
                        ncols,
                    },
                )
            });

            let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
            RandomRotation::try_unpack(proto).unwrap_err()
        }

        /// Round-trips several transforms and checks that malformed tables are rejected.
        #[test]
        fn random_rotation() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);

            // (input dim, target): same, natural, enlarged, and sub-sampled outputs.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let original =
                    RandomRotation::new(NonZeroUsize::new(dim).unwrap(), target_dim, &mut rng);
                let bytes = to_flatbuffer(|buf| original.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&bytes).unwrap();
                let reloaded = RandomRotation::try_unpack(proto).unwrap();

                assert_eq!(original, reloaded);
            }

            // Zero rows.
            assert_eq!(
                unpack_failure(&[1.0, 0.0, 0.0, 1.0], 0, 2),
                RandomRotationError::RowsZero
            );

            // Zero columns.
            assert_eq!(
                unpack_failure(&[1.0, 0.0, 0.0, 1.0], 2, 0),
                RandomRotationError::ColsZero
            );

            // Three entries cannot fill a 2 x 2 matrix.
            assert_eq!(
                unpack_failure(&[1.0, 0.0, 0.0], 2, 2),
                RandomRotationError::IncorrectDim
            );
        }
    }
}