use crate::error::{Result, TurboQuantError};
/// An orthogonal `dim x dim` matrix generated deterministically from a seed and kept in memory.
///
/// The matrix is stored row-major: row `i` occupies `data[i * dim..i * dim + dim]`.
/// When deserialized via the `serde-support` feature, `data` is taken as-is; nothing in
/// this type re-checks that it has `dim * dim` entries or matches what `seed` would generate.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub struct StoredRotation {
    /// Dimension of the vectors the matrix acts on; the matrix is `dim x dim`.
    dim: usize,
    /// Seed that `data` was generated from.
    seed: u64,
    /// Row-major matrix entries (`dim * dim` values when built by `new`).
    data: Vec<f32>,
}
impl StoredRotation {
    /// Creates a rotation for `dim`-dimensional vectors, generated deterministically from `seed`.
    ///
    /// The full `dim * dim` matrix is materialized, so memory grows quadratically with
    /// `dim` (see [`StoredRotation::size_bytes`]).
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::ZeroDimension`] if `dim == 0`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::ZeroDimension);
        }
        let data = Self::generate_rotation(dim, seed);
        Ok(Self { dim, seed, data })
    }

    /// Dimension of the vectors this rotation acts on.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Seed the matrix was generated from.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Writes `M * slice` into `out`, where `M` is the stored row-major matrix.
    ///
    /// `out` is cleared and resized to `dim`, reusing its allocation. If
    /// `slice.len() != dim`, debug builds panic; release builds leave `out` as `dim` zeros.
    ///
    /// # Panics
    ///
    /// Panics if the stored matrix has fewer than `dim * dim` entries, which can only
    /// happen for a value not built by [`StoredRotation::new`].
    #[inline]
    pub fn apply_slice(&self, slice: &[f32], out: &mut Vec<f32>) {
        let d = self.dim;
        debug_assert_eq!(slice.len(), d, "apply_slice: slice.len() must equal self.dim");
        debug_assert_eq!(self.data.len(), d * d, "apply_slice: data.len() must equal dim*dim");
        out.clear();
        out.resize(d, 0.0);
        // `d == 0` also protects `chunks_exact`, which rejects a chunk size of zero.
        if d == 0 || slice.len() != d {
            return;
        }
        // One up-front bounds check; each row is then a contiguous slice, so the
        // dot products below need no per-element indexing.
        let rows = &self.data[..d * d];
        for (o, row) in out.iter_mut().zip(rows.chunks_exact(d)) {
            *o = row.iter().zip(slice).map(|(m, v)| m * v).sum();
        }
    }

    /// Writes `M^T * slice` into `out`. For an orthogonal `M` this undoes
    /// [`StoredRotation::apply_slice`].
    ///
    /// Buffer handling, length-mismatch behavior and panics match
    /// [`StoredRotation::apply_slice`].
    #[inline]
    pub fn apply_inverse_slice(&self, slice: &[f32], out: &mut Vec<f32>) {
        let d = self.dim;
        debug_assert_eq!(slice.len(), d, "apply_inverse_slice: slice.len() must equal self.dim");
        debug_assert_eq!(self.data.len(), d * d, "apply_inverse_slice: data.len() must equal dim*dim");
        out.clear();
        out.resize(d, 0.0);
        if d == 0 || slice.len() != d {
            return;
        }
        // out[i] = sum_j M[j][i] * slice[j]. Accumulating row by row (j outer, i inner)
        // walks `data` contiguously instead of striding by `d` for every term, and still
        // adds the terms of each out[i] in increasing-j order.
        let rows = &self.data[..d * d];
        for (row, &s) in rows.chunks_exact(d).zip(slice) {
            for (o, &m) in out.iter_mut().zip(row) {
                *o += m * s;
            }
        }
    }

    /// Heap size of the stored matrix entries, in bytes.
    pub fn size_bytes(&self) -> usize {
        self.data.len() * core::mem::size_of::<f32>()
    }

    /// Returns a copy of the stored matrix as an `nalgebra` dense matrix.
    #[cfg(feature = "std")]
    pub fn matrix(&self) -> nalgebra::DMatrix<f32> {
        nalgebra::DMatrix::from_row_slice(self.dim, self.dim, &self.data)
    }

    /// Generates the row-major `dim * dim` matrix for `(dim, seed)`.
    ///
    /// * With the `std` feature: fills a matrix with standard-normal samples drawn from a
    ///   ChaCha8 stream seeded with `seed`, takes the `Q` factor of its QR decomposition,
    ///   and fixes signs so that `det(Q) = +1`.
    /// * Without `std`: returns a diagonal matrix of random `±1` entries. It is orthogonal
    ///   but does not mix coordinates, and its determinant is not adjusted.
    fn generate_rotation(dim: usize, seed: u64) -> Vec<f32> {
        #[cfg(feature = "std")]
        {
            use nalgebra::DMatrix;
            use rand::SeedableRng;
            use rand_chacha::ChaCha8Rng;
            use rand_distr::StandardNormal;
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            let dist = StandardNormal;
            let raw: Vec<f32> = (0..dim * dim)
                .map(|_| <StandardNormal as rand_distr::Distribution<f64>>::sample(&dist, &mut rng) as f32)
                .collect();
            let m = DMatrix::from_row_slice(dim, dim, &raw);
            // NOTE(review): Q from a plain QR is likely not exactly Haar-uniform unless its
            // columns are also multiplied by sign(diag(R)). Changing that would alter the
            // matrix for every seed, so it is deliberately left as-is.
            let q = m.qr().q();
            let mut data = vec![0.0_f32; dim * dim];
            for i in 0..dim {
                for j in 0..dim {
                    data[i * dim + j] = q[(i, j)];
                }
            }
            if q.determinant() < 0.0 {
                if dim % 2 == 1 {
                    // det(-Q) = (-1)^dim * det(Q), so negating every entry flips the sign
                    // only in odd dimensions. Kept for odd `dim` so those matrices are unchanged.
                    for x in &mut data {
                        *x = -*x;
                    }
                } else {
                    // In even dimensions -Q still has det = -1. Negating a single column
                    // flips the determinant in any dimension and keeps Q orthogonal.
                    for row in data.chunks_exact_mut(dim) {
                        row[0] = -row[0];
                    }
                }
            }
            data
        }
        #[cfg(not(feature = "std"))]
        {
            use rand::SeedableRng;
            use rand_chacha::ChaCha8Rng;
            use rand::Rng;
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            let mut data = vec![0.0_f32; dim * dim];
            for i in 0..dim {
                let sign: f32 = if rng.gen::<bool>() { 1.0 } else { -1.0 };
                data[i * dim + i] = sign;
            }
            data
        }
    }
}
impl PartialEq for StoredRotation {
    /// Two rotations are equal when they share dimension, seed, and matrix entries.
    ///
    /// `dim` and `seed` are compared first as a cheap short-circuit. `data` is compared
    /// too, because a value deserialized through `serde-support`, or one generated under a
    /// different feature set (the `std` and non-`std` generators differ), can carry a
    /// matrix that differs from what `(dim, seed)` produces in this build.
    fn eq(&self, other: &Self) -> bool {
        self.dim == other.dim && self.seed == other.seed && self.data == other.data
    }
}
impl crate::traits::RotationStrategy for StoredRotation {
fn rotate(&self, vector: &[f32]) -> Vec<f32> {
let mut out = Vec::with_capacity(self.dim);
self.apply_slice(vector, &mut out);
out
}
fn rotate_inverse(&self, vector: &[f32]) -> Vec<f32> {
let mut out = Vec::with_capacity(self.dim);
self.apply_inverse_slice(vector, &mut out);
out
}
fn dim(&self) -> usize {
self.dim
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction must refuse a zero-sized rotation.
    #[test]
    fn test_zero_dimension_error() {
        let outcome = StoredRotation::new(0, 42);
        assert!(matches!(outcome, Err(TurboQuantError::ZeroDimension)));
    }

    /// The accessors report exactly the arguments given to `new`.
    #[test]
    fn test_basic_construction() {
        let rotation = StoredRotation::new(4, 42).unwrap();
        assert_eq!((rotation.dim(), rotation.seed()), (4, 42));
    }

    /// Identical `(dim, seed)` pairs compare equal; a different seed does not.
    #[test]
    fn test_partial_eq() {
        let [a, b, c] = [(4, 42), (4, 42), (4, 99)]
            .map(|(dim, seed)| StoredRotation::new(dim, seed).unwrap());
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    /// Rotating and then inverse-rotating recovers the input within tolerance.
    #[test]
    fn test_round_trip_identity() {
        let rotation = StoredRotation::new(4, 7).unwrap();
        let input = [1.0_f32, 2.0, 3.0, 4.0];
        let (mut forward, mut back) = (Vec::new(), Vec::new());
        rotation.apply_slice(&input, &mut forward);
        rotation.apply_inverse_slice(&forward, &mut back);
        for (a, b) in input.iter().zip(&back) {
            assert!((a - b).abs() < 1e-4, "round-trip failed: {a} vs {b}");
        }
    }
}