diskann_quantization/algorithms/transforms/
mod.rs1use std::num::NonZeroUsize;
8
9#[cfg(feature = "flatbuffers")]
10use flatbuffers::{FlatBufferBuilder, WIPOffset};
11use rand::RngCore;
12use thiserror::Error;
13
14use crate::alloc::{Allocator, AllocatorError, ScopedAllocator, TryClone};
15#[cfg(feature = "flatbuffers")]
16use crate::flatbuffers as fb;
17
18mod double_hadamard;
20mod null;
21mod padding_hadamard;
22
23crate::utils::features! {
24 #![feature = "linalg"]
25 mod random_rotation;
26}
27
28mod utils;
29
30#[cfg(test)]
31mod test_utils;
32
33pub use double_hadamard::{DoubleHadamard, DoubleHadamardError};
35pub use null::NullTransform;
36pub use padding_hadamard::{PaddingHadamard, PaddingHadamardError};
37pub use utils::TransformFailed;
38
39crate::utils::features! {
40 #![feature = "linalg"]
41 pub use random_rotation::RandomRotation;
42}
43
44crate::utils::features! {
45 #![all(feature = "linalg", feature = "flatbuffers")]
46 pub use random_rotation::RandomRotationError;
47}
48
49crate::utils::features! {
50 #![feature = "flatbuffers"]
51 pub use null::NullTransformError;
52}
53
54#[derive(Debug, Clone, Copy)]
59#[non_exhaustive]
60pub enum TransformKind {
61 PaddingHadamard { target_dim: TargetDim },
79
80 DoubleHadamard { target_dim: TargetDim },
93
94 Null,
96
97 #[cfg(feature = "linalg")]
102 #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
103 RandomRotation { target_dim: TargetDim },
104}
105
106#[derive(Debug, Clone, Error)]
107pub enum NewTransformError {
108 #[error("random number generator is required for {0:?}")]
109 RngMissing(TransformKind),
110 #[error(transparent)]
111 AllocatorError(#[from] AllocatorError),
112}
113
114#[derive(Debug)]
115#[cfg_attr(test, derive(PartialEq))]
116pub enum Transform<A>
117where
118 A: Allocator,
119{
120 PaddingHadamard(PaddingHadamard<A>),
121 DoubleHadamard(DoubleHadamard<A>),
122 Null(NullTransform),
123
124 #[cfg(feature = "linalg")]
125 #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
126 RandomRotation(RandomRotation),
127}
128
129impl<A> Transform<A>
130where
131 A: Allocator,
132{
133 pub fn new(
143 transform_kind: TransformKind,
144 dim: NonZeroUsize,
145 rng: Option<&mut dyn RngCore>,
146 allocator: A,
147 ) -> Result<Self, NewTransformError> {
148 match transform_kind {
149 TransformKind::PaddingHadamard { target_dim } => {
150 let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
151 Ok(Transform::PaddingHadamard(PaddingHadamard::new(
152 dim, target_dim, rng, allocator,
153 )?))
154 }
155 TransformKind::DoubleHadamard { target_dim } => {
156 let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
157 Ok(Transform::DoubleHadamard(DoubleHadamard::new(
158 dim, target_dim, rng, allocator,
159 )?))
160 }
161 TransformKind::Null => Ok(Transform::Null(NullTransform::new(dim))),
162 #[cfg(feature = "linalg")]
163 TransformKind::RandomRotation { target_dim } => {
164 let rng = rng.ok_or(NewTransformError::RngMissing(transform_kind))?;
165 Ok(Transform::RandomRotation(RandomRotation::new(
166 dim, target_dim, rng,
167 )))
168 }
169 }
170 }
171
172 pub(crate) fn input_dim(&self) -> usize {
173 match self {
174 Self::PaddingHadamard(t) => t.input_dim(),
175 Self::DoubleHadamard(t) => t.input_dim(),
176 Self::Null(t) => t.dim(),
177 #[cfg(feature = "linalg")]
178 Self::RandomRotation(t) => t.input_dim(),
179 }
180 }
181 pub(crate) fn output_dim(&self) -> usize {
182 match self {
183 Self::PaddingHadamard(t) => t.output_dim(),
184 Self::DoubleHadamard(t) => t.output_dim(),
185 Self::Null(t) => t.dim(),
186 #[cfg(feature = "linalg")]
187 Self::RandomRotation(t) => t.output_dim(),
188 }
189 }
190
191 pub(crate) fn preserves_norms(&self) -> bool {
192 match self {
193 Self::PaddingHadamard(t) => t.preserves_norms(),
194 Self::DoubleHadamard(t) => t.preserves_norms(),
195 Self::Null(t) => t.preserves_norms(),
196 #[cfg(feature = "linalg")]
197 Self::RandomRotation(t) => t.preserves_norms(),
198 }
199 }
200
201 pub(crate) fn transform_into(
202 &self,
203 dst: &mut [f32],
204 src: &[f32],
205 allocator: ScopedAllocator<'_>,
206 ) -> Result<(), TransformFailed> {
207 match self {
208 Self::PaddingHadamard(t) => t.transform_into(dst, src, allocator),
209 Self::DoubleHadamard(t) => t.transform_into(dst, src, allocator),
210 Self::Null(t) => t.transform_into(dst, src),
211 #[cfg(feature = "linalg")]
212 Self::RandomRotation(t) => t.transform_into(dst, src),
213 }
214 }
215}
216
217impl<A> TryClone for Transform<A>
218where
219 A: Allocator,
220{
221 fn try_clone(&self) -> Result<Self, AllocatorError> {
222 match self {
223 Self::PaddingHadamard(t) => Ok(Self::PaddingHadamard(t.try_clone()?)),
224 Self::DoubleHadamard(t) => Ok(Self::DoubleHadamard(t.try_clone()?)),
225 Self::Null(t) => Ok(Self::Null(t.clone())),
226 #[cfg(feature = "linalg")]
227 Self::RandomRotation(t) => Ok(Self::RandomRotation(t.clone())),
228 }
229 }
230}
231
/// Errors returned by `Transform::try_unpack` when reading a [`Transform`] back from its
/// flatbuffer representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum TransformError {
    /// Unpacking a [`PaddingHadamard`] failed.
    #[error(transparent)]
    PaddingHadamardError(#[from] PaddingHadamardError),
    /// Unpacking a [`DoubleHadamard`] failed.
    #[error(transparent)]
    DoubleHadamardError(#[from] DoubleHadamardError),
    /// Unpacking a [`NullTransform`] failed.
    #[error(transparent)]
    NullTransformError(#[from] NullTransformError),
    /// Unpacking a `RandomRotation` failed. Only available with the `linalg` feature.
    #[cfg(feature = "linalg")]
    #[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
    #[error(transparent)]
    RandomRotationError(#[from] RandomRotationError),
    /// The serialized union held none of the transform kinds this build can unpack.
    #[error("invalid transform kind")]
    InvalidTransformKind,
}

#[cfg(feature = "flatbuffers")]
impl<A> Transform<A>
where
    A: Allocator,
{
    /// Serialize `self` into `buf` as a `fb::transforms::Transform` table.
    ///
    /// The table's union tag (`transform_type`) names the variant. The union value
    /// (`transform`) holds that variant's own packed representation.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::Transform<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Pack the inner transform first so its offset can be stored in the outer table.
        let (kind, offset) = match self {
            Self::PaddingHadamard(t) => (
                fb::transforms::TransformKind::PaddingHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::DoubleHadamard(t) => (
                fb::transforms::TransformKind::DoubleHadamard,
                t.pack(buf).as_union_value(),
            ),
            Self::Null(t) => (
                fb::transforms::TransformKind::NullTransform,
                t.pack(buf).as_union_value(),
            ),
            #[cfg(feature = "linalg")]
            Self::RandomRotation(t) => (
                fb::transforms::TransformKind::RandomRotation,
                t.pack(buf).as_union_value(),
            ),
        };

        fb::transforms::Transform::create(
            buf,
            &fb::transforms::TransformArgs {
                transform_type: kind,
                transform: Some(offset),
            },
        )
    }

    /// Reconstruct a [`Transform`] from `proto`.
    ///
    /// `alloc` is used only by the allocator-backed Hadamard variants.
    ///
    /// # Errors
    ///
    /// * The variant-specific error, wrapped in [`TransformError`], if unpacking the inner
    ///   transform fails.
    /// * [`TransformError::InvalidTransformKind`] if none of the probed union accessors return
    ///   a value. Without the `linalg` feature, this includes a serialized `RandomRotation`,
    ///   because that probe is compiled out.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::Transform<'_>,
    ) -> Result<Self, TransformError> {
        // Each probe returns as soon as it matches, so `alloc` is moved only on a path that
        // leaves the function. Presumably each `transform_as_*` accessor checks the union
        // tag, so at most one of them returns `Some`.
        if let Some(transform) = proto.transform_as_padding_hadamard() {
            return Ok(Self::PaddingHadamard(PaddingHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        #[cfg(feature = "linalg")]
        if let Some(transform) = proto.transform_as_random_rotation() {
            return Ok(Self::RandomRotation(RandomRotation::try_unpack(transform)?));
        }

        if let Some(transform) = proto.transform_as_double_hadamard() {
            return Ok(Self::DoubleHadamard(DoubleHadamard::try_unpack(
                alloc, transform,
            )?));
        }

        if let Some(transform) = proto.transform_as_null_transform() {
            return Ok(Self::Null(NullTransform::try_unpack(transform)?));
        }

        Err(TransformError::InvalidTransformKind)
    }
}

/// Output-dimension policy for a transform.
///
/// Passed as the `target_dim` field of the non-null `TransformKind` variants and forwarded to
/// the transform constructors. Those constructors define the exact semantics of each variant;
/// the notes below are inferred from the variant names. TODO: confirm them against the
/// constructors.
#[derive(Debug, Clone, Copy)]
pub enum TargetDim {
    /// Presumably: output dimension equal to the input dimension.
    Same,

    /// Presumably: whatever output dimension is most natural for the specific transform.
    Natural,

    /// Presumably: exactly the given (non-zero) output dimension.
    Override(NonZeroUsize),
}

// Test-only: connect `Transform` (using the global allocator) to the shared test helpers via
// `test_utils::delegate_transformer!`.
#[cfg(test)]
test_utils::delegate_transformer!(Transform<crate::alloc::GlobalAllocator>);