use alloc::sync::Arc;
use alloc::vec;
use faer::Mat;
use libm::fabs;
use crate::codec::codec_config::CodecConfig;
use crate::codec::gaussian::ChaChaGaussianStream;
use crate::errors::CodecError;
/// A square matrix of `f64` entries together with the parameters it was built from.
///
/// Instances are produced by `RotationMatrix::build`, which fills the matrix from the
/// sign-adjusted Q factor of a QR decomposition of seeded Gaussian samples.
/// The entry buffer lives behind an `Arc`, so cloning shares it instead of copying it.
#[derive(Clone, Debug)]
pub struct RotationMatrix {
// Row-major entries: element `(i, j)` lives at index `i * dimension + j`.
matrix: Arc<[f64]>,
// Seed that was passed to `build` to generate the entries.
seed: u64,
// Side length of the matrix; `build` rejects zero.
dimension: u32,
}
impl RotationMatrix {
/// Builds the rotation matrix described by `config`.
///
/// Reads the seed and the dimension from the configuration and forwards both to
/// [`RotationMatrix::build`].
///
/// # Panics
///
/// Panics if `config.dimension()` is zero, because `build` asserts a positive dimension.
#[inline]
pub fn from_config(config: &CodecConfig) -> Self {
    let seed = config.seed();
    let dimension = config.dimension();
    Self::build(seed, dimension)
}
#[allow(clippy::indexing_slicing)] pub fn build(seed: u64, dimension: u32) -> Self {
assert!(
dimension > 0,
"RotationMatrix::build requires dimension > 0"
);
let dim = dimension as usize;
let mut data = vec![0.0f64; dim * dim];
let mut stream = ChaChaGaussianStream::new(seed);
for slot in &mut data {
*slot = stream.next_f64();
}
let a = Mat::<f64>::from_fn(dim, dim, |i, j| data[i * dim + j]);
let qr = a.qr();
let q = qr.compute_q();
let r = qr.compute_r();
let mut row_major = vec![0.0f64; dim * dim];
for j in 0..dim {
let diag = r[(j, j)];
let sign = if diag >= 0.0 { 1.0 } else { -1.0 };
for i in 0..dim {
row_major[i * dim + j] = q[(i, j)] * sign;
}
}
Self {
matrix: Arc::from(row_major.into_boxed_slice()),
seed,
dimension,
}
}
/// Returns the matrix entries as a flat slice.
///
/// The slice holds `dimension * dimension` values in row-major order, i.e. entry
/// `(i, j)` is at index `i * dimension + j`.
#[inline]
pub fn matrix(&self) -> &[f64] {
    self.matrix.as_ref()
}
#[inline]
pub const fn dimension(&self) -> u32 {
self.dimension
}
#[inline]
pub const fn seed(&self) -> u64 {
self.seed
}
/// Multiplies the matrix by `input` and writes the product into `output`.
///
/// Each output element is the dot product of one matrix row with `input`. The dot
/// product is accumulated in `f64` and narrowed to `f32` only when it is stored.
///
/// # Errors
///
/// Returns [`CodecError::LengthMismatch`] when `input` or `output` does not have
/// exactly `dimension` elements. The error carries `input.len()` as `left` and
/// `output.len()` as `right`.
///
/// NOTE(review): if both slices have the same wrong length, the reported `left` and
/// `right` are equal and the expected dimension is not part of the error.
pub fn apply_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
    let n = self.dimension as usize;
    let lengths_ok = input.len() == n && output.len() == n;
    if !lengths_ok {
        return Err(CodecError::LengthMismatch {
            left: input.len(),
            right: output.len(),
        });
    }

    // Dot product of a single matrix row with the input vector, in f64.
    let dot_with_input = |row: &[f64]| -> f64 {
        row.iter()
            .zip(input)
            .map(|(m, x)| m * f64::from(*x))
            .sum()
    };

    for (dst, row) in output.iter_mut().zip(self.matrix.chunks_exact(n)) {
        let projected = dot_with_input(row);
        // Narrowing back to the caller's f32 precision is intentional.
        #[allow(clippy::cast_possible_truncation)]
        {
            *dst = projected as f32;
        }
    }
    Ok(())
}
/// Multiplies the transpose of the matrix by `input` and writes the product into `output`.
///
/// Computes `output[j] = Σ_i matrix[i][j] * input[i]`. For an orthogonal matrix the
/// transpose is the inverse, so this undoes [`RotationMatrix::apply_into`] up to
/// rounding. Sums are accumulated in a temporary `f64` buffer and narrowed to `f32`
/// at the end.
///
/// # Errors
///
/// Returns [`CodecError::LengthMismatch`] when `input` or `output` does not have
/// exactly `dimension` elements. The error carries `input.len()` as `left` and
/// `output.len()` as `right`.
///
/// NOTE(review): if both slices have the same wrong length, the reported `left` and
/// `right` are equal and the expected dimension is not part of the error.
pub fn apply_inverse_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
    let n = self.dimension as usize;
    let lengths_ok = input.len() == n && output.len() == n;
    if !lengths_ok {
        return Err(CodecError::LengthMismatch {
            left: input.len(),
            right: output.len(),
        });
    }

    // Walk the matrix row by row (contiguous memory) and scatter each row, scaled by
    // its input element, into the per-column running sums.
    let mut column_sums = vec![0.0f64; n];
    input
        .iter()
        .zip(self.matrix.chunks_exact(n))
        .for_each(|(x, row)| {
            let weight = f64::from(*x);
            row.iter()
                .zip(column_sums.iter_mut())
                .for_each(|(m, sum)| *sum += m * weight);
        });

    for (dst, sum) in output.iter_mut().zip(column_sums) {
        // Narrowing back to the caller's f32 precision is intentional.
        #[allow(clippy::cast_possible_truncation)]
        {
            *dst = sum as f32;
        }
    }
    Ok(())
}
/// Checks that the rows of the matrix are orthonormal within `tol`.
///
/// For every pair of rows the dot product is compared against the identity matrix
/// (`1.0` for a row with itself, `0.0` otherwise). The check passes only if every
/// absolute deviation is `<= tol`.
///
/// Returns `false` if any deviation exceeds `tol`, and also if a deviation or `tol`
/// itself is NaN: a NaN can never be shown to lie within the tolerance, so a matrix
/// containing NaN entries is never reported as orthogonal.
pub fn verify_orthogonality(&self, tol: f64) -> bool {
    let dim = self.dimension as usize;
    for (i, row_i) in self.matrix.chunks_exact(dim).enumerate() {
        // The row-by-row dot products are symmetric in (i, j), so pairs with j < i
        // were already covered when the outer loop visited row j.
        for (j, row_j) in self.matrix.chunks_exact(dim).enumerate().skip(i) {
            let acc: f64 = row_i.iter().zip(row_j.iter()).map(|(a, b)| a * b).sum();
            let expected = if i == j { 1.0 } else { 0.0 };
            // Phrase the test as "is within tolerance" so that NaN, for which every
            // comparison is false, counts as a failure rather than slipping through.
            let within_tol = fabs(acc - expected) <= tol;
            if !within_tol {
                return false;
            }
        }
    }
    true
}
}