use nalgebra::{DMatrix, DVector};
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
/// A dense square matrix kept together with its precomputed transpose, used to
/// rotate `f64` vectors and to undo that rotation.
///
/// `RandomRotation::new` builds the matrix from a seed; `rotate` multiplies by
/// `matrix` and `rotate_inverse` multiplies by `matrix_t`.
///
/// The derived `Deserialize` restores the three fields exactly as stored;
/// nothing in this file re-checks that they agree with each other afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomRotation {
/// Forward matrix; `rotate` left-multiplies its input by this.
matrix: DMatrix<f64>,
/// Transpose of `matrix`, computed once in `new` so that `rotate_inverse`
/// does not have to transpose on every call.
matrix_t: DMatrix<f64>,
/// Side length of the square matrix, and the slice length that `rotate` and
/// `rotate_inverse` require of their input.
pub dim: usize,
}
impl RandomRotation {
    /// Builds a `dim × dim` random orthogonal matrix, deterministically from `seed`.
    ///
    /// A square matrix of i.i.d. standard-normal entries is drawn from a
    /// `StdRng` seeded with `seed` and factored as `G = Q·R`. Every column of
    /// `Q` whose matching diagonal entry of `R` is negative is then negated,
    /// which is the same as using the factorisation whose `R` has a
    /// non-negative diagonal. That normalisation is the standard way to make
    /// the resulting `Q` uniformly (Haar) distributed over orthogonal matrices.
    /// If the QR routine already returns a non-negative diagonal the
    /// correction is simply a no-op.
    ///
    /// The result is orthogonal; its determinant may be `+1` or `-1`.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::InvalidDimension`] when `dim` is zero or when
    /// `dim * dim` does not fit in a `usize`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        // Reject sizes whose element count overflows instead of wrapping (and
        // then panicking inside `from_vec` on a length mismatch).
        let entry_count = dim
            .checked_mul(dim)
            .ok_or(TurboQuantError::InvalidDimension(dim))?;

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).expect("a standard deviation of 1.0 is always valid");
        // Draw order is part of the seed contract: one sample per entry, in
        // the column-major order `from_vec` consumes them.
        let data: Vec<f64> = (0..entry_count).map(|_| normal.sample(&mut rng)).collect();
        let gaussian = DMatrix::from_vec(dim, dim, data);

        let qr = gaussian.qr();
        let r = qr.r();
        let mut matrix = qr.q();
        // Flip the sign of column `j` wherever `R[j, j] < 0`. Negating the
        // column in place is equivalent to right-multiplying by a ±1 diagonal
        // matrix, without the dense O(dim³) product.
        for j in 0..dim {
            if r[(j, j)] < 0.0 {
                matrix.column_mut(j).neg_mut();
            }
        }

        let matrix_t = matrix.transpose();
        Ok(Self {
            matrix,
            matrix_t,
            dim,
        })
    }

    /// Returns `matrix · x`.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::DimensionMismatch`] when `x.len() != self.dim`.
    pub fn rotate(&self, x: &[f64]) -> Result<Vec<f64>> {
        self.mul_checked(&self.matrix, x)
    }

    /// Returns `matrixᵀ · y`, which undoes [`RandomRotation::rotate`] because
    /// the matrix is orthogonal.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::DimensionMismatch`] when `y.len() != self.dim`.
    pub fn rotate_inverse(&self, y: &[f64]) -> Result<Vec<f64>> {
        self.mul_checked(&self.matrix_t, y)
    }

    /// Borrows the forward rotation matrix.
    pub fn matrix(&self) -> &DMatrix<f64> {
        &self.matrix
    }

    /// Length-checks `v` against `self.dim`, then returns `m · v` as a `Vec`.
    fn mul_checked(&self, m: &DMatrix<f64>, v: &[f64]) -> Result<Vec<f64>> {
        if v.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: v.len(),
            });
        }
        let product = m * DVector::from_column_slice(v);
        // Hand the backing storage over as a `Vec` rather than copying it out.
        Ok(product.data.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// `M · Mᵀ` must match the identity entry by entry, within 1e-10.
    #[test]
    fn test_rotation_orthogonality() {
        let n = 8;
        let rotation = RandomRotation::new(n, 42).unwrap();
        let gram = &rotation.matrix * &rotation.matrix_t;
        for row in 0..n {
            for col in 0..n {
                let target = if row != col { 0.0 } else { 1.0 };
                assert_abs_diff_eq!(gram[(row, col)], target, epsilon = 1e-10);
            }
        }
    }

    /// `rotate` followed by `rotate_inverse` must give back the input.
    #[test]
    fn test_rotation_inverse() {
        let n = 16;
        let rotation = RandomRotation::new(n, 123).unwrap();
        let original: Vec<f64> = (0..n).map(|k| k as f64 / n as f64).collect();
        let round_trip = rotation
            .rotate(&original)
            .and_then(|rotated| rotation.rotate_inverse(&rotated))
            .unwrap();
        for (want, got) in original.iter().zip(round_trip.iter()) {
            assert_abs_diff_eq!(want, got, epsilon = 1e-10);
        }
    }

    /// Rotating a vector must leave its norm unchanged, within 1e-8.
    #[test]
    fn test_rotation_preserves_norm() {
        use crate::utils::norm;

        let n = 32;
        let rotation = RandomRotation::new(n, 7).unwrap();
        let input: Vec<f64> = (0..n).map(|k| (k as f64).sin()).collect();
        let norm_before = norm(&input);
        let output = rotation.rotate(&input).unwrap();
        let norm_after = norm(&output);
        assert_abs_diff_eq!(norm_before, norm_after, epsilon = 1e-8);
    }

    /// A 2-element input to a 4-dimensional rotation is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let rotation = RandomRotation::new(4, 1).unwrap();
        assert!(rotation.rotate(&[1.0, 2.0]).is_err());
    }

    /// Dimension zero is rejected at construction.
    #[test]
    fn test_invalid_dimension_zero() {
        let outcome = RandomRotation::new(0, 1);
        assert!(outcome.is_err());
    }

    /// A 3-element input to the inverse of a 4-dimensional rotation is rejected.
    #[test]
    fn test_rotate_inverse_dimension_mismatch() {
        let rotation = RandomRotation::new(4, 1).unwrap();
        assert!(rotation.rotate_inverse(&[1.0, 2.0, 3.0]).is_err());
    }
}