1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 TargetDim,
20 utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
30#[derive(Debug)]
55#[cfg_attr(test, derive(PartialEq))]
56pub struct DoubleHadamard<A>
57where
58 A: Allocator,
59{
60 signs0: Poly<[u32], A>,
67 signs1: Poly<[u32], A>,
68
69 target_dim: usize,
71
72 subsample: Option<Poly<[u32], A>>,
74}
75
76impl<A> DoubleHadamard<A>
77where
78 A: Allocator,
79{
80 pub fn new<R>(
99 dim: NonZeroUsize,
100 target_dim: TargetDim,
101 rng: &mut R,
102 allocator: A,
103 ) -> Result<Self, AllocatorError>
104 where
105 R: Rng + ?Sized,
106 {
107 let dim = dim.get();
108
109 let target_dim = match target_dim {
110 TargetDim::Override(target) => target.get(),
111 TargetDim::Same => dim,
112 TargetDim::Natural => dim,
113 };
114
115 let intermediate_dim = dim.max(target_dim);
120
121 let mut sample = |_: usize| {
123 let sign: bool = StandardUniform {}.sample(rng);
124 if sign { 0x8000_0000 } else { 0 }
125 };
126
127 let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
130 let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
131
132 let subsample = if dim > target_dim {
133 Some(subsample_indices(rng, dim, target_dim, allocator)?)
134 } else {
135 None
136 };
137
138 Ok(Self {
139 signs0,
140 signs1,
141 target_dim,
142 subsample,
143 })
144 }
145
146 pub fn try_from_parts(
147 signs0: Poly<[u32], A>,
148 signs1: Poly<[u32], A>,
149 subsample: Option<Poly<[u32], A>>,
150 ) -> Result<Self, DoubleHadamardError> {
151 type E = DoubleHadamardError;
152 if signs0.is_empty() {
153 return Err(E::Signs0Empty);
154 }
155 if signs1.len() < signs0.len() {
156 return Err(E::Signs1TooSmall);
157 }
158 if !signs0.iter().copied().all(is_sign) {
159 return Err(E::Signs0Invalid);
160 }
161 if !signs1.iter().copied().all(is_sign) {
162 return Err(E::Signs1Invalid);
163 }
164
165 let target_dim = if let Some(ref subsample) = subsample {
167 if !utils::is_strictly_monotonic(subsample.iter()) {
168 return Err(E::SubsampleNotMonotonic);
169 }
170
171 match subsample.last() {
172 Some(last) => {
173 if *last as usize >= signs1.len() {
174 return Err(E::LastSubsampleTooLarge);
181 }
182 }
183 None => {
184 return Err(E::InvalidSubsampleLength);
186 }
187 }
188
189 debug_assert!(
190 subsample.len() < signs1.len(),
191 "since we've verified monotonicity and the last element, this is implied"
192 );
193
194 subsample.len()
195 } else {
196 signs1.len()
198 };
199
200 Ok(Self {
201 signs0,
202 signs1,
203 target_dim,
204 subsample,
205 })
206 }
207
208 pub fn input_dim(&self) -> usize {
210 self.signs0.len()
211 }
212
213 pub fn output_dim(&self) -> usize {
215 self.target_dim
216 }
217
218 pub fn preserves_norms(&self) -> bool {
223 self.subsample.is_none()
224 }
225
226 fn intermediate_dim(&self) -> usize {
227 self.input_dim().max(self.output_dim())
228 }
229
230 pub fn transform_into(
239 &self,
240 dst: &mut [f32],
241 src: &[f32],
242 allocator: ScopedAllocator<'_>,
243 ) -> Result<(), TransformFailed> {
244 check_dims(dst, src, self.input_dim(), self.output_dim())?;
245
246 let intermediate_dim = self.intermediate_dim();
248 let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
249
250 std::iter::zip(tmp.iter_mut(), src.iter())
251 .zip(self.signs0.iter())
252 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
253
254 let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
255
256 #[allow(clippy::unwrap_used)]
261 hadamard_transform(&mut tmp[..split]).unwrap();
262
263 tmp.iter_mut()
267 .zip(self.signs1.iter())
268 .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
269
270 #[allow(clippy::unwrap_used)]
271 hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
272
273 match self.subsample.as_ref() {
274 None => {
275 dst.copy_from_slice(&tmp);
276 }
277 Some(indices) => {
278 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
279 debug_assert_eq!(dst.len(), indices.len());
280 dst.iter_mut()
281 .zip(indices.iter())
282 .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
283 }
284 }
285
286 Ok(())
287 }
288}
289
290impl<A> TryClone for DoubleHadamard<A>
291where
292 A: Allocator,
293{
294 fn try_clone(&self) -> Result<Self, AllocatorError> {
295 Ok(Self {
296 signs0: self.signs0.try_clone()?,
297 signs1: self.signs1.try_clone()?,
298 target_dim: self.target_dim,
299 subsample: self.subsample.try_clone()?,
300 })
301 }
302}
303
304#[derive(Debug, Clone, Copy, Error, PartialEq)]
305#[non_exhaustive]
306pub enum DoubleHadamardError {
307 #[error("first signs stage cannot be empty")]
308 Signs0Empty,
309 #[error("first signs stage has invalid coding")]
310 Signs0Invalid,
311
312 #[error("invalid sign representation for second stage")]
313 Signs1Invalid,
314 #[error("second sign stage must be at least as large as the first stage")]
315 Signs1TooSmall,
316
317 #[error("subsample length must equal `target_dim`")]
318 InvalidSubsampleLength,
319 #[error("subsample indices is not monotonic")]
320 SubsampleNotMonotonic,
321 #[error("last subsample index exceeded intermediate dim")]
322 LastSubsampleTooLarge,
323
324 #[error(transparent)]
325 AllocatorError(#[from] AllocatorError),
326}
327
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Serialize this transform into `buf`, returning the offset of the new table.
    ///
    /// Sign masks are stored as booleans via `sign_to_bool`, and subsample indices are
    /// stored verbatim. `target_dim` is not written; `try_unpack` recovers it from the
    /// other fields.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // The child vectors are created first, in the order signs0, signs1, subsample.
        // The table is created last and refers to them by offset.
        let signs0 = buf.create_vector_from_iter(self.signs0.iter().map(|&s| sign_to_bool(s)));
        let signs1 = buf.create_vector_from_iter(self.signs1.iter().map(|&s| sign_to_bool(s)));
        let subsample = self
            .subsample
            .as_deref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(signs0),
            signs1: Some(signs1),
            subsample,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Rebuild a transform from its serialized form, validating it through
    /// [`DoubleHadamard::try_from_parts`].
    ///
    /// # Errors
    ///
    /// Returns a [`DoubleHadamardError`] if allocation fails or the decoded parts violate
    /// an invariant checked by `try_from_parts`.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let (signs0, signs1) = (
            Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?,
            Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?,
        );

        let subsample = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(signs0, signs1, subsample)
    }
}
380
#[cfg(test)]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::{Transform, TransformKind, test_utils},
        alloc::GlobalAllocator,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    /// Checks dimensions, `preserves_norms`, and numerical error bounds for a range of
    /// input/output dimension combinations. Each combination is exercised through both
    /// the concrete `DoubleHadamard` type and the type-erased `Transform` wrapper.
    #[test]
    fn test_double_hadamard() {
        // Tolerances for configurations without subsampling (`preserves_norms() == true`).
        // `Check::ulp(5)` presumably allows 5 units in the last place; confirm in test_utils.
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // Looser tolerances for subsampled configurations; inner products are skipped.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        // (input dim, expected output dim, expected `preserves_norms()`, requested target,
        //  error tolerances)
        let dim_combos = [
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Growing the dimension: no subsampling, so norms are preserved.
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Shrinking the dimension: a subsample is taken, so norms are only approximate.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The transform must actually change the leading `min(input, output)`
                // coordinates of both test vectors; then the configured error checks run.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Snapshot the RNG so the `Transform` path below starts from the same RNG
                // state as the direct `DoubleHadamard` construction.
                let mut rng_clone = rng.clone();

                {
                    // Exercise the concrete type directly.
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                {
                    // Exercise the same configuration through `TransformKind`/`Transform`.
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Round-trips valid transforms through `pack`/`try_unpack`, then checks that
        /// malformed parts are rejected with the expected error.
        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // (input dim, requested target): same-size, growing, and shrinking transforms.
            let test_cases = [
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Serialize a (possibly invalid) transform and return the error from unpacking it.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            // (signs0, signs1, target_dim, subsample, expected error). `target_dim` is stored
            // in the struct, but `pack` does not serialize it.
            let error_cases = [
                (
                    // signs1 shorter than signs0.
                    vec![0, 0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                (
                    // Empty first sign stage.
                    vec![],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                (
                    // Out-of-order subsample indices.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // Duplicate subsample index.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // Last index 3 is not below signs1.len() == 3.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![0, 3]),
                    E::LastSubsampleTooLarge,
                ),
                (
                    // Empty subsample.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![]),
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                // Build the struct directly, bypassing validation, so the invalid parts
                // reach `try_unpack`.
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}