use rand::SeedableRng;
use rand::rngs::StdRng;
use rand_distr::{Distribution, StandardNormal};
/// A seeded, reproducible rotation of `dim`-dimensional `f32` vectors.
///
/// The rotation is stored as a dense square matrix in row-major order, so the
/// whole value can be cloned or debug-printed as plain data.
#[derive(Debug, Clone)]
pub struct Rotation {
    /// Number of components in the vectors this rotation accepts.
    dim: usize,
    /// Row-major `dim * dim` matrix entries; row `i` occupies
    /// `matrix[i * dim..(i + 1) * dim]`.
    matrix: Vec<f32>,
    /// Seed the matrix was generated from, retained so it can be reported back.
    seed: u64,
}
impl Rotation {
    /// Creates a rotation for `dim`-dimensional vectors.
    ///
    /// The matrix comes from `generate_orthogonal(dim, seed)`, so two calls
    /// with the same arguments produce the same matrix.
    pub fn new(dim: usize, seed: u64) -> Self {
        Self {
            dim,
            matrix: generate_orthogonal(dim, seed),
            seed,
        }
    }

    /// Number of components a vector must have to be rotated.
    pub fn dimension(&self) -> usize {
        self.dim
    }

    /// Seed that was passed to [`Rotation::new`].
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Rotates `x` in place: component `i` of the output is the dot product
    /// of matrix row `i` with the input vector.
    ///
    /// A temporary buffer holds the output because every output component
    /// needs the complete, unmodified input.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != self.dimension()`: in debug builds through the
    /// `debug_assert_eq!`, otherwise when the result is copied back into `x`.
    pub fn forward(&self, x: &mut [f32]) {
        debug_assert_eq!(x.len(), self.dim);
        let n = self.dim;
        let rotated: Vec<f32> = (0..n)
            .map(|r| dot(&self.matrix[r * n..(r + 1) * n], x))
            .collect();
        x.copy_from_slice(&rotated);
    }

    /// Applies the transposed matrix to `y` in place, undoing
    /// [`Rotation::forward`] when the matrix rows are orthonormal.
    ///
    /// Each input component `y[i]` scales matrix row `i`, and the scaled rows
    /// are accumulated into the output in row order.
    ///
    /// # Panics
    ///
    /// Panics if `y.len() != self.dimension()`: in debug builds through the
    /// `debug_assert_eq!`, otherwise when the result is copied back into `y`.
    pub fn inverse(&self, y: &mut [f32]) {
        debug_assert_eq!(y.len(), self.dim);
        let n = self.dim;
        let mut restored = vec![0.0f32; n];
        for (r, &weight) in y.iter().take(n).enumerate() {
            let row = &self.matrix[r * n..(r + 1) * n];
            for (acc, &m) in restored.iter_mut().zip(row) {
                *acc += m * weight;
            }
        }
        y.copy_from_slice(&restored);
    }
}
/// Inner product of two `f32` slices.
///
/// Elements are paired by position; if the slices differ in length the extra
/// elements of the longer one are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Builds a row-major `dim * dim` matrix with orthonormal rows, reproducibly
/// from `seed`.
///
/// Every entry is first sampled from `StandardNormal` using an `StdRng`
/// seeded with `seed`. The rows are then processed in order: each row has its
/// component along every earlier row removed, one earlier row at a time (the
/// projection is recomputed against the partially reduced row), and is then
/// divided by its Euclidean norm.
///
/// A row whose norm is not greater than zero is left unscaled. For `dim == 0`
/// the result is an empty vector.
fn generate_orthogonal(dim: usize, seed: u64) -> Vec<f32> {
    let mut rng = StdRng::seed_from_u64(seed);
    let gaussian = StandardNormal;

    let mut q = vec![0.0f32; dim * dim];
    q.iter_mut()
        .for_each(|entry| *entry = gaussian.sample(&mut rng));

    for i in 0..dim {
        // Everything before row `i` is already orthonormal and only read;
        // `tail` begins at row `i`, the row being reduced.
        let (finished, tail) = q.split_at_mut(i * dim);
        let current = &mut tail[..dim];

        for earlier in finished.chunks_exact(dim) {
            let proj = current
                .iter()
                .zip(earlier)
                .fold(0.0f32, |acc, (&c, &e)| acc + c * e);
            for (c, &e) in current.iter_mut().zip(earlier) {
                *c -= proj * e;
            }
        }

        let norm = current.iter().fold(0.0f32, |acc, &c| acc + c * c).sqrt();
        if norm > 0.0 {
            for c in current.iter_mut() {
                *c /= norm;
            }
        }
    }

    q
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows row `index` of the rotation's row-major matrix.
    fn row(rot: &Rotation, index: usize) -> &[f32] {
        &rot.matrix[index * rot.dim..(index + 1) * rot.dim]
    }

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Every pair of rows must be orthogonal and every row must have unit length.
    #[test]
    fn orthogonality() {
        let dim = 32;
        let rot = Rotation::new(dim, 42);
        for i in 0..dim {
            for j in 0..dim {
                let d = dot(row(&rot, i), row(&rot, j));
                if i != j {
                    assert!(d.abs() < 1e-4, "off-diagonal [{i},{j}] = {d}, expected 0");
                } else {
                    assert!(
                        (d - 1.0).abs() < 1e-4,
                        "diagonal [{i},{j}] = {d}, expected 1.0"
                    );
                }
            }
        }
    }

    /// The same dimension and seed must reproduce the same matrix.
    #[test]
    fn deterministic_with_seed() {
        assert_eq!(Rotation::new(16, 123).matrix, Rotation::new(16, 123).matrix);
    }

    /// Changing only the seed must change the matrix.
    #[test]
    fn different_seeds_differ() {
        assert_ne!(Rotation::new(16, 1).matrix, Rotation::new(16, 2).matrix);
    }

    /// `inverse` must undo `forward`, and `forward` must actually move the vector.
    #[test]
    fn forward_inverse_roundtrip() {
        let dim = 64;
        let rot = Rotation::new(dim, 99);
        let original: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) / dim as f32).collect();

        let mut v = original.clone();
        rot.forward(&mut v);
        let moved = v
            .iter()
            .zip(original.iter())
            .any(|(a, b)| (a - b).abs() > 1e-4);
        assert!(moved, "rotation had no effect");

        rot.inverse(&mut v);
        for (a, b) in v.iter().zip(original.iter()) {
            assert!(
                (a - b).abs() < 1e-3,
                "roundtrip failed: {a} vs {b} (diff={})",
                (a - b).abs()
            );
        }
    }

    /// Rotating a vector must leave its Euclidean length unchanged.
    #[test]
    fn preserves_norm() {
        let dim = 64;
        let rot = Rotation::new(dim, 7);
        let mut v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();

        let norm_before: f32 = l2_norm(&v);
        rot.forward(&mut v);
        let norm_after: f32 = l2_norm(&v);

        assert!(
            (norm_before - norm_after).abs() < 1e-3,
            "norm changed: {norm_before} → {norm_after}"
        );
    }
}