use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::StdRng;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
/// Mask selecting only the most significant bit (bit 31) of a 32-bit pattern.
///
/// XOR-ing the bit pattern of an `f32` with this mask flips its sign; XOR-ing with `0`
/// leaves the value unchanged.
const F32_SIGN_BIT: u32 = 1 << 31;
/// Rotation over `f32` vectors of length `padded_dim`, made of `num_rounds` rounds that each
/// flip signs via XOR masks and then apply a Walsh-Hadamard transform, followed by a single
/// scaling by `norm_factor`.
pub struct RotationMatrix {
/// Flattened XOR masks, one `u32` per coordinate per round. Round `r` occupies
/// `sign_masks[r * padded_dim..(r + 1) * padded_dim]`. The constructors only ever store
/// `0` (keep the sign) or `F32_SIGN_BIT` (flip the sign).
sign_masks: Vec<u32>,
/// Number of sign-flip + transform rounds; the constructors require this to be at least 1.
num_rounds: usize,
/// Working vector length: the requested dimension rounded up to the next power of two.
padded_dim: usize,
/// Scale applied once after all rounds, computed as `padded_dim^(-num_rounds / 2)`.
norm_factor: f32,
}
impl RotationMatrix {
/// Builds a rotation whose sign masks are drawn from an `StdRng` seeded with `seed`.
///
/// `dimension` is rounded up to the next power of two to obtain the working length
/// (`padded_dim`). For each of the `num_rounds` rounds, `padded_dim` sign masks are drawn
/// and stored back to back, round 0 first.
///
/// # Errors
///
/// Returns an error if `num_rounds` is zero.
pub fn try_new(seed: u64, dimension: usize, num_rounds: usize) -> VortexResult<Self> {
    vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}");

    let padded_dim = dimension.next_power_of_two();
    let mut rng = StdRng::seed_from_u64(seed);

    // `flat_map` asks for round `r + 1` only after round `r` has been fully consumed, so the
    // generator is advanced in exactly round-major order.
    let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim);
    sign_masks.extend((0..num_rounds).flat_map(|_| gen_random_sign_masks(&mut rng, padded_dim)));

    // Each round's transform scales by sqrt(padded_dim); this one factor undoes all rounds.
    let exponent = -(num_rounds as f64) / 2.0;
    #[expect(
        clippy::cast_possible_truncation,
        reason = "Intentional f64 -> f32 truncation for normalization factor."
    )]
    let norm_factor = (padded_dim as f64).powf(exponent) as f32;

    Ok(Self {
        sign_masks,
        num_rounds,
        padded_dim,
        norm_factor,
    })
}
/// Applies the forward rotation to `input`, writing the result into `output`.
///
/// `input` is copied into `output`, which is then transformed in place; `input` itself is
/// left untouched.
///
/// # Panics
///
/// Panics if `input.len()` or `output.len()` differs from `padded_dim()`.
pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
    // Enforced in release builds too: with a wrong-sized buffer the sign masks would be
    // silently truncated and the normalization factor would not match the transform
    // length, so the result would be wrong with no error raised.
    assert_eq!(
        input.len(),
        self.padded_dim,
        "input length must equal padded_dim"
    );
    assert_eq!(
        output.len(),
        self.padded_dim,
        "output length must equal padded_dim"
    );
    output.copy_from_slice(input);
    self.apply_srht(output);
}
/// Applies the inverse rotation to `input`, writing the result into `output`.
///
/// `input` is copied into `output`, which is then transformed in place; `input` itself is
/// left untouched.
///
/// # Panics
///
/// Panics if `input.len()` or `output.len()` differs from `padded_dim()`.
pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
    // Enforced in release builds too: with a wrong-sized buffer the sign masks would be
    // silently truncated and the normalization factor would not match the transform
    // length, so the result would be wrong with no error raised.
    assert_eq!(
        input.len(),
        self.padded_dim,
        "input length must equal padded_dim"
    );
    assert_eq!(
        output.len(),
        self.padded_dim,
        "output length must equal padded_dim"
    );
    output.copy_from_slice(input);
    self.apply_inverse_srht(output);
}
/// Returns the number of sign-flip + Walsh-Hadamard rounds this rotation applies.
pub fn num_rounds(&self) -> usize {
self.num_rounds
}
/// Returns the working vector length: the requested dimension rounded up to the next power
/// of two. `rotate` and `inverse_rotate` expect buffers of exactly this length.
pub fn padded_dim(&self) -> usize {
self.padded_dim
}
/// Runs the forward transform on `buf` in place.
///
/// Rounds are processed in ascending order; each round first flips signs with that round's
/// masks and then applies the unnormalized Walsh-Hadamard transform. One multiplication by
/// `norm_factor` at the end accounts for the scaling of all rounds.
fn apply_srht(&self, buf: &mut [f32]) {
    let width = self.padded_dim;
    let per_round_masks =
        (0..self.num_rounds).map(|round| &self.sign_masks[round * width..(round + 1) * width]);

    for masks in per_round_masks {
        apply_signs_xor(buf, masks);
        walsh_hadamard_transform(buf);
    }

    for value in buf.iter_mut() {
        *value *= self.norm_factor;
    }
}
/// Runs the inverse transform on `buf` in place.
///
/// Mirrors the forward transform: rounds are visited from last to first, and within each
/// round the Walsh-Hadamard transform is applied before the sign flips. The same
/// `norm_factor` is applied once at the end.
fn apply_inverse_srht(&self, buf: &mut [f32]) {
    let width = self.padded_dim;
    let per_round_masks = (0..self.num_rounds)
        .rev()
        .map(|round| &self.sign_masks[round * width..(round + 1) * width]);

    for masks in per_round_masks {
        walsh_hadamard_transform(buf);
        apply_signs_xor(buf, masks);
    }

    for value in buf.iter_mut() {
        *value *= self.norm_factor;
    }
}
/// Exports the sign masks as one byte per coordinate, with rounds in reverse order.
///
/// The last round is written first and round 0 last. Within a round, a byte of `1` means
/// the mask is zero (sign kept) and `0` means the mask is non-zero (sign flipped). The
/// result holds `num_rounds * padded_dim` bytes and is the layout `from_u8_slice` reads.
pub fn export_inverse_signs_u8(&self) -> Vec<u8> {
    let width = self.padded_dim;
    let mut bytes = Vec::with_capacity(self.num_rounds * width);
    bytes.extend(
        (0..self.num_rounds)
            .rev()
            .flat_map(|round| &self.sign_masks[round * width..(round + 1) * width])
            // `true` converts to 1 and `false` to 0.
            .map(|&mask| u8::from(mask == 0)),
    );
    bytes
}
/// Rebuilds a rotation from sign bytes laid out as `export_inverse_signs_u8` writes them.
///
/// `signs_u8` holds `num_rounds` consecutive groups of `padded_dim` bytes, where
/// `padded_dim` is `dimension` rounded up to the next power of two. The first stored group
/// belongs to the last round, so groups are reversed while loading. A non-zero byte becomes
/// a zero mask (sign kept); a zero byte becomes `F32_SIGN_BIT` (sign flipped).
///
/// # Errors
///
/// Returns an error if `num_rounds` is zero or if `signs_u8.len()` is not
/// `num_rounds * padded_dim`.
pub fn from_u8_slice(
    signs_u8: &[u8],
    dimension: usize,
    num_rounds: usize,
) -> VortexResult<Self> {
    vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}");

    let padded_dim = dimension.next_power_of_two();
    let expected_len = num_rounds * padded_dim;
    vortex_ensure!(
        signs_u8.len() == expected_len,
        "Expected {} sign bytes, got {}",
        expected_len,
        signs_u8.len()
    );

    // Walking the stored groups back to front puts round 0 first in `sign_masks`.
    let mut sign_masks = Vec::with_capacity(expected_len);
    sign_masks.extend(
        (0..num_rounds)
            .rev()
            .flat_map(|stored| &signs_u8[stored * padded_dim..(stored + 1) * padded_dim])
            .map(|&byte| if byte == 0 { F32_SIGN_BIT } else { 0u32 }),
    );

    // Each round's transform scales by sqrt(padded_dim); this one factor undoes all rounds.
    let exponent = -(num_rounds as f64) / 2.0;
    #[expect(
        clippy::cast_possible_truncation,
        reason = "Intentional f64 -> f32 truncation for normalization factor."
    )]
    let norm_factor = (padded_dim as f64).powf(exponent) as f32;

    Ok(Self {
        sign_masks,
        num_rounds,
        padded_dim,
        norm_factor,
    })
}
}
/// Draws `len` sign masks from `rng`, one `random_bool(0.5)` call per mask, in order.
///
/// A `true` draw yields `0` (no sign flip); a `false` draw yields `F32_SIGN_BIT` (flip).
fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
    let draw_mask = || {
        if rng.random_bool(0.5) {
            0u32
        } else {
            F32_SIGN_BIT
        }
    };
    std::iter::repeat_with(draw_mask).take(len).collect()
}
/// XORs the bit pattern of each element of `buf` with the mask at the same position.
///
/// Elements are paired up positionally; if the slices differ in length, only the common
/// prefix is touched.
#[inline]
fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
    buf.iter_mut()
        .zip(masks)
        .for_each(|(value, &mask)| *value = f32::from_bits(value.to_bits() ^ mask));
}
/// Applies the unnormalized Walsh-Hadamard transform to `buf` in place.
///
/// `buf.len()` is expected to be a power of two (checked in debug builds only). No scaling
/// is performed, so applying the transform twice multiplies every element by `buf.len()`.
fn walsh_hadamard_transform(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two());

    // Half-widths 1, 2, 4, ... for as long as they stay below the buffer length.
    let half_widths = std::iter::successors(Some(1usize), |half| half.checked_mul(2))
        .take_while(|&half| half < len);

    for half in half_widths {
        // Each block of `2 * half` elements is split into two equal halves that are
        // combined pairwise.
        for block in buf.chunks_exact_mut(half * 2) {
            let (lo, hi) = block.split_at_mut(half);
            butterfly(lo, hi);
        }
    }
}
/// Replaces each positional pair `(lo[i], hi[i])` with `(lo[i] + hi[i], lo[i] - hi[i])`.
///
/// The two slices are expected to have equal length (checked in debug builds only).
#[inline(always)]
fn butterfly(lo: &mut [f32], hi: &mut [f32]) {
    debug_assert_eq!(lo.len(), hi.len());
    lo.iter_mut().zip(hi.iter_mut()).for_each(|(left, right)| {
        // Read both inputs before writing so neither update sees the other's result.
        let (x, y) = (*left, *right);
        *left = x + y;
        *right = x - y;
    });
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_error::VortexResult;
use super::*;
/// Two rotations built from the same seed, dimension and round count must rotate the same
/// input to bit-identical output.
#[test]
fn deterministic_from_seed() -> VortexResult<()> {
    let first = RotationMatrix::try_new(42, 64, 3)?;
    let second = RotationMatrix::try_new(42, 64, 3)?;
    let padded = first.padded_dim();

    // Ramp 0, 1, 2, ... over the first 64 slots.
    let mut input = vec![0.0f32; padded];
    for (idx, slot) in input.iter_mut().enumerate().take(64) {
        *slot = idx as f32;
    }

    let mut from_first = vec![0.0f32; padded];
    let mut from_second = vec![0.0f32; padded];
    first.rotate(&input, &mut from_first);
    second.rotate(&input, &mut from_second);

    assert_eq!(from_first, from_second);
    Ok(())
}
/// Rotating and then inverse-rotating must recover the input to within a relative error of
/// `1e-5`, measured as the largest absolute difference over the largest input magnitude.
#[rstest]
#[case(32, 3)]
#[case(64, 3)]
#[case(100, 3)]
#[case(128, 1)]
#[case(128, 2)]
#[case(128, 3)]
#[case(128, 5)]
#[case(256, 3)]
#[case(512, 3)]
#[case(768, 3)]
#[case(1024, 3)]
fn roundtrip_exact(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> {
    let rot = RotationMatrix::try_new(42, dim, num_rounds)?;
    let padded_dim = rot.padded_dim();

    // Ramp 0.01, 0.02, ... over the first `dim` slots; the padding stays zero.
    let mut input = vec![0.0f32; padded_dim];
    for (idx, slot) in input.iter_mut().enumerate().take(dim) {
        *slot = (idx as f32 + 1.0) * 0.01;
    }

    let mut rotated = vec![0.0f32; padded_dim];
    rot.rotate(&input, &mut rotated);
    let mut recovered = vec![0.0f32; padded_dim];
    rot.inverse_rotate(&rotated, &mut recovered);

    let mut max_err = 0.0f32;
    let mut max_val = 0.0f32;
    for (&original, &restored) in input.iter().zip(recovered.iter()) {
        max_err = max_err.max((original - restored).abs());
        max_val = max_val.max(original.abs());
    }
    let rel_err = max_err / max_val;

    assert!(
        rel_err < 1e-5,
        "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}"
    );
    Ok(())
}
/// The forward rotation must keep the Euclidean norm of the input to within a relative
/// error of `1e-5`.
#[rstest]
#[case(128, 1)]
#[case(128, 3)]
#[case(128, 5)]
#[case(768, 3)]
fn preserves_norm(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> {
    let l2_norm = |values: &[f32]| values.iter().map(|x| x * x).sum::<f32>().sqrt();

    let rot = RotationMatrix::try_new(7, dim, num_rounds)?;
    let padded_dim = rot.padded_dim();

    // Ramp 0.00, 0.01, 0.02, ... over the first `dim` slots; the padding stays zero.
    let mut input = vec![0.0f32; padded_dim];
    for (idx, slot) in input.iter_mut().enumerate().take(dim) {
        *slot = (idx as f32) * 0.01;
    }

    let mut rotated = vec![0.0f32; padded_dim];
    rot.rotate(&input, &mut rotated);

    let input_norm = l2_norm(&input);
    let rotated_norm = l2_norm(&rotated);
    let rel_err = (input_norm - rotated_norm).abs() / input_norm;

    assert!(
        rel_err < 1e-5,
        "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})",
        input_norm,
        rotated_norm,
        rel_err
    );
    Ok(())
}
/// A rotation rebuilt from its exported sign bytes must produce exactly the same forward
/// and inverse results as the rotation it was exported from.
#[rstest]
#[case(64, 3)]
#[case(128, 1)]
#[case(128, 3)]
#[case(128, 5)]
#[case(768, 3)]
fn sign_export_import_roundtrip(
    #[case] dim: usize,
    #[case] num_rounds: usize,
) -> VortexResult<()> {
    let original = RotationMatrix::try_new(42, dim, num_rounds)?;
    let padded_dim = original.padded_dim();
    let exported = original.export_inverse_signs_u8();
    let restored = RotationMatrix::from_u8_slice(&exported, dim, num_rounds)?;

    // Ramp 0.01, 0.02, ... over the first `dim` slots; the padding stays zero.
    let mut input = vec![0.0f32; padded_dim];
    for (idx, slot) in input.iter_mut().enumerate().take(dim) {
        *slot = (idx as f32 + 1.0) * 0.01;
    }

    let mut fwd_original = vec![0.0f32; padded_dim];
    let mut fwd_restored = vec![0.0f32; padded_dim];
    original.rotate(&input, &mut fwd_original);
    restored.rotate(&input, &mut fwd_restored);
    assert_eq!(
        fwd_original, fwd_restored,
        "Forward rotation mismatch after export/import"
    );

    // Both inverses start from the same rotated vector.
    let mut inv_original = vec![0.0f32; padded_dim];
    let mut inv_restored = vec![0.0f32; padded_dim];
    original.inverse_rotate(&fwd_original, &mut inv_original);
    restored.inverse_rotate(&fwd_original, &mut inv_restored);
    assert_eq!(
        inv_original, inv_restored,
        "Inverse rotation mismatch after export/import"
    );
    Ok(())
}
/// A unit impulse transforms to all ones, and transforming again gives the impulse scaled
/// by the length (4), since the transform is unnormalized.
#[test]
fn wht_basic() {
    let mut buf = [1.0f32, 0.0, 0.0, 0.0];

    walsh_hadamard_transform(&mut buf);
    assert_eq!(buf, [1.0; 4]);

    walsh_hadamard_transform(&mut buf);
    assert_eq!(buf, [4.0, 0.0, 0.0, 0.0]);
}
}