1use nalgebra::DMatrix;
2use rand::SeedableRng;
3use rand_chacha::ChaCha8Rng;
4use rand_distr::{Distribution, StandardNormal};
5use serde::{Deserialize, Serialize};
6
7use crate::{digest::json_digest, profile::MAX_ROTATION_MATRIX_VALUES, FibQuantError, Result};
8
/// Schema identifier for stored rotations; it is passed to `json_digest` and
/// embedded in the digested view as `rotation_schema`.
pub const ROTATION_SCHEMA: &'static str = "fib_rotation_v1";

/// Identifier of the matrix-generation algorithm (Gaussian samples drawn from
/// ChaCha8, QR factorization, sign correction); embedded in the digested view
/// as `algorithm_version`.
pub const ROTATION_ALGORITHM_VERSION: &'static str = "qr-gaussian-chacha8-sign-corrected-v1";
13
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub struct StoredRotation {
17 dim: usize,
18 seed: u64,
19 matrix: Vec<f64>,
20}
21
22impl StoredRotation {
23 pub fn new(dim: usize, seed: u64) -> Result<Self> {
25 if dim == 0 {
26 return Err(FibQuantError::ZeroDimension);
27 }
28 let matrix_values = dim.checked_mul(dim).ok_or_else(|| {
29 FibQuantError::ResourceLimitExceeded("rotation matrix value count overflow".into())
30 })?;
31 if matrix_values > MAX_ROTATION_MATRIX_VALUES {
32 return Err(FibQuantError::ResourceLimitExceeded(format!(
33 "rotation matrix values {matrix_values} exceed MAX_ROTATION_MATRIX_VALUES {MAX_ROTATION_MATRIX_VALUES}"
34 )));
35 }
36 let mut rng = ChaCha8Rng::seed_from_u64(seed);
37 let data: Vec<f64> = (0..matrix_values)
38 .map(|_| StandardNormal.sample(&mut rng))
39 .collect();
40 let m = DMatrix::from_vec(dim, dim, data);
41 let qr = m.qr();
42 let mut q = qr.q();
43 let r = qr.r();
44 for j in 0..dim {
45 if r[(j, j)] < 0.0 {
46 for i in 0..dim {
47 q[(i, j)] *= -1.0;
48 }
49 }
50 }
51 let mut matrix = vec![0.0; matrix_values];
52 for row in 0..dim {
53 for col in 0..dim {
54 matrix[row * dim + col] = q[(row, col)];
55 }
56 }
57 Ok(Self { dim, seed, matrix })
58 }
59
60 pub fn dim(&self) -> usize {
62 self.dim
63 }
64
65 pub fn seed(&self) -> u64 {
67 self.seed
68 }
69
70 pub fn rotation_schema(&self) -> &'static str {
72 ROTATION_SCHEMA
73 }
74
75 pub fn algorithm_version(&self) -> &'static str {
77 ROTATION_ALGORITHM_VERSION
78 }
79
80 pub fn digest(&self) -> Result<String> {
82 #[derive(Serialize)]
83 struct RotationDigestView<'a> {
84 rotation_schema: &'a str,
85 algorithm_version: &'a str,
86 dim: usize,
87 seed: u64,
88 matrix: &'a [f64],
89 }
90
91 json_digest(
92 ROTATION_SCHEMA,
93 &RotationDigestView {
94 rotation_schema: ROTATION_SCHEMA,
95 algorithm_version: ROTATION_ALGORITHM_VERSION,
96 dim: self.dim,
97 seed: self.seed,
98 matrix: &self.matrix,
99 },
100 )
101 }
102
103 pub fn apply(&self, input: &[f64]) -> Result<Vec<f64>> {
105 self.check_dim(input.len())?;
106 let mut out = vec![0.0; self.dim];
107 for (row, output) in out.iter_mut().enumerate().take(self.dim) {
108 *output = self.matrix[row * self.dim..(row + 1) * self.dim]
109 .iter()
110 .zip(input)
111 .map(|(a, b)| a * b)
112 .sum();
113 }
114 Ok(out)
115 }
116
117 pub fn apply_inverse(&self, input: &[f64]) -> Result<Vec<f64>> {
119 self.check_dim(input.len())?;
120 let mut out = vec![0.0; self.dim];
121 for (col, output) in out.iter_mut().enumerate().take(self.dim) {
122 let mut sum = 0.0;
123 for (row, value) in input.iter().enumerate().take(self.dim) {
124 sum += self.matrix[row * self.dim + col] * value;
125 }
126 *output = sum;
127 }
128 Ok(out)
129 }
130
131 fn check_dim(&self, got: usize) -> Result<()> {
132 if got != self.dim {
133 return Err(FibQuantError::CorruptPayload(format!(
134 "rotation expected dimension {}, got {got}",
135 self.dim
136 )));
137 }
138 Ok(())
139 }
140}