Skip to main content

diskann_quantization/algorithms/transforms/
null.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10#[cfg(feature = "flatbuffers")]
11use thiserror::Error;
12
13use super::utils::{TransformFailed, check_dims};
14#[cfg(feature = "flatbuffers")]
15use crate::flatbuffers as fb;
16
/// An identity transform: `transform_into` copies its input to its output unchanged.
///
/// It exists so that quantizers which always apply some transform can be
/// configured to effectively apply none.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct NullTransform {
    /// Number of `f32` components expected in every input and output vector.
    dim: usize,
}
23
24impl NullTransform {
25    pub fn new(dim: NonZeroUsize) -> Self {
26        NullTransform { dim: dim.get() }
27    }
28
29    pub fn dim(&self) -> usize {
30        self.dim
31    }
32
33    /// The null transform always preserves norms because it leaves data unmodified.
34    pub const fn preserves_norms(&self) -> bool {
35        true
36    }
37
38    pub fn transform_into(&self, dst: &mut [f32], src: &[f32]) -> Result<(), TransformFailed> {
39        check_dims(dst, src, self.dim(), self.dim())?;
40        dst.copy_from_slice(src);
41        Ok(())
42    }
43}
44
45// Serialization
/// Errors that can arise when deserializing a [`NullTransform`] from its
/// flatbuffer representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Clone, Copy, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum NullTransformError {
    /// The serialized `dim` was zero, and a [`NullTransform`] cannot have a zero dimension.
    #[error("dim cannot be zero")]
    DimCannotBeZero,
}
54
#[cfg(feature = "flatbuffers")]
impl NullTransform {
    /// Pack into a [`crate::flatbuffers::transforms::NullTransform`] serialized
    /// representation.
    ///
    /// # Panics
    ///
    /// Panics if `dim` does not fit in the `u32` field of the serialized format.
    /// Truncating instead would silently write a different (possibly zero)
    /// dimension, and that value would not round-trip back to `self`.
    pub(crate) fn pack<'a, A>(
        &self,
        buf: &mut FlatBufferBuilder<'a, A>,
    ) -> WIPOffset<fb::transforms::NullTransform<'a>>
    where
        A: flatbuffers::Allocator + 'a,
    {
        // A checked conversion instead of `as u32`, which would wrap modulo 2^32.
        let dim = u32::try_from(self.dim)
            .expect("NullTransform dim exceeds u32::MAX and cannot be serialized");
        fb::transforms::NullTransform::create(buf, &fb::transforms::NullTransformArgs { dim })
    }

    /// Attempt to unpack from a [`crate::flatbuffers::transforms::NullTransform`]
    /// serialized representation, returning any error if encountered.
    ///
    /// # Errors
    ///
    /// Returns [`NullTransformError::DimCannotBeZero`] if the serialized `dim` is 0.
    /// This includes a table built from default `NullTransformArgs`.
    pub(crate) fn try_unpack(
        proto: fb::transforms::NullTransform<'_>,
    ) -> Result<Self, NullTransformError> {
        // The u32 -> usize widening is lossless on 32- and 64-bit targets.
        let dim =
            NonZeroUsize::new(proto.dim() as usize).ok_or(NullTransformError::DimCannotBeZero)?;
        Ok(Self::new(dim))
    }
}
84
85///////////
86// Tests //
87///////////
88
#[cfg(all(test, feature = "flatbuffers"))]
mod tests {
    use super::*;

    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Packs `transform` into a flatbuffer and immediately unpacks it again.
        fn round_trip(transform: &NullTransform) -> NullTransform {
            let bytes = to_flatbuffer(|builder| transform.pack(builder));
            let table = flatbuffers::root::<fb::transforms::NullTransform>(&bytes).unwrap();
            NullTransform::try_unpack(table).unwrap()
        }

        /// Serialization must round-trip every valid dimension and reject a zero one.
        #[test]
        fn null_transform() {
            const DIMS: [usize; 5] = [1, 2, 10, 20, 1536];

            for &raw in DIMS.iter() {
                let transform = NullTransform::new(NonZeroUsize::new(raw).unwrap());
                assert!(transform.preserves_norms());
                assert_eq!(transform, round_trip(&transform));
            }

            // A table built from default arguments carries `dim == 0` and must be rejected.
            let bytes = to_flatbuffer(|builder| {
                fb::transforms::NullTransform::create(
                    builder,
                    &fb::transforms::NullTransformArgs::default(),
                )
            });
            let table = flatbuffers::root::<fb::transforms::NullTransform>(&bytes).unwrap();
            assert_eq!(
                NullTransform::try_unpack(table).unwrap_err(),
                NullTransformError::DimCannotBeZero
            );
        }
    }
}