use std::cell::RefCell;
use rand::{SeedableRng, rngs::StdRng};
use rand_distr::{Distribution, Normal};
use wide::f32x8;
/// Number of (sign flip -> Walsh–Hadamard transform -> rescale) rounds that
/// [`RandomRotation::apply`] composes; also the number of sign vectors drawn by
/// [`RandomRotation::new`].
const ROTATION_STAGES: usize = 3;
/// Lane count of `wide::f32x8`; the SIMD loops below consume this many `f32`s
/// per step and fall back to scalar code for any shorter tail.
const F32X8_LANES: usize = 8;
/// A seeded pseudo-random linear transform built from repeated `±1` sign flips
/// and Walsh–Hadamard transforms over a power-of-two working size.
///
/// Build it with [`RandomRotation::new`] and apply it with
/// [`RandomRotation::apply`].
#[derive(Debug)]
pub struct RandomRotation {
/// Length of the vectors accepted and produced by `apply`.
///
/// NOTE(review): this field is `pub` and mutable, but `apply` slices its
/// `padded_dim`-long working buffer with `[..dim]`. Raising `dim` above
/// `padded_dim` after construction will make `apply` panic.
pub dim: usize,
/// `dim` rounded up to the next power of two (at least 1). This is the length
/// of the working buffer in `apply` and of every sign vector.
padded_dim: usize,
/// `ROTATION_STAGES` vectors of length `padded_dim` whose entries are `1.0`
/// or `-1.0`; one vector is used per stage of `apply`.
signs: Vec<Vec<f32>>,
}
impl RandomRotation {
    /// Creates a transform for `dim`-length vectors, derived deterministically
    /// from `seed`.
    ///
    /// The working size is `dim` rounded up to the next power of two (at least 1),
    /// because the Walsh–Hadamard transform needs a power-of-two length. The
    /// function draws [`ROTATION_STAGES`] sign vectors of that size, stage by
    /// stage. Each entry is the sign of one standard-normal sample from a
    /// `StdRng` seeded with `seed`; a sample of exactly `0.0` yields `+1.0`.
    ///
    /// The same `(dim, seed)` gives the same signs for a given `rand` build.
    /// `StdRng`'s algorithm is not promised to stay fixed across `rand`
    /// releases.
    pub fn new(dim: usize, seed: u64) -> Self {
        let padded_dim = dim.max(1).next_power_of_two();
        let mut rng = StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0f32, 1.0).expect("valid stddev");

        // Stage-major draw order: every entry of stage 0, then stage 1, ...
        let mut signs = Vec::with_capacity(ROTATION_STAGES);
        for _ in 0..ROTATION_STAGES {
            let stage: Vec<f32> = (0..padded_dim)
                .map(|_| {
                    let non_negative = normal.sample(&mut rng) >= 0.0;
                    if non_negative { 1.0 } else { -1.0 }
                })
                .collect();
            signs.push(stage);
        }

        Self {
            dim,
            padded_dim,
            signs,
        }
    }

    /// Transforms `x` and writes the result into `out`.
    ///
    /// `x` is zero-padded to `padded_dim`. Each stage then multiplies
    /// element-wise by that stage's sign vector, runs an unnormalized in-place
    /// Walsh–Hadamard transform, and rescales by `1 / sqrt(padded_dim)`. Only the
    /// first `dim` coordinates of the final buffer are copied to `out`.
    ///
    /// When `dim` is not a power of two, the trailing `padded_dim - dim`
    /// coordinates are discarded. The result restricted to `dim` coordinates is
    /// then not guaranteed to preserve norms or inner products.
    ///
    /// Uses the thread-local `SCRATCH` buffer as working space.
    ///
    /// # Panics
    /// Panics if `x.len()` or `out.len()` differs from `dim`. In debug builds
    /// this is caught by `debug_assert_eq!`; in release builds the slice copies
    /// panic instead.
    #[inline]
    pub fn apply(&self, x: &[f32], out: &mut [f32]) {
        debug_assert_eq!(x.len(), self.dim);
        debug_assert_eq!(out.len(), self.dim);
        let padded = self.padded_dim;
        let norm = 1.0 / (padded as f32).sqrt();

        SCRATCH.with(|scratch| {
            let mut work = scratch.borrow_mut();
            // Reset to exactly `padded` zeros, then drop the input in front.
            work.clear();
            work.resize(padded, 0.0);
            work[..self.dim].copy_from_slice(x);

            for stage in self.signs.iter() {
                apply_signs(&mut work, stage);
                walsh_hadamard(&mut work);
                scale_in_place(&mut work, norm);
            }

            out.copy_from_slice(&work[..self.dim]);
        });
    }
}
// Per-thread working buffer for `RandomRotation::apply`. Each call clears it and
// resizes it to `padded_dim`, so repeated calls on one thread can reuse its
// allocation. The `RefCell` is mutably borrowed for the whole of `apply`.
thread_local! {
static SCRATCH: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
}
/// Multiplies `buf` element-wise by `signs`, in place.
///
/// Full groups of [`F32X8_LANES`] elements are handled with `f32x8`, and any
/// shorter tail is multiplied one element at a time. Only the first `buf.len()`
/// entries of `signs` are read. Debug builds require the lengths to match
/// exactly; a shorter `signs` panics in any build.
#[inline]
fn apply_signs(buf: &mut [f32], signs: &[f32]) {
    debug_assert_eq!(buf.len(), signs.len());
    let signs = &signs[..buf.len()];

    let mut buf_chunks = buf.chunks_exact_mut(F32X8_LANES);
    let mut sign_chunks = signs.chunks_exact(F32X8_LANES);
    for (dst, sgn) in buf_chunks.by_ref().zip(sign_chunks.by_ref()) {
        let lanes = f32x8::from(<[f32; F32X8_LANES]>::try_from(&*dst).expect("len-8"));
        let flips = f32x8::from(<[f32; F32X8_LANES]>::try_from(sgn).expect("len-8"));
        dst.copy_from_slice(&(lanes * flips).to_array());
    }

    // Scalar tail: fewer than F32X8_LANES elements remain.
    for (dst, sgn) in buf_chunks
        .into_remainder()
        .iter_mut()
        .zip(sign_chunks.remainder())
    {
        *dst *= *sgn;
    }
}
/// Multiplies every element of `buf` by `scale`, in place.
///
/// Full groups of [`F32X8_LANES`] elements are handled with a splatted
/// `f32x8`, and the leftover tail is scaled one element at a time.
#[inline]
fn scale_in_place(buf: &mut [f32], scale: f32) {
    let factor = f32x8::splat(scale);
    let mut chunks = buf.chunks_exact_mut(F32X8_LANES);
    for chunk in chunks.by_ref() {
        let lanes = f32x8::from(<[f32; F32X8_LANES]>::try_from(&*chunk).expect("len-8"));
        chunk.copy_from_slice(&(lanes * factor).to_array());
    }
    chunks
        .into_remainder()
        .iter_mut()
        .for_each(|value| *value *= scale);
}
/// In-place, unnormalized fast Walsh–Hadamard transform of `a`.
///
/// For each butterfly width `half = 1, 2, 4, ...` below `a.len()`, `a` is cut
/// into consecutive blocks of `2 * half`. In each block, every pair
/// `(lo[k], hi[k])` of the two halves becomes `(lo[k] + hi[k], lo[k] - hi[k])`.
/// Widths of at least [`F32X8_LANES`] use `f32x8`; smaller widths are scalar.
///
/// No normalization is applied; callers rescale separately. `a.len()` must be
/// a power of two. Debug builds assert this; in release builds an invalid
/// length can panic on an out-of-range slice.
#[inline]
fn walsh_hadamard(a: &mut [f32]) {
    let len = a.len();
    debug_assert!(len.is_power_of_two());

    let mut half = 1;
    while half < len {
        let span = 2 * half;
        for start in (0..len).step_by(span) {
            let (lo, hi) = a[start..start + span].split_at_mut(half);
            if half < F32X8_LANES {
                for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                    let (u, v) = (*x, *y);
                    *x = u + v;
                    *y = u - v;
                }
            } else {
                // `half` is a power of two >= F32X8_LANES, so both halves split
                // evenly into full SIMD chunks.
                for (x, y) in lo
                    .chunks_exact_mut(F32X8_LANES)
                    .zip(hi.chunks_exact_mut(F32X8_LANES))
                {
                    let u = f32x8::from(<[f32; F32X8_LANES]>::try_from(&*x).expect("len-8"));
                    let v = f32x8::from(<[f32; F32X8_LANES]>::try_from(&*y).expect("len-8"));
                    x.copy_from_slice(&(u + v).to_array());
                    y.copy_from_slice(&(u - v).to_array());
                }
            }
        }
        half = span;
    }
}
// Unit tests for `RandomRotation`. The geometric checks below (unit columns,
// orthogonality, norm/inner-product preservation, linearity) use power-of-two
// dimensions (16, 32, 64), where `apply` discards no padded coordinates.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): `dot` is assumed to be the plain inner product sum(a[i] * b[i]).
use crate::superfile::vector::distance::dot;
/// Absolute-tolerance float comparison: true when `|a - b| < eps`.
fn approx(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
/// Returns `R · e_i`, the image of the i-th standard basis vector, i.e. the
/// i-th column of the transform as seen through `apply`.
fn column(rot: &RandomRotation, i: usize) -> Vec<f32> {
let mut e = vec![0.0f32; rot.dim];
e[i] = 1.0;
let mut out = vec![0.0f32; rot.dim];
rot.apply(&e, &mut out);
out
}
// A power-of-two `dim` needs no padding; one sign vector is drawn per stage.
#[test]
fn new_with_dim_8_succeeds() {
let r = RandomRotation::new(8, 42);
assert_eq!(r.dim, 8);
assert_eq!(r.padded_dim, 8);
assert_eq!(r.signs.len(), ROTATION_STAGES);
}
// Includes non-power-of-two dims (384, 768); only construction is checked here.
#[test]
fn new_with_realistic_dim_succeeds() {
for &dim in &[16, 64, 128, 384, 768, 1024] {
let r = RandomRotation::new(dim, 7);
assert_eq!(r.dim, dim);
assert!(r.padded_dim >= dim && r.padded_dim.is_power_of_two());
}
}
#[test]
fn columns_are_unit_vectors() {
let r = RandomRotation::new(64, 7);
for i in 0..r.dim {
let c = column(&r, i);
let mag_sq = dot(&c, &c);
assert!(approx(mag_sq, 1.0, 1e-4), "column {i} mag² = {mag_sq}");
}
}
#[test]
fn columns_are_pairwise_orthogonal() {
let r = RandomRotation::new(32, 11);
for i in 0..r.dim {
let ci = column(&r, i);
for j in (i + 1)..r.dim {
let cj = column(&r, j);
let p = dot(&ci, &cj);
assert!(approx(p, 0.0, 1e-4), "columns {i}, {j} dot = {p}");
}
}
}
// Determinism: identical seeds give identical signs and bit-identical outputs.
#[test]
fn same_seed_yields_same_rotation() {
let r1 = RandomRotation::new(64, 12345);
let r2 = RandomRotation::new(64, 12345);
assert_eq!(r1.signs, r2.signs);
let x: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
let (mut a, mut b) = (vec![0.0; 64], vec![0.0; 64]);
r1.apply(&x, &mut a);
r2.apply(&x, &mut b);
assert_eq!(a, b);
}
#[test]
fn different_seed_yields_different_rotation() {
let r1 = RandomRotation::new(64, 1);
let r2 = RandomRotation::new(64, 2);
let x: Vec<f32> = (0..64).map(|i| i as f32 * 0.1 + 1.0).collect();
let (mut a, mut b) = (vec![0.0; 64], vec![0.0; 64]);
r1.apply(&x, &mut a);
r2.apply(&x, &mut b);
assert_ne!(a, b);
}
#[test]
fn apply_preserves_l2_norm() {
let r = RandomRotation::new(64, 42);
let mut x = vec![0.0f32; 64];
for (i, v) in x.iter_mut().enumerate() {
*v = (i as f32) * 0.1 - 1.5;
}
let mag_in = dot(&x, &x).sqrt();
let mut y = vec![0.0; 64];
r.apply(&x, &mut y);
let mag_out = dot(&y, &y).sqrt();
assert!(
approx(mag_in, mag_out, 1e-3),
"input |x| = {mag_in}, output |R·x| = {mag_out}"
);
}
// `out` is pre-filled with 1.0 so the test also proves every slot is overwritten.
#[test]
fn apply_zero_vector_yields_zero() {
let r = RandomRotation::new(32, 0xCAFE_F00D);
let x = vec![0.0; 32];
let mut y = vec![1.0; 32];
r.apply(&x, &mut y);
for &v in &y {
assert_eq!(v, 0.0);
}
}
#[test]
fn apply_preserves_inner_products() {
let r = RandomRotation::new(32, 7);
let x: Vec<f32> = (0..32).map(|i| (i as f32) * 0.3 - 4.0).collect();
let y: Vec<f32> = (0..32).map(|i| (i as f32) * -0.2 + 1.7).collect();
let mut rx = vec![0.0; 32];
let mut ry = vec![0.0; 32];
r.apply(&x, &mut rx);
r.apply(&y, &mut ry);
let inner_in = dot(&x, &y);
let inner_out = dot(&rx, &ry);
assert!(
approx(inner_in, inner_out, 1e-3),
"<x,y> = {inner_in}, <Rx,Ry> = {inner_out}"
);
}
// Checks R(x + alpha·y) == R(x) + alpha·R(y) coordinate-wise within 1e-3.
#[test]
fn apply_is_linear() {
let r = RandomRotation::new(16, 99);
let x: Vec<f32> = (0..16).map(|i| i as f32).collect();
let y: Vec<f32> = (0..16).map(|i| (i as f32) * 0.5).collect();
let alpha = 2.5;
let mut rx = vec![0.0; 16];
let mut ry = vec![0.0; 16];
r.apply(&x, &mut rx);
r.apply(&y, &mut ry);
let combined: Vec<f32> = x.iter().zip(&y).map(|(a, b)| a + alpha * b).collect();
let mut r_combined = vec![0.0; 16];
r.apply(&combined, &mut r_combined);
for i in 0..16 {
let expected = rx[i] + alpha * ry[i];
assert!(
approx(r_combined[i], expected, 1e-3),
"linearity broken at i={i}: got {} expected {expected}",
r_combined[i]
);
}
}
}