use crate::error::{Result, TurboQuantError};
use nalgebra::DMatrix;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
// Selects which rotation implementation a caller is asking for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RotationKind {
// Defer the choice to `RotationBackend::new`, which builds the Hadamard backend when the
// dimension is a power of two and the stored-matrix backend otherwise.
Auto,
// Seeded sign flips combined with a normalized fast Walsh-Hadamard transform.
FastHadamard,
// Dense orthogonal matrix derived from the QR factorization of a seeded Gaussian matrix.
StoredQr,
}
impl RotationKind {
pub fn label(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::FastHadamard => "fast_hadamard",
Self::StoredQr => "stored_qr_reference",
}
}
}
/// Concrete rotation implementation, wrapping one of the two available backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RotationBackend {
/// Sign-flip plus fast Walsh-Hadamard transform backend.
FastHadamard(FastHadamardRotation),
/// Explicitly stored orthogonal-matrix backend.
StoredQr(StoredRotation),
}
impl RotationBackend {
pub fn new(dim: usize, seed: u64, kind: RotationKind) -> Result<Self> {
match kind {
RotationKind::Auto if dim.is_power_of_two() => {
FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
}
RotationKind::Auto => StoredRotation::new(dim, seed).map(Self::StoredQr),
RotationKind::FastHadamard => {
FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
}
RotationKind::StoredQr => StoredRotation::new(dim, seed).map(Self::StoredQr),
}
}
pub fn kind(&self) -> RotationKind {
match self {
Self::FastHadamard(_) => RotationKind::FastHadamard,
Self::StoredQr(_) => RotationKind::StoredQr,
}
}
pub fn kind_label(&self) -> &'static str {
self.kind().label()
}
pub fn seed(&self) -> u64 {
match self {
Self::FastHadamard(rotation) => rotation.seed(),
Self::StoredQr(rotation) => rotation.seed(),
}
}
}
impl Rotation for RotationBackend {
fn dim(&self) -> usize {
match self {
Self::FastHadamard(rotation) => rotation.dim(),
Self::StoredQr(rotation) => rotation.dim(),
}
}
fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
match self {
Self::FastHadamard(rotation) => rotation.apply(input, output),
Self::StoredQr(rotation) => rotation.apply(input, output),
}
}
fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
match self {
Self::FastHadamard(rotation) => rotation.apply_inverse(input, output),
Self::StoredQr(rotation) => rotation.apply_inverse(input, output),
}
}
}
/// Common interface for rotations that transform a fixed-length `f32` vector into a
/// caller-provided output buffer.
pub trait Rotation: Send + Sync {
/// Length of the vectors this rotation operates on.
fn dim(&self) -> usize;
/// Writes the rotated form of `input` into `output`.
/// The implementations in this file return an error unless both slices have length `dim()`.
fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()>;
/// Writes the inverse-rotated form of `input` into `output`.
/// The implementations in this file return an error unless both slices have length `dim()`.
fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()>;
}
/// Rotation made of per-coordinate sign flips combined with a normalized fast
/// Walsh-Hadamard transform.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastHadamardRotation {
/// Vector length; the constructor only accepts non-zero powers of two.
dim: usize,
/// Seed passed to the constructor; the sign pattern is derived from it.
seed: u64,
/// One entry per coordinate, each either `1.0` or `-1.0` when built by the constructor.
signs: Vec<f32>,
}
impl FastHadamardRotation {
    /// Creates a Hadamard rotation for vectors of length `dim`.
    ///
    /// The sign pattern is drawn from a ChaCha8 generator seeded with `seed`
    /// plus a fixed offset, so equal arguments always yield equal rotations.
    ///
    /// # Errors
    ///
    /// * `TurboQuantError::ZeroDimension` when `dim` is zero.
    /// * `TurboQuantError::RotationFailed` when `dim` is not a power of two.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        // Zero must be rejected first: it is not a power of two either, but it
        // has its own dedicated error variant.
        match dim {
            0 => return Err(TurboQuantError::ZeroDimension),
            d if !d.is_power_of_two() => {
                return Err(TurboQuantError::RotationFailed {
                    reason: format!("Hadamard rotation requires a power-of-two dimension, got {dim}"),
                });
            }
            _ => {}
        }

        let mut rng = ChaCha8Rng::seed_from_u64(seed.wrapping_add(0xA11C_E55E_D5A5_EED5));
        // Exactly one boolean draw per coordinate, in coordinate order.
        let mut signs: Vec<f32> = Vec::with_capacity(dim);
        for _ in 0..dim {
            let positive = rng.gen::<bool>();
            signs.push(if positive { 1.0 } else { -1.0 });
        }

        Ok(Self { dim, seed, signs })
    }

    /// Seed this rotation was constructed with.
    pub fn seed(&self) -> u64 {
        self.seed
    }
}
impl Rotation for FastHadamardRotation {
    fn dim(&self) -> usize {
        self.dim
    }

    /// Forward rotation: multiply each coordinate by its sign, then run the
    /// normalized Walsh-Hadamard transform in place on `output`.
    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
        check_dim(input.len(), self.dim)?;
        check_dim(output.len(), self.dim)?;

        let signed = input.iter().zip(&self.signs).map(|(x, s)| x * s);
        for (slot, v) in output.iter_mut().zip(signed) {
            *slot = v;
        }
        fwht_normalized(output);
        Ok(())
    }

    /// Inverse rotation: the same two steps in the opposite order (transform
    /// first, then the sign multiplication).
    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
        check_dim(input.len(), self.dim)?;
        check_dim(output.len(), self.dim)?;

        output.copy_from_slice(input);
        fwht_normalized(output);
        output
            .iter_mut()
            .zip(&self.signs)
            .for_each(|(slot, s)| *slot *= s);
        Ok(())
    }
}
/// Rotation backed by an explicitly stored square matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredRotation {
/// Vector length; the constructor builds a `dim` x `dim` matrix.
dim: usize,
/// Seed the matrix was generated from.
seed: u64,
/// The rotation matrix; (de)serialized through the `matrix_serde` helper module.
#[serde(with = "matrix_serde")]
matrix: DMatrix<f32>,
}
impl StoredRotation {
    /// Creates a stored rotation for vectors of length `dim`, generating its
    /// matrix deterministically from `seed`.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::ZeroDimension` when `dim` is zero, and
    /// forwards any error from the matrix generator.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::ZeroDimension);
        }
        generate_orthogonal(dim, seed).map(|matrix| Self { dim, seed, matrix })
    }

    /// Seed this rotation was constructed with.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Size in bytes of `dim * dim` `f32` matrix entries.
    pub fn memory_bytes(&self) -> usize {
        let entries = self.dim * self.dim;
        entries * std::mem::size_of::<f32>()
    }
}
impl Rotation for StoredRotation {
    fn dim(&self) -> usize {
        self.dim
    }

    /// Forward rotation: `output[i]` is the dot product of matrix row `i`
    /// with `input` (matrix times vector).
    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
        check_dim(input.len(), self.dim)?;
        check_dim(output.len(), self.dim)?;

        output.iter_mut().enumerate().for_each(|(row_idx, slot)| {
            let row = self.matrix.row(row_idx);
            *slot = row
                .iter()
                .zip(input.iter())
                .map(|(coeff, x)| coeff * x)
                .sum();
        });
        Ok(())
    }

    /// Inverse rotation: `output[i]` is the dot product of matrix column `i`
    /// with `input` (transposed matrix times vector).
    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
        check_dim(input.len(), self.dim)?;
        check_dim(output.len(), self.dim)?;

        output.iter_mut().enumerate().for_each(|(col_idx, slot)| {
            let column = self.matrix.column(col_idx);
            *slot = column
                .iter()
                .zip(input.iter())
                .map(|(coeff, y)| coeff * y)
                .sum();
        });
        Ok(())
    }
}
/// Builds a `dim` x `dim` orthogonal matrix deterministically from `seed`.
///
/// A matrix of standard-normal samples is QR-factorized; every column of `Q`
/// whose matching diagonal entry of `R` is not `>= 0.0` is then negated, which
/// fixes the sign ambiguity of the factorization.
fn generate_orthogonal(dim: usize, seed: u64) -> Result<DMatrix<f32>> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);

    // Samples are drawn in storage order and handed to `from_vec` unchanged.
    let mut entries: Vec<f32> = Vec::with_capacity(dim * dim);
    for _ in 0..dim * dim {
        entries.push(StandardNormal.sample(&mut rng));
    }

    let factorization = DMatrix::from_vec(dim, dim, entries).qr();
    let mut basis = factorization.q();
    let upper = factorization.r();

    for col in 0..dim {
        if upper[(col, col)] >= 0.0 {
            continue;
        }
        for row in 0..dim {
            basis[(row, col)] *= -1.0;
        }
    }
    Ok(basis)
}
/// Succeeds when `got` equals `expected`; otherwise reports both values in a
/// `TurboQuantError::DimensionMismatch`.
fn check_dim(got: usize, expected: usize) -> Result<()> {
    if got == expected {
        Ok(())
    } else {
        Err(TurboQuantError::DimensionMismatch { expected, got })
    }
}
/// In-place fast Walsh-Hadamard transform scaled by `1 / sqrt(len)`.
///
/// Intended for slices whose length is a power of two (lengths 0 and 1 are
/// left untouched apart from the scaling pass).
///
/// # Panics
///
/// Panics on out-of-range slicing when the length is not a power of two.
fn fwht_normalized(values: &mut [f32]) {
    let len = values.len();
    let mut half = 1;
    while half < len {
        let width = half * 2;
        for base in (0..len).step_by(width) {
            // Each butterfly pairs element `k` of the lower half of the block
            // with element `k` of its upper half.
            let (lower, upper) = values[base..base + width].split_at_mut(half);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let sum = *lo + *hi;
                let diff = *lo - *hi;
                *lo = sum;
                *hi = diff;
            }
        }
        half = width;
    }

    let norm = (len as f32).sqrt().recip();
    values.iter_mut().for_each(|v| *v *= norm);
}
mod matrix_serde {
use nalgebra::DMatrix;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Serialize, Deserialize)]
struct MatrixProxy {
rows: usize,
cols: usize,
data: Vec<f32>,
}
pub fn serialize<S: Serializer>(
m: &DMatrix<f32>,
s: S,
) -> std::result::Result<S::Ok, S::Error> {
MatrixProxy {
rows: m.nrows(),
cols: m.ncols(),
data: m.as_slice().to_vec(),
}
.serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(
d: D,
) -> std::result::Result<DMatrix<f32>, D::Error> {
let p = MatrixProxy::deserialize(d)?;
Ok(DMatrix::from_vec(p.rows, p.cols, p.data))
}
}
// Unit tests for `StoredRotation`: determinism, orthogonality, invertibility,
// inner-product preservation, input validation and JSON round-tripping.
#[cfg(test)]
mod tests {
use super::*;
// Building twice with the same dimension and seed must give identical matrices.
#[test]
fn rotation_is_deterministic_for_same_seed() {
let r1 = StoredRotation::new(8, 42).unwrap();
let r2 = StoredRotation::new(8, 42).unwrap();
assert_eq!(r1.matrix.as_slice(), r2.matrix.as_slice());
}
// Different seeds must give different matrices.
#[test]
fn rotation_differs_across_seeds() {
let r1 = StoredRotation::new(8, 1).unwrap();
let r2 = StoredRotation::new(8, 2).unwrap();
assert_ne!(r1.matrix.as_slice(), r2.matrix.as_slice());
}
// Despite the `rrt` in the name, the product formed here is RᵀR; it is compared
// entry by entry against the identity.
#[test]
fn rotation_is_orthogonal_rrt_equals_identity() {
let r = StoredRotation::new(16, 7).unwrap();
let m = &r.matrix;
let product = m.transpose() * m;
for i in 0..16 {
for j in 0..16 {
let expected = if i == j { 1.0f32 } else { 0.0f32 };
let got = product[(i, j)];
assert!(
(got - expected).abs() < 1e-5,
"RᵀR[{i},{j}] = {got}, expected {expected}"
);
}
}
}
// `apply` followed by `apply_inverse` must return the original vector (within tolerance).
#[test]
fn apply_inverse_recovers_input() {
let r = StoredRotation::new(8, 99).unwrap();
let x = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut y = vec![0.0f32; 8];
let mut recovered = vec![0.0f32; 8];
r.apply(&x, &mut y).unwrap();
r.apply_inverse(&y, &mut recovered).unwrap();
for (orig, rec) in x.iter().zip(recovered.iter()) {
assert!((orig - rec).abs() < 1e-5, "orig={orig}, recovered={rec}");
}
}
// The dot product of two vectors must be unchanged after rotating both.
#[test]
fn rotation_preserves_inner_products() {
let r = StoredRotation::new(8, 13).unwrap();
let x = vec![1.0f32, 0.5, -1.0, 2.0, 0.1, -0.3, 1.5, 0.8];
let y = vec![0.2f32, -1.0, 0.5, 1.0, -0.5, 0.3, 0.9, -0.7];
let mut rx = vec![0.0f32; 8];
let mut ry = vec![0.0f32; 8];
r.apply(&x, &mut rx).unwrap();
r.apply(&y, &mut ry).unwrap();
let ip_original: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
let ip_rotated: f32 = rx.iter().zip(ry.iter()).map(|(a, b)| a * b).sum();
assert!((ip_original - ip_rotated).abs() < 1e-4);
}
// A zero dimension must be refused by the constructor.
#[test]
fn zero_dimension_is_rejected() {
assert!(StoredRotation::new(0, 0).is_err());
}
// Serializing to JSON and back must reproduce the matrix values exactly.
#[test]
fn serialization_roundtrip() {
let r = StoredRotation::new(8, 55).unwrap();
let json = serde_json::to_string(&r).unwrap();
let restored: StoredRotation = serde_json::from_str(&json).unwrap();
assert_eq!(r.matrix.as_slice(), restored.matrix.as_slice());
}
}