Skip to main content

mnemonist_quant/
rotation.rs

//! Random orthogonal rotation matrix for TurboQuant.
//!
//! After applying a random rotation Π to a unit vector x ∈ S^{d-1}, the result
//! Π·x is uniformly distributed on the sphere, so each coordinate follows a
//! shifted and scaled Beta distribution (Lemma 1):
//! (1 + (Π·x)_j) / 2 ~ Beta((d-1)/2, (d-1)/2), which approaches N(0, 1/d) as d
//! grows. Distinct coordinates become nearly independent in high dimensions.
//! This enables per-coordinate scalar quantization.
//!
//! We generate Π via QR decomposition of a random Gaussian matrix.
//! The rotation is seeded for reproducibility — the same seed must be used
//! for quantization and dequantization.
11
12use rand::SeedableRng;
13use rand::rngs::StdRng;
14use rand_distr::{Distribution, StandardNormal};
15
/// A seeded random orthogonal rotation.
///
/// The full d×d matrix is held in memory as `f32` values laid out row by row,
/// i.e. 4·d² bytes: roughly 0.6 MB at d = 384, 2.4 MB at d = 768 and 9.4 MB at
/// d = 1536.
#[derive(Clone, Debug)]
pub struct Rotation {
    /// Side length d of the square matrix.
    dim: usize,
    /// The d×d orthogonal matrix, row-major (`matrix[i * dim + j]` is entry (i, j)).
    matrix: Vec<f32>,
    /// Seed the matrix was generated from.
    seed: u64,
}
27
28impl Rotation {
29    /// Generate a random orthogonal rotation matrix of the given dimension.
30    ///
31    /// Uses QR decomposition of a random Gaussian matrix with the given seed.
32    pub fn new(dim: usize, seed: u64) -> Self {
33        let matrix = generate_orthogonal(dim, seed);
34        Self { dim, matrix, seed }
35    }
36
37    /// The dimension of the rotation.
38    pub fn dimension(&self) -> usize {
39        self.dim
40    }
41
42    /// The seed used to generate this rotation.
43    pub fn seed(&self) -> u64 {
44        self.seed
45    }
46
47    /// Apply the forward rotation: y = Π · x (in-place).
48    pub fn forward(&self, x: &mut [f32]) {
49        debug_assert_eq!(x.len(), self.dim);
50        let mut result = vec![0.0f32; self.dim];
51        for (i, item) in result.iter_mut().enumerate().take(self.dim) {
52            let row = &self.matrix[i * self.dim..(i + 1) * self.dim];
53            *item = dot(row, x);
54        }
55        x.copy_from_slice(&result);
56    }
57
58    /// Apply the inverse rotation: x = Π^T · y (in-place).
59    ///
60    /// Since Π is orthogonal, Π^{-1} = Π^T.
61    pub fn inverse(&self, y: &mut [f32]) {
62        debug_assert_eq!(y.len(), self.dim);
63        let mut result = vec![0.0f32; self.dim];
64        // Π^T multiplication: result[j] = sum_i matrix[i][j] * y[i]
65        for (i, &yi) in y.iter().enumerate().take(self.dim) {
66            let row = &self.matrix[i * self.dim..(i + 1) * self.dim];
67            for j in 0..self.dim {
68                result[j] += row[j] * yi;
69            }
70        }
71        y.copy_from_slice(&result);
72    }
73}
74
/// Inner product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, only the
/// leading `min(a.len(), b.len())` pairs contribute.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
80
81/// Generate a random orthogonal matrix via QR decomposition of a Gaussian matrix.
82///
83/// Uses Gram-Schmidt orthogonalization (numerically sufficient for our use case
84/// since we only need the rotation to randomize coordinate distributions,
85/// not for high-precision numerical linear algebra).
86fn generate_orthogonal(dim: usize, seed: u64) -> Vec<f32> {
87    let mut rng = StdRng::seed_from_u64(seed);
88    let normal = StandardNormal;
89
90    // Generate random Gaussian matrix
91    let mut q = vec![0.0f32; dim * dim];
92    for val in &mut q {
93        *val = normal.sample(&mut rng);
94    }
95
96    // Modified Gram-Schmidt orthogonalization (row-wise)
97    for i in 0..dim {
98        let row_start = i * dim;
99
100        // Subtract projections of all previous rows
101        for j in 0..i {
102            let prev_start = j * dim;
103            let mut proj = 0.0f32;
104            for k in 0..dim {
105                proj += q[row_start + k] * q[prev_start + k];
106            }
107            for k in 0..dim {
108                q[row_start + k] -= proj * q[prev_start + k];
109            }
110        }
111
112        // Normalize
113        let mut norm = 0.0f32;
114        for k in 0..dim {
115            norm += q[row_start + k] * q[row_start + k];
116        }
117        let norm = norm.sqrt();
118        if norm > 0.0 {
119            for k in 0..dim {
120                q[row_start + k] /= norm;
121            }
122        }
123    }
124
125    q
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Every pair of matrix rows must be orthonormal: row_i · row_j ≈ δ_ij.
    #[test]
    fn orthogonality() {
        let dim = 32;
        let rot = Rotation::new(dim, 42);
        let row = |r: usize| &rot.matrix[r * dim..(r + 1) * dim];

        for i in 0..dim {
            for j in 0..dim {
                let d = dot(row(i), row(j));
                if i != j {
                    assert!(d.abs() < 1e-4, "off-diagonal [{i},{j}] = {d}, expected 0");
                } else {
                    assert!(
                        (d - 1.0).abs() < 1e-4,
                        "diagonal [{i},{j}] = {d}, expected 1.0"
                    );
                }
            }
        }
    }

    /// Building twice from the same seed yields identical matrices.
    #[test]
    fn deterministic_with_seed() {
        let build = || Rotation::new(16, 123);
        let (r1, r2) = (build(), build());
        assert_eq!(r1.matrix, r2.matrix);
    }

    /// Distinct seeds yield distinct matrices.
    #[test]
    fn different_seeds_differ() {
        let (r1, r2) = (Rotation::new(16, 1), Rotation::new(16, 2));
        assert_ne!(r1.matrix, r2.matrix);
    }

    /// `inverse` undoes `forward` up to single-precision rounding.
    #[test]
    fn forward_inverse_roundtrip() {
        let dim = 64;
        let rot = Rotation::new(dim, 99);

        let original: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) / dim as f32).collect();
        let mut v = original.clone();

        // The rotated vector must differ from the input in at least one coordinate.
        rot.forward(&mut v);
        let moved = v
            .iter()
            .zip(original.iter())
            .any(|(a, b)| (a - b).abs() > 1e-4);
        assert!(moved, "rotation had no effect");

        // Rotating back must recover every coordinate.
        rot.inverse(&mut v);
        for (a, b) in v.iter().zip(original.iter()) {
            let diff = (a - b).abs();
            assert!(diff < 1e-3, "roundtrip failed: {a} vs {b} (diff={})", diff);
        }
    }

    /// The forward rotation leaves the Euclidean norm unchanged.
    #[test]
    fn preserves_norm() {
        let dim = 64;
        let rot = Rotation::new(dim, 7);
        let l2 = |s: &[f32]| s.iter().map(|x| x * x).sum::<f32>().sqrt();

        let mut v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let norm_before: f32 = l2(&v);

        rot.forward(&mut v);
        let norm_after: f32 = l2(&v);

        assert!(
            (norm_before - norm_after).abs() < 1e-3,
            "norm changed: {norm_before} → {norm_after}"
        );
    }
}