1use std::num::NonZeroUsize;
7
8#[cfg(feature = "flatbuffers")]
9use flatbuffers::{FlatBufferBuilder, WIPOffset};
10use rand::{
11 Rng,
12 distr::{Distribution, StandardUniform},
13};
14use thiserror::Error;
15
16#[cfg(feature = "flatbuffers")]
17use super::utils::{bool_to_sign, sign_to_bool};
18use super::{
19 TargetDim,
20 utils::{TransformFailed, check_dims, is_sign, subsample_indices},
21};
22#[cfg(feature = "flatbuffers")]
23use crate::flatbuffers as fb;
24use crate::{
25 algorithms::hadamard_transform,
26 alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
27 utils,
28};
29
/// A randomized transform built from two sign-flip + Hadamard stages.
///
/// An input of length `signs0.len()` is sign-flipped, zero-padded to the intermediate
/// length `signs1.len()`, passed through a Hadamard transform on a power-of-two prefix,
/// sign-flipped again, passed through a Hadamard transform on a power-of-two suffix, and
/// finally either returned whole or subsampled down to `target_dim` coordinates.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct DoubleHadamard<A>
where
    A: Allocator,
{
    // Sign masks XOR-ed into the bit pattern of each input coordinate. `new` fills it with
    // `0` or `0x8000_0000`; `try_from_parts` requires each entry to pass `is_sign`.
    // Its length is the input dimension.
    signs0: Poly<[u32], A>,
    // Sign masks applied between the two Hadamard stages. Its length is the intermediate
    // dimension and is never smaller than `signs0.len()`.
    signs1: Poly<[u32], A>,

    // Output dimension: `subsample.len()` when subsampling, otherwise `signs1.len()`.
    target_dim: usize,

    // Strictly increasing indices into the intermediate vector that select the output
    // coordinates. `new` only creates this when the output dimension is smaller than the
    // input dimension.
    subsample: Option<Poly<[u32], A>>,
}
75
76impl<A> DoubleHadamard<A>
77where
78 A: Allocator,
79{
80 pub fn new<R>(
99 dim: NonZeroUsize,
100 target_dim: TargetDim,
101 rng: &mut R,
102 allocator: A,
103 ) -> Result<Self, AllocatorError>
104 where
105 R: Rng + ?Sized,
106 {
107 let dim = dim.get();
108
109 let target_dim = match target_dim {
110 TargetDim::Override(target) => target.get(),
111 TargetDim::Same => dim,
112 TargetDim::Natural => dim,
113 };
114
115 let intermediate_dim = dim.max(target_dim);
120
121 let mut sample = |_: usize| {
123 let sign: bool = StandardUniform {}.sample(rng);
124 if sign { 0x8000_0000 } else { 0 }
125 };
126
127 let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
130 let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;
131
132 let subsample = if dim > target_dim {
133 Some(subsample_indices(rng, dim, target_dim, allocator)?)
134 } else {
135 None
136 };
137
138 Ok(Self {
139 signs0,
140 signs1,
141 target_dim,
142 subsample,
143 })
144 }
145
146 pub fn try_from_parts(
147 signs0: Poly<[u32], A>,
148 signs1: Poly<[u32], A>,
149 subsample: Option<Poly<[u32], A>>,
150 ) -> Result<Self, DoubleHadamardError> {
151 type E = DoubleHadamardError;
152 if signs0.is_empty() {
153 return Err(E::Signs0Empty);
154 }
155 if signs1.len() < signs0.len() {
156 return Err(E::Signs1TooSmall);
157 }
158 if !signs0.iter().copied().all(is_sign) {
159 return Err(E::Signs0Invalid);
160 }
161 if !signs1.iter().copied().all(is_sign) {
162 return Err(E::Signs1Invalid);
163 }
164
165 let target_dim = if let Some(ref subsample) = subsample {
167 if !utils::is_strictly_monotonic(subsample.iter()) {
168 return Err(E::SubsampleNotMonotonic);
169 }
170
171 match subsample.last() {
172 Some(last) => {
173 if *last as usize >= signs1.len() {
174 return Err(E::LastSubsampleTooLarge);
181 }
182 }
183 None => {
184 return Err(E::InvalidSubsampleLength);
186 }
187 }
188
189 debug_assert!(
190 subsample.len() < signs1.len(),
191 "since we've verified monotonicity and the last element, this is implied"
192 );
193
194 subsample.len()
195 } else {
196 signs1.len()
198 };
199
200 Ok(Self {
201 signs0,
202 signs1,
203 target_dim,
204 subsample,
205 })
206 }
207
208 pub fn input_dim(&self) -> usize {
210 self.signs0.len()
211 }
212
213 pub fn output_dim(&self) -> usize {
215 self.target_dim
216 }
217
218 pub fn preserves_norms(&self) -> bool {
223 self.subsample.is_none()
224 }
225
226 fn intermediate_dim(&self) -> usize {
227 self.input_dim().max(self.output_dim())
228 }
229
230 pub fn transform_into(
239 &self,
240 dst: &mut [f32],
241 src: &[f32],
242 allocator: ScopedAllocator<'_>,
243 ) -> Result<(), TransformFailed> {
244 check_dims(dst, src, self.input_dim(), self.output_dim())?;
245
246 let intermediate_dim = self.intermediate_dim();
248 let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;
249
250 std::iter::zip(tmp.iter_mut(), src.iter())
251 .zip(self.signs0.iter())
252 .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));
253
254 let split = 1usize << (usize::BITS - intermediate_dim.leading_zeros() - 1);
255
256 #[allow(clippy::unwrap_used)]
261 hadamard_transform(&mut tmp[..split]).unwrap();
262
263 tmp.iter_mut()
267 .zip(self.signs1.iter())
268 .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));
269
270 #[allow(clippy::unwrap_used)]
271 hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();
272
273 match self.subsample.as_ref() {
274 None => {
275 dst.copy_from_slice(&tmp);
276 }
277 Some(indices) => {
278 let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
279 debug_assert_eq!(dst.len(), indices.len());
280 dst.iter_mut()
281 .zip(indices.iter())
282 .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
283 }
284 }
285
286 Ok(())
287 }
288}
289
290impl<A> TryClone for DoubleHadamard<A>
291where
292 A: Allocator,
293{
294 fn try_clone(&self) -> Result<Self, AllocatorError> {
295 Ok(Self {
296 signs0: self.signs0.try_clone()?,
297 signs1: self.signs1.try_clone()?,
298 target_dim: self.target_dim,
299 subsample: self.subsample.try_clone()?,
300 })
301 }
302}
303
/// Reasons [`DoubleHadamard::try_from_parts`] (and therefore flatbuffer deserialization)
/// can reject a set of components, plus allocation failure.
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum DoubleHadamardError {
    /// `signs0` was empty, which would give a zero-dimensional input.
    #[error("first signs stage cannot be empty")]
    Signs0Empty,
    /// An entry of `signs0` failed the `is_sign` check.
    #[error("first signs stage has invalid coding")]
    Signs0Invalid,

    /// An entry of `signs1` failed the `is_sign` check.
    #[error("invalid sign representation for second stage")]
    Signs1Invalid,
    /// `signs1` was shorter than `signs0`.
    #[error("second sign stage must be at least as large as the first stage")]
    Signs1TooSmall,

    /// `try_from_parts` returns this when `subsample` is `Some` but empty.
    // NOTE(review): the display message mentions `target_dim`, but the only trigger visible
    // in this file is an empty subsample.
    #[error("subsample length must equal `target_dim`")]
    InvalidSubsampleLength,
    /// The subsample indices were not strictly increasing (repeated indices included).
    #[error("subsample indices is not monotonic")]
    SubsampleNotMonotonic,
    /// The last (largest) subsample index was not less than `signs1.len()`.
    #[error("last subsample index exceeded intermediate dim")]
    LastSubsampleTooLarge,

    /// Allocating one of the internal buffers failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
327
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Serialize this transform into `buf` as a `fb::transforms::DoubleHadamard` table.
    ///
    /// Sign masks are written as booleans via `sign_to_bool`. Subsample indices, when
    /// present, are written unchanged. The output dimension is not stored; `try_unpack`
    /// re-derives it through `try_from_parts`.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        // Child vectors are created first, then referenced by the table.
        let first_stage = self.signs0.iter().map(|&mask| sign_to_bool(mask));
        let signs0 = buf.create_vector_from_iter(first_stage);

        let second_stage = self.signs1.iter().map(|&mask| sign_to_bool(mask));
        let signs1 = buf.create_vector_from_iter(second_stage);

        let subsample = self
            .subsample
            .as_ref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(signs0),
            signs1: Some(signs1),
            subsample,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Rebuild a transform from its serialized form, allocating buffers with `alloc`.
    ///
    /// The decoded components are validated by [`Self::try_from_parts`].
    ///
    /// # Errors
    ///
    /// Propagates allocation failures and any validation error from `try_from_parts`.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let stage0 = Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let stage1 = Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;

        // `transpose` turns `Option<Result<..>>` into `Result<Option<..>>`. An allocation
        // failure is propagated by `?`, while an absent subsample stays `None`.
        let indices = proto
            .subsample()
            .map(|raw| Poly::from_iter(raw.into_iter(), alloc))
            .transpose()?;

        Self::try_from_parts(stage0, stage1, indices)
    }
}
380
// Unit tests for `DoubleHadamard`: numerical behavior of the transform and flatbuffer
// round-tripping / validation.
#[cfg(test)]
// NOTE(review): skipped under Miri, presumably because the randomized trials are too slow
// there — confirm.
#[cfg(not(miri))]
mod tests {
    use diskann_utils::lazy_format;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::{Transform, TransformKind, test_utils},
        alloc::GlobalAllocator,
    };

    test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);

    // Builds transforms for a range of (input, output) dimension pairs, both directly and
    // through the `Transform` wrapper, and checks their dims and numerical error bounds.
    #[test]
    fn test_double_hadamard() {
        // Tolerances for configurations without subsampling (`preserves_norms() == true`).
        let natural_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::ulp(5),
            l2: test_utils::Check::ulp(5),
            ip: test_utils::Check::absrel(2.5e-5, 2e-4),
        };

        // Subsampled configurations: looser relative tolerances, inner products not checked.
        let subsampled_errors = test_utils::ErrorSetup {
            norm: test_utils::Check::absrel(0.0, 2e-2),
            l2: test_utils::Check::absrel(0.0, 2e-2),
            ip: test_utils::Check::skip(),
        };

        let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
        // (input dim, expected output dim, expected preserves_norms, target, tolerances)
        let dim_combos = [
            // Same-size transforms, including non-power-of-two dims.
            (15, 15, true, TargetDim::Same, &natural_errors),
            (15, 15, true, TargetDim::Natural, &natural_errors),
            (16, 16, true, TargetDim::Same, &natural_errors),
            (16, 16, true, TargetDim::Natural, &natural_errors),
            (256, 256, true, TargetDim::Same, &natural_errors),
            (1000, 1000, true, TargetDim::Same, &natural_errors),
            // Up-projection (zero-padding).
            (15, 16, true, target_dim(16), &natural_errors),
            (100, 128, true, target_dim(128), &natural_errors),
            (15, 32, true, target_dim(32), &natural_errors),
            (16, 64, true, target_dim(64), &natural_errors),
            // Down-projection (subsampling).
            (1024, 1023, false, target_dim(1023), &subsampled_errors),
            (1000, 999, false, target_dim(999), &subsampled_errors),
        ];

        let trials_per_combo = 20;
        let trials_per_dim = 100;

        let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
        for (input, output, preserves_norms, target, errors) in dim_combos {
            let input_nz = NonZeroUsize::new(input).unwrap();
            for trial in 0..trials_per_combo {
                let ctx = &lazy_format!(
                    "input dim = {}, output dim = {}, macro trial {} of {}",
                    input,
                    output,
                    trial,
                    trials_per_combo
                );

                // The leading `min(input, output)` coordinates must differ from the input
                // (i.e. the transform is not the identity), and errors must be in tolerance.
                let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
                    let d = input.min(output);
                    assert_ne!(&io.input0[..d], &io.output0[..d]);
                    assert_ne!(&io.input1[..d], &io.output1[..d]);
                    test_utils::check_errors(io, context, errors);
                };

                // Snapshot the RNG before the direct construction consumes it, so the
                // `Transform`-wrapped construction below starts from the same state.
                // NOTE(review): presumably this gives both the same signs — confirm.
                let mut rng_clone = rng.clone();

                {
                    // Direct construction.
                    let transformer = DoubleHadamard::new(
                        NonZeroUsize::new(input).unwrap(),
                        target,
                        &mut rng,
                        GlobalAllocator,
                    )
                    .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng,
                        ctx,
                    )
                }

                {
                    // Construction through the generic `Transform` front end.
                    let kind = TransformKind::DoubleHadamard { target_dim: target };
                    let transformer =
                        Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
                            .unwrap();

                    assert_eq!(transformer.input_dim(), input);
                    assert_eq!(transformer.output_dim(), output);
                    assert_eq!(transformer.preserves_norms(), preserves_norms);

                    test_utils::test_transform(
                        &transformer,
                        trials_per_dim,
                        &mut checker,
                        &mut rng_clone,
                        ctx,
                    )
                }
            }
        }
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use super::*;
        use crate::flatbuffers::to_flatbuffer;

        // Round-trips transforms through `pack`/`try_unpack` and checks that malformed
        // components are rejected with the expected error variant.
        #[test]
        fn double_hadamard() {
            let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
            let alloc = GlobalAllocator;

            // (input dim, target) pairs covering same-size, up-projection and
            // down-projection (subsampled) transforms.
            let test_cases = [
                (5, TargetDim::Same),
                (8, TargetDim::Same),
                (10, TargetDim::Natural),
                (16, TargetDim::Natural),
                (8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
                (15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
                (16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
            ];

            for (dim, target_dim) in test_cases {
                let transform = DoubleHadamard::new(
                    NonZeroUsize::new(dim).unwrap(),
                    target_dim,
                    &mut rng,
                    alloc,
                )
                .unwrap();
                let data = to_flatbuffer(|buf| transform.pack(buf));

                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();

                assert_eq!(transform, reloaded);
            }

            // Serialize an (intentionally invalid) transform and return the error that
            // deserialization reports for it.
            let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
                let data = to_flatbuffer(|buf| x.pack(buf));
                let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
                DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
            };

            type E = DoubleHadamardError;
            // (signs0, signs1, target_dim, subsample, expected error). `target_dim` is not
            // serialized, so it only matters for the diagnostic output below.
            let error_cases = [
                (
                    // `signs1` shorter than `signs0`.
                    vec![0, 0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs1TooSmall,
                ),
                (
                    // Empty first stage.
                    vec![],
                    vec![0, 0, 0, 0],
                    4,
                    None,
                    E::Signs0Empty,
                ),
                (
                    // Indices out of order.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 2, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // Repeated index: monotonicity must be strict.
                    vec![0, 0, 0, 0],
                    vec![0, 0, 0, 0],
                    3,
                    Some(vec![0, 1, 1]),
                    E::SubsampleNotMonotonic,
                ),
                (
                    // Last index equals `signs1.len()`.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![0, 3]),
                    E::LastSubsampleTooLarge,
                ),
                (
                    // Present but empty subsample.
                    vec![0, 0, 0],
                    vec![0, 0, 0],
                    2,
                    Some(vec![]),
                    E::InvalidSubsampleLength,
                ),
            ];

            let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();

            for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
                println!(
                    "on case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample,
                );
                // Build the struct directly to bypass `try_from_parts` validation, so the
                // invalid state reaches the serializer.
                let err = gen_err(DoubleHadamard {
                    signs0: poly(signs0),
                    signs1: poly(signs1),
                    target_dim: *target_dim,
                    subsample: subsample.as_ref().map(poly),
                });

                assert_eq!(
                    err, *expected,
                    "failed for case ({:?}, {:?}, {}, {:?})",
                    signs0, signs1, target_dim, subsample
                );
            }
        }
    }
}