mnemonist_quant/
rotation.rs1use rand::SeedableRng;
13use rand::rngs::StdRng;
14use rand_distr::{Distribution, StandardNormal};
15
16#[derive(Debug, Clone)]
21pub struct Rotation {
22 dim: usize,
23 matrix: Vec<f32>,
25 seed: u64,
26}
27
28impl Rotation {
29 pub fn new(dim: usize, seed: u64) -> Self {
33 let matrix = generate_orthogonal(dim, seed);
34 Self { dim, matrix, seed }
35 }
36
37 pub fn dimension(&self) -> usize {
39 self.dim
40 }
41
42 pub fn seed(&self) -> u64 {
44 self.seed
45 }
46
47 pub fn forward(&self, x: &mut [f32]) {
49 debug_assert_eq!(x.len(), self.dim);
50 let mut result = vec![0.0f32; self.dim];
51 for (i, item) in result.iter_mut().enumerate().take(self.dim) {
52 let row = &self.matrix[i * self.dim..(i + 1) * self.dim];
53 *item = dot(row, x);
54 }
55 x.copy_from_slice(&result);
56 }
57
58 pub fn inverse(&self, y: &mut [f32]) {
62 debug_assert_eq!(y.len(), self.dim);
63 let mut result = vec![0.0f32; self.dim];
64 for (i, &yi) in y.iter().enumerate().take(self.dim) {
66 let row = &self.matrix[i * self.dim..(i + 1) * self.dim];
67 for j in 0..self.dim {
68 result[j] += row[j] * yi;
69 }
70 }
71 y.copy_from_slice(&result);
72 }
73}
74
/// Inner product of two `f32` slices.
///
/// Elements are paired by position, and pairing stops at the end of the shorter
/// slice, so callers must pass equal-length inputs to get a true dot product.
/// The products are accumulated left to right via [`Iterator::sum`].
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
81fn generate_orthogonal(dim: usize, seed: u64) -> Vec<f32> {
87 let mut rng = StdRng::seed_from_u64(seed);
88 let normal = StandardNormal;
89
90 let mut q = vec![0.0f32; dim * dim];
92 for val in &mut q {
93 *val = normal.sample(&mut rng);
94 }
95
96 for i in 0..dim {
98 let row_start = i * dim;
99
100 for j in 0..i {
102 let prev_start = j * dim;
103 let mut proj = 0.0f32;
104 for k in 0..dim {
105 proj += q[row_start + k] * q[prev_start + k];
106 }
107 for k in 0..dim {
108 q[row_start + k] -= proj * q[prev_start + k];
109 }
110 }
111
112 let mut norm = 0.0f32;
114 for k in 0..dim {
115 norm += q[row_start + k] * q[row_start + k];
116 }
117 let norm = norm.sqrt();
118 if norm > 0.0 {
119 for k in 0..dim {
120 q[row_start + k] /= norm;
121 }
122 }
123 }
124
125 q
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `v`, used by the norm-preservation checks.
    fn l2(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Rows must be unit length and mutually orthogonal (Q·Qᵀ ≈ I).
    #[test]
    fn orthogonality() {
        let dim = 32;
        let rot = Rotation::new(dim, 42);

        for i in 0..dim {
            for j in 0..dim {
                let row_i = &rot.matrix[i * dim..(i + 1) * dim];
                let row_j = &rot.matrix[j * dim..(j + 1) * dim];
                let d = dot(row_i, row_j);
                if i == j {
                    assert!(
                        (d - 1.0).abs() < 1e-4,
                        "diagonal [{i},{j}] = {d}, expected 1.0"
                    );
                } else {
                    assert!(d.abs() < 1e-4, "off-diagonal [{i},{j}] = {d}, expected 0");
                }
            }
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let r1 = Rotation::new(16, 123);
        let r2 = Rotation::new(16, 123);
        assert_eq!(r1.matrix, r2.matrix);
    }

    #[test]
    fn different_seeds_differ() {
        let r1 = Rotation::new(16, 1);
        let r2 = Rotation::new(16, 2);
        assert_ne!(r1.matrix, r2.matrix);
    }

    #[test]
    fn accessors_report_construction_arguments() {
        let rot = Rotation::new(8, 31);
        assert_eq!(rot.dimension(), 8);
        assert_eq!(rot.seed(), 31);
        assert_eq!(rot.matrix.len(), 64);
    }

    #[test]
    fn forward_inverse_roundtrip() {
        let dim = 64;
        let rot = Rotation::new(dim, 99);

        let original: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) / dim as f32).collect();
        let mut v = original.clone();

        rot.forward(&mut v);
        // Guard against an accidental identity matrix.
        assert!(
            v.iter()
                .zip(original.iter())
                .any(|(a, b)| (a - b).abs() > 1e-4),
            "rotation had no effect"
        );

        rot.inverse(&mut v);
        for (a, b) in v.iter().zip(original.iter()) {
            assert!(
                (a - b).abs() < 1e-3,
                "roundtrip failed: {a} vs {b} (diff={})",
                (a - b).abs()
            );
        }
    }

    #[test]
    fn preserves_norm() {
        let dim = 64;
        let rot = Rotation::new(dim, 7);

        let mut v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let norm_before = l2(&v);

        rot.forward(&mut v);
        let norm_after = l2(&v);

        assert!(
            (norm_before - norm_after).abs() < 1e-3,
            "norm changed: {norm_before} → {norm_after}"
        );
    }

    /// The transpose is orthogonal too, so `inverse` must also preserve length.
    #[test]
    fn inverse_preserves_norm() {
        let dim = 64;
        let rot = Rotation::new(dim, 11);

        let mut v: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.3).cos()).collect();
        let norm_before = l2(&v);

        rot.inverse(&mut v);
        let norm_after = l2(&v);

        assert!(
            (norm_before - norm_after).abs() < 1e-3,
            "inverse changed norm: {norm_before} -> {norm_after}"
        );
    }

    #[test]
    fn zero_dimension_is_a_no_op() {
        let rot = Rotation::new(0, 5);
        assert_eq!(rot.dimension(), 0);
        let mut v: Vec<f32> = Vec::new();
        rot.forward(&mut v);
        rot.inverse(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    #[should_panic]
    fn forward_rejects_wrong_length() {
        let rot = Rotation::new(4, 0);
        let mut v = vec![0.0f32; 3];
        rot.forward(&mut v);
    }

    #[test]
    #[should_panic]
    fn inverse_rejects_wrong_length() {
        let rot = Rotation::new(4, 0);
        let mut v = vec![0.0f32; 3];
        rot.inverse(&mut v);
    }
}