1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 TargetDim,
20 utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
/// A randomized transform that flips the sign of each input coordinate,
/// zero-pads the result up to a power-of-two dimension, and applies a
/// Hadamard transform — optionally subsampling the result back down to a
/// smaller output dimension.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct PaddingHadamard<A>
where
    A: Allocator,
{
    // One entry per *input* dimension. Each entry is either `0` (keep sign)
    // or `0x8000_0000` (the IEEE-754 `f32` sign-bit mask, i.e. flip sign via
    // XOR on the bit pattern). The length of this buffer is the input dim.
    signs: Poly<[u32], A>,

    // Power-of-two length the input is zero-padded to before the Hadamard
    // transform is applied.
    padded_dim: usize,

    // When present: strictly-increasing indices (each `< padded_dim`) of the
    // transformed coordinates kept in the output. `None` means the full
    // padded vector is emitted and norms are preserved exactly.
    subsample: Option<Poly<[u32], A>>,
}
70
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Builds a randomized padding-Hadamard transform for `dim`-dimensional
    /// input vectors, drawing the per-coordinate sign flips (and any
    /// subsample indices) from `rng`.
    ///
    /// The output dimension is controlled by `target`:
    ///
    /// * `TargetDim::Same` — output dim equals `dim` (subsampled back down
    ///   when `dim` is not already a power of two).
    /// * `TargetDim::Natural` — output dim is the padded power of two.
    /// * `TargetDim::Override(t)` — output dim is `t`; padding goes to the
    ///   next power of two of `max(dim, t)`.
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if allocating the sign or subsample
    /// buffers fails.
    pub fn new<R>(
        dim: NonZeroUsize,
        target: TargetDim,
        rng: &mut R,
        allocator: A,
    ) -> Result<Self, AllocatorError>
    where
        R: Rng + ?Sized,
    {
        // One random sign per input dimension, stored as the `f32` sign-bit
        // mask (`0x8000_0000`) so applying it later is a single XOR.
        let signs = Poly::from_iter(
            (0..dim.get()).map(|_| {
                let sign: bool = StandardUniform {}.sample(rng);
                if sign { 0x8000_0000 } else { 0 }
            }),
            allocator.clone(),
        )?;

        let (padded_dim, target_dim) = match target {
            TargetDim::Same => (dim.get().next_power_of_two(), dim.get()),
            TargetDim::Natural => {
                let next = dim.get().next_power_of_two();
                (next, next)
            }
            TargetDim::Override(target) => {
                // Pad at least as far as both the input and the requested
                // output so every subsample index can be in range.
                (target.max(dim).get().next_power_of_two(), target.get())
            }
        };

        // Subsampling is only needed when the padded dimension overshoots
        // the requested output dimension.
        let subsample = if padded_dim > target_dim {
            Some(subsample_indices(rng, padded_dim, target_dim, allocator)?)
        } else {
            None
        };

        Ok(Self {
            signs,
            padded_dim,
            subsample,
        })
    }

    /// Reassembles a transform from raw parts, validating every structural
    /// invariant (used e.g. when deserializing untrusted data).
    ///
    /// # Errors
    ///
    /// Returns a [`PaddingHadamardError`] when an invariant is violated:
    /// each sign must be canonical (`0` or `0x8000_0000`), `signs.len()`
    /// must not exceed `padded_dim`, `padded_dim` must be a power of two,
    /// and `subsample` (when present) must be non-empty, strictly
    /// increasing, and bounded by `padded_dim`.
    pub fn try_from_parts(
        signs: Poly<[u32], A>,
        padded_dim: usize,
        subsample: Option<Poly<[u32], A>>,
    ) -> Result<Self, PaddingHadamardError> {
        if !signs.iter().copied().all(is_sign) {
            return Err(PaddingHadamardError::InvalidSignRepresentation);
        }

        if signs.len() > padded_dim {
            return Err(PaddingHadamardError::SignsTooLong);
        }

        // The in-place Hadamard transform requires a power-of-two length.
        if !padded_dim.is_power_of_two() {
            return Err(PaddingHadamardError::DimNotPowerOfTwo);
        }

        if let Some(ref subsample) = subsample {
            if !utils::is_strictly_monotonic(subsample.iter()) {
                return Err(PaddingHadamardError::SubsampleNotMonotonic);
            }

            if let Some(last) = subsample.last() {
                // Indices are strictly increasing (checked above), so
                // verifying only the last one suffices for the whole slice.
                if *last as usize >= padded_dim {
                    return Err(PaddingHadamardError::LastSubsampleTooLarge);
                }
            } else {
                return Err(PaddingHadamardError::SubsampleEmpty);
            }
        }

        Ok(Self {
            signs,
            padded_dim,
            subsample,
        })
    }

    /// Dimensionality of the vectors this transform accepts (one sign per
    /// input coordinate).
    pub fn input_dim(&self) -> usize {
        self.signs.len()
    }

    /// Dimensionality of the vectors this transform produces: the padded
    /// power of two, or the number of subsample indices when subsampling.
    pub fn output_dim(&self) -> usize {
        match &self.subsample {
            None => self.padded_dim,
            Some(v) => v.len(),
        }
    }

    /// Returns `true` when the transform preserves vector norms exactly,
    /// i.e. when no subsampling is performed. (The subsampled path only
    /// rescales to compensate — see `transform_into`.)
    pub fn preserves_norms(&self) -> bool {
        self.subsample.is_none()
    }

    /// Copies `src` into the front of `dst` with each element's sign flipped
    /// per `self.signs`, then zero-fills the padding tail of `dst`.
    ///
    /// Expects `dst.len() == self.padded_dim` and
    /// `src.len() == self.input_dim()`; checked only in debug builds.
    fn copy_and_flip_signs(&self, dst: &mut [f32], src: &[f32]) {
        debug_assert_eq!(dst.len(), self.padded_dim);
        debug_assert_eq!(src.len(), self.input_dim());

        // XOR with the stored sign-bit mask either negates the value or
        // leaves it untouched. The zip truncates at `signs.len()`, i.e. at
        // the input dimension.
        std::iter::zip(dst.iter_mut(), src.iter())
            .zip(self.signs.iter())
            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));

        // Zero the padding region past the copied input.
        dst.iter_mut()
            .skip(self.input_dim())
            .for_each(|dst| *dst = 0.0);
    }

    /// Applies the transform to `src`, writing the result into `dst`.
    ///
    /// `src.len()` must equal [`Self::input_dim`] and `dst.len()` must equal
    /// [`Self::output_dim`]. `allocator` provides scratch space (used only
    /// on the subsampled path).
    ///
    /// # Errors
    ///
    /// Returns [`TransformFailed`] if the slice lengths are wrong or the
    /// scratch allocation fails.
    pub fn transform_into(
        &self,
        dst: &mut [f32],
        src: &[f32],
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), TransformFailed> {
        let input_dim = self.input_dim();
        let output_dim = self.output_dim();
        check_dims(dst, src, input_dim, output_dim)?;

        match &self.subsample {
            None => {
                // No subsampling: `dst` is exactly `padded_dim` long, so we
                // can transform in place without scratch space.
                self.copy_and_flip_signs(dst, src);

                // Cannot fail: `padded_dim` is a power of two by
                // construction/validation, which is what the transform needs.
                #[allow(clippy::unwrap_used)]
                hadamard_transform(dst).unwrap();
            }
            Some(indices) => {
                // Subsampling: transform into scratch space, then gather the
                // selected coordinates into `dst`.
                let mut tmp = Poly::broadcast(0.0f32, self.padded_dim, allocator)?;

                self.copy_and_flip_signs(&mut tmp, src);

                // Cannot fail: `padded_dim` is a power of two (see above).
                #[allow(clippy::unwrap_used)]
                hadamard_transform(&mut tmp).unwrap();

                // Rescale by sqrt(padded / kept) to compensate for keeping
                // only `indices.len()` of the `tmp.len()` coordinates.
                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
                debug_assert_eq!(dst.len(), indices.len());
                std::iter::zip(dst.iter_mut(), indices.iter()).for_each(
                    |(d, i): (&mut f32, &u32)| {
                        *d = tmp[*i as usize] * rescale;
                    },
                );
            }
        }

        Ok(())
    }
}
275
276impl<A> TryClone for PaddingHadamard<A>
277where
278 A: Allocator,
279{
280 fn try_clone(&self) -> Result<Self, AllocatorError> {
281 Ok(Self {
282 signs: self.signs.try_clone()?,
283 padded_dim: self.padded_dim,
284 subsample: self.subsample.try_clone()?,
285 })
286 }
287}
288
/// Validation errors produced by [`PaddingHadamard::try_from_parts`] (and,
/// transitively, by deserialization paths that call it).
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum PaddingHadamardError {
    /// A sign entry was neither `0` nor `0x8000_0000`.
    #[error("an invalid sign representation was discovered")]
    InvalidSignRepresentation,
    /// There were more signs (input dimensions) than padded dimensions.
    #[error("`signs` length exceeds `padded_dim`")]
    SignsTooLong,
    /// The padded dimension is not a power of two, so the Hadamard transform
    /// cannot be applied.
    #[error("padded dim is not a power of two")]
    DimNotPowerOfTwo,
    /// A subsample vector was present but contained no indices.
    #[error("subsample indices cannot be empty")]
    SubsampleEmpty,
    /// The subsample indices were not strictly increasing.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The largest subsample index is out of range (`>= padded_dim`).
    #[error("last subsample index exceeded `padded_dim`")]
    LastSubsampleTooLarge,
    /// An allocation failed while assembling the transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
308
#[cfg(feature = "flatbuffers")]
impl<A> PaddingHadamard<A>
where
    A: Allocator,
{
    /// Serializes this transform into `buf`, returning the offset of the
    /// written table.
    ///
    /// Signs are stored as booleans (via `sign_to_bool`) rather than raw
    /// sign-bit masks, and `padded_dim` is narrowed to `u32` — assumes it
    /// fits, which holds for any practical dimension.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::PaddingHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let signs = buf.create_vector_from_iter(self.signs.iter().copied().map(sign_to_bool));

        // The subsample vector is optional and omitted entirely when absent.
        let subsample = self
            .subsample
            .as_ref()
            .map(|indices| buf.create_vector(indices));

        fb::transforms::PaddingHadamard::create(
            buf,
            &fb::transforms::PaddingHadamardArgs {
                signs: Some(signs),
                padded_dim: self.padded_dim as u32,
                subsample,
            },
        )
    }

    /// Deserializes a transform from `proto`, converting boolean signs back
    /// to sign-bit masks and re-validating all structural invariants through
    /// [`Self::try_from_parts`].
    ///
    /// # Errors
    ///
    /// Returns a [`PaddingHadamardError`] if allocation fails or the decoded
    /// parts violate any invariant.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::PaddingHadamard<'_>,
    ) -> Result<Self, PaddingHadamardError> {
        let signs = Poly::from_iter(proto.signs().iter().map(bool_to_sign), alloc.clone())?;

        let subsample = match proto.subsample() {
            Some(subsample) => Some(Poly::from_iter(subsample.into_iter(), alloc)?),
            None => None,
        };

        Self::try_from_parts(signs, proto.padded_dim() as usize, subsample)
    }
}
358
359#[cfg(test)]
364mod tests {
365 use diskann_utils::lazy_format;
366 use rand::{SeedableRng, rngs::StdRng};
367
368 use super::*;
369 use crate::{
370 algorithms::transforms::{Transform, TransformKind, test_utils},
371 alloc::GlobalAllocator,
372 };
373
374 #[test]
377 fn test_sign_flipping() {
378 let mut rng = StdRng::seed_from_u64(0xf8ee12b1e9f33dbd);
379 let dim = 14;
380
381 let transform = PaddingHadamard::new(
382 NonZeroUsize::new(dim).unwrap(),
383 TargetDim::Same,
384 &mut rng,
385 GlobalAllocator,
386 )
387 .unwrap();
388
389 assert_eq!(transform.input_dim(), dim);
390 assert_eq!(transform.output_dim(), dim);
391
392 let positive = vec![1.0f32; dim];
393 let negative = vec![-1.0f32; dim];
394
395 let mut output = vec![f32::INFINITY; 16];
396
397 transform.copy_and_flip_signs(&mut output, &positive);
399
400 let mut unflipped = 0;
401 let mut flipped = 0;
402 std::iter::zip(output.iter(), transform.signs.iter())
403 .enumerate()
404 .for_each(|(i, (o, s))| {
405 if *s == 0x8000_0000 {
406 flipped += 1;
407 assert_eq!(*o, -1.0, "expected entry {} to be flipped", i);
408 } else {
409 unflipped += 1;
410 assert_eq!(*o, 1.0, "expected entry {} to be unchanged", i);
411 }
412 });
413
414 assert!(unflipped > 0);
416 assert!(flipped > 0);
417
418 assert_eq!(output[14], 0.0f32);
420 assert_eq!(output[15], 0.0f32);
421
422 output.fill(f32::INFINITY);
424 transform.copy_and_flip_signs(&mut output, &negative);
425 std::iter::zip(output.iter(), transform.signs.iter())
426 .enumerate()
427 .for_each(|(i, (o, s))| {
428 if *s == 0x8000_0000 {
429 assert_eq!(*o, 1.0, "expected entry {} to be flipped", i);
430 } else {
431 assert_eq!(*o, -1.0, "expected entry {} to be unchanged", i);
432 }
433 });
434
435 assert_eq!(output[14], 0.0f32);
437 assert_eq!(output[15], 0.0f32);
438 }
439
    // Instantiates the shared transformer test-suite for this type —
    // presumably the standard battery generated by
    // `test_utils::delegate_transformer!`; see that macro for the exact tests.
    test_utils::delegate_transformer!(PaddingHadamard<GlobalAllocator>);
441
    /// End-to-end accuracy test: runs the transform over a matrix of
    /// (input dim, target mode) combinations, both constructed directly and
    /// through the `Transform` wrapper, checking dimensions, norm
    /// preservation, and numerical error bounds.
    #[test]
    fn test_padding_hadamard() {
        // Tight tolerances for the norm-preserving (non-subsampled) paths.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(4),
            l2: test_utils::Check::ulp(4),
            ip: test_utils::Check::absrel(5.0e-6, 2e-4),
        };

        // Subsampled paths are only approximately norm-preserving, so use
        // much looser bounds and skip the inner-product check entirely.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 1e-1),
            l2: test_utils::Check::absrel(0.0, 1e-1),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());

        // (input dim, expected output dim, preserves norms, target mode,
        // tolerance set).
        let dim_combos = [
            (15, 16, true, target_dim(16), &natural_errors),
            (15, 16, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (16, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (100, 128, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, target_dim(256), &natural_errors),
            (1000, 1000, false, TargetDim::Same, &subsampled_errors),
            (500, 1000, false, target_dim(1000), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe0626147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // Beyond the error bounds, require that the transform is not
                // acting as the identity on the data.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    assert_ne!(io.input0, &io.output0[..input]);
                    assert_ne!(io.input1, &io.output1[..input]);
                    test_utils::check_errors(io, context, errors);
                };

                // Clone the RNG so the direct and wrapper-based runs consume
                // identical random streams.
                let mut rng_clone = rng.clone();

                // Direct construction.
                {
                    let transformer = PaddingHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        &ctx,
                    )
                }

                // Construction through the `Transform` wrapper must behave
                // identically.
                {
                    let kind = TransformKind::PaddingHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        &ctx,
                    )
                }
            }
        }
    }
555
    #[cfg(test)]
    mod serialization {
        use super::*;
        use crate::{flatbuffers::to_flatbuffer, poly};

        /// Round-trips transforms through flatbuffers and verifies that
        /// every `try_from_parts` invariant is enforced on unpack.
        #[test]
        fn padding_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // Pack/unpack round trip across a mix of target-dim modes.
            let test_cases = [
                (5, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = PaddingHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                let reloaded = PaddingHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Helper: serialize a hand-built (invalid) transform and return
            // the error produced when unpacking it again.
            let gen_err = |x: PaddingHadamard<_>| -> PaddingHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::PaddingHadamard>(&data).unwrap();
                PaddingHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // More signs than `padded_dim`.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::SignsTooLong);
            }

            // `padded_dim` not a power of two.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 5,
                    subsample: None,
                });

                assert_eq!(err, PaddingHadamardError::DimNotPowerOfTwo);
            }

            // Present-but-empty subsample vectors are rejected.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::SubsampleEmpty);
            }

            // Subsample indices must be strictly increasing (duplicates fail).
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 2, 2], alloc).unwrap()),
                });
                assert_eq!(err, PaddingHadamardError::SubsampleNotMonotonic);
            }

            // Too many indices push the last one past `padded_dim`.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 1, 2, 3, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }

            // A single out-of-range index is caught as well.
            {
                let err = gen_err(PaddingHadamard {
                    signs: poly!([0, 0, 0, 0], alloc).unwrap(),
                    padded_dim: 4,
                    subsample: Some(poly!([0, 4], alloc).unwrap()),
                });

                assert_eq!(err, PaddingHadamardError::LastSubsampleTooLarge);
            }
        }
    }
663}