1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 distr::{Distribution, StandardUniform},
12 Rng,
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 utils::{check_dims, is_sign, subsample_indices, TransformFailed},
20 TargetDim,
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
30#[derive(Debug)]
55#[cfg_attr(test, derive(PartialEq))]
56pub struct DoubleHadamard<A>
57where
58 A: Allocator,
59{
60 signs0: Poly<[u32], A>,
67 signs1: Poly<[u32], A>,
68
69 target_dim: usize,
71
72 subsample: Option<Poly<[u32], A>>,
74}
75
76impl<A> DoubleHadamard<A>
77where
78 A: Allocator,
79{
80 pub fn new<R>(
99 dim: NonZeroUsize,
100 target_dim: TargetDim,
101 rng: &mut R,
102 allocator: A,
103 ) -> Result<Self, AllocatorError>
104 where
105 R: Rng + ?Sized,
106 {
107 let dim = dim.get();
108
109 let target_dim = match target_dim {
110 TargetDim::Override(target) => target.get(),
111 TargetDim::Same => dim,
112 TargetDim::Natural => dim,
113 };
114
115 let intermediate_dim = dim.max(target_dim);
120
121 let mut sample = |_: usize| {
123 let sign: bool = StandardUniform {}.sample(rng);
124 if sign {
125 0x8000_0000
126 } else {
127 0
128 }
129 };
130
131 let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
134 let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
135
136 let subsample = if dim > target_dim {
137 Some(subsample_indices(rng, dim, target_dim, allocator)?)
138 } else {
139 None
140 };
141
142 Ok(Self {
143 signs0,
144 signs1,
145 target_dim,
146 subsample,
147 })
148 }
149
150 pub fn try_from_parts(
151 signs0: Poly<[u32], A>,
152 signs1: Poly<[u32], A>,
153 subsample: Option<Poly<[u32], A>>,
154 ) -> Result<Self, DoubleHadamardError> {
155 type E = DoubleHadamardError;
156 if signs0.is_empty() {
157 return Err(E::Signs0Empty);
158 }
159 if signs1.len() < signs0.len() {
160 return Err(E::Signs1TooSmall);
161 }
162 if !signs0.iter().copied().all(is_sign) {
163 return Err(E::Signs0Invalid);
164 }
165 if !signs1.iter().copied().all(is_sign) {
166 return Err(E::Signs1Invalid);
167 }
168
169 let target_dim = if let Some(ref subsample) = subsample {
171 if !utils::is_strictly_monotonic(subsample.iter()) {
172 return Err(E::SubsampleNotMonotonic);
173 }
174
175 match subsample.last() {
176 Some(last) => {
177 if *last as usize >= signs1.len() {
178 return Err(E::LastSubsampleTooLarge);
185 }
186 }
187 None => {
188 return Err(E::InvalidSubsampleLength);
190 }
191 }
192
193 debug_assert!(
194 subsample.len() < signs1.len(),
195 "since we've verified monotonicity and the last element, this is implied"
196 );
197
198 subsample.len()
199 } else {
200 signs1.len()
202 };
203
204 Ok(Self {
205 signs0,
206 signs1,
207 target_dim,
208 subsample,
209 })
210 }
211
212 pub fn input_dim(&self) -> usize {
214 self.signs0.len()
215 }
216
217 pub fn output_dim(&self) -> usize {
219 self.target_dim
220 }
221
222 pub fn preserves_norms(&self) -> bool {
227 self.subsample.is_none()
228 }
229
230 fn intermediate_dim(&self) -> usize {
231 self.input_dim().max(self.output_dim())
232 }
233
234 pub fn transform_into(
243 &self,
244 dst: &mut [f32],
245 src: &[f32],
246 allocator: ScopedAllocator<'_>,
247 ) -> Result<(), TransformFailed> {
248 check_dims(dst, src, self.input_dim(), self.output_dim())?;
249
250 let intermediate_dim = self.intermediate_dim();
252 let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
253
254 std::iter::zip(tmp.iter_mut(), src.iter())
255 .zip(self.signs0.iter())
256 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
257
258 let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
259
260 #[allow(clippy::unwrap_used)]
265 hadamard_transform(&mut tmp[..split]).unwrap();
266
267 tmp.iter_mut()
271 .zip(self.signs1.iter())
272 .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
273
274 #[allow(clippy::unwrap_used)]
275 hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
276
277 match self.subsample.as_ref() {
278 None => {
279 dst.copy_from_slice(&tmp);
280 }
281 Some(indices) => {
282 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
283 debug_assert_eq!(dst.len(), indices.len());
284 dst.iter_mut()
285 .zip(indices.iter())
286 .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
287 }
288 }
289
290 Ok(())
291 }
292}
293
294impl<A> TryClone for DoubleHadamard<A>
295where
296 A: Allocator,
297{
298 fn try_clone(&self) -> Result<Self, AllocatorError> {
299 Ok(Self {
300 signs0: self.signs0.try_clone()?,
301 signs1: self.signs1.try_clone()?,
302 target_dim: self.target_dim,
303 subsample: self.subsample.try_clone()?,
304 })
305 }
306}
307
308#[derive(Debug, Clone, Copy, Error, PartialEq)]
309#[non_exhaustive]
310pub enum DoubleHadamardError {
311 #[error("first signs stage cannot be empty")]
312 Signs0Empty,
313 #[error("first signs stage has invalid coding")]
314 Signs0Invalid,
315
316 #[error("invalid sign representation for second stage")]
317 Signs1Invalid,
318 #[error("second sign stage must be at least as large as the first stage")]
319 Signs1TooSmall,
320
321 #[error("subsample length must equal `target_dim`")]
322 InvalidSubsampleLength,
323 #[error("subsample indices is not monotonic")]
324 SubsampleNotMonotonic,
325 #[error("last subsample index exceeded intermediate dim")]
326 LastSubsampleTooLarge,
327
328 #[error(transparent)]
329 AllocatorError(#[from] AllocatorError),
330}
331
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Serialize this transform into `buf` and return the offset of the resulting table.
    ///
    /// Signs are stored as booleans via `sign_to_bool`, and subsample indices are stored
    /// verbatim. `target_dim` is not written; it is re-derived by `try_unpack`.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Vectors are created in a fixed order (signs0, signs1, subsample) so that the
        // serialized byte layout stays stable.
        let first_stage =
            buf.create_vector_from_iter(self.signs0.iter().map(|&sign| sign_to_bool(sign)));
        let second_stage =
            buf.create_vector_from_iter(self.signs1.iter().map(|&sign| sign_to_bool(sign)));
        let kept_indices = self
            .subsample
            .as_deref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(first_stage),
            signs1: Some(second_stage),
            subsample: kept_indices,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Rebuild a transform from its flatbuffer representation.
    ///
    /// All buffers are allocated with `alloc`, and the parts are validated by
    /// [`Self::try_from_parts`].
    ///
    /// # Errors
    ///
    /// Returns a [`DoubleHadamardError`] on allocation failure or when the stored parts fail
    /// validation.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let first_stage = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let second_stage =
            Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;

        let kept_indices = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(first_stage, second_stage, kept_indices)
    }
}
384
// Unit tests for `DoubleHadamard`: numerical behaviour (and parity with the generic `Transform`
// front end), plus flatbuffer round-tripping and error reporting.
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::{test_utils, Transform, TransformKind},
        alloc::GlobalAllocator,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    // Builds transforms for a range of input/output dimension combinations, checks the
    // reported dimensions and norm-preservation flag, and runs the shared error checks from
    // `test_utils`.
    #[test]
    fn test_double_hadamard() {
        // Tolerances for transforms without subsampling (`preserves_norms() == true`).
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // Subsampled transforms are only approximately norm-preserving, so the tolerances are
        // relative and the inner-product check is skipped.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        // Tuple layout: (input dim, expected output dim, expected `preserves_norms`,
        // requested target, tolerance set).
        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        let dim_combos = [
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Growing the dimension (zero-padding) keeps norms.
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Shrinking the dimension subsamples and does not preserve norms exactly.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // Besides the tolerance checks, require that the transform actually changed
                // the overlapping leading coordinates of both test inputs.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Snapshot the RNG so the `Transform` front end below starts from the same
                // random state as the direct construction.
                let mut rng_clone = rng.clone();

                // Direct construction via `DoubleHadamard::new`.
                {
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                // Construction through the generic `Transform` wrapper using the cloned RNG.
                // NOTE(review): the two transformers are not compared directly here.
                {
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        // Round-trips transforms through `pack`/`try_unpack` and checks that malformed parts
        // are rejected with the expected error.
        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // (input dim, requested target): covers same-size, growing, and shrinking
            // (subsampled) configurations.
            let test_cases = [
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                // Relies on the test-only `PartialEq` derive on `DoubleHadamard`.
                assert_eq!(transform, reloaded);
            }

            // Serializes an arbitrary (possibly invalid) transform and returns the error
            // produced when unpacking it.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            // Tuple layout: (signs0, signs1, target_dim, subsample, expected error).
            // `pack` does not serialize `target_dim`, so that field only appears in the
            // diagnostic output below.
            type E = DoubleHadamardError;
            let error_cases = [
                (
                    // signs1 shorter than signs0.
                    vec![0, 0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                (
                    // Empty first stage.
                    vec![],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                (
                    // Out-of-order subsample indices.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // A repeated index is not *strictly* increasing.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // Index 3 is out of range for a length-3 second stage.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![0, 3]),
                    E::LastSubsampleTooLarge,
                ),
                (
                    // An empty subsample is rejected.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![]),
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                // Build the struct directly to bypass `try_from_parts` validation, so the
                // invalid parts reach `try_unpack`.
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}