use nalgebra::DMatrix;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
use serde::{Deserialize, Serialize};
use crate::{digest::json_digest, profile::MAX_ROTATION_MATRIX_VALUES, FibQuantError, Result};
pub const ROTATION_SCHEMA: &str = "fib_rotation_v1";
pub const ROTATION_ALGORITHM_VERSION: &str = "qr-gaussian-chacha8-sign-corrected-v1";
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StoredRotation {
dim: usize,
seed: u64,
matrix: Vec<f64>,
}
impl StoredRotation {
pub fn new(dim: usize, seed: u64) -> Result<Self> {
if dim == 0 {
return Err(FibQuantError::ZeroDimension);
}
let matrix_values = dim.checked_mul(dim).ok_or_else(|| {
FibQuantError::ResourceLimitExceeded("rotation matrix value count overflow".into())
})?;
if matrix_values > MAX_ROTATION_MATRIX_VALUES {
return Err(FibQuantError::ResourceLimitExceeded(format!(
"rotation matrix values {matrix_values} exceed MAX_ROTATION_MATRIX_VALUES {MAX_ROTATION_MATRIX_VALUES}"
)));
}
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let data: Vec<f64> = (0..matrix_values)
.map(|_| StandardNormal.sample(&mut rng))
.collect();
let m = DMatrix::from_vec(dim, dim, data);
let qr = m.qr();
let mut q = qr.q();
let r = qr.r();
for j in 0..dim {
if r[(j, j)] < 0.0 {
for i in 0..dim {
q[(i, j)] *= -1.0;
}
}
}
let mut matrix = vec![0.0; matrix_values];
for row in 0..dim {
for col in 0..dim {
matrix[row * dim + col] = q[(row, col)];
}
}
Ok(Self { dim, seed, matrix })
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn seed(&self) -> u64 {
self.seed
}
pub fn rotation_schema(&self) -> &'static str {
ROTATION_SCHEMA
}
pub fn algorithm_version(&self) -> &'static str {
ROTATION_ALGORITHM_VERSION
}
pub fn digest(&self) -> Result<String> {
#[derive(Serialize)]
struct RotationDigestView<'a> {
rotation_schema: &'a str,
algorithm_version: &'a str,
dim: usize,
seed: u64,
matrix: &'a [f64],
}
json_digest(
ROTATION_SCHEMA,
&RotationDigestView {
rotation_schema: ROTATION_SCHEMA,
algorithm_version: ROTATION_ALGORITHM_VERSION,
dim: self.dim,
seed: self.seed,
matrix: &self.matrix,
},
)
}
pub fn apply(&self, input: &[f64]) -> Result<Vec<f64>> {
self.check_dim(input.len())?;
let mut out = vec![0.0; self.dim];
for (row, output) in out.iter_mut().enumerate().take(self.dim) {
*output = self.matrix[row * self.dim..(row + 1) * self.dim]
.iter()
.zip(input)
.map(|(a, b)| a * b)
.sum();
}
Ok(out)
}
pub fn apply_inverse(&self, input: &[f64]) -> Result<Vec<f64>> {
self.check_dim(input.len())?;
let mut out = vec![0.0; self.dim];
for (col, output) in out.iter_mut().enumerate().take(self.dim) {
let mut sum = 0.0;
for (row, value) in input.iter().enumerate().take(self.dim) {
sum += self.matrix[row * self.dim + col] * value;
}
*output = sum;
}
Ok(out)
}
fn check_dim(&self, got: usize) -> Result<()> {
if got != self.dim {
return Err(FibQuantError::CorruptPayload(format!(
"rotation expected dimension {}, got {got}",
self.dim
)));
}
Ok(())
}
}