use std::num::NonZeroUsize;
#[cfg(feature = "flatbuffers")]
use flatbuffers::{FlatBufferBuilder, WIPOffset};
use rand::{
Rng,
distr::{Distribution, StandardUniform},
};
use thiserror::Error;
#[cfg(feature = "flatbuffers")]
use super::utils::{bool_to_sign, sign_to_bool};
use super::{
TargetDim,
utils::{TransformFailed, check_dims, is_sign, subsample_indices},
};
#[cfg(feature = "flatbuffers")]
use crate::flatbuffers as fb;
use crate::{
algorithms::hadamard_transform,
alloc::{Allocator, AllocatorError, Poly, ScopedAllocator, TryClone},
utils,
};
/// A randomized "double Hadamard" transform from `input_dim` to `output_dim`.
///
/// The transform (see `transform_into`) is: flip signs of the (zero-padded) input
/// with `signs0`, apply a Hadamard transform to the leading power-of-two prefix, flip
/// signs of the whole intermediate vector with `signs1`, apply a Hadamard transform to
/// the trailing power-of-two suffix, and finally either copy the result or gather a
/// rescaled subset of its coordinates.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct DoubleHadamard<A>
where
A: Allocator,
{
// Sign masks XORed into the bit pattern of each input `f32` (`new` produces `0` or
// `0x8000_0000`). The length is the input dimension.
signs0: Poly<[u32], A>,
// Sign masks applied between the two Hadamard stages. The length is the dimension of
// the intermediate vector.
signs1: Poly<[u32], A>,
// The output dimension.
target_dim: usize,
// Optional indices into the intermediate vector selecting the output coordinates.
// `try_from_parts` validates them as strictly monotonic and `< signs1.len()`.
subsample: Option<Poly<[u32], A>>,
}
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Construct a new randomized double-Hadamard transform.
    ///
    /// * `dim`: The input dimension.
    /// * `target_dim`: The requested output dimension. [`TargetDim::Same`] and
    ///   [`TargetDim::Natural`] both resolve to `dim`; [`TargetDim::Override`] uses the
    ///   provided value.
    /// * `rng`: Source of randomness for the sign flips and, if needed, the subsample.
    /// * `allocator`: Allocator used for all internal buffers.
    ///
    /// If the output dimension exceeds `dim`, inputs are implicitly zero-padded to the
    /// output dimension. If it is smaller, `target_dim` coordinates chosen by
    /// `subsample_indices` are kept after the second Hadamard stage.
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if any internal allocation fails.
    pub fn new<R>(
        dim: NonZeroUsize,
        target_dim: TargetDim,
        rng: &mut R,
        allocator: A,
    ) -> Result<Self, AllocatorError>
    where
        R: Rng + ?Sized,
    {
        let dim = dim.get();
        let target_dim = match target_dim {
            TargetDim::Override(target) => target.get(),
            TargetDim::Same => dim,
            TargetDim::Natural => dim,
        };
        let intermediate_dim = dim.max(target_dim);

        // Each sign is stored as a mask that is XORed into the bit pattern of an `f32`.
        // `0x8000_0000` flips the IEEE-754 sign bit; `0` leaves the value unchanged.
        let mut sample = |_: usize| {
            let sign: bool = StandardUniform {}.sample(rng);
            if sign { 0x8000_0000 } else { 0 }
        };
        let signs0 = Poly::from_iter((0..dim).map(&mut sample), allocator.clone())?;
        let signs1 = Poly::from_iter((0..intermediate_dim).map(&mut sample), allocator.clone())?;

        // Subsampling is only needed when shrinking. In that case the intermediate
        // dimension equals `dim`.
        let subsample = if dim > target_dim {
            Some(subsample_indices(rng, dim, target_dim, allocator)?)
        } else {
            None
        };

        Ok(Self {
            signs0,
            signs1,
            target_dim,
            subsample,
        })
    }

    /// Construct a transform from its raw components after validating them.
    ///
    /// * `signs0`: Sign masks applied to the input. Its length is the input dimension.
    /// * `signs1`: Sign masks applied between the two Hadamard stages. Its length is the
    ///   intermediate dimension and must be at least `signs0.len()`.
    /// * `subsample`: Optional indices into the intermediate vector. When present, its
    ///   length is the output dimension. Otherwise the output dimension is
    ///   `signs1.len()`.
    ///
    /// # Errors
    ///
    /// * [`DoubleHadamardError::Signs0Empty`] if `signs0` is empty.
    /// * [`DoubleHadamardError::Signs1TooSmall`] if `signs1` is shorter than `signs0`.
    /// * [`DoubleHadamardError::Signs0Invalid`] / [`DoubleHadamardError::Signs1Invalid`]
    ///   if an entry is rejected by `is_sign`.
    /// * [`DoubleHadamardError::SubsampleNotMonotonic`] if `subsample` is not strictly
    ///   monotonic.
    /// * [`DoubleHadamardError::InvalidSubsampleLength`] if `subsample` is empty.
    /// * [`DoubleHadamardError::LastSubsampleTooLarge`] if the last subsample index is
    ///   not a valid index into `signs1`.
    pub fn try_from_parts(
        signs0: Poly<[u32], A>,
        signs1: Poly<[u32], A>,
        subsample: Option<Poly<[u32], A>>,
    ) -> Result<Self, DoubleHadamardError> {
        type E = DoubleHadamardError;
        if signs0.is_empty() {
            return Err(E::Signs0Empty);
        }
        if signs1.len() < signs0.len() {
            return Err(E::Signs1TooSmall);
        }
        if !signs0.iter().copied().all(is_sign) {
            return Err(E::Signs0Invalid);
        }
        if !signs1.iter().copied().all(is_sign) {
            return Err(E::Signs1Invalid);
        }

        let target_dim = match subsample.as_ref() {
            Some(subsample) => {
                if !utils::is_strictly_monotonic(subsample.iter()) {
                    return Err(E::SubsampleNotMonotonic);
                }
                // NOTE(review): bounding only the last index assumes
                // `is_strictly_monotonic` enforces *increasing* order. Confirm in `utils`.
                let Some(&last) = subsample.last() else {
                    return Err(E::InvalidSubsampleLength);
                };
                if last as usize >= signs1.len() {
                    return Err(E::LastSubsampleTooLarge);
                }
                // Strictly increasing indices that are all `< signs1.len()` imply
                // `subsample.len() <= signs1.len()`. Equality (the identity subsample)
                // is reachable, so the bound is not strict.
                debug_assert!(
                    subsample.len() <= signs1.len(),
                    "strict monotonicity and the bound on the last index imply this"
                );
                subsample.len()
            }
            None => signs1.len(),
        };

        Ok(Self {
            signs0,
            signs1,
            target_dim,
            subsample,
        })
    }

    /// Return the dimension of vectors accepted by [`Self::transform_into`].
    pub fn input_dim(&self) -> usize {
        self.signs0.len()
    }

    /// Return the dimension of vectors produced by [`Self::transform_into`].
    pub fn output_dim(&self) -> usize {
        self.target_dim
    }

    /// Return `true` if the output is not subsampled.
    ///
    /// Subsampled outputs are rescaled by `sqrt(intermediate_dim / output_dim)` but in
    /// general do not preserve norms exactly.
    pub fn preserves_norms(&self) -> bool {
        self.subsample.is_none()
    }

    /// Return the length of the intermediate vector the Hadamard stages operate on.
    ///
    /// This is `signs1.len()`, the same bound `try_from_parts` validates subsample
    /// indices against, so every stored subsample index is in bounds. For transforms
    /// built by [`Self::new`] it equals `max(input_dim, output_dim)`.
    fn intermediate_dim(&self) -> usize {
        self.signs1.len()
    }

    /// Apply the transform to `src`, writing the result into `dst`.
    ///
    /// Steps:
    /// 1. Copy `src` into a zero-initialized scratch buffer of the intermediate dimension,
    ///    flipping signs according to `signs0`.
    /// 2. Apply a Hadamard transform to the largest power-of-two prefix.
    /// 3. Flip signs of the whole buffer according to `signs1`.
    /// 4. Apply a Hadamard transform to the power-of-two suffix of the same length.
    /// 5. Copy the buffer into `dst`, or gather the subsampled coordinates scaled by
    ///    `sqrt(intermediate_dim / output_dim)`.
    ///
    /// When the intermediate dimension is a power of two, both Hadamard stages cover the
    /// whole buffer.
    ///
    /// # Errors
    ///
    /// Fails if `dst`/`src` are rejected by `check_dims` for this transform's
    /// dimensions, or if allocating the scratch buffer fails.
    pub fn transform_into(
        &self,
        dst: &mut [f32],
        src: &[f32],
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), TransformFailed> {
        check_dims(dst, src, self.input_dim(), self.output_dim())?;
        let intermediate_dim = self.intermediate_dim();
        let mut tmp = Poly::broadcast(0.0f32, intermediate_dim, allocator)?;

        // Stage 1: sign-flip the input into the scratch buffer. Any tail beyond
        // `src.len()` stays zero (zero padding).
        std::iter::zip(tmp.iter_mut(), src.iter())
            .zip(self.signs0.iter())
            .for_each(|((dst, src), sign)| *dst = f32::from_bits(src.to_bits() ^ sign));

        // Largest power of two not exceeding `intermediate_dim`. `intermediate_dim` is
        // non-zero because `signs0` is non-empty and `signs1.len() >= signs0.len()`.
        let split = 1usize << intermediate_dim.ilog2();

        // Both Hadamard slices have power-of-two length `split`.
        // NOTE(review): assumes `hadamard_transform` only fails on non-power-of-two input.
        #[allow(clippy::unwrap_used)]
        hadamard_transform(&mut tmp[..split]).unwrap();

        tmp.iter_mut()
            .zip(self.signs1.iter())
            .for_each(|(dst, sign)| *dst = f32::from_bits(dst.to_bits() ^ sign));

        #[allow(clippy::unwrap_used)]
        hadamard_transform(&mut tmp[intermediate_dim - split..]).unwrap();

        match self.subsample.as_ref() {
            None => {
                dst.copy_from_slice(&tmp);
            }
            Some(indices) => {
                let rescale = ((tmp.len() as f32) / (indices.len() as f32)).sqrt();
                debug_assert_eq!(dst.len(), indices.len());
                // Every index is `< signs1.len() == tmp.len()` (enforced on
                // construction), so this indexing cannot go out of bounds.
                dst.iter_mut()
                    .zip(indices.iter())
                    .for_each(|(d, s)| *d = tmp[*s as usize] * rescale);
            }
        }
        Ok(())
    }
}
impl<A> TryClone for DoubleHadamard<A>
where
    A: Allocator,
{
    /// Clone every buffer through [`TryClone`], returning the first allocation error.
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        let signs0 = self.signs0.try_clone()?;
        let signs1 = self.signs1.try_clone()?;
        let target_dim = self.target_dim;
        let subsample = self.subsample.try_clone()?;

        Ok(Self {
            signs0,
            signs1,
            target_dim,
            subsample,
        })
    }
}
#[derive(Debug, Clone, Copy, Error, PartialEq)]
#[non_exhaustive]
pub enum DoubleHadamardError {
#[error("first signs stage cannot be empty")]
Signs0Empty,
#[error("first signs stage has invalid coding")]
Signs0Invalid,
#[error("invalid sign representation for second stage")]
Signs1Invalid,
#[error("second sign stage must be at least as large as the first stage")]
Signs1TooSmall,
#[error("subsample length must equal `target_dim`")]
InvalidSubsampleLength,
#[error("subsample indices is not monotonic")]
SubsampleNotMonotonic,
#[error("last subsample index exceeded intermediate dim")]
LastSubsampleTooLarge,
#[error(transparent)]
AllocatorError(#[from] AllocatorError),
}
#[cfg(feature = "flatbuffers")]
impl<A> DoubleHadamard<A>
where
    A: Allocator,
{
    /// Serialize this transform into `buf`.
    ///
    /// Sign masks are written as booleans (via `sign_to_bool`) and subsample indices are
    /// written verbatim. `target_dim` is not written; `try_unpack` re-derives it.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<fb::transforms::DoubleHadamard<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let first_stage =
            buf.create_vector_from_iter(self.signs0.iter().copied().map(sign_to_bool));
        let second_stage =
            buf.create_vector_from_iter(self.signs1.iter().copied().map(sign_to_bool));
        let kept_indices = self
            .subsample
            .as_ref()
            .map(|indices| buf.create_vector(indices));

        let args = fb::transforms::DoubleHadamardArgs {
            signs0: Some(first_stage),
            signs1: Some(second_stage),
            subsample: kept_indices,
        };
        fb::transforms::DoubleHadamard::create(buf, &args)
    }

    /// Rebuild a transform from its serialized form, allocating buffers with `alloc`.
    ///
    /// The decoded parts are validated by [`Self::try_from_parts`].
    ///
    /// # Errors
    ///
    /// Returns an error if an allocation fails or the decoded parts are invalid.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: fb::transforms::DoubleHadamard<'_>,
    ) -> Result<Self, DoubleHadamardError> {
        let first_stage =
            Poly::from_iter(proto.signs0().iter().map(bool_to_sign), alloc.clone())?;
        let second_stage =
            Poly::from_iter(proto.signs1().iter().map(bool_to_sign), alloc.clone())?;
        let kept_indices = if let Some(indices) = proto.subsample() {
            Some(Poly::from_iter(indices.into_iter(), alloc)?)
        } else {
            None
        };
        Self::try_from_parts(first_stage, second_stage, kept_indices)
    }
}
// Unit tests for `DoubleHadamard`: numerical behavior over a range of dimensions, plus
// (behind the `flatbuffers` feature) serialization round-trips and decode-time
// validation errors.
#[cfg(test)]
#[cfg(not(miri))]
mod tests {
use diskann_utils::lazy_format;
use rand::{SeedableRng, rngs::StdRng};
use super::*;
use crate::{
algorithms::transforms::{Transform, TransformKind, test_utils},
alloc::GlobalAllocator,
test_util::Check,
};
// NOTE(review): presumably implements the `test_utils` transformer trait for this type
// so `test_utils::test_transform` can drive it. Confirm in `test_utils`.
test_utils::delegate_transformer!(DoubleHadamard<GlobalAllocator>);
#[test]
fn test_double_hadamard() {
// Tight tolerances for transforms without subsampling.
let natural_errors = test_utils::ErrorSetup {
norm: Check::ulp(5),
l2: Check::ulp(5),
ip: Check::absrel(2.5e-5, 2e-4),
};
// Looser tolerances for subsampled transforms. Inner products are not checked.
let subsampled_errors = test_utils::ErrorSetup {
norm: Check::absrel(0.0, 2e-2),
l2: Check::absrel(0.0, 2e-2),
ip: Check::skip(),
};
let target_dim = |v| TargetDim::Override(NonZeroUsize::new(v).unwrap());
// Each entry: (input dim, expected output dim, expected `preserves_norms()`,
// requested target dim, error tolerances).
let dim_combos = [
(15, 15, true, TargetDim::Same, &natural_errors),
(15, 15, true, TargetDim::Natural, &natural_errors),
(16, 16, true, TargetDim::Same, &natural_errors),
(16, 16, true, TargetDim::Natural, &natural_errors),
(256, 256, true, TargetDim::Same, &natural_errors),
(1000, 1000, true, TargetDim::Same, &natural_errors),
(15, 16, true, target_dim(16), &natural_errors),
(100, 128, true, target_dim(128), &natural_errors),
(15, 32, true, target_dim(32), &natural_errors),
(16, 64, true, target_dim(64), &natural_errors),
(1024, 1023, false, target_dim(1023), &subsampled_errors),
(1000, 999, false, target_dim(999), &subsampled_errors),
];
let trials_per_combo = 20;
let trials_per_dim = 100;
let mut rng = StdRng::seed_from_u64(0x6d1699abe066147);
for (input, output, preserves_norms, target, errors) in dim_combos {
let input_nz = NonZeroUsize::new(input).unwrap();
for trial in 0..trials_per_combo {
let ctx = &lazy_format!(
"input dim = {}, output dim = {}, macro trial {} of {}",
input,
output,
trial,
trials_per_combo
);
// The leading coordinates should change under the transform, and the
// configured error tolerances should hold.
let mut checker = |io: test_utils::IO<'_>, context: &dyn std::fmt::Display| {
let d = input.min(output);
assert_ne!(&io.input0[..d], &io.output0[..d]);
assert_ne!(&io.input1[..d], &io.output1[..d]);
test_utils::check_errors(io, context, errors);
};
// Snapshot the RNG so the direct constructor below and the `Transform::new`
// path further down start from the same random state.
let mut rng_clone = rng.clone();
{
let transformer = DoubleHadamard::new(
NonZeroUsize::new(input).unwrap(),
target,
&mut rng,
GlobalAllocator,
)
.unwrap();
assert_eq!(transformer.input_dim(), input);
assert_eq!(transformer.output_dim(), output);
assert_eq!(transformer.preserves_norms(), preserves_norms);
test_utils::test_transform(
&transformer,
trials_per_dim,
&mut checker,
&mut rng,
ctx,
)
}
// Repeat through the type-erased `Transform` wrapper.
{
let kind = TransformKind::DoubleHadamard { target_dim: target };
let transformer =
Transform::new(kind, input_nz, Some(&mut rng_clone), GlobalAllocator)
.unwrap();
assert_eq!(transformer.input_dim(), input);
assert_eq!(transformer.output_dim(), output);
assert_eq!(transformer.preserves_norms(), preserves_norms);
test_utils::test_transform(
&transformer,
trials_per_dim,
&mut checker,
&mut rng_clone,
ctx,
)
}
}
}
}
#[cfg(feature = "flatbuffers")]
mod serialization {
use super::*;
use crate::flatbuffers::to_flatbuffer;
#[test]
fn double_hadamard() {
let mut rng = StdRng::seed_from_u64(0x123456789abcdef0);
let alloc = GlobalAllocator;
// Round-trip cases: (input dim, target dim), covering same, padded and
// subsampled outputs.
let test_cases = [
(5, TargetDim::Same),
(8, TargetDim::Same),
(10, TargetDim::Natural),
(16, TargetDim::Natural),
(8, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
(10, TargetDim::Override(NonZeroUsize::new(12).unwrap())),
(15, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
(16, TargetDim::Override(NonZeroUsize::new(16).unwrap())),
(15, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
(16, TargetDim::Override(NonZeroUsize::new(32).unwrap())),
(15, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
(16, TargetDim::Override(NonZeroUsize::new(10).unwrap())),
];
for (dim, target_dim) in test_cases {
let transform = DoubleHadamard::new(
NonZeroUsize::new(dim).unwrap(),
target_dim,
&mut rng,
alloc,
)
.unwrap();
let data = to_flatbuffer(|buf| transform.pack(buf));
let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
let reloaded = DoubleHadamard::try_unpack(alloc, proto).unwrap();
assert_eq!(transform, reloaded);
}
// Serialize a (possibly invalid) transform and return the error from decoding it.
let gen_err = |x: DoubleHadamard<_>| -> DoubleHadamardError {
let data = to_flatbuffer(|buf| x.pack(buf));
let proto = flatbuffers::root::<fb::transforms::DoubleHadamard>(&data).unwrap();
DoubleHadamard::try_unpack(alloc, proto).unwrap_err()
};
type E = DoubleHadamardError;
// Each entry: (signs0, signs1, target_dim, subsample, expected error). `target_dim`
// only fills the struct field; `pack` does not serialize it.
let error_cases = [
(
vec![0, 0, 0, 0, 0], vec![0, 0, 0, 0], 4,
None,
E::Signs1TooSmall,
),
(
vec![], vec![0, 0, 0, 0],
4,
None,
E::Signs0Empty,
),
(
vec![0, 0, 0, 0],
vec![0, 0, 0, 0],
3,
Some(vec![0, 2, 1]), E::SubsampleNotMonotonic,
),
(
vec![0, 0, 0, 0],
vec![0, 0, 0, 0],
3,
Some(vec![0, 1, 1]), E::SubsampleNotMonotonic,
),
(
vec![0, 0, 0], vec![0, 0, 0], 2,
Some(vec![0, 3]), E::LastSubsampleTooLarge,
),
(
vec![0, 0, 0], vec![0, 0, 0], 2,
Some(vec![]), E::InvalidSubsampleLength,
),
];
let poly = |v: &Vec<u32>| Poly::from_iter(v.iter().copied(), alloc).unwrap();
for (signs0, signs1, target_dim, subsample, expected) in error_cases.iter() {
println!(
"on case ({:?}, {:?}, {}, {:?})",
signs0, signs1, target_dim, subsample,
);
// Build the struct directly to bypass `try_from_parts` validation.
let err = gen_err(DoubleHadamard {
signs0: poly(signs0),
signs1: poly(signs1),
target_dim: *target_dim,
subsample: subsample.as_ref().map(poly),
});
assert_eq!(
err, *expected,
"failed for case ({:?}, {:?}, {}, {:?})",
signs0, signs1, target_dim, subsample
);
}
}
}
}