use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use super::splitmix64::SplitMix64;
/// Mask selecting the sign bit (bit 31) of an `f32`'s raw bit pattern.
///
/// XOR-ing a value's bits with this mask negates the value; XOR-ing with `0` leaves it unchanged.
const F32_SIGN_BIT: u32 = 1 << 31;
/// A seeded random rotation, applied as `num_rounds` rounds of per-element sign flips that are
/// each followed by a Walsh-Hadamard transform, with a single rescale at the end.
///
/// The rotation is never materialised as a dense matrix; only the sign masks are stored.
#[derive(Debug, Clone)]
pub struct SorfMatrix {
    /// XOR masks for the sign flips, stored round after round with `padded_dim` entries per
    /// round. Each entry is either `0` (keep the sign) or the `f32` sign bit (negate).
    sign_masks: Vec<u32>,
    /// Number of sign-flip + transform rounds.
    num_rounds: usize,
    /// Length of the vectors the rotation operates on.
    padded_dim: usize,
    /// Scale applied once after all rounds: `padded_dim^(-num_rounds / 2)`.
    norm_factor: f32,
}
impl SorfMatrix {
    /// Builds a rotation for vectors of `dimensions` elements, rounding the working size up to
    /// the next power of two.
    ///
    /// Note the argument order: `seed` comes first here but last in `try_new_padded`.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_rounds` is zero, or if the next power of two above `dimensions`
    /// does not fit in a `usize`.
    pub fn try_new(seed: u64, dimensions: usize, num_rounds: usize) -> VortexResult<Self> {
        // Plain `next_power_of_two` panics in debug builds (and wraps to 0 in release builds)
        // when the result overflows. Mapping overflow to 0 explicitly lets the power-of-two
        // check in `try_new_padded` reject it with an `Err` in every build profile.
        let padded_dimensions = dimensions.checked_next_power_of_two().unwrap_or(0);
        Self::try_new_padded(padded_dimensions, num_rounds, seed)
    }

    /// Builds a rotation whose working size is exactly `padded_dimensions`.
    ///
    /// The sign masks are derived deterministically from `seed`.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_rounds` is zero or `padded_dimensions` is not a power of two.
    pub(crate) fn try_new_padded(
        padded_dimensions: usize,
        num_rounds: usize,
        seed: u64,
    ) -> VortexResult<Self> {
        vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}");
        vortex_ensure!(
            padded_dimensions.is_power_of_two(),
            "padded_dimensions must be a power of two, got {padded_dimensions}"
        );

        let padded_dim = padded_dimensions;
        Ok(Self {
            sign_masks: gen_sign_masks_from_seed(seed, padded_dim, num_rounds),
            num_rounds,
            padded_dim,
            norm_factor: Self::norm_factor_for(padded_dim, num_rounds),
        })
    }

    /// Computes the final rescale `padded_dim^(-num_rounds / 2)`.
    ///
    /// Every unnormalised Walsh-Hadamard pass scales vector norms by `sqrt(padded_dim)`, so this
    /// factor undoes the growth accumulated over `num_rounds` passes.
    #[expect(
        clippy::cast_possible_truncation,
        reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow"
    )]
    fn norm_factor_for(padded_dim: usize, num_rounds: usize) -> f32 {
        (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32
    }

    /// Returns the length of the vectors this rotation operates on.
    pub fn padded_dim(&self) -> usize {
        self.padded_dim
    }

    /// Writes the rotated `input` into `output`.
    ///
    /// Both slices are expected to hold exactly `padded_dim()` elements. That is only checked in
    /// debug builds; release builds do not validate it.
    ///
    /// # Panics
    ///
    /// Panics if `input` and `output` differ in length.
    pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
        debug_assert_eq!(input.len(), self.padded_dim);
        debug_assert_eq!(output.len(), self.padded_dim);
        output.copy_from_slice(input);
        self.apply_srht(output);
    }

    /// Writes the inverse rotation of `input` into `output`, undoing [`Self::rotate`] up to
    /// floating-point rounding.
    ///
    /// Both slices are expected to hold exactly `padded_dim()` elements. That is only checked in
    /// debug builds; release builds do not validate it.
    ///
    /// # Panics
    ///
    /// Panics if `input` and `output` differ in length.
    pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
        debug_assert_eq!(input.len(), self.padded_dim);
        debug_assert_eq!(output.len(), self.padded_dim);
        output.copy_from_slice(input);
        self.apply_inverse_srht(output);
    }

    /// Forward transform in place: for each round, flip signs and then run the Walsh-Hadamard
    /// transform; rescale once at the end.
    fn apply_srht(&self, buf: &mut [f32]) {
        for round in 0..self.num_rounds {
            self.apply_signs_xor(buf, round);
            walsh_hadamard_transform(buf);
        }
        let norm = self.norm_factor;
        buf.iter_mut().for_each(|val| *val *= norm);
    }

    /// Inverse transform in place: walks the rounds backwards and swaps the order of the two
    /// steps (transform first, then sign flip); rescale once at the end.
    fn apply_inverse_srht(&self, buf: &mut [f32]) {
        for round in (0..self.num_rounds).rev() {
            walsh_hadamard_transform(buf);
            self.apply_signs_xor(buf, round);
        }
        let norm = self.norm_factor;
        buf.iter_mut().for_each(|val| *val *= norm);
    }

    /// Flips the sign of every element of `buf` whose mask for `round` has the sign bit set.
    fn apply_signs_xor(&self, buf: &mut [f32], round: usize) {
        let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim];
        for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
            // XOR on the raw bits toggles only the sign bit; a zero mask is a no-op.
            *val = f32::from_bits(val.to_bits() ^ mask);
        }
    }

    /// Exports the sign masks as one byte per element, with the rounds in reverse (inverse
    /// application) order. A byte of `1` means "keep the sign", `0` means "negate".
    #[cfg(test)]
    pub fn export_inverse_signs_u8(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.num_rounds * self.padded_dim);
        for round in (0..self.num_rounds).rev() {
            let offset = round * self.padded_dim;
            let round_masks = &self.sign_masks[offset..offset + self.padded_dim];
            out.extend(round_masks.iter().map(|&mask| u8::from(mask == 0)));
        }
        out
    }

    /// Rebuilds a matrix from bytes in the layout produced by
    /// [`Self::export_inverse_signs_u8`]: rounds in reverse order, non-zero meaning "keep the
    /// sign" and zero meaning "negate".
    ///
    /// # Errors
    ///
    /// Returns an error if `num_rounds` is zero or `signs_u8` does not hold exactly
    /// `num_rounds * dimension.next_power_of_two()` bytes.
    #[cfg(test)]
    pub fn from_u8_slice(
        signs_u8: &[u8],
        dimension: usize,
        num_rounds: usize,
    ) -> VortexResult<Self> {
        vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}");
        let padded_dim = dimension.next_power_of_two();
        vortex_ensure!(
            signs_u8.len() == num_rounds * padded_dim,
            "Expected {} sign bytes, got {}",
            num_rounds * padded_dim,
            signs_u8.len()
        );

        let mut sign_masks = vec![0u32; num_rounds * padded_dim];
        for storage_idx in 0..num_rounds {
            // The bytes are stored last round first, so mirror the index to restore round order.
            let round = num_rounds - 1 - storage_idx;
            let src = &signs_u8[storage_idx * padded_dim..][..padded_dim];
            let dst = &mut sign_masks[round * padded_dim..][..padded_dim];
            for (mask, &sign) in dst.iter_mut().zip(src) {
                *mask = if sign != 0 { 0u32 } else { F32_SIGN_BIT };
            }
        }

        Ok(Self {
            sign_masks,
            num_rounds,
            padded_dim,
            norm_factor: Self::norm_factor_for(padded_dim, num_rounds),
        })
    }
}
/// Generates `num_rounds * padded_dim` sign masks from a `SplitMix64` stream seeded with `seed`.
///
/// Within each round, one 64-bit word is drawn per block of up to 64 elements and its bits are
/// consumed starting from the least significant one. A final block shorter than 64 elements
/// still draws a whole word and ignores the unused high bits.
fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec<u32> {
    let mut rng = SplitMix64::new(seed);
    let mut masks = Vec::with_capacity(num_rounds * padded_dim);

    for _ in 0..num_rounds {
        let mut remaining = padded_dim;
        while remaining > 0 {
            let word = rng.next_u64();
            let block_len = remaining.min(64);
            for bit_idx in 0..block_len {
                masks.push(sign_mask_from_word(word, bit_idx));
            }
            remaining -= block_len;
        }
    }

    masks
}
/// Maps bit `bit_idx` of `word` to an XOR mask: a set bit yields `0` (keep the sign) and a clear
/// bit yields the `f32` sign bit (negate).
fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 {
    let selected_bit = word & (1u64 << bit_idx);
    if selected_bit == 0 {
        F32_SIGN_BIT
    } else {
        0u32
    }
}
/// Runs an in-place, unnormalised Walsh-Hadamard transform over `buf`.
///
/// The block width doubles on every pass (2, 4, 8, ...) until it reaches `buf.len()`; within
/// each block the lower and upper halves are combined by [`butterfly`]. The length is expected
/// to be a power of two, which is only asserted in debug builds.
fn walsh_hadamard_transform(buf: &mut [f32]) {
    let total = buf.len();
    debug_assert!(total.is_power_of_two());

    let mut width = 2;
    while width <= total {
        let half = width / 2;
        buf.chunks_exact_mut(width).for_each(|block| {
            let (lower, upper) = block.split_at_mut(half);
            butterfly(lower, upper);
        });
        width *= 2;
    }
}
/// Replaces each pair `(lo[i], hi[i])` with `(lo[i] + hi[i], lo[i] - hi[i])`.
///
/// The two slices are expected to have the same length, which is only asserted in debug builds.
fn butterfly(lo: &mut [f32], hi: &mut [f32]) {
    debug_assert_eq!(lo.len(), hi.len());
    lo.iter_mut().zip(hi.iter_mut()).for_each(|(low, high)| {
        // Read both values before writing so the difference uses the original `low`.
        let (x, y) = (*low, *high);
        *low = x + y;
        *high = x - y;
    });
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use super::*;
    use crate::scalar_fns::sorf_transform::splitmix64::SplitMix64;

    /// Seed used by every test in this module.
    const SEED: u64 = 42;

    /// Expands the low `count` bits of `word` into bytes, least significant bit first.
    fn unpack_sign_bits(word: u64, count: usize) -> Vec<u8> {
        let mut bits = Vec::with_capacity(count);
        for bit_idx in 0..count {
            bits.push(u8::from(word & (1u64 << bit_idx) != 0));
        }
        bits
    }

    /// Converts a test-case dimension to `usize`.
    fn dim_to_usize(dim: u32) -> usize {
        usize::try_from(dim).unwrap()
    }

    /// Converts a test-case round count to `usize`.
    fn rounds_to_usize(num_rounds: u8) -> usize {
        usize::from(num_rounds)
    }

    /// Builds a zero-padded vector of `padded_dim` elements whose first `dim` entries are
    /// `value_at(index)`.
    fn padded_input(dim: usize, padded_dim: usize, value_at: impl Fn(usize) -> f32) -> Vec<f32> {
        let mut input = vec![0.0f32; padded_dim];
        for (idx, slot) in input.iter_mut().enumerate().take(dim) {
            *slot = value_at(idx);
        }
        input
    }

    /// Runs the forward rotation into a freshly allocated buffer.
    fn rotated(rot: &SorfMatrix, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; rot.padded_dim()];
        rot.rotate(input, &mut output);
        output
    }

    /// Runs the inverse rotation into a freshly allocated buffer.
    fn inverse_rotated(rot: &SorfMatrix, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0f32; rot.padded_dim()];
        rot.inverse_rotate(input, &mut output);
        output
    }

    /// Euclidean norm, accumulated in `f32` in element order.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Checks that a single round over `padded_dim <= 64` elements consumes exactly the low
    /// bits of the first word of the stream.
    fn assert_masks_come_from_first_word(padded_dim: usize) {
        let masks = gen_sign_masks_from_seed(SEED, padded_dim, rounds_to_usize(1u8));
        assert_eq!(masks.len(), padded_dim);

        let word = SplitMix64::new(SEED).next_u64();
        let mut expected = Vec::with_capacity(padded_dim);
        for bit_idx in 0..padded_dim {
            expected.push(sign_mask_from_word(word, bit_idx));
        }
        assert_eq!(masks, expected);
    }

    #[test]
    fn deterministic_from_seed() -> VortexResult<()> {
        let dim = dim_to_usize(64u32);
        let num_rounds = rounds_to_usize(3u8);
        let first = SorfMatrix::try_new(SEED, dim, num_rounds)?;
        let second = SorfMatrix::try_new(SEED, dim, num_rounds)?;

        let input = padded_input(dim, first.padded_dim(), |i| i as f32);

        assert_eq!(rotated(&first, &input), rotated(&second, &input));
        Ok(())
    }

    #[test]
    fn export_inverse_signs_matches_golden_words() -> VortexResult<()> {
        let dim = dim_to_usize(64u32);
        let num_rounds = rounds_to_usize(2u8);
        let rot = SorfMatrix::try_new(SEED, dim, num_rounds)?;
        let padded_dim = rot.padded_dim();
        let actual = rot.export_inverse_signs_u8();

        let mut rng = SplitMix64::new(SEED);
        let round0_word = rng.next_u64();
        let round1_word = rng.next_u64();

        // The export lists the last round first.
        let expected: Vec<u8> = [round1_word, round0_word]
            .into_iter()
            .flat_map(|word| unpack_sign_bits(word, padded_dim))
            .collect();
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn one_word_generates_64_signs_lsb_first() {
        assert_masks_come_from_first_word(dim_to_usize(64u32));
    }

    #[test]
    fn accepts_non_power_of_two_dimensions() -> VortexResult<()> {
        let dim = dim_to_usize(100u32);
        let num_rounds = rounds_to_usize(3u8);
        let rot = SorfMatrix::try_new(SEED, dim, num_rounds)?;
        assert_eq!(rot.padded_dim(), 128);
        Ok(())
    }

    #[test]
    fn tail_block_uses_only_required_bits() {
        assert_masks_come_from_first_word(dim_to_usize(32u32));
    }

    #[rstest]
    #[case(32u32, 3u8)]
    #[case(64u32, 3u8)]
    #[case(100u32, 3u8)]
    #[case(128u32, 1u8)]
    #[case(128u32, 2u8)]
    #[case(128u32, 3u8)]
    #[case(128u32, 5u8)]
    #[case(256u32, 3u8)]
    #[case(512u32, 3u8)]
    #[case(768u32, 3u8)]
    #[case(1024u32, 3u8)]
    fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
        let dim = dim_to_usize(dim);
        let num_rounds = rounds_to_usize(num_rounds);
        let rot = SorfMatrix::try_new(SEED, dim, num_rounds)?;

        let input = padded_input(dim, rot.padded_dim(), |i| (i as f32 + 1.0) * 0.01);
        let recovered = inverse_rotated(&rot, &rotated(&rot, &input));

        let mut max_err = 0.0f32;
        for (original, restored) in input.iter().zip(recovered.iter()) {
            max_err = max_err.max((original - restored).abs());
        }
        let mut max_val = 0.0f32;
        for value in &input {
            max_val = max_val.max(value.abs());
        }

        let rel_err = max_err / max_val;
        assert!(
            rel_err < 1e-5,
            "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}"
        );
        Ok(())
    }

    #[rstest]
    #[case(128u32, 1u8)]
    #[case(128u32, 3u8)]
    #[case(128u32, 5u8)]
    #[case(768u32, 3u8)]
    fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
        let dim = dim_to_usize(dim);
        let num_rounds = rounds_to_usize(num_rounds);
        let rot = SorfMatrix::try_new(SEED, dim, num_rounds)?;

        let input = padded_input(dim, rot.padded_dim(), |i| (i as f32) * 0.01);
        let input_norm = l2_norm(&input);
        let rotated_norm = l2_norm(&rotated(&rot, &input));

        let rel_err = (input_norm - rotated_norm).abs() / input_norm;
        assert!(
            rel_err < 1e-5,
            "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})",
            input_norm,
            rotated_norm,
            rel_err
        );
        Ok(())
    }

    #[rstest]
    #[case(64u32, 3u8)]
    #[case(128u32, 1u8)]
    #[case(128u32, 3u8)]
    #[case(128u32, 5u8)]
    #[case(768u32, 3u8)]
    fn sign_export_import_roundtrip(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> {
        let dim = dim_to_usize(dim);
        let num_rounds = rounds_to_usize(num_rounds);
        let rot = SorfMatrix::try_new(SEED, dim, num_rounds)?;
        let reimported =
            SorfMatrix::from_u8_slice(&rot.export_inverse_signs_u8(), dim, num_rounds)?;

        let input = padded_input(dim, rot.padded_dim(), |i| (i as f32 + 1.0) * 0.01);

        let forward = rotated(&rot, &input);
        let forward_reimported = rotated(&reimported, &input);
        assert_eq!(
            forward, forward_reimported,
            "Forward transform mismatch after export/import"
        );

        // Both inverses start from the original matrix's forward output.
        let inverse = inverse_rotated(&rot, &forward);
        let inverse_reimported = inverse_rotated(&reimported, &forward);
        assert_eq!(
            inverse, inverse_reimported,
            "Inverse transform mismatch after export/import"
        );
        Ok(())
    }

    #[test]
    fn wht_basic() {
        let mut buf = vec![1.0f32, 0.0, 0.0, 0.0];

        walsh_hadamard_transform(&mut buf);
        assert_eq!(buf, vec![1.0; 4]);

        // Applying the unnormalised transform twice scales the input by its length.
        walsh_hadamard_transform(&mut buf);
        assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]);
    }
}