use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};
/// Public tag naming the construction a [`RandomRotation`] uses.
///
/// Returned by [`RandomRotation::kind`]; carries no data of its own.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RandomRotationKind {
/// Dense `dim x dim` matrix, as built by [`RandomRotation::random`].
HaarDense,
/// Sign flips interleaved with Walsh-Hadamard transforms, as built by
/// [`RandomRotation::hadamard`].
HadamardSigned,
}
/// Internal storage for the two rotation constructions.
///
/// This is the data `RandomRotation::apply_into` actually reads.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
enum Mode {
/// Row-major `dim x dim` matrix (row `i` occupies `matrix[i * dim..(i + 1) * dim]`).
HaarDense { matrix: Vec<f32> },
/// Three `+1.0` / `-1.0` sign vectors plus the padded transform length.
HadamardSigned {
// Each vector holds `padded_dim` entries; `apply_into` multiplies by
// `signs[2]` first, then `signs[1]`, then `signs[0]`.
signs: [Vec<f32>; 3],
// `dim.next_power_of_two()` at construction time; length of the scratch
// buffer handed to the Walsh-Hadamard transform.
padded_dim: usize,
},
}
/// A seeded random linear map on `dim`-dimensional `f32` vectors.
///
/// Built with [`RandomRotation::random`] (dense matrix) or
/// [`RandomRotation::hadamard`] (sign flips + Walsh-Hadamard transforms).
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct RandomRotation {
// Construction-specific data; this is what `apply_into` reads.
mode: Mode,
/// Length of the vectors this rotation accepts and produces.
pub dim: usize,
/// Row-major copy of the dense matrix when built by `random`; empty when
/// built by `hadamard`. Defaults to empty if absent during deserialization.
///
/// NOTE: `apply_into` reads the matrix stored in `mode`, not this field, so
/// editing this field does not change the result of `apply`.
#[serde(default)]
pub matrix: Vec<f32>,
}
impl RandomRotation {
    /// Builds a dense `dim x dim` rotation from `seed`.
    ///
    /// Every entry is first drawn from a standard normal distribution
    /// (sampled as `f64`, stored as `f32`, row by row). The rows are then
    /// orthonormalised in order: each row has its projection onto every
    /// earlier row removed, one earlier row at a time, and is then scaled to
    /// unit length. A row whose remaining norm is not above `1e-10` is left
    /// unscaled.
    ///
    /// The resulting matrix is stored twice: once inside the private mode
    /// (used by [`apply_into`](Self::apply_into)) and once in the public
    /// `matrix` field.
    pub fn random(dim: usize, seed: u64) -> Self {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

        // Draw the Gaussian entries in row-major order so the RNG stream is
        // consumed in a fixed, seed-reproducible sequence.
        let mut rows: Vec<Vec<f32>> = Vec::with_capacity(dim);
        for _ in 0..dim {
            let row: Vec<f32> = (0..dim)
                .map(|_| {
                    <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
                        as f32
                })
                .collect();
            rows.push(row);
        }

        for i in 0..dim {
            // `finished` holds the already-orthonormalised rows 0..i;
            // `pending[0]` is row i, the one being processed.
            let (finished, pending) = rows.split_at_mut(i);
            let current = &mut pending[0];

            for earlier in finished.iter() {
                // The projection is recomputed against the partially updated
                // row before each subtraction.
                let projection: f32 = current
                    .iter()
                    .zip(earlier.iter())
                    .map(|(&c, &e)| c * e)
                    .sum();
                for (c, &e) in current.iter_mut().zip(earlier.iter()) {
                    *c -= projection * e;
                }
            }

            let norm: f32 = current.iter().map(|&x| x * x).sum::<f32>().sqrt();
            if norm > 1e-10 {
                for c in current.iter_mut() {
                    *c /= norm;
                }
            }
        }

        let matrix: Vec<f32> = rows.concat();
        Self {
            mode: Mode::HaarDense {
                matrix: matrix.clone(),
            },
            dim,
            matrix,
        }
    }

    /// Builds a structured rotation from `seed` using random signs and
    /// Walsh-Hadamard transforms.
    ///
    /// Three sign vectors of length `dim.next_power_of_two()` are drawn, each
    /// entry being `+1.0` or `-1.0`. The public `matrix` field is left empty.
    ///
    /// # Panics
    ///
    /// Panics if `dim == 0`.
    pub fn hadamard(dim: usize, seed: u64) -> Self {
        assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0");
        let padded_dim = dim.next_power_of_two();
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

        // One call consumes `padded_dim` booleans from the RNG.
        let mut draw_signs = || -> Vec<f32> {
            let mut signs = Vec::with_capacity(padded_dim);
            for _ in 0..padded_dim {
                signs.push(if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 });
            }
            signs
        };

        // Drawn strictly in this order so index 0 gets the first RNG output.
        let first = draw_signs();
        let second = draw_signs();
        let third = draw_signs();

        Self {
            mode: Mode::HadamardSigned {
                signs: [first, second, third],
                padded_dim,
            },
            dim,
            matrix: Vec::new(),
        }
    }

    /// Reports which construction this rotation was built with.
    #[inline]
    pub fn kind(&self) -> RandomRotationKind {
        match self.mode {
            Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned,
            Mode::HaarDense { .. } => RandomRotationKind::HaarDense,
        }
    }

    /// Applies the rotation to `v` and returns the result in a new vector of
    /// length `self.dim`.
    ///
    /// Debug builds assert that `v.len() == self.dim`.
    #[inline]
    pub fn apply(&self, v: &[f32]) -> Vec<f32> {
        debug_assert_eq!(v.len(), self.dim);
        let mut rotated = vec![0.0f32; self.dim];
        self.apply_into(v, &mut rotated);
        rotated
    }

    /// Applies the rotation to `v`, writing the result into `out`.
    ///
    /// * Dense mode: `out[i]` is the dot product of matrix row `i` with `v`.
    /// * Hadamard mode: `v` is zero-padded to `padded_dim`, multiplied by
    ///   `signs[2]`, transformed, multiplied by `signs[1]`, transformed again,
    ///   multiplied by `signs[0]`, scaled by `1 / padded_dim`, and the first
    ///   `dim` coordinates are written to `out`. A scratch buffer of
    ///   `padded_dim` floats is allocated on every call.
    ///
    /// Debug builds assert that both `v` and `out` have length `self.dim`.
    #[inline]
    pub fn apply_into(&self, v: &[f32], out: &mut [f32]) {
        debug_assert_eq!(v.len(), self.dim);
        debug_assert_eq!(out.len(), self.dim);
        match &self.mode {
            Mode::HaarDense { matrix } => {
                let width = self.dim;
                for (row_idx, slot) in out.iter_mut().enumerate() {
                    let start = row_idx * width;
                    let row = &matrix[start..start + width];
                    *slot = row.iter().zip(v).map(|(&r, &x)| r * x).sum();
                }
            }
            Mode::HadamardSigned { signs, padded_dim } => {
                let mut work = vec![0.0_f32; *padded_dim];
                work[..self.dim].copy_from_slice(v);

                Self::multiply_signs(&mut work, &signs[2]);
                fwht_inplace(&mut work);
                Self::multiply_signs(&mut work, &signs[1]);
                fwht_inplace(&mut work);
                Self::multiply_signs(&mut work, &signs[0]);

                // Each unnormalised transform scales norms by
                // sqrt(padded_dim); two of them need 1 / padded_dim overall.
                let scale = 1.0_f32 / (*padded_dim as f32);
                for (dst, src) in out.iter_mut().zip(&work[..self.dim]) {
                    *dst = src * scale;
                }
            }
        }
    }

    /// Size in bytes of the `f32` data held by the active mode.
    ///
    /// The public `matrix` field is not included in the count.
    pub fn bytes(&self) -> usize {
        let elems: usize = match &self.mode {
            Mode::HaarDense { matrix } => matrix.len(),
            Mode::HadamardSigned { signs, .. } => signs.iter().map(Vec::len).sum(),
        };
        elems * std::mem::size_of::<f32>()
    }

    /// Multiplies `buf` element-wise by `signs`, stopping at the shorter of
    /// the two slices.
    fn multiply_signs(buf: &mut [f32], signs: &[f32]) {
        for (value, sign) in buf.iter_mut().zip(signs) {
            *value *= *sign;
        }
    }
}
/// Scales `v` in place so that it has unit Euclidean length.
///
/// The norm is computed in `f32`. When it is finite, the vector is divided by
/// it if it exceeds `1e-10`; shorter vectors (including all-zero and empty
/// slices) are left unchanged.
///
/// When the `f32` sum of squares overflows to infinity (component magnitudes
/// around `1.8e19` and above), the norm is recomputed in `f64` and the
/// division is carried out in `f64`, so large but finite vectors are still
/// normalised instead of being collapsed to zeros by a division by infinity.
///
/// Vectors containing NaN or infinite components have no finite norm and are
/// left unchanged.
pub fn normalize_inplace(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
    if norm.is_finite() {
        // Common path: identical arithmetic to a plain f32 normalisation.
        if norm > 1e-10 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
        return;
    }

    // The f32 norm is NaN or +inf. Squares of finite f32 values cannot
    // overflow f64, so a finite result here means the f32 sum merely
    // overflowed and the vector has a well-defined direction.
    let wide_norm: f64 = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if wide_norm.is_finite() {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / wide_norm) as f32;
        }
    }
}
/// Unnormalised fast Walsh-Hadamard transform, performed in place.
///
/// Runs butterfly passes with half-widths 1, 2, 4, ... below `buf.len()`.
/// Within each block of `2 * half` elements, the pair `(x, y)` located
/// `half` apart is replaced by `(x + y, x - y)`. No scaling is applied.
///
/// Debug builds assert that `buf.len()` is a power of two.
#[inline]
fn fwht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two(), "FWHT requires power-of-two length");

    let mut half = 1;
    while half < len {
        let span = half * 2;
        for start in (0..len).step_by(span) {
            // Split the block into its lower and upper halves so both can be
            // updated together without index arithmetic.
            let (lower, upper) = buf[start..start + span].split_at_mut(half);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let (x, y) = (*lo, *hi);
                *lo = x + y;
                *hi = x - y;
            }
        }
        half = span;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand_distr::StandardNormal;

    /// Euclidean length of `v`, accumulated in `f32`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    /// Row `r` of the row-major matrix exposed through `rot.matrix`.
    fn matrix_row(rot: &RandomRotation, r: usize) -> &[f32] {
        let d = rot.dim;
        &rot.matrix[r * d..(r + 1) * d]
    }

    /// Asserts that every row of a dense rotation has unit norm and that all
    /// distinct row pairs are orthogonal, both within `tol`.
    fn check_orthonormal(dim: usize, seed: u64, tol: f32) {
        let rot = RandomRotation::random(dim, seed);
        let d = rot.dim;
        for i in 0..d {
            let ri = matrix_row(&rot, i);
            let ni = l2_norm(ri);
            assert!((ni - 1.0).abs() < tol, "row {i} norm = {ni}, D={d}");
            for j in (i + 1)..d {
                let rj = matrix_row(&rot, j);
                let dot: f32 = ri.iter().zip(rj).map(|(&a, &b)| a * b).sum();
                assert!(dot.abs() < tol, "rows {i},{j} dot={dot}, D={d}");
            }
        }
    }

    /// Generates `n` unit-length vectors of dimension `dim` with Gaussian
    /// directions, reproducibly from `seed`.
    fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut vecs = Vec::with_capacity(n);
        for _ in 0..n {
            let mut v: Vec<f32> = Vec::with_capacity(dim);
            for _ in 0..dim {
                let g = <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng);
                v.push(g as f32);
            }
            normalize_inplace(&mut v);
            vecs.push(v);
        }
        vecs
    }

    /// Rotates 100 random unit vectors with a Hadamard rotation and asserts
    /// each output norm lies in `0.95..=1.05`.
    fn hadamard_norm_check(dim: usize, seed: u64) {
        let rot = RandomRotation::hadamard(dim, seed);
        assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned);
        for v in &random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF) {
            let n = l2_norm(&rot.apply(v));
            assert!(
                (0.95..=1.05).contains(&n),
                "D={dim}: rotated unit vector has norm {n}",
            );
        }
    }

    #[test]
    fn orthogonality_all_pairs_d64() {
        check_orthonormal(64, 42, 1e-4);
    }

    #[test]
    fn orthogonality_all_pairs_d128() {
        check_orthonormal(128, 7, 1e-4);
    }

    #[test]
    fn orthogonality_all_pairs_d256() {
        check_orthonormal(256, 11, 1e-3);
    }

    #[test]
    fn apply_preserves_norm() {
        let rot = RandomRotation::random(128, 7);
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).sin()).collect();
        let rv = rot.apply(&v);
        let norm_in = l2_norm(&v);
        let norm_out = l2_norm(&rv);
        assert!((norm_in - norm_out).abs() / norm_in < 1e-3);
    }

    #[test]
    fn seed_reproducibility() {
        let a = RandomRotation::random(64, 1234);
        let b = RandomRotation::random(64, 1234);
        assert_eq!(a.matrix, b.matrix);
    }

    #[test]
    fn hadamard_apply_preserves_norm_power_of_two() {
        hadamard_norm_check(128, 7);
        hadamard_norm_check(256, 11);
    }

    #[test]
    fn hadamard_apply_preserves_norm_non_power_of_two() {
        hadamard_norm_check(1000, 3);
    }

    #[test]
    fn hadamard_is_deterministic() {
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).cos()).collect();

        // Same seed: identical output.
        let a = RandomRotation::hadamard(128, 0xC0FFEE);
        let b = RandomRotation::hadamard(128, 0xC0FFEE);
        assert_eq!(a.apply(&v), b.apply(&v));

        // Neighbouring seed: different output.
        let c = RandomRotation::hadamard(128, 0xC0FFEE + 1);
        assert_ne!(a.apply(&v), c.apply(&v));
    }

    // Despite the name, nothing is timed here: this checks the transform on a
    // constant input and compares the storage reported by `bytes()`.
    #[test]
    fn hadamard_is_fast() {
        // A constant input concentrates everything in coefficient 0.
        let mut buf = vec![1.0_f32; 8];
        fwht_inplace(&mut buf);
        assert!((buf[0] - 8.0).abs() < 1e-6);
        for coeff in &buf[1..] {
            assert!(coeff.abs() < 1e-6);
        }

        let had = RandomRotation::hadamard(128, 1);
        let haar = RandomRotation::random(128, 1);
        assert!(had.bytes() < haar.bytes() / 10);
        assert_eq!(had.kind(), RandomRotationKind::HadamardSigned);
        assert_eq!(haar.kind(), RandomRotationKind::HaarDense);
    }
}