1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 TargetDim,
20 utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
/// A dimensionality-adjusting random transform: flip the sign of each input
/// coordinate, zero-pad to a power-of-two length, apply a fast Walsh-Hadamard
/// transform, and (optionally) subsample the result down to a target output
/// dimension with rescaling.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct PaddingHadamard<A>
where
    A: Allocator,
{
    /// Per-input-coordinate sign masks: each entry is either `0` (keep) or
    /// `0x8000_0000` (XOR-ing it flips the IEEE-754 sign bit of an `f32`).
    /// Its length is the input dimension.
    signs: Poly<[u32], A>,

    /// Power-of-two length the input is zero-padded to before the Hadamard
    /// transform. Invariant: `padded_dim >= signs.len()`.
    padded_dim: usize,

    /// When `Some`, the strictly increasing indices (each `< padded_dim`) of
    /// the transformed coordinates retained in the output. `None` means the
    /// whole padded vector is emitted (and norms are preserved).
    subsample: Option<Poly<[u32], A>>,
}
70
71impl<A> PaddingHadamard<A>
72where
73 A: Allocator,
74{
75 pub fn new<R>(
95 dim: NonZeroUsize,
96 target: TargetDim,
97 rng: &mut R,
98 allocator: A,
99 ) -> Result<Self, AllocatorError>
100 where
101 R: Rng + ?Sized,
102 {
103 let signs = Poly::from_iter(
104 (0..dim.get()).map(|_| {
105 let sign: bool = StandardUniform {}.sample(rng);
106 if sign { 0x8000_0000 } else { 0 }
107 }),
108 allocator.clone(),
109 )?;
110
111 let (padded_dim, target_dim) = match target {
112 TargetDim::Same => (dim.get().next_power_of_two(), dim.get()),
113 TargetDim::Natural => {
114 let next = dim.get().next_power_of_two();
115 (next, next)
116 }
117 TargetDim::Override(target) => {
118 (target.max(dim).get().next_power_of_two(), target.get())
119 }
120 };
121
122 let subsample = if padded_dim > target_dim {
123 Some(subsample_indices(rng, padded_dim, target_dim, allocator)?)
124 } else {
125 None
126 };
127
128 Ok(Self {
129 signs,
130 padded_dim,
131 subsample,
132 })
133 }
134
135 pub fn try_from_parts(
138 signs: Poly<[u32], A>,
139 padded_dim: usize,
140 subsample: Option<Poly<[u32], A>>,
141 ) -> Result<Self, PaddingHadamardError> {
142 if !signs.iter().copied().all(is_sign) {
143 return Err(PaddingHadamardError::InvalidSignRepresentation);
144 }
145
146 if signs.len() > padded_dim {
147 return Err(PaddingHadamardError::SignsTooLong);
148 }
149
150 if !padded_dim.is_power_of_two() {
151 return Err(PaddingHadamardError::DimNotPowerOfTwo);
152 }
153
154 if let Some(ref subsample) = subsample {
155 if !utils::is_strictly_monotonic(subsample.iter()) {
156 return Err(PaddingHadamardError::SubsampleNotMonotonic);
157 }
158
159 if let Some(last) = subsample.last() {
160 if *last as usize >= padded_dim {
161 return Err(PaddingHadamardError::LastSubsampleTooLarge);
162 }
163 } else {
164 return Err(PaddingHadamardError::SubsampleEmpty);
165 }
166 }
167
168 Ok(Self {
169 signs,
170 padded_dim,
171 subsample,
172 })
173 }
174
175 pub fn input_dim(&self) -> usize {
177 self.signs.len()
178 }
179
180 pub fn output_dim(&self) -> usize {
182 match &self.subsample {
183 None => self.padded_dim,
184 Some(v) => v.len(),
185 }
186 }
187
188 pub fn preserves_norms(&self) -> bool {
193 self.subsample.is_none()
194 }
195
196 fn copy_and_flip_signs(&self, dst: &mut [f32], src: &[f32]) {
205 debug_assert_eq!(dst.len(), self.padded_dim);
206 debug_assert_eq!(src.len(), self.input_dim());
207
208 std::iter::zip(dst.iter_mut(), src.iter())
210 .zip(self.signs.iter())
211 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
212
213 dst.iter_mut()
215 .skip(self.input_dim())
216 .for_each(|dst| *dst = 0.0);
217 }
218
219 pub fn transform_into(
228 &self,
229 dst: &mut [f32],
230 src: &[f32],
231 allocator: ScopedAllocator<'_>,
232 ) -> Result<(), TransformFailed> {
233 let input_dim = self.input_dim();
234 let output_dim = self.output_dim();
235 check_dims(dst, src, input_dim, output_dim)?;
236
237 match &self.subsample {
239 None => {
240 self.copy_and_flip_signs(dst, src);
242
243 #[allow(clippy::unwrap_used)]
248 hadamard_transform(dst).unwrap();
249 }
250 Some(indices) => {
251 let mut tmp = Poly::broadcast(0.0f32, self.padded_dim, allocator)?;
252
253 self.copy_and_flip_signs(&mut tmp, src);
254
255 #[allow(clippy::unwrap_used)]
260 hadamard_transform(&mut tmp).unwrap();
261
262 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
263 debug_assert_eq!(dst.len(), indices.len());
264 std::iter::zip(dst.iter_mut(), indices.iter()).for_each(
265 |(d, i): (&mut f32, &u32)| {
266 *d = tmp[*i as usize] * rescale;
267 },
268 );
269 }
270 }
271
272 Ok(())
273 }
274}
275
276impl<A> TryClone for PaddingHadamard<A>
277where
278 A: Allocator,
279{
280 fn try_clone(&self) -> Result<Self, AllocatorError> {
281 Ok(Self {
282 signs: self.signs.try_clone()?,
283 padded_dim: self.padded_dim,
284 subsample: self.subsample.try_clone()?,
285 })
286 }
287}
288
/// Validation failures produced by [`PaddingHadamard::try_from_parts`]
/// (and, transitively, by deserialization).
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum PaddingHadamardError {
    /// A sign entry was neither `0` nor `0x8000_0000`.
    #[error("an invalid sign representation was discovered")]
    InvalidSignRepresentation,
    /// More sign entries than padded slots.
    #[error("`signs` length exceeds `padded_dim`")]
    SignsTooLong,
    /// The Hadamard transform requires a power-of-two length.
    #[error("padded dim is not a power of two")]
    DimNotPowerOfTwo,
    /// A present-but-empty subsample list is rejected.
    #[error("subsample indices cannot be empty")]
    SubsampleEmpty,
    /// Subsample indices must be strictly increasing.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// A subsample index pointed past the padded vector.
    #[error("last subsample index exceeded `padded_dim`")]
    LastSubsampleTooLarge,
    /// An allocation failed while rebuilding the transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
308
#[cfg(feature = "flatbuffers")]
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Serializes this transform into `buf` as a flatbuffer table.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::PaddingHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Sign masks travel as booleans on the wire.
        let signs =
            Some(buf.create_vector_from_iter(self.signs.iter().copied().map(sign_to_bool)));
        let subsample = self.subsample.as_ref().map(|idx| buf.create_vector(idx));

        let args = fb::transforms::PaddingHadamardArgs {
            signs,
            padded_dim: self.padded_dim as u32,
            subsample,
        };
        fb::transforms::PaddingHadamard::create(buf, &args)
    }

    /// Deserializes from a flatbuffer table, re-validating all invariants
    /// via [`Self::try_from_parts`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::PaddingHadamard<'_>,
    ) -> Result<Self, PaddingHadamardError> {
        let signs = Poly::from_iter(proto.signs().iter().map(bool_to_sign), alloc.clone())?;

        let subsample = proto
            .subsample()
            .map(|v| Poly::from_iter(v.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(signs, proto.padded_dim() as usize, subsample)
    }
}
358
#[cfg(test)]
mod tests {
    #[cfg(not(miri))]
    use crate::algorithms::transforms::{Transform, TransformKind, test_utils};
    #[cfg(not(miri))]
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{alloc::GlobalAllocator, test_util::Check};

    // Verifies that `copy_and_flip_signs` negates exactly the coordinates
    // whose stored mask is `0x8000_0000`, leaves the others untouched, and
    // zero-fills the padding tail (indices 14..16 here).
    #[test]
    fn test_sign_flipping() {
        let mut rng = StdRng::seed_from_u64(0xf8ee12b1e9f33dbd);
        let dim = 14;

        let transform = PaddingHadamard::new(
            NonZeroUsize::new(dim).unwrap(),
            TargetDim::Same,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        assert_eq!(transform.input_dim(), dim);
        assert_eq!(transform.output_dim(), dim);

        let positive = vec![1.0f32; dim];
        let negative = vec![-1.0f32; dim];

        // Sentinel fill so untouched slots are obvious.
        let mut output = vec![f32::INFINITY; 16];

        transform.copy_and_flip_signs(&mut output, &positive);

        // Count both kinds so we know the RNG produced a non-degenerate mask.
        let mut unflipped = 0;
        let mut flipped = 0;
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    flipped += 1;
                    assert_eq!(*o, -1.0, "expected entry {} to be flipped", i);
                } else {
                    unflipped += 1;
                    assert_eq!(*o, 1.0, "expected entry {} to be unchanged", i);
                }
            });

        assert!(unflipped > 0);
        assert!(flipped > 0);

        // Padding tail must have been zeroed.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);

        // Same check from the negative side: masked entries come out +1.
        output.fill(f32::INFINITY);
        transform.copy_and_flip_signs(&mut output, &negative);
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    assert_eq!(*o, 1.0, "expected entry {} to be flipped", i);
                } else {
                    assert_eq!(*o, -1.0, "expected entry {} to be unchanged", i);
                }
            });

        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);
    }

    #[cfg(not(miri))]
    test_utils::delegate_transformer!(PaddingHadamard<GlobalAllocator>);

    // End-to-end check over many (input, output, target) combinations: the
    // transform must report the expected dimensions and norm-preservation
    // flag, and stay within the error budgets both when constructed directly
    // and when built through the `Transform` wrapper (with a cloned RNG so
    // both paths see identical randomness).
    #[test]
    #[cfg(not(miri))]
    fn test_padding_hadamard() {
        // Tight budgets for the isometric (non-subsampled) configurations.
        let natural_errors = test_utils::ErrorSetup {
            norm: Check::ulp(4),
            l2: Check::ulp(4),
            ip: Check::absrel(5.0e-6, 2e-4),
        };

        // Subsampling only preserves norms approximately, so the budgets are
        // much looser and inner products are skipped.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: Check::absrel(0.0, 1e-1),
            l2: Check::absrel(0.0, 1e-1),
            ip: Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // (input dim, expected output dim, preserves norms?, target, budget)
        let dim_combos = [
            (15, 16, true, target_dim(16), &natural_errors),
            (15, 16, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (16, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (100, 128, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, target_dim(256), &natural_errors),
            (1000, 1000, false, TargetDim::Same, &subsampled_errors),
            (500, 1000, false, target_dim(1000), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe0626147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The transform must actually change the data (it is random),
                // and the configured error budgets must hold.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    assert_ne!(io.input0, &io.output0[..input]);
                    assert_ne!(io.input1, &io.output1[..input]);
                    test_utils::check_errors(io, context, errors);
                };

                // Cloned so the wrapper path below draws the same randomness.
                let mut rng_clone = rng.clone();

                // Direct construction.
                {
                    let transformer = PaddingHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        &ctx,
                    )
                }

                // Construction through the `Transform` wrapper.
                {
                    let kind = TransformKind::PaddingHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        &ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::{flatbuffers::to_flatbuffer, poly};

        // Round-trips several configurations through flatbuffers, then
        // exercises every validation error in `try_from_parts` by packing
        // deliberately malformed transforms and unpacking them.
        #[test]
        fn padding_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = PaddingHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                let reloaded = PaddingHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Packs an (invalid) transform and returns the unpack error.
            let gen_err = |x: PaddingHadamard<_>| -> PaddingHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                PaddingHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // Five signs cannot fit in a padded dimension of four.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::SignsTooLong);
            }

            // The padded dimension must be a power of two.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 5,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::DimNotPowerOfTwo);
            }

            // A present subsample list may not be empty.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::SubsampleEmpty);
            }

            // Subsample indices must be strictly increasing.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 2, 2], alloc).unwrap()),
                });
                assert_eq!(err, PaddingHadamardError::SubsampleNotMonotonic);
            }

            // Five strictly increasing indices necessarily exceed padded_dim 4.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 1, 2, 3, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }

            // A single out-of-range index is also rejected.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }
        }
    }
}