1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 distr::{Distribution, StandardUniform},
12 Rng,
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 utils::{check_dims, is_sign, subsample_indices, TransformFailed},
20 TargetDim,
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
/// A randomized transform that flips the sign of each input coordinate,
/// zero-pads the vector up to a power-of-two dimension, applies a Hadamard
/// transform, and optionally subsamples the transformed coordinates down to
/// a target output dimension.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct PaddingHadamard<A>
where
    A: Allocator,
{
    /// One mask per input coordinate: `0x8000_0000` flips that coordinate's
    /// sign (single XOR against the f32 bit pattern), `0` leaves it
    /// unchanged. The length of this buffer is the input dimension.
    signs: Poly<[u32], A>,

    /// Power-of-two working dimension (always `>= signs.len()`); inputs are
    /// zero-padded to this length before the Hadamard transform.
    padded_dim: usize,

    /// When present: a non-empty, strictly increasing list of indices in
    /// `0..padded_dim` selecting which transformed coordinates are emitted.
    /// `None` means all `padded_dim` coordinates are emitted.
    subsample: Option<Poly<[u32], A>>,
}
70
71impl<A> PaddingHadamard<A>
72where
73 A: Allocator,
74{
75 pub fn new<R>(
95 dim: NonZeroUsize,
96 target: TargetDim,
97 rng: &mut R,
98 allocator: A,
99 ) -> Result<Self, AllocatorError>
100 where
101 R: Rng + ?Sized,
102 {
103 let signs = Poly::from_iter(
104 (0..dim.get()).map(|_| {
105 let sign: bool = StandardUniform {}.sample(rng);
106 if sign {
107 0x8000_0000
108 } else {
109 0
110 }
111 }),
112 allocator.clone(),
113 )?;
114
115 let (padded_dim, target_dim) = match target {
116 TargetDim::Same => (dim.get().next_power_of_two(), dim.get()),
117 TargetDim::Natural => {
118 let next = dim.get().next_power_of_two();
119 (next, next)
120 }
121 TargetDim::Override(target) => {
122 (target.max(dim).get().next_power_of_two(), target.get())
123 }
124 };
125
126 let subsample = if padded_dim > target_dim {
127 Some(subsample_indices(rng, padded_dim, target_dim, allocator)?)
128 } else {
129 None
130 };
131
132 Ok(Self {
133 signs,
134 padded_dim,
135 subsample,
136 })
137 }
138
139 pub fn try_from_parts(
142 signs: Poly<[u32], A>,
143 padded_dim: usize,
144 subsample: Option<Poly<[u32], A>>,
145 ) -> Result<Self, PaddingHadamardError> {
146 if !signs.iter().copied().all(is_sign) {
147 return Err(PaddingHadamardError::InvalidSignRepresentation);
148 }
149
150 if signs.len() > padded_dim {
151 return Err(PaddingHadamardError::SignsTooLong);
152 }
153
154 if !padded_dim.is_power_of_two() {
155 return Err(PaddingHadamardError::DimNotPowerOfTwo);
156 }
157
158 if let Some(ref subsample) = subsample {
159 if !utils::is_strictly_monotonic(subsample.iter()) {
160 return Err(PaddingHadamardError::SubsampleNotMonotonic);
161 }
162
163 if let Some(last) = subsample.last() {
164 if *last as usize >= padded_dim {
165 return Err(PaddingHadamardError::LastSubsampleTooLarge);
166 }
167 } else {
168 return Err(PaddingHadamardError::SubsampleEmpty);
169 }
170 }
171
172 Ok(Self {
173 signs,
174 padded_dim,
175 subsample,
176 })
177 }
178
179 pub fn input_dim(&self) -> usize {
181 self.signs.len()
182 }
183
184 pub fn output_dim(&self) -> usize {
186 match &self.subsample {
187 None => self.padded_dim,
188 Some(v) => v.len(),
189 }
190 }
191
192 pub fn preserves_norms(&self) -> bool {
197 self.subsample.is_none()
198 }
199
200 fn copy_and_flip_signs(&self, dst: &mut [f32], src: &[f32]) {
209 debug_assert_eq!(dst.len(), self.padded_dim);
210 debug_assert_eq!(src.len(), self.input_dim());
211
212 std::iter::zip(dst.iter_mut(), src.iter())
214 .zip(self.signs.iter())
215 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
216
217 dst.iter_mut()
219 .skip(self.input_dim())
220 .for_each(|dst| *dst = 0.0);
221 }
222
223 pub fn transform_into(
232 &self,
233 dst: &mut [f32],
234 src: &[f32],
235 allocator: ScopedAllocator<'_>,
236 ) -> Result<(), TransformFailed> {
237 let input_dim = self.input_dim();
238 let output_dim = self.output_dim();
239 check_dims(dst, src, input_dim, output_dim)?;
240
241 match &self.subsample {
243 None => {
244 self.copy_and_flip_signs(dst, src);
246
247 #[allow(clippy::unwrap_used)]
252 hadamard_transform(dst).unwrap();
253 }
254 Some(indices) => {
255 let mut tmp = Poly::broadcast(0.0f32, self.padded_dim, allocator)?;
256
257 self.copy_and_flip_signs(&mut tmp, src);
258
259 #[allow(clippy::unwrap_used)]
264 hadamard_transform(&mut tmp).unwrap();
265
266 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
267 debug_assert_eq!(dst.len(), indices.len());
268 std::iter::zip(dst.iter_mut(), indices.iter()).for_each(
269 |(d, i): (&mut f32, &u32)| {
270 *d = tmp[*i as usize] * rescale;
271 },
272 );
273 }
274 }
275
276 Ok(())
277 }
278}
279
280impl<A> TryClone for PaddingHadamard<A>
281where
282 A: Allocator,
283{
284 fn try_clone(&self) -> Result<Self, AllocatorError> {
285 Ok(Self {
286 signs: self.signs.try_clone()?,
287 padded_dim: self.padded_dim,
288 subsample: self.subsample.try_clone()?,
289 })
290 }
291}
292
/// Validation failures produced by [`PaddingHadamard::try_from_parts`]
/// (and therefore also by flatbuffer deserialization, which routes through
/// that constructor).
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum PaddingHadamardError {
    /// A sign entry was not a canonical sign mask (rejected by `is_sign`).
    #[error("an invalid sign representation was discovered")]
    InvalidSignRepresentation,
    /// More sign entries than padded coordinates.
    #[error("`signs` length exceeds `padded_dim`")]
    SignsTooLong,
    /// `padded_dim` must be a power of two for the Hadamard transform.
    #[error("padded dim is not a power of two")]
    DimNotPowerOfTwo,
    /// A subsample vector was present but contained no indices.
    #[error("subsample indices cannot be empty")]
    SubsampleEmpty,
    /// Subsample indices must be strictly increasing.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The largest subsample index was out of range (`>= padded_dim`).
    #[error("last subsample index exceeded `padded_dim`")]
    LastSubsampleTooLarge,
    /// An allocation failed while building the transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
312
#[cfg(feature = "flatbuffers")]
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Serializes this transform into `buf` as a flatbuffer table and
    /// returns the offset of the created table.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::PaddingHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Signs travel as booleans on the wire; the pre-shifted bit-mask
        // encoding is an in-memory detail.
        let sign_bools = self.signs.iter().copied().map(sign_to_bool);
        let signs = buf.create_vector_from_iter(sign_bools);

        let subsample = match self.subsample.as_ref() {
            Some(indices) => Some(buf.create_vector(indices)),
            None => None,
        };

        let args = fb::transforms::PaddingHadamardArgs {
            signs: Some(signs),
            padded_dim: self.padded_dim as u32,
            subsample,
        };
        fb::transforms::PaddingHadamard::create(buf, &args)
    }

    /// Reconstructs a transform from its flatbuffer representation,
    /// re-validating every invariant through [`Self::try_from_parts`].
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::PaddingHadamard<'_>,
    ) -> Result<Self, PaddingHadamardError> {
        let signs = Poly::from_iter(proto.signs().iter().map(bool_to_sign), alloc.clone())?;

        let subsample = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(signs, proto.padded_dim() as usize, subsample)
    }
}
362
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformKind},
        alloc::GlobalAllocator,
    };

    // Verifies that `copy_and_flip_signs` applies the stored masks and
    // zeroes the padding tail (dim 14 pads up to 16).
    #[test]
    fn test_sign_flipping() {
        let mut rng = StdRng::seed_from_u64(0xf8ee12b1e9f33dbd);
        let dim = 14;

        let transform = PaddingHadamard::new(
            NonZeroUsize::new(dim).unwrap(),
            TargetDim::Same,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        assert_eq!(transform.input_dim(), dim);
        assert_eq!(transform.output_dim(), dim);

        let positive = vec![1.0f32; dim];
        let negative = vec![-1.0f32; dim];

        // Sized to the padded dimension; INFINITY sentinels make missed
        // writes obvious.
        let mut output = vec![f32::INFINITY; 16];

        transform.copy_and_flip_signs(&mut output, &positive);

        // Every entry must match its mask: flipped -> -1.0, unflipped -> 1.0.
        let mut unflipped = 0;
        let mut flipped = 0;
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    flipped += 1;
                    assert_eq!(*o, -1.0, "expected entry {} to be flipped", i);
                } else {
                    unflipped += 1;
                    assert_eq!(*o, 1.0, "expected entry {} to be unchanged", i);
                }
            });

        // With this seed, both mask values should occur at least once.
        assert!(unflipped > 0);
        assert!(flipped > 0);

        // The two padding entries (indices 14 and 15) must be zeroed.
        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);

        // Same check on negated input: flipped entries now become +1.0.
        output.fill(f32::INFINITY);
        transform.copy_and_flip_signs(&mut output, &negative);
        std::iter::zip(output.iter(), transform.signs.iter())
            .enumerate()
            .for_each(|(i, (o, s))| {
                if *s == 0x8000_0000 {
                    assert_eq!(*o, 1.0, "expected entry {} to be flipped", i);
                } else {
                    assert_eq!(*o, -1.0, "expected entry {} to be unchanged", i);
                }
            });

        assert_eq!(output[14], 0.0f32);
        assert_eq!(output[15], 0.0f32);
    }

    test_utils::delegate_transformer!(PaddingHadamard<GlobalAllocator>);

    // End-to-end accuracy test over several (input, target) combinations,
    // exercising both direct construction and the `Transform` wrapper.
    #[test]
    fn test_padding_hadamard() {
        // Tight tolerances for the norm-preserving (non-subsampled) path.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(4),
            l2: test_utils::Check::ulp(4),
            ip: test_utils::Check::absrel(5.0e-6, 2e-4),
        };

        // Subsampling only preserves norms approximately, so tolerances are
        // much looser and inner products are not checked.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 1e-1),
            l2: test_utils::Check::absrel(0.0, 1e-1),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // Tuple layout:
        // (input dim, expected output dim, preserves norms, target, tolerances)
        let dim_combos = [
            (15, 16, true, target_dim(16), &natural_errors),
            (15, 16, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (16, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (100, 128, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, target_dim(256), &natural_errors),
            (1000, 1000, false, TargetDim::Same, &subsampled_errors),
            (500, 1000, false, target_dim(1000), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe0626147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The transform must actually change the data, and all
                // errors must stay within the configured tolerances.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    assert_ne!(io.input0, &io.output0[..input]);
                    assert_ne!(io.input1, &io.output1[..input]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the RNG so the wrapper path below sees the exact
                // same random state as the direct-construction path.
                let mut rng_clone = rng.clone();

                // Direct construction.
                {
                    let transformer = PaddingHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        &ctx,
                    )
                }

                // Construction through the `Transform` wrapper.
                {
                    let kind = TransformKind::PaddingHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        &ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::{flatbuffers::to_flatbuffer, poly};

        #[test]
        fn padding_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Round-trip serialization over a mix of target-dim modes.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = PaddingHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                let reloaded = PaddingHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Serializes an intentionally invalid transform and returns the
            // error produced when unpacking it.
            let gen_err = |x: PaddingHadamard<_>| -> PaddingHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                PaddingHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // Five signs but a padded_dim of only 4.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(), padded_dim: 4,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::SignsTooLong);
            }

            // padded_dim of 5 is not a power of two.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 5, subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::DimNotPowerOfTwo);
            }

            // Present-but-empty subsample vector.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([], alloc).unwrap()), });

                assert_eq!(err, PaddingHadamardError::SubsampleEmpty);
            }

            // A repeated index breaks strict monotonicity.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 2, 2], alloc).unwrap()), });
                assert_eq!(err, PaddingHadamardError::SubsampleNotMonotonic);
            }

            // The last subsample index (4) is out of range for padded_dim 4.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 1, 2, 3, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }

            // An index equal to padded_dim is also out of range.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }
        }
    }
}