use crate::error::{require, values_match, Result, TurboQuantError};
use crate::packed::is_valid_dim;
/// Applies the normalized fast Walsh–Hadamard transform to `data` in place.
///
/// Each pass combines elements that are `half` apart into their sum and
/// difference, doubling `half` until it reaches the slice length. The result
/// is then scaled by `1 / sqrt(len)`, which makes the transform its own
/// inverse: calling it twice returns the input up to rounding error.
///
/// An empty slice is left untouched.
///
/// # Panics
///
/// Panics if `data.len()` is non-zero and not a power of two. The check runs
/// before any element is written, so the slice is unmodified when it fires.
pub fn wht_inplace(data: &mut [f32]) {
    let n = data.len();
    if n == 0 {
        return;
    }
    // Fail fast: the butterfly passes below only tile the slice exactly when
    // the length is a power of two.
    assert!(
        n.is_power_of_two(),
        "wht_inplace: length must be a power of two, got {}",
        n
    );

    let mut half = 1;
    while half < n {
        // Every chunk of `2 * half` elements is one independent butterfly group.
        for group in data.chunks_exact_mut(half * 2) {
            let (lower, upper) = group.split_at_mut(half);
            for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }

    // Orthonormal scaling so that the transform is an involution.
    let norm = 1.0 / (n as f32).sqrt();
    for v in data.iter_mut() {
        *v *= norm;
    }
}
/// `2^64 / φ` as an odd 64-bit integer: the Fibonacci-hashing multiplier, also
/// used as the state increment of SplitMix64.
const GOLDEN_RATIO: u64 = 0x9E37_79B9_7F4A_7C15;

/// Hashes `(seed, index)` to a well-mixed 64-bit value.
///
/// The sum `seed + index` (wrapping) is multiplied by [`GOLDEN_RATIO`] and then
/// passed through the SplitMix64 finalizer. The finalizer is required for
/// callers that read low bits: because the multiplier is odd, the lowest bit of
/// the bare product is simply the parity of `seed + index`, which would make
/// any bit-0-derived sequence alternate strictly regardless of the seed.
///
/// The function is pure; equal inputs always give equal outputs.
///
/// NOTE(review): values differ from a multiply-only hash, so anything that was
/// derived from the previous outputs and persisted must be regenerated.
fn golden_ratio_hash(seed: u64, index: usize) -> u64 {
    // SplitMix64 finalizer multipliers.
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let mut z = seed.wrapping_add(index as u64).wrapping_mul(GOLDEN_RATIO);
    // Xor-shift/multiply rounds fold the high bits back into the low ones.
    z = (z ^ (z >> 30)).wrapping_mul(MIX_A);
    z = (z ^ (z >> 27)).wrapping_mul(MIX_B);
    z ^ (z >> 31)
}
/// Builds a deterministic vector of `dim` signs, each either `1.0` or `-1.0`.
///
/// Element `i` is `1.0` when the lowest bit of `golden_ratio_hash(seed, i)` is
/// clear and `-1.0` when it is set, so the same `(dim, seed)` pair always
/// yields the same pattern.
pub fn generate_sign_pattern(dim: usize, seed: u64) -> Vec<f32> {
    let mut signs = Vec::with_capacity(dim);
    for index in 0..dim {
        let low_bit_set = golden_ratio_hash(seed, index) & 1 != 0;
        let sign: f32 = if low_bit_set { -1.0 } else { 1.0 };
        signs.push(sign);
    }
    signs
}
/// Multiplies each element of `data` by the sign at the same position.
///
/// The two slices are walked in lockstep; if their lengths differ, only the
/// common prefix is touched and the remaining elements of `data` are left
/// as they were.
fn apply_sign_flip(data: &mut [f32], signs: &[f32]) {
    data.iter_mut()
        .zip(signs)
        .for_each(|(value, sign)| *value *= *sign);
}
/// Checks the two preconditions of a rotation, in order:
///
/// 1. `data_len` passes `is_valid_dim`, otherwise
///    `TurboQuantError::InvalidDimension(data_len)` is handed to `require`;
/// 2. `data_len` and `sign_len` pass `values_match`, otherwise
///    `TurboQuantError::DimensionMismatch` is handed to `require`.
///
/// NOTE(review): `require` lives in `crate::error`; presumably it returns the
/// given error when the condition is false — confirm there.
fn validate_rotation_inputs(data_len: usize, sign_len: usize) -> Result<()> {
    let dim_is_valid = is_valid_dim(data_len);
    require(dim_is_valid, TurboQuantError::InvalidDimension(data_len))?;

    let lengths_agree = values_match(data_len, sign_len);
    let mismatch = TurboQuantError::DimensionMismatch {
        expected: data_len,
        actual: sign_len,
    };
    require(lengths_agree, mismatch)
}
/// Selects the order in which [`rotate`] applies its two steps.
///
/// The derives make the value printable, comparable, and freely copyable so a
/// caller can keep using it after passing it to [`rotate`] by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotationOrder {
    /// Sign flip first, then the Walsh–Hadamard transform.
    Forward,
    /// Walsh–Hadamard transform first, then the sign flip.
    Inverse,
}
/// Rotates `data` in place using a sign flip and a Walsh–Hadamard transform.
///
/// * `RotationOrder::Forward` flips signs with `sign_pattern`, then transforms.
/// * `RotationOrder::Inverse` transforms, then flips signs.
///
/// # Errors
///
/// Propagates any error reported by `validate_rotation_inputs` for the lengths
/// of `data` and `sign_pattern`; in that case `data` has not been modified.
pub fn rotate(data: &mut [f32], sign_pattern: &[f32], order: RotationOrder) -> Result<()> {
    validate_rotation_inputs(data.len(), sign_pattern.len())?;

    // The transform always runs once; only the position of the flip differs.
    let flip_before_transform = match order {
        RotationOrder::Forward => true,
        RotationOrder::Inverse => false,
    };

    if flip_before_transform {
        apply_sign_flip(data, sign_pattern);
    }
    wht_inplace(data);
    if !flip_before_transform {
        apply_sign_flip(data, sign_pattern);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length used by the small rotation tests.
    const TEST_SMALL_DIM: usize = 4;
    /// Seed shared by the hash and rotation tests.
    const TEST_SEED: u64 = 42;
    /// Length of the pattern checked for valid sign values.
    const TEST_SIGN_PATTERN_DIM: usize = 128;
    /// Seed of the pattern checked for valid sign values.
    const TEST_SIGN_PATTERN_SEED: u64 = 99;
    /// Index fed to the hash determinism test.
    const TEST_HASH_INDEX: usize = 7;

    /// Asserts that every zipped pair differs by less than `tol`.
    fn assert_pairs_within(actual: &[f32], expected: &[f32], tol: f32) {
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert!((got - want).abs() < tol);
        }
    }

    /// Asserts that every element is exactly `1.0` or `-1.0`.
    fn assert_sign_values_valid(pattern: &[f32]) {
        pattern
            .iter()
            .for_each(|&sign| assert!(sign == 1.0 || sign == -1.0));
    }

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        for &dim in &[1, 2, 64, 256] {
            assert!(is_valid_dim(dim));
        }
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        for &dim in &[0, 3, 100] {
            assert!(!is_valid_dim(dim));
        }
    }

    #[test]
    fn golden_ratio_hash_is_deterministic() {
        let first = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        let second = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        assert_eq!(first, second);
    }

    #[test]
    fn sign_pattern_elements_are_plus_or_minus_one() {
        let pattern = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        assert_sign_values_valid(&pattern);
    }

    #[test]
    fn rotate_accepts_matching_pow2_dims() {
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        let outcome = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(outcome.is_ok());
    }

    #[test]
    fn rotate_rejects_non_pow2() {
        let mut data = [1.0_f32, 2.0, 3.0];
        let signs = [1.0_f32, -1.0, 1.0];
        let outcome = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(outcome.is_err());
    }

    #[test]
    fn rotate_rejects_mismatched_lengths() {
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        let signs = [1.0_f32, -1.0];
        let outcome = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(outcome.is_err());
    }

    #[test]
    fn wht_inplace_known_vector() {
        let mut data = [1.0_f32; 4];
        wht_inplace(&mut data);
        assert_pairs_within(&data, &[2.0, 0.0, 0.0, 0.0], 1e-6);
    }

    #[test]
    fn wht_inplace_self_inverse() {
        let original = [1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original;
        // Two applications of the normalized transform must cancel out.
        for _ in 0..2 {
            wht_inplace(&mut data);
        }
        assert_pairs_within(&data, &original, 1e-5);
    }

    #[test]
    fn apply_sign_flip_basic() {
        let mut data = [2.0_f32, -3.0, 4.0, -5.0];
        let signs = [1.0_f32, -1.0, -1.0, 1.0];
        apply_sign_flip(&mut data, &signs);
        assert_eq!(data, [2.0, 3.0, -4.0, -5.0]);
    }

    #[test]
    fn values_match_equal() {
        for &(left, right) in &[(4, 4), (128, 128)] {
            assert!(values_match(left, right));
        }
    }

    #[test]
    fn values_match_unequal() {
        for &(left, right) in &[(4, 8), (0, 1)] {
            assert!(!values_match(left, right));
        }
    }

    #[test]
    fn rotate_inverse_rotate_roundtrip() {
        let original = [1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original;
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        for order in [RotationOrder::Forward, RotationOrder::Inverse] {
            rotate(&mut data, &signs, order).unwrap();
        }
        assert_pairs_within(&data, &original, 1e-5);
    }
}