1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use diskann_linalg::{self, Transpose};
11use rand::Rng;
12#[cfg(feature = "flatbuffers")]
13use thiserror::Error;
14
15use super::{
16 utils::{check_dims, TransformFailed},
17 TargetDim,
18};
19#[cfg(feature = "flatbuffers")]
20use crate::flatbuffers as fb;
21
22#[derive(Debug, Clone)]
31#[cfg_attr(test, derive(PartialEq))]
32pub struct RandomRotation {
33 transform: diskann_utils::views::Matrix<f32>,
35}
36
37impl RandomRotation {
38 pub fn new<R>(dim: NonZeroUsize, target_dim: TargetDim, rng: &mut R) -> Self
56 where
57 R: Rng + ?Sized,
58 {
59 let dim = dim.get();
60
61 let (target_dim, matrix_dim) = match target_dim {
77 TargetDim::Same | TargetDim::Natural => (dim, dim),
78 TargetDim::Override(target) => {
79 let target_dim = target.get();
80 if target_dim <= dim {
81 (target_dim, dim)
82 } else {
83 (target_dim, target_dim)
84 }
85 }
86 };
87
88 #[allow(clippy::unwrap_used)]
91 let initial = diskann_utils::views::Matrix::try_from(
92 diskann_linalg::random_distance_preserving_matrix(matrix_dim, rng).into(),
93 matrix_dim,
94 matrix_dim,
95 )
96 .unwrap();
97
98 let transform = match target_dim.cmp(&dim) {
100 std::cmp::Ordering::Equal => initial,
101 std::cmp::Ordering::Less => {
102 let indices = rand::seq::index::sample(rng, dim, target_dim);
103 let scaling = (dim as f32 / target_dim as f32).sqrt();
104
105 let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
106 std::iter::zip(transform.row_iter_mut(), indices.iter()).for_each(|(ro, ri)| {
107 std::iter::zip(ro.iter_mut(), initial.row(ri).iter()).for_each(|(o, i)| {
108 *o = scaling * (*i);
109 })
110 });
111 transform
112 }
113 std::cmp::Ordering::Greater => {
114 let mut transform = diskann_utils::views::Matrix::new(0.0f32, target_dim, dim);
115 std::iter::zip(transform.row_iter_mut(), initial.row_iter())
116 .for_each(|(o, i)| o.copy_from_slice(&i[..dim]));
117 transform
118 }
119 };
120
121 Self { transform }
122 }
123
124 pub fn input_dim(&self) -> usize {
126 self.transform.ncols()
127 }
128
129 pub fn output_dim(&self) -> usize {
131 self.transform.nrows()
132 }
133
134 pub fn preserves_norms(&self) -> bool {
139 self.output_dim() >= self.input_dim()
140 }
141
142 pub fn transform_into(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
151 let input_dim = self.input_dim();
152 let output_dim = self.output_dim();
153 check_dims(dst, src, input_dim, output_dim)?;
154 diskann_linalg::sgemm(
155 Transpose::None,
156 Transpose::None,
157 output_dim,
158 1,
159 input_dim,
160 1.0,
161 self.transform.as_slice(),
162 src,
163 None,
164 dst,
165 );
166 Ok(())
167 }
168}
169
/// Failures that can occur while rebuilding a [`RandomRotation`] from its serialized form.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Clone, Copy, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum RandomRotationError {
    /// The stored data could not be viewed as a matrix with the stored row and column counts.
    #[error("buffer size not product of rows and columns")]
    IncorrectDim,
    /// The stored row count is zero.
    #[error("number of rows cannot be zero")]
    RowsZero,
    /// The stored column count is zero.
    #[error("number of cols cannot be zero")]
    ColsZero,
}
182
#[cfg(feature = "flatbuffers")]
impl RandomRotation {
    /// Serialize this transform into `buf` and return the offset of the written table.
    ///
    /// The matrix contents are written as a flat `f32` vector together with the row and
    /// column counts.
    ///
    /// NOTE(review): the counts are narrowed with `as u32`, so a dimension above `u32::MAX`
    /// would be truncated silently.
    pub(crate) fn pack<'a, A>(
        &self,
        buf: &mut FlatBufferBuilder<'a, A>,
    ) -> WIPOffset<fb::transforms::RandomRotation<'a>>
    where
        A: flatbuffers::Allocator + 'a,
    {
        // The vector has to be written before the table that refers to it is created.
        let args = fb::transforms::RandomRotationArgs {
            data: Some(buf.create_vector(self.transform.as_slice())),
            nrows: self.transform.nrows() as u32,
            ncols: self.transform.ncols() as u32,
        };
        fb::transforms::RandomRotation::create(buf, &args)
    }

    /// Rebuild a transform from its serialized representation.
    ///
    /// # Errors
    ///
    /// * [`RandomRotationError::RowsZero`] if the stored row count is zero (checked first).
    /// * [`RandomRotationError::ColsZero`] if the stored column count is zero.
    /// * [`RandomRotationError::IncorrectDim`] if the stored data cannot be viewed as a
    ///   matrix of the stored shape.
    pub(crate) fn try_unpack(
        proto: fb::transforms::RandomRotation<'_>,
    ) -> Result<Self, RandomRotationError> {
        let (nrows, ncols) = match (proto.nrows(), proto.ncols()) {
            (0, _) => return Err(RandomRotationError::RowsZero),
            (_, 0) => return Err(RandomRotationError::ColsZero),
            (rows, cols) => (rows as usize, cols as usize),
        };

        let values = proto.data().into_iter().collect();
        diskann_utils::views::Matrix::try_from(values, nrows, ncols)
            .map(|transform| Self { transform })
            .map_err(|_| RandomRotationError::IncorrectDim)
    }
}
229
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformFailed, TransformKind},
        alloc::GlobalAllocator,
        test_util::Check,
    };

    // Adapter so the shared transform test-suite can drive `RandomRotation` directly.
    impl test_utils::Transformer for RandomRotation {
        fn input_dim_(&self) -> usize {
            self.input_dim()
        }
        fn output_dim_(&self) -> usize {
            self.output_dim()
        }
        fn transform_into_(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
            self.transform_into(dst, src)
        }
    }

    /// Exercise the transform over several input/output dimension combinations, both
    /// directly and through the `Transform` wrapper.
    #[test]
    fn test_transform_matrix() {
        // Tolerances for transforms whose output is at least as long as the input.
        let nonsubsampled_errors = test_utils::ErrorSetup {
            norm: Check::ulp(10),
            l2: Check::ulp(10),
            ip: Check::absrel(2e-5, 1e-4),
        };

        // Sub-sampled transforms get a much looser relative bound on norms and distances;
        // inner products are not checked at all.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: Check::absrel(0.0, 0.18),
            l2: Check::absrel(0.0, 0.18),
            ip: Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // (input dim, expected output dim, expected `preserves_norms`, target, tolerances)
        let dim_combos = [
            (15, 15, true, TargetDim::Same, &nonsubsampled_errors),
            (15, 15, true, TargetDim::Natural, &nonsubsampled_errors),
            (16, 16, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Same, &nonsubsampled_errors),
            (100, 100, true, TargetDim::Natural, &nonsubsampled_errors),
            (256, 256, true, TargetDim::Same, &nonsubsampled_errors),
            // Output longer than the input.
            (15, 20, true, target_dim(20), &nonsubsampled_errors),
            // Output shorter than the input (sub-sampling).
            (256, 200, false, target_dim(200), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 50;

        let mut rng = StdRng::seed_from_u64(0x30e37c10c36cc64b);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    test_utils::check_errors(io, context, errors);
                };

                // Snapshot the generator before it is advanced so that the wrapper path
                // below starts from the same random state as the direct path.
                let mut rng_clone = rng.clone();

                // Direct construction.
                {
                    let transformer = RandomRotation::new(input_nz, target, &mut rng);
                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    );
                }

                // Construction through the `Transform` wrapper.
                {
                    let kind = TransformKind::RandomRotation { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input, "{}", ctx);
                    assert_eq!(transformer.output_dim(), output, "{}", ctx);
                    assert_eq!(transformer.preserves_norms(), preserves_norms, "{}", ctx);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    );
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Hand-build a serialized `RandomRotation` from raw parts and return the error
        /// produced when unpacking it.
        fn unpack_error(values: &[f32], nrows: u32, ncols: u32) -> RandomRotationError {
            let data = to_flatbuffer(|buf| {
                let data = buf.create_vector::<f32>(values);
                fb::transforms::RandomRotation::create(
                    buf,
                    &fb::transforms::RandomRotationArgs {
                        data: Some(data),
                        nrows,
                        ncols,
                    },
                )
            });

            let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
            RandomRotation::try_unpack(proto).unwrap_err()
        }

        #[test]
        fn random_rotation() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);

            // Round trips: same size, natural size, longer output, shorter output.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform =
                    RandomRotation::new(NonZeroUsize::new(dim).unwrap(), target_dim, &mut rng);
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::RandomRotation>(&data).unwrap();
                let reloaded = RandomRotation::try_unpack(proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Zero rows.
            assert_eq!(
                unpack_error(&[1.0, 0.0, 0.0, 1.0], 0, 2),
                RandomRotationError::RowsZero
            );

            // Zero columns.
            assert_eq!(
                unpack_error(&[1.0, 0.0, 0.0, 1.0], 2, 0),
                RandomRotationError::ColsZero
            );

            // Three values cannot fill a 2 x 2 matrix.
            assert_eq!(
                unpack_error(&[1.0, 0.0, 0.0], 2, 2),
                RandomRotationError::IncorrectDim
            );
        }
    }
}