use crate::error::{Result, TurboQuantError};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand::Rng;
/// Seeded, randomized Walsh–Hadamard rotation.
///
/// The forward map flips the sign of each coordinate with a fixed random ±1 pattern and then
/// applies the orthonormal fast Walsh–Hadamard transform. The transform runs over the input
/// zero-padded to the next power of two. The only per-instance state besides the sizes is
/// one `f32` sign per padded coordinate.
///
/// Caveat: `forward`/`inverse` truncate the padded result back to `dim` entries. The
/// rotation is therefore only exactly orthogonal and invertible when `dim` is a power of two.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub struct WhtRotation {
    // Length of the vectors accepted and returned by `forward`/`inverse`.
    dim: usize,
    // `dim.next_power_of_two()`: the length the FWHT actually runs on.
    padded_dim: usize,
    // One ±1.0 per padded coordinate, drawn from a ChaCha8 RNG seeded with `seed` in `new`.
    signs: Vec<f32>,
    // RNG seed; together with `dim` it determines `signs` when built via `new`.
    seed: u64,
}
impl WhtRotation {
    /// Builds a rotation for `dim`-dimensional vectors.
    ///
    /// The working length is `dim` rounded up to the next power of two. One random sign
    /// (±1.0) per padded coordinate is drawn from a ChaCha8 RNG seeded with `seed`, so the
    /// same `(dim, seed)` pair always yields the same rotation.
    ///
    /// # Errors
    /// Returns [`TurboQuantError::ZeroDimension`] if `dim == 0`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::ZeroDimension);
        }
        let padded_dim = dim.next_power_of_two();
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let signs: Vec<f32> = (0..padded_dim)
            .map(|_| if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 })
            .collect();
        Ok(Self { dim, padded_dim, signs, seed })
    }

    /// Number of coordinates accepted and produced by the rotation.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Seed the sign pattern was generated from.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Bytes occupied by the stored sign vector (`padded_dim` × `size_of::<f32>()`).
    pub fn size_bytes(&self) -> usize {
        self.signs.len() * core::mem::size_of::<f32>()
    }

    /// Applies the rotation: sign flips followed by the orthonormal Walsh–Hadamard transform.
    ///
    /// The input is zero-padded to the internal power-of-two length, and the result is
    /// truncated back to `dim` entries. When `dim` is not a power of two, the discarded tail
    /// generally carries part of the signal. The map is then neither norm-preserving nor
    /// exactly undone by [`inverse`](Self::inverse).
    ///
    /// # Panics
    /// Panics if `vector.len() != self.dim()`.
    pub fn forward(&self, vector: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.padded_dim);
        self.forward_into(vector, &mut out);
        out
    }

    /// Undoes [`forward`](Self::forward): orthonormal WHT (its own inverse) followed by the
    /// same sign flips (each ±1 is its own inverse).
    ///
    /// This is exact only when `dim` is a power of two; see [`forward`](Self::forward).
    ///
    /// # Panics
    /// Panics if `vector.len() != self.dim()`.
    pub fn inverse(&self, vector: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.padded_dim);
        self.inverse_into(vector, &mut out);
        out
    }

    /// Forward rotation into a caller-provided buffer.
    ///
    /// `out` is cleared and refilled with exactly `dim` values. Its existing allocation is
    /// reused, so repeated calls with the same buffer stop allocating once its capacity
    /// reaches the padded length.
    ///
    /// # Panics
    /// Panics if `slice.len() != self.dim()`; `out` is left untouched in that case.
    #[inline]
    pub fn apply_slice(&self, slice: &[f32], out: &mut Vec<f32>) {
        self.forward_into(slice, out);
    }

    /// Inverse rotation into a caller-provided buffer. It has the same buffer-reuse and panic
    /// behavior as [`apply_slice`](Self::apply_slice).
    #[inline]
    pub fn apply_inverse_slice(&self, slice: &[f32], out: &mut Vec<f32>) {
        self.inverse_into(slice, out);
    }

    /// Forward rotation written into `out`; the length is checked before `out` is touched.
    fn forward_into(&self, vector: &[f32], out: &mut Vec<f32>) {
        assert_eq!(
            vector.len(),
            self.dim,
            "WhtRotation::forward: expected {} dims, got {}",
            self.dim,
            vector.len()
        );
        self.load_padded(vector, out);
        self.flip_signs(out);
        fwht_normalized(out);
        out.truncate(self.dim);
    }

    /// Inverse rotation written into `out`; the length is checked before `out` is touched.
    fn inverse_into(&self, vector: &[f32], out: &mut Vec<f32>) {
        assert_eq!(
            vector.len(),
            self.dim,
            "WhtRotation::inverse: expected {} dims, got {}",
            self.dim,
            vector.len()
        );
        self.load_padded(vector, out);
        fwht_normalized(out);
        self.flip_signs(out);
        out.truncate(self.dim);
    }

    /// Overwrites `buf` with `vector` followed by zeros up to `padded_dim`.
    fn load_padded(&self, vector: &[f32], buf: &mut Vec<f32>) {
        buf.clear();
        buf.extend_from_slice(vector);
        buf.resize(self.padded_dim, 0.0);
    }

    /// Multiplies every element of `buf` by the matching stored sign.
    fn flip_signs(&self, buf: &mut [f32]) {
        for (x, &s) in buf.iter_mut().zip(&self.signs) {
            *x *= s;
        }
    }
}
/// Unnormalized in-place fast Walsh–Hadamard transform, in Sylvester (natural) order.
///
/// Runs `log2(n)` butterfly passes. Pass `k` combines elements `2^k` apart as
/// `(a, b) -> (a + b, a - b)`, for O(n log n) work in total. Applying it twice multiplies
/// every element by `n`; see `fwht_normalized` for the orthonormal version.
///
/// # Panics
/// Panics if `data.len()` is not a power of two (this includes an empty slice).
pub(crate) fn fwht_in_place(data: &mut [f32]) {
    let n = data.len();
    assert!(n.is_power_of_two(), "fwht_in_place: length {} must be power of 2", n);
    let mut half = 1;
    while half < n {
        // `n` and `2 * half` are both powers of two with `2 * half <= n`, so every chunk is
        // full-sized. Splitting each chunk lets the butterfly run bounds-check-free over zipped halves.
        for block in data.chunks_exact_mut(half * 2) {
            let (lo, hi) = block.split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (a, b) = (*x, *y);
                *x = a + b;
                *y = a - b;
            }
        }
        half *= 2;
    }
}
/// Orthonormal fast Walsh–Hadamard transform, computed in place.
///
/// Runs [`fwht_in_place`] and then scales by `1 / sqrt(n)`. The unscaled transform applied
/// twice multiplies by `n`, so after scaling the transform is norm-preserving and is its own
/// inverse.
///
/// # Panics
/// Panics (inside `fwht_in_place`) if `data.len()` is not a power of two.
pub(crate) fn fwht_normalized(data: &mut [f32]) {
    fwht_in_place(data);
    let norm = 1.0 / crate::compat::math::sqrtf(data.len() as f32);
    data.iter_mut().for_each(|v| *v *= norm);
}
impl crate::traits::RotationStrategy for WhtRotation {
fn rotate(&self, vector: &[f32]) -> Vec<f32> {
self.forward(vector)
}
fn rotate_inverse(&self, vector: &[f32]) -> Vec<f32> {
self.inverse(vector)
}
fn dim(&self) -> usize {
self.dim
}
}
/// Rotations compare equal when they share `dim` and `seed`.
///
/// `padded_dim` and `signs` are not compared: `WhtRotation::new` derives both
/// deterministically from `dim` and `seed`.
impl PartialEq for WhtRotation {
    fn eq(&self, other: &Self) -> bool {
        (self.dim, self.seed) == (other.dim, other.seed)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the FWHT kernels and the `WhtRotation` wrapper, covering exact
    // transform values, preconditions, round trips, determinism, equality and buffer APIs.
    use super::*;

    #[test]
    fn test_fwht_power_of_two() {
        let mut data = [1.0_f32, 0.0, 0.0, 0.0];
        fwht_in_place(&mut data);
        assert_eq!(data, [1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_fwht_length_one() {
        let mut data = [42.0_f32];
        fwht_in_place(&mut data);
        assert_eq!(data, [42.0]);
    }

    #[test]
    fn test_fwht_all_ones() {
        let mut data = [1.0_f32, 1.0, 1.0, 1.0];
        fwht_in_place(&mut data);
        assert_eq!(data, [4.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_fwht_known_vector() {
        // Rows of the 4x4 Sylvester Hadamard matrix dotted with [1, 2, 3, 4].
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        fwht_in_place(&mut data);
        assert_eq!(data, [10.0, -2.0, -4.0, 0.0]);
    }

    #[test]
    #[should_panic(expected = "must be power of 2")]
    fn test_fwht_rejects_non_power_of_two() {
        let mut data = [1.0_f32, 2.0, 3.0];
        fwht_in_place(&mut data);
    }

    #[test]
    #[should_panic(expected = "must be power of 2")]
    fn test_fwht_rejects_empty() {
        let mut data: [f32; 0] = [];
        fwht_in_place(&mut data);
    }

    #[test]
    fn test_normalized_self_inverse() {
        let original = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut data = original.clone();
        fwht_normalized(&mut data);
        fwht_normalized(&mut data);
        for (a, b) in original.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-5, "self-inverse failed: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_normalized_preserves_norm() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let original_norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mut transformed = data.clone();
        fwht_normalized(&mut transformed);
        let transformed_norm: f32 = transformed.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (original_norm - transformed_norm).abs() < 1e-5,
            "norm not preserved: {} vs {}",
            original_norm,
            transformed_norm
        );
    }

    #[test]
    fn test_wht_rotation_round_trip() {
        let wht = WhtRotation::new(128, 42).unwrap();
        let v: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let rotated = wht.forward(&v);
        let recovered = wht.inverse(&rotated);
        for (a, b) in v.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-4, "round-trip failed: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_wht_rotation_non_power_of_two_construction() {
        let wht = WhtRotation::new(100, 42).unwrap();
        assert_eq!(wht.dim(), 100);
        assert_eq!(wht.padded_dim, 128);
        let v: Vec<f32> = (0..100).map(|i| i as f32 * 0.01).collect();
        let rotated = wht.forward(&v);
        assert_eq!(rotated.len(), 100);
    }

    #[test]
    fn test_wht_power_of_two_exact_round_trip() {
        for dim in [2, 4, 8, 16, 32, 64, 128, 256] {
            let wht = WhtRotation::new(dim, 42).unwrap();
            let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
            let rotated = wht.forward(&v);
            let recovered = wht.inverse(&rotated);
            for (a, b) in v.iter().zip(recovered.iter()) {
                assert!(
                    (a - b).abs() < 1e-4,
                    "pow2 round-trip failed at dim={}: {} vs {}",
                    dim, a, b
                );
            }
        }
    }

    #[test]
    fn test_wht_power_of_two_preserves_norm() {
        fn l2(x: &[f32]) -> f32 {
            x.iter().map(|a| a * a).sum::<f32>().sqrt()
        }
        let wht = WhtRotation::new(64, 7).unwrap();
        let v: Vec<f32> = (0..64).map(|i| i as f32 - 31.5).collect();
        let rotated = wht.forward(&v);
        let (before, after) = (l2(&v), l2(&rotated));
        assert!(
            (before - after).abs() <= 1e-4 * before,
            "norm not preserved: {} vs {}",
            before,
            after
        );
    }

    #[test]
    #[should_panic(expected = "WhtRotation::forward: expected 64 dims, got 63")]
    fn test_forward_rejects_wrong_length() {
        let wht = WhtRotation::new(64, 42).unwrap();
        let _ = wht.forward(&[0.0; 63]);
    }

    #[test]
    #[should_panic(expected = "WhtRotation::inverse: expected 64 dims, got 65")]
    fn test_inverse_rejects_wrong_length() {
        let wht = WhtRotation::new(64, 42).unwrap();
        let _ = wht.inverse(&[0.0; 65]);
    }

    #[test]
    fn test_wht_deterministic() {
        let wht1 = WhtRotation::new(64, 99).unwrap();
        let wht2 = WhtRotation::new(64, 99).unwrap();
        let v: Vec<f32> = (0..64).map(|i| i as f32).collect();
        assert_eq!(wht1.forward(&v), wht2.forward(&v));
    }

    #[test]
    fn test_wht_different_seeds_differ() {
        let wht1 = WhtRotation::new(64, 1).unwrap();
        let wht2 = WhtRotation::new(64, 2).unwrap();
        let v: Vec<f32> = (0..64).map(|i| i as f32).collect();
        assert_ne!(wht1.forward(&v), wht2.forward(&v));
    }

    #[test]
    fn test_partial_eq_uses_dim_and_seed() {
        let a = WhtRotation::new(64, 7).unwrap();
        assert_eq!(a, WhtRotation::new(64, 7).unwrap());
        assert_eq!(a, a.clone());
        assert_ne!(a, WhtRotation::new(64, 8).unwrap());
        assert_ne!(a, WhtRotation::new(32, 7).unwrap());
    }

    #[test]
    fn test_zero_dimension_error() {
        assert!(matches!(WhtRotation::new(0, 42), Err(TurboQuantError::ZeroDimension)));
    }

    #[test]
    fn test_memory_much_smaller_than_haar() {
        let wht = WhtRotation::new(768, 42).unwrap();
        let haar = crate::rotation::StoredRotation::new(768, 42).unwrap();
        assert!(wht.size_bytes() < haar.size_bytes() / 100,
            "WHT ({}) should be >100x smaller than Haar ({})",
            wht.size_bytes(), haar.size_bytes()
        );
    }

    #[test]
    fn test_rotation_strategy_trait() {
        use crate::traits::RotationStrategy;
        let wht = WhtRotation::new(64, 42).unwrap();
        let v: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let rotated = wht.rotate(&v);
        let recovered = wht.rotate_inverse(&rotated);
        for (a, b) in v.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-4, "trait round-trip: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_apply_slice_interface() {
        let wht = WhtRotation::new(32, 42).unwrap();
        let v: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut out = Vec::new();
        wht.apply_slice(&v, &mut out);
        assert_eq!(out.len(), 32);
        let mut recovered = Vec::new();
        wht.apply_inverse_slice(&out, &mut recovered);
        for (a, b) in v.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-4, "apply_slice round-trip: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_apply_slice_overwrites_existing_contents() {
        // A pre-filled, oversized buffer must end up holding exactly the rotation output.
        let wht = WhtRotation::new(16, 3).unwrap();
        let v: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let mut out = vec![99.0_f32; 40];
        wht.apply_slice(&v, &mut out);
        assert_eq!(out, wht.forward(&v));
        let mut back = vec![-1.0_f32; 3];
        wht.apply_inverse_slice(&out, &mut back);
        assert_eq!(back, wht.inverse(&out));
    }
}