diskann_quantization/algorithms/transforms/
utils.rs1use rand::Rng;
7use thiserror::Error;
8
9use crate::alloc::{Allocator, AllocatorError, Poly};
10
11#[derive(Debug, Clone, Error, PartialEq)]
12pub enum TransformFailed {
13 #[error("incorrect transform input vector - expected length {expected} but got {found}")]
14 SourceMismatch { expected: usize, found: usize },
15 #[error("incorrect transform output vector - expected length {expected} but got {found}")]
16 DestinationMismatch { expected: usize, found: usize },
17 #[error(transparent)]
18 AllocatorError(#[from] AllocatorError),
19 #[cfg(feature = "linalg")]
20 #[error(transparent)]
21 SgemmError(#[from] diskann_linalg::SgemmError),
22}
23
24pub(super) fn check_dims(
25 dst: &[f32],
26 src: &[f32],
27 input_dim: usize,
28 output_dim: usize,
29) -> Result<(), TransformFailed> {
30 if src.len() != input_dim {
31 return Err(TransformFailed::SourceMismatch {
32 expected: input_dim,
33 found: src.len(),
34 });
35 }
36
37 if dst.len() != output_dim {
38 return Err(TransformFailed::DestinationMismatch {
39 expected: output_dim,
40 found: dst.len(),
41 });
42 }
43 Ok(())
44}
45
46pub(super) fn is_sign(x: u32) -> bool {
47 x == 0 || x == 0x8000_0000
48}
49
/// Decode a sign-bit pattern into a `bool`.
///
/// Returns `true` only for exactly `0x8000_0000` (bit 31 set, nothing else);
/// every other value, including `0`, yields `false`.
#[cfg(feature = "flatbuffers")]
pub(super) fn sign_to_bool(x: u32) -> bool {
    const SIGN_BIT: u32 = 1 << 31;
    x == SIGN_BIT
}
54
/// Encode a `bool` as a sign-bit pattern: `true` becomes `0x8000_0000`,
/// `false` becomes `0`.
#[cfg(feature = "flatbuffers")]
pub(super) fn bool_to_sign(x: bool) -> u32 {
    // `u32::from(bool)` is 0 or 1; shifting moves that bit into position 31.
    u32::from(x) << 31
}
59
60pub(super) fn subsample_indices<R, A>(
61 rng: &mut R,
62 length: usize,
63 amount: usize,
64 allocator: A,
65) -> Result<Poly<[u32], A>, AllocatorError>
66where
67 R: Rng + ?Sized,
68 A: Allocator,
69{
70 let mut subsample = Poly::from_iter(
71 rand::seq::index::sample(rng, length, amount)
72 .into_iter()
73 .map(|i| i as u32),
74 allocator,
75 )?;
76 subsample.sort();
77 Ok(subsample)
78}