1use crate::error::{Result, TurboQuantError};
13use nalgebra::DMatrix;
14use rand::{Rng, SeedableRng};
15use rand_chacha::ChaCha8Rng;
16use rand_distr::{Distribution, StandardNormal};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
/// Selects which rotation implementation `RotationBackend::new` constructs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RotationKind {
    /// Let the constructor choose: the fast Hadamard rotation when the
    /// dimension is a power of two, the stored QR rotation otherwise.
    Auto,
    /// Always build a `FastHadamardRotation`; its constructor rejects
    /// dimensions that are not a power of two.
    FastHadamard,
    /// Always build a `StoredRotation` (an explicit orthogonal matrix).
    StoredQr,
}
30
31impl RotationKind {
32 pub fn label(self) -> &'static str {
33 match self {
34 Self::Auto => "auto",
35 Self::FastHadamard => "fast_hadamard",
36 Self::StoredQr => "stored_qr_reference",
37 }
38 }
39}
40
/// A concrete rotation: one of the two implementations defined in this file.
///
/// `RotationKind::Auto` never appears here; `RotationBackend::new` resolves it
/// to one of these variants at construction time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RotationBackend {
    /// Sign flips followed by a normalized fast Walsh–Hadamard transform.
    FastHadamard(FastHadamardRotation),
    /// Multiplication by an explicitly stored orthogonal matrix.
    StoredQr(StoredRotation),
}
47
48impl RotationBackend {
49 pub fn new(dim: usize, seed: u64, kind: RotationKind) -> Result<Self> {
50 match kind {
51 RotationKind::Auto if dim.is_power_of_two() => {
52 FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
53 }
54 RotationKind::Auto => StoredRotation::new(dim, seed).map(Self::StoredQr),
55 RotationKind::FastHadamard => {
56 FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
57 }
58 RotationKind::StoredQr => StoredRotation::new(dim, seed).map(Self::StoredQr),
59 }
60 }
61
62 pub fn kind(&self) -> RotationKind {
63 match self {
64 Self::FastHadamard(_) => RotationKind::FastHadamard,
65 Self::StoredQr(_) => RotationKind::StoredQr,
66 }
67 }
68
69 pub fn kind_label(&self) -> &'static str {
70 self.kind().label()
71 }
72
73 pub fn seed(&self) -> u64 {
74 match self {
75 Self::FastHadamard(rotation) => rotation.seed(),
76 Self::StoredQr(rotation) => rotation.seed(),
77 }
78 }
79}
80
81impl Rotation for RotationBackend {
82 fn dim(&self) -> usize {
83 match self {
84 Self::FastHadamard(rotation) => rotation.dim(),
85 Self::StoredQr(rotation) => rotation.dim(),
86 }
87 }
88
89 fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
90 match self {
91 Self::FastHadamard(rotation) => rotation.apply(input, output),
92 Self::StoredQr(rotation) => rotation.apply(input, output),
93 }
94 }
95
96 fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
97 match self {
98 Self::FastHadamard(rotation) => rotation.apply_inverse(input, output),
99 Self::StoredQr(rotation) => rotation.apply_inverse(input, output),
100 }
101 }
102}
103
104impl RotationBackend {
105 pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
111 match self {
112 Self::FastHadamard(rotation) => rotation.apply_inverse_batch(inputs),
113 Self::StoredQr(rotation) => rotation.apply_inverse_batch(inputs),
114 }
115 }
116}
117
/// Common interface of the rotations in this file.
///
/// Implementors must be `Send + Sync`.
pub trait Rotation: Send + Sync {
    /// Length of the vectors this rotation operates on.
    fn dim(&self) -> usize;

    /// Writes the rotated `input` into `output`.
    ///
    /// The implementations in this file return a dimension-mismatch error
    /// when either slice's length differs from [`Rotation::dim`].
    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()>;

    /// Writes the inverse rotation of `input` into `output`.
    ///
    /// The implementations in this file return a dimension-mismatch error
    /// when either slice's length differs from [`Rotation::dim`].
    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()>;
}
133
/// Rotation made of per-coordinate sign flips followed by a normalized
/// fast Walsh–Hadamard transform.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastHadamardRotation {
    /// Length of the vectors this rotation accepts; the constructor only
    /// admits powers of two.
    dim: usize,
    /// Seed supplied at construction, kept so it can be reported back.
    seed: u64,
    /// One `+1.0` / `-1.0` multiplier per coordinate, drawn in the constructor.
    signs: Vec<f32>,
}
141
142impl FastHadamardRotation {
143 pub fn new(dim: usize, seed: u64) -> Result<Self> {
144 if dim == 0 {
145 return Err(TurboQuantError::ZeroDimension);
146 }
147 if !dim.is_power_of_two() {
148 return Err(TurboQuantError::RotationFailed {
149 reason: format!("Hadamard rotation requires a power-of-two dimension, got {dim}"),
150 });
151 }
152 let mut rng = ChaCha8Rng::seed_from_u64(seed.wrapping_add(0xA11C_E55E_D5A5_EED5));
153 let signs = (0..dim)
154 .map(|_| if rng.gen::<bool>() { 1.0 } else { -1.0 })
155 .collect();
156 Ok(Self { dim, seed, signs })
157 }
158
159 pub fn seed(&self) -> u64 {
160 self.seed
161 }
162
163 pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
171 if inputs.is_empty() {
172 return Ok(Vec::new());
173 }
174 let dim = self.dim;
175 let signs = &self.signs;
176 let mut outputs: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
177 for input in inputs {
178 if input.len() != dim {
179 return Err(TurboQuantError::DimensionMismatch {
180 expected: dim,
181 got: input.len(),
182 });
183 }
184 let mut out = vec![0.0f32; dim];
185 out.copy_from_slice(input);
187 fwht_normalized(&mut out);
188 for (out_val, sign) in out.iter_mut().zip(signs.iter()) {
189 *out_val *= *sign;
190 }
191 outputs.push(out);
192 }
193 Ok(outputs)
194 }
195}
196
197impl Rotation for FastHadamardRotation {
198 fn dim(&self) -> usize {
199 self.dim
200 }
201
202 fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
203 check_dim(input.len(), self.dim)?;
204 check_dim(output.len(), self.dim)?;
205 for ((out, value), sign) in output.iter_mut().zip(input.iter()).zip(self.signs.iter()) {
206 *out = value * sign;
207 }
208 fwht_normalized(output);
209 Ok(())
210 }
211
212 fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
213 check_dim(input.len(), self.dim)?;
214 check_dim(output.len(), self.dim)?;
215 output.copy_from_slice(input);
216 fwht_normalized(output);
217 for (out, sign) in output.iter_mut().zip(self.signs.iter()) {
218 *out *= sign;
219 }
220 Ok(())
221 }
222}
223
/// Rotation backed by an explicitly stored matrix, generated by
/// `generate_orthogonal` from a seeded Gaussian matrix via QR decomposition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredRotation {
    /// Length of the vectors this rotation accepts.
    dim: usize,
    /// Seed the matrix was generated from.
    seed: u64,
    /// The `dim × dim` rotation matrix; (de)serialized through the
    /// `matrix_serde` proxy module.
    #[serde(with = "matrix_serde")]
    matrix: DMatrix<f32>,
}
238
239impl StoredRotation {
240 pub fn new(dim: usize, seed: u64) -> Result<Self> {
245 if dim == 0 {
246 return Err(TurboQuantError::ZeroDimension);
247 }
248
249 let matrix = generate_orthogonal(dim, seed)?;
250 Ok(Self { dim, seed, matrix })
251 }
252
253 pub fn seed(&self) -> u64 {
255 self.seed
256 }
257
258 pub fn memory_bytes(&self) -> usize {
260 self.dim * self.dim * std::mem::size_of::<f32>()
261 }
262}
263
264impl Rotation for StoredRotation {
265 fn dim(&self) -> usize {
266 self.dim
267 }
268
269 fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
270 check_dim(input.len(), self.dim)?;
271 check_dim(output.len(), self.dim)?;
272
273 for (i, out) in output.iter_mut().enumerate() {
275 *out = self
276 .matrix
277 .row(i)
278 .iter()
279 .zip(input)
280 .map(|(r, x)| r * x)
281 .sum();
282 }
283 Ok(())
284 }
285
286 fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
287 check_dim(input.len(), self.dim)?;
288 check_dim(output.len(), self.dim)?;
289
290 for (i, out) in output.iter_mut().enumerate() {
292 *out = self
293 .matrix
294 .column(i)
295 .iter()
296 .zip(input)
297 .map(|(r, y)| r * y)
298 .sum();
299 }
300 Ok(())
301 }
302}
303
304impl StoredRotation {
305 pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
310 if inputs.is_empty() {
311 return Ok(Vec::new());
312 }
313 let dim = self.dim;
314 let mut outputs: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
315 for input in inputs {
316 if input.len() != dim {
317 return Err(TurboQuantError::DimensionMismatch {
318 expected: dim,
319 got: input.len(),
320 });
321 }
322 let mut out = vec![0.0f32; dim];
323 for i in 0..dim {
324 out[i] = self
325 .matrix
326 .column(i)
327 .iter()
328 .zip(input.iter())
329 .map(|(r, y)| r * y)
330 .sum();
331 }
332 outputs.push(out);
333 }
334 Ok(outputs)
335 }
336}
337
338fn generate_orthogonal(dim: usize, seed: u64) -> Result<DMatrix<f32>> {
341 let mut rng = ChaCha8Rng::seed_from_u64(seed);
342 let dist = StandardNormal;
343
344 let data: Vec<f32> = (0..dim * dim).map(|_| dist.sample(&mut rng)).collect();
346
347 let m = DMatrix::from_vec(dim, dim, data);
349
350 let qr = m.qr();
351 let q = qr.q();
352
353 let r = qr.r();
356 let signs: Vec<f32> = (0..dim)
357 .map(|i| if r[(i, i)] >= 0.0 { 1.0 } else { -1.0 })
358 .collect();
359
360 let mut corrected = q;
361 for (j, &s) in signs.iter().enumerate() {
362 if s < 0.0 {
363 for i in 0..dim {
364 corrected[(i, j)] *= -1.0;
365 }
366 }
367 }
368
369 Ok(corrected)
370}
371
372fn check_dim(got: usize, expected: usize) -> Result<()> {
373 if got != expected {
374 return Err(TurboQuantError::DimensionMismatch { expected, got });
375 }
376 Ok(())
377}
378
/// In-place fast Walsh–Hadamard transform scaled by `1 / sqrt(len)`.
///
/// The butterfly passes assume `values.len()` is a power of two; other
/// non-trivial lengths index out of bounds and panic. Applying the function
/// twice to a power-of-two-length slice yields the original values again (up
/// to rounding), since the normalized transform is its own inverse.
fn fwht_normalized(values: &mut [f32]) {
    let len = values.len();

    // Each pass combines pairs that sit `half` apart inside blocks of
    // `2 * half` elements; `half` doubles until it spans the whole slice.
    let mut half = 1;
    while half < len {
        let stride = half * 2;
        let mut base = 0;
        while base < len {
            for lo in base..base + half {
                let hi = lo + half;
                let (x, y) = (values[lo], values[hi]);
                values[lo] = x + y;
                values[hi] = x - y;
            }
            base += stride;
        }
        half = stride;
    }

    let norm = (len as f32).sqrt().recip();
    values.iter_mut().for_each(|v| *v *= norm);
}
399
400mod matrix_serde {
401 use nalgebra::DMatrix;
402 use serde::{Deserialize, Deserializer, Serialize, Serializer};
403
404 #[derive(Serialize, Deserialize)]
405 struct MatrixProxy {
406 rows: usize,
407 cols: usize,
408 data: Vec<f32>,
409 }
410
411 pub fn serialize<S: Serializer>(
412 m: &DMatrix<f32>,
413 s: S,
414 ) -> std::result::Result<S::Ok, S::Error> {
415 MatrixProxy {
416 rows: m.nrows(),
417 cols: m.ncols(),
418 data: m.as_slice().to_vec(),
419 }
420 .serialize(s)
421 }
422
423 pub fn deserialize<'de, D: Deserializer<'de>>(
424 d: D,
425 ) -> std::result::Result<DMatrix<f32>, D::Error> {
426 let p = MatrixProxy::deserialize(d)?;
427 Ok(DMatrix::from_vec(p.rows, p.cols, p.data))
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain dot product of two equally long slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(p, q)| p * q).sum()
    }

    /// Runs `rotation.apply` on `input` and returns the freshly rotated vector.
    fn rotated(rotation: &StoredRotation, input: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; input.len()];
        rotation.apply(input, &mut out).unwrap();
        out
    }

    // Same seed and dimension must give bit-identical matrices.
    #[test]
    fn rotation_is_deterministic_for_same_seed() {
        let first = StoredRotation::new(8, 42).unwrap();
        let second = StoredRotation::new(8, 42).unwrap();
        assert_eq!(first.matrix.as_slice(), second.matrix.as_slice());
    }

    // Different seeds must give different matrices.
    #[test]
    fn rotation_differs_across_seeds() {
        let seeded_one = StoredRotation::new(8, 1).unwrap();
        let seeded_two = StoredRotation::new(8, 2).unwrap();
        assert_ne!(seeded_one.matrix.as_slice(), seeded_two.matrix.as_slice());
    }

    // MᵀM must equal the identity up to rounding.
    #[test]
    fn rotation_is_orthogonal_rrt_equals_identity() {
        const N: usize = 16;
        let rotation = StoredRotation::new(N, 7).unwrap();
        let product = rotation.matrix.transpose() * &rotation.matrix;

        for (i, j) in (0..N).flat_map(|i| (0..N).map(move |j| (i, j))) {
            let expected = if i == j { 1.0f32 } else { 0.0f32 };
            let got = product[(i, j)];
            assert!(
                (got - expected).abs() < 1e-5,
                "RᵀR[{i},{j}] = {got}, expected {expected}"
            );
        }
    }

    // apply followed by apply_inverse must return the original vector.
    #[test]
    fn apply_inverse_recovers_input() {
        let rotation = StoredRotation::new(8, 99).unwrap();
        let x: Vec<f32> = (1..=8).map(|v| v as f32).collect();

        let y = rotated(&rotation, &x);
        let mut recovered = vec![0.0f32; 8];
        rotation.apply_inverse(&y, &mut recovered).unwrap();

        for (orig, rec) in x.iter().zip(recovered.iter()) {
            assert!((orig - rec).abs() < 1e-5, "orig={orig}, recovered={rec}");
        }
    }

    // Rotating both vectors must leave their inner product unchanged.
    #[test]
    fn rotation_preserves_inner_products() {
        let rotation = StoredRotation::new(8, 13).unwrap();
        let x = [1.0f32, 0.5, -1.0, 2.0, 0.1, -0.3, 1.5, 0.8];
        let y = [0.2f32, -1.0, 0.5, 1.0, -0.5, 0.3, 0.9, -0.7];

        let ip_original = dot(&x, &y);
        let ip_rotated = dot(&rotated(&rotation, &x), &rotated(&rotation, &y));

        assert!((ip_original - ip_rotated).abs() < 1e-4);
    }

    // A zero-length rotation cannot be constructed.
    #[test]
    fn zero_dimension_is_rejected() {
        let outcome = StoredRotation::new(0, 0);
        assert!(outcome.is_err());
    }

    // A JSON round trip must reproduce the matrix exactly.
    #[test]
    fn serialization_roundtrip() {
        let original = StoredRotation::new(8, 55).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StoredRotation = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.matrix.as_slice(), decoded.matrix.as_slice());
    }
}