diskann_quantization/algorithms/transforms/
utils.rs1use rand::Rng;
7use thiserror::Error;
8
9use crate::alloc::{Allocator, AllocatorError, Poly};
10
11#[derive(Debug, Clone, Copy, Error, PartialEq)]
12pub enum TransformFailed {
13 #[error("incorrect transform input vector - expected length {expected} but got {found}")]
14 SourceMismatch { expected: usize, found: usize },
15 #[error("incorrect transform output vector - expected length {expected} but got {found}")]
16 DestinationMismatch { expected: usize, found: usize },
17 #[error(transparent)]
18 AllocatorError(#[from] AllocatorError),
19}
20
21pub(super) fn check_dims(
22 dst: &[f32],
23 src: &[f32],
24 input_dim: usize,
25 output_dim: usize,
26) -> Result<(), TransformFailed> {
27 if src.len() != input_dim {
28 return Err(TransformFailed::SourceMismatch {
29 expected: input_dim,
30 found: src.len(),
31 });
32 }
33
34 if dst.len() != output_dim {
35 return Err(TransformFailed::DestinationMismatch {
36 expected: output_dim,
37 found: dst.len(),
38 });
39 }
40 Ok(())
41}
42
43pub(super) fn is_sign(x: u32) -> bool {
44 x == 0 || x == 0x8000_0000
45}
46
/// Encodes a sign word as a boolean: `true` exactly when `x` has only the top bit set.
///
/// Any other value maps to `false`, including bit patterns that [`is_sign`] would reject.
/// Only compiled with the `flatbuffers` feature.
#[cfg(feature = "flatbuffers")]
pub(super) fn sign_to_bool(x: u32) -> bool {
    const TOP_BIT: u32 = 1 << 31;
    TOP_BIT == x
}
50}
51
/// Decodes a boolean back into a sign word.
///
/// `true` becomes `0x8000_0000` (only the top bit set) and `false` becomes `0`. Only
/// compiled with the `flatbuffers` feature.
#[cfg(feature = "flatbuffers")]
pub(super) fn bool_to_sign(x: bool) -> u32 {
    // `u32::from(bool)` is 0 or 1; shifting by 31 moves that bit into the top position.
    u32::from(x) << 31
}
55}
56
57pub(super) fn subsample_indices<R, A>(
58 rng: &mut R,
59 length: usize,
60 amount: usize,
61 allocator: A,
62) -> Result<Poly<[u32], A>, AllocatorError>
63where
64 R: Rng + ?Sized,
65 A: Allocator,
66{
67 let mut subsample = Poly::from_iter(
68 rand::seq::index::sample(rng, length, amount)
69 .into_iter()
70 .map(|i| i as u32),
71 allocator,
72 )?;
73 subsample.sort();
74 Ok(subsample)
75}