Skip to main content

fib_quant/
rotation.rs

1use nalgebra::DMatrix;
2use rand::SeedableRng;
3use rand_chacha::ChaCha8Rng;
4use rand_distr::{Distribution, StandardNormal};
5use serde::{Deserialize, Serialize};
6
7use crate::{digest::json_digest, profile::MAX_ROTATION_MATRIX_VALUES, FibQuantError, Result};
8
/// Schema identifier recorded with every stored deterministic rotation; must stay stable.
pub const ROTATION_SCHEMA: &str = "fib_rotation_v1";
/// Generator identity: ChaCha8-seeded Gaussian samples, QR decomposition, sign-corrected `Q`.
pub const ROTATION_ALGORITHM_VERSION: &str = "qr-gaussian-chacha8-sign-corrected-v1";
13
14/// Stored deterministic orthogonal rotation.
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub struct StoredRotation {
17    dim: usize,
18    seed: u64,
19    matrix: Vec<f64>,
20}
21
22impl StoredRotation {
23    /// Generate a Haar-like orthogonal matrix via QR decomposition.
24    pub fn new(dim: usize, seed: u64) -> Result<Self> {
25        if dim == 0 {
26            return Err(FibQuantError::ZeroDimension);
27        }
28        let matrix_values = dim.checked_mul(dim).ok_or_else(|| {
29            FibQuantError::ResourceLimitExceeded("rotation matrix value count overflow".into())
30        })?;
31        if matrix_values > MAX_ROTATION_MATRIX_VALUES {
32            return Err(FibQuantError::ResourceLimitExceeded(format!(
33                "rotation matrix values {matrix_values} exceed MAX_ROTATION_MATRIX_VALUES {MAX_ROTATION_MATRIX_VALUES}"
34            )));
35        }
36        let mut rng = ChaCha8Rng::seed_from_u64(seed);
37        let data: Vec<f64> = (0..matrix_values)
38            .map(|_| StandardNormal.sample(&mut rng))
39            .collect();
40        let m = DMatrix::from_vec(dim, dim, data);
41        let qr = m.qr();
42        let mut q = qr.q();
43        let r = qr.r();
44        for j in 0..dim {
45            if r[(j, j)] < 0.0 {
46                for i in 0..dim {
47                    q[(i, j)] *= -1.0;
48                }
49            }
50        }
51        let mut matrix = vec![0.0; matrix_values];
52        for row in 0..dim {
53            for col in 0..dim {
54                matrix[row * dim + col] = q[(row, col)];
55            }
56        }
57        Ok(Self { dim, seed, matrix })
58    }
59
60    /// Dimension of this rotation.
61    pub fn dim(&self) -> usize {
62        self.dim
63    }
64
65    /// Seed used for deterministic generation.
66    pub fn seed(&self) -> u64 {
67        self.seed
68    }
69
70    /// Stable rotation schema marker.
71    pub fn rotation_schema(&self) -> &'static str {
72        ROTATION_SCHEMA
73    }
74
75    /// Rotation algorithm identity.
76    pub fn algorithm_version(&self) -> &'static str {
77        ROTATION_ALGORITHM_VERSION
78    }
79
80    /// Deterministic digest over the rotation identity and matrix values.
81    pub fn digest(&self) -> Result<String> {
82        #[derive(Serialize)]
83        struct RotationDigestView<'a> {
84            rotation_schema: &'a str,
85            algorithm_version: &'a str,
86            dim: usize,
87            seed: u64,
88            matrix: &'a [f64],
89        }
90
91        json_digest(
92            ROTATION_SCHEMA,
93            &RotationDigestView {
94                rotation_schema: ROTATION_SCHEMA,
95                algorithm_version: ROTATION_ALGORITHM_VERSION,
96                dim: self.dim,
97                seed: self.seed,
98                matrix: &self.matrix,
99            },
100        )
101    }
102
103    /// Apply `y = Pi x`.
104    pub fn apply(&self, input: &[f64]) -> Result<Vec<f64>> {
105        self.check_dim(input.len())?;
106        let mut out = vec![0.0; self.dim];
107        for (row, output) in out.iter_mut().enumerate().take(self.dim) {
108            *output = self.matrix[row * self.dim..(row + 1) * self.dim]
109                .iter()
110                .zip(input)
111                .map(|(a, b)| a * b)
112                .sum();
113        }
114        Ok(out)
115    }
116
117    /// Apply inverse `x = Pi^T y`.
118    pub fn apply_inverse(&self, input: &[f64]) -> Result<Vec<f64>> {
119        self.check_dim(input.len())?;
120        let mut out = vec![0.0; self.dim];
121        for (col, output) in out.iter_mut().enumerate().take(self.dim) {
122            let mut sum = 0.0;
123            for (row, value) in input.iter().enumerate().take(self.dim) {
124                sum += self.matrix[row * self.dim + col] * value;
125            }
126            *output = sum;
127        }
128        Ok(out)
129    }
130
131    fn check_dim(&self, got: usize) -> Result<()> {
132        if got != self.dim {
133            return Err(FibQuantError::CorruptPayload(format!(
134                "rotation expected dimension {}, got {got}",
135                self.dim
136            )));
137        }
138        Ok(())
139    }
140}