1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 TargetDim,
20 utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
/// A randomized transform made of two sign-flip stages, each followed by a Hadamard
/// transform over part of an intermediate buffer, with optional subsampling of the result.
///
/// The allocator `A` owns the storage for the sign masks and the subsample indices.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct DoubleHadamard<A>
where
    A: Allocator,
{
    /// First-stage sign masks, one per input dimension. Each entry is XORed into the bit
    /// pattern of an `f32`; `new` generates only `0` or `0x8000_0000` (the `f32` sign bit).
    signs0: Poly<[u32], A>,
    /// Second-stage sign masks, applied to the intermediate buffer after the first
    /// Hadamard pass. `new` sizes this to `max(input dim, target dim)`.
    signs1: Poly<[u32], A>,

    /// Number of values written to the destination by `transform_into`.
    target_dim: usize,

    /// Positions of the intermediate buffer that are kept when the output is smaller than
    /// the input. `None` means the whole intermediate buffer is copied out unchanged.
    subsample: Option<Poly<[u32], A>>,
}
75
76impl<A> DoubleHadamard<A>
77where
78 A: Allocator,
79{
80 pub fn new<R>(
99 dim: NonZeroUsize,
100 target_dim: TargetDim,
101 rng: &mut R,
102 allocator: A,
103 ) -> Result<Self, AllocatorError>
104 where
105 R: Rng + ?Sized,
106 {
107 let dim = dim.get();
108
109 let target_dim = match target_dim {
110 TargetDim::Override(target) => target.get(),
111 TargetDim::Same => dim,
112 TargetDim::Natural => dim,
113 };
114
115 let intermediate_dim = dim.max(target_dim);
120
121 let mut sample = |_: usize| {
123 let sign: bool = StandardUniform {}.sample(rng);
124 if sign { 0x8000_0000 } else { 0 }
125 };
126
127 let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
130 let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
131
132 let subsample = if dim > target_dim {
133 Some(subsample_indices(rng, dim, target_dim, allocator)?)
134 } else {
135 None
136 };
137
138 Ok(Self {
139 signs0,
140 signs1,
141 target_dim,
142 subsample,
143 })
144 }
145
146 pub fn try_from_parts(
147 signs0: Poly<[u32], A>,
148 signs1: Poly<[u32], A>,
149 subsample: Option<Poly<[u32], A>>,
150 ) -> Result<Self, DoubleHadamardError> {
151 type E = DoubleHadamardError;
152 if signs0.is_empty() {
153 return Err(E::Signs0Empty);
154 }
155 if signs1.len() < signs0.len() {
156 return Err(E::Signs1TooSmall);
157 }
158 if !signs0.iter().copied().all(is_sign) {
159 return Err(E::Signs0Invalid);
160 }
161 if !signs1.iter().copied().all(is_sign) {
162 return Err(E::Signs1Invalid);
163 }
164
165 let target_dim = if let Some(ref subsample) = subsample {
167 if !utils::is_strictly_monotonic(subsample.iter()) {
168 return Err(E::SubsampleNotMonotonic);
169 }
170
171 match subsample.last() {
172 Some(last) => {
173 if *last as usize >= signs1.len() {
174 return Err(E::LastSubsampleTooLarge);
181 }
182 }
183 None => {
184 return Err(E::InvalidSubsampleLength);
186 }
187 }
188
189 debug_assert!(
190 subsample.len() < signs1.len(),
191 "since we've verified monotonicity and the last element, this is implied"
192 );
193
194 subsample.len()
195 } else {
196 signs1.len()
198 };
199
200 Ok(Self {
201 signs0,
202 signs1,
203 target_dim,
204 subsample,
205 })
206 }
207
208 pub fn input_dim(&self) -> usize {
210 self.signs0.len()
211 }
212
213 pub fn output_dim(&self) -> usize {
215 self.target_dim
216 }
217
218 pub fn preserves_norms(&self) -> bool {
223 self.subsample.is_none()
224 }
225
226 fn intermediate_dim(&self) -> usize {
227 self.input_dim().max(self.output_dim())
228 }
229
230 pub fn transform_into(
239 &self,
240 dst: &mut [f32],
241 src: &[f32],
242 allocator: ScopedAllocator<'_>,
243 ) -> Result<(), TransformFailed> {
244 check_dims(dst, src, self.input_dim(), self.output_dim())?;
245
246 let intermediate_dim = self.intermediate_dim();
248 let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
249
250 std::iter::zip(tmp.iter_mut(), src.iter())
251 .zip(self.signs0.iter())
252 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
253
254 let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
255
256 #[allow(clippy::unwrap_used)]
261 hadamard_transform(&mut tmp[..split]).unwrap();
262
263 tmp.iter_mut()
267 .zip(self.signs1.iter())
268 .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
269
270 #[allow(clippy::unwrap_used)]
271 hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
272
273 match self.subsample.as_ref() {
274 None => {
275 dst.copy_from_slice(&tmp);
276 }
277 Some(indices) => {
278 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
279 debug_assert_eq!(dst.len(), indices.len());
280 dst.iter_mut()
281 .zip(indices.iter())
282 .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
283 }
284 }
285
286 Ok(())
287 }
288}
289
290impl<A> TryClone for DoubleHadamard<A>
291where
292 A: Allocator,
293{
294 fn try_clone(&self) -> Result<Self, AllocatorError> {
295 Ok(Self {
296 signs0: self.signs0.try_clone()?,
297 signs1: self.signs1.try_clone()?,
298 target_dim: self.target_dim,
299 subsample: self.subsample.try_clone()?,
300 })
301 }
302}
303
/// Reasons a [`DoubleHadamard`] cannot be assembled from its parts or deserialized.
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum DoubleHadamardError {
    /// The first sign stage has no entries, so the input dimension would be zero.
    #[error("first signs stage cannot be empty")]
    Signs0Empty,
    /// An entry of the first sign stage failed the sign-encoding check.
    #[error("first signs stage has invalid coding")]
    Signs0Invalid,

    /// An entry of the second sign stage failed the sign-encoding check.
    #[error("invalid sign representation for second stage")]
    Signs1Invalid,
    /// The second sign stage is shorter than the first.
    #[error("second sign stage must be at least as large as the first stage")]
    Signs1TooSmall,

    /// A subsample index list was supplied but contains no indices.
    #[error("subsample length must equal `target_dim`")]
    InvalidSubsampleLength,
    /// The subsample indices failed the strict monotonicity check.
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The final subsample index is too large to address a transformed value.
    #[error("last subsample index exceeded intermediate dim")]
    LastSubsampleTooLarge,

    /// Allocating storage for one of the parts failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
327
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Serialize this transform into `buf` and return the offset of the written table.
    ///
    /// Both sign stages are written as booleans (via `sign_to_bool`), followed by the
    /// subsample indices when present. The output dimension is not written.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Vectors must be finished before the table that refers to them is started.
        let first_stage =
            buf.create_vector_from_iter(self.signs0.iter().copied().map(sign_to_bool));
        let second_stage =
            buf.create_vector_from_iter(self.signs1.iter().copied().map(sign_to_bool));

        let kept_indices = match self.subsample.as_ref() {
            Some(indices) => Some(buf.create_vector(indices)),
            None => None,
        };

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(first_stage),
            signs1: Some(second_stage),
            subsample: kept_indices,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Rebuild a transform from its serialized form, allocating from `alloc`.
    ///
    /// The decoded parts are handed to `try_from_parts`, so every validation error it
    /// reports can surface here, as can allocation failures.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let first_stage = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let second_stage = Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;

        // The last buffer takes ownership of the allocator; no further clones are needed.
        let kept_indices = proto
            .subsample()
            .map(|indices| Poly::from_iter(indices.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(first_stage, second_stage, kept_indices)
    }
}
380
// Unit tests for `DoubleHadamard`: numerical behavior across dimension combinations and,
// when the `flatbuffers` feature is enabled, serialization round trips and rejection of
// malformed parts.
#[cfg(test)]
// NOTE(review): presumably excluded under Miri because of run time - confirm.
#[cfg(not(miri))]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::{Transform, TransformKind, test_utils},
        alloc::GlobalAllocator,
        test_util::Check,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    /// Runs the shared transform checks over a table of input/output dimensions, once
    /// through `DoubleHadamard` directly and once through the `Transform` wrapper.
    #[test]
    fn test_double_hadamard() {
        // Tolerances used for the combinations that do not subsample.
        let natural_errors = test_utils::ErrorSetup {
            norm: Check::ulp(5),
            l2: Check::ulp(5),
            ip: Check::absrel(2.5e-5, 2e-4),
        };

        // Looser tolerances for the subsampling combinations; inner products are not
        // checked at all.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: Check::absrel(0.0, 2e-2),
            l2: Check::absrel(0.0, 2e-2),
            ip: Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        // Each entry is (input dim, expected output dim, expected `preserves_norms`,
        // requested target, tolerances).
        let dim_combos = [
            // Output equals input, for power-of-two and non-power-of-two sizes.
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Output larger than input.
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Output smaller than input: subsampling, norms no longer preserved.
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The transform must actually change the data over the shared prefix
                // before the numerical error checks are applied.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Snapshot the RNG so the wrapper-based run below starts from the same
                // state as the direct run.
                let mut rng_clone = rng.clone();

                // Direct use of `DoubleHadamard`.
                {
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                // The same configuration constructed through the `Transform` wrapper.
                {
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    // Serialization tests, compiled only when the `flatbuffers` feature is enabled.
    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        /// Checks that packing then unpacking reproduces the transform, and that
        /// malformed parts are rejected with the expected error on unpack.
        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // (input dim, requested target dim)
            let test_cases = [
                // Output equals input.
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                // Output larger than input, target not a power of two.
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                // Output a power of two, at or above the input.
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                // Output smaller than input: exercises the subsample path.
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                // Relies on the `PartialEq` derive that is only enabled for tests.
                assert_eq!(transform, reloaded);
            }

            // Serialize a hand-built (possibly invalid) transform and return the error
            // produced when reading it back.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            // Each entry is (signs0, signs1, target_dim, subsample, expected error). The
            // `target_dim` value only fills the struct literal; `pack` does not write it.
            let error_cases = [
                (
                    // Second stage shorter than the first.
                    vec![0, 0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                (
                    // Empty first stage.
                    vec![],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    // Indices out of order.
                    Some(vec![0, 2, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    // Repeated index.
                    Some(vec![0, 1, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    // Index 3 is past the end of a three-element stage.
                    Some(vec![0, 3]),
                    E::LastSubsampleTooLarge,
                ),
                (
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    // Subsample present but empty.
                    Some(vec![]),
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}