1use rand::prelude::*;
2use rand_distr::{Distribution, Normal, Uniform};
3
4use crate::math::{dot, normalize};
5use crate::RabitqError;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9#[repr(u8)]
10pub enum RotatorType {
11 MatrixRotator = 0,
13 FhtKacRotator = 1,
15}
16
17impl RotatorType {
18 pub fn from_u8(value: u8) -> Option<Self> {
19 match value {
20 0 => Some(RotatorType::MatrixRotator),
21 1 => Some(RotatorType::FhtKacRotator),
22 _ => None,
23 }
24 }
25
26 pub fn padding_requirement(self, dim: usize) -> usize {
28 match self {
29 RotatorType::MatrixRotator => dim,
30 RotatorType::FhtKacRotator => round_up_to_multiple(dim, 64),
31 }
32 }
33}
34
/// Rounds `value` up to the nearest multiple of `multiple`.
///
/// Values that are already a multiple are returned unchanged.
///
/// # Panics
///
/// Panics if `multiple` is zero (remainder by zero).
fn round_up_to_multiple(value: usize, multiple: usize) -> usize {
    match value % multiple {
        0 => value,
        // Add only the shortfall: unlike `value + multiple - 1`, no
        // intermediate value exceeds the final result, so this cannot
        // overflow when the rounded value itself fits in `usize`.
        remainder => value + (multiple - remainder),
    }
}
38
39pub trait Rotator: Send + Sync {
41 fn dim(&self) -> usize;
43
44 fn padded_dim(&self) -> usize;
46
47 fn rotate(&self, input: &[f32]) -> Vec<f32>;
49
50 fn rotate_into(&self, input: &[f32], output: &mut [f32]);
52
53 fn rotator_type(&self) -> RotatorType;
55
56 fn serialize(&self) -> Vec<u8>;
58
59 fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError>
61 where
62 Self: Sized;
63}
64
65#[derive(Debug, Clone)]
67pub struct MatrixRotator {
68 dim: usize,
69 padded_dim: usize,
70 matrix: Vec<f32>, }
72
73impl MatrixRotator {
74 pub fn new(dim: usize, seed: u64) -> Self {
76 let padded_dim = RotatorType::MatrixRotator.padding_requirement(dim);
77 let mut rng = StdRng::seed_from_u64(seed);
78 Self::with_rng(dim, padded_dim, &mut rng)
79 }
80
81 fn with_rng(dim: usize, padded_dim: usize, rng: &mut StdRng) -> Self {
82 let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");
83 let mut basis: Vec<Vec<f32>> = Vec::with_capacity(padded_dim);
84
85 for _ in 0..padded_dim {
86 let mut vec: Vec<f32> = (0..padded_dim).map(|_| normal.sample(rng) as f32).collect();
87
88 for prev in &basis {
90 let proj = dot(&vec, prev);
91 for (v, p) in vec.iter_mut().zip(prev.iter()) {
92 *v -= proj * *p;
93 }
94 }
95
96 let mut attempts = 0;
97 loop {
98 let norm = normalize(&mut vec);
99 if norm > f32::EPSILON {
100 break;
101 }
102 attempts += 1;
103 if attempts > 8 {
104 vec.fill(0.0);
106 vec[attempts % padded_dim] = 1.0;
107 break;
108 }
109 for value in vec.iter_mut() {
110 *value = normal.sample(rng) as f32;
111 }
112 for prev in &basis {
113 let proj = dot(&vec, prev);
114 for (v, p) in vec.iter_mut().zip(prev.iter()) {
115 *v -= proj * *p;
116 }
117 }
118 }
119
120 basis.push(vec);
121 }
122
123 let mut matrix = Vec::with_capacity(padded_dim * padded_dim);
124 for row in basis {
125 matrix.extend_from_slice(&row);
126 }
127
128 Self {
129 dim,
130 padded_dim,
131 matrix,
132 }
133 }
134}
135
136impl Rotator for MatrixRotator {
137 fn dim(&self) -> usize {
138 self.dim
139 }
140
141 fn padded_dim(&self) -> usize {
142 self.padded_dim
143 }
144
145 fn rotate(&self, input: &[f32]) -> Vec<f32> {
146 assert_eq!(input.len(), self.dim);
147 let mut output = vec![0.0f32; self.padded_dim];
148 self.rotate_into(input, &mut output);
149 output
150 }
151
152 fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
153 assert_eq!(input.len(), self.dim);
154 assert_eq!(output.len(), self.padded_dim);
155
156 let mut padded_input = vec![0.0f32; self.padded_dim];
158 padded_input[..self.dim].copy_from_slice(input);
159
160 for (row_idx, chunk) in self.matrix.chunks(self.padded_dim).enumerate() {
161 let mut acc = 0.0f32;
162 for (value, &weight) in padded_input.iter().zip(chunk.iter()) {
163 acc += value * weight;
164 }
165 output[row_idx] = acc;
166 }
167 }
168
169 fn rotator_type(&self) -> RotatorType {
170 RotatorType::MatrixRotator
171 }
172
173 fn serialize(&self) -> Vec<u8> {
174 let mut bytes = Vec::new();
175 for &value in &self.matrix {
176 bytes.extend_from_slice(&value.to_le_bytes());
177 }
178 bytes
179 }
180
181 fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError> {
182 let expected_len = padded_dim * padded_dim * 4; if data.len() != expected_len {
184 return Err(RabitqError::InvalidPersistence(
185 "rotator matrix length mismatch",
186 ));
187 }
188
189 let mut matrix = Vec::with_capacity(padded_dim * padded_dim);
190 for chunk in data.chunks_exact(4) {
191 let bytes: [u8; 4] = chunk.try_into().unwrap();
192 matrix.push(f32::from_le_bytes(bytes));
193 }
194
195 Ok(Self {
196 dim,
197 padded_dim,
198 matrix,
199 })
200 }
201}
202
203#[derive(Debug, Clone)]
206pub struct FhtKacRotator {
207 dim: usize,
208 padded_dim: usize,
209 flip: Vec<u8>, trunc_dim: usize,
211 fac: f32,
212}
213
214impl FhtKacRotator {
215 pub fn new(dim: usize, seed: u64) -> Self {
217 let padded_dim = RotatorType::FhtKacRotator.padding_requirement(dim);
218 assert_eq!(
219 padded_dim % 64,
220 0,
221 "FHT rotator requires dimension to be multiple of 64"
222 );
223
224 let mut rng = StdRng::seed_from_u64(seed);
225 let uniform = Uniform::new_inclusive(0u8, 255u8);
226
227 let flip_bytes = 4 * padded_dim / 8;
229 let flip: Vec<u8> = (0..flip_bytes).map(|_| uniform.sample(&mut rng)).collect();
230
231 let bottom_log_dim = floor_log2(dim);
233 let trunc_dim = 1 << bottom_log_dim;
234 let fac = 1.0 / (trunc_dim as f32).sqrt();
235
236 Self {
237 dim,
238 padded_dim,
239 flip,
240 trunc_dim,
241 fac,
242 }
243 }
244
245 fn flip_sign(data: &mut [f32], flip_bits: &[u8]) {
247 for (i, value) in data.iter_mut().enumerate() {
248 let byte_idx = i / 8;
249 let bit_idx = i % 8;
250 if byte_idx < flip_bits.len() {
251 let bit = (flip_bits[byte_idx] >> bit_idx) & 1;
252 if bit == 1 {
253 *value = -*value;
254 }
255 }
256 }
257 }
258
259 fn fht(data: &mut [f32]) {
261 let n = data.len();
262 assert!(
263 n.is_power_of_two(),
264 "FHT requires power-of-2 dimension, got {}",
265 n
266 );
267
268 let mut h = 1;
269 while h < n {
270 for i in (0..n).step_by(h * 2) {
271 for j in i..(i + h) {
272 let x = data[j];
273 let y = data[j + h];
274 data[j] = x + y;
275 data[j + h] = x - y;
276 }
277 }
278 h *= 2;
279 }
280 }
281
282 fn kacs_walk(data: &mut [f32]) {
284 let len = data.len();
285 let half = len / 2;
286 for i in 0..half {
287 let x = data[i];
288 let y = data[i + half];
289 data[i] = x + y;
290 data[i + half] = x - y;
291 }
292 }
293
294 fn rescale(data: &mut [f32], factor: f32) {
296 for value in data.iter_mut() {
297 *value *= factor;
298 }
299 }
300}
301
302impl Rotator for FhtKacRotator {
303 fn dim(&self) -> usize {
304 self.dim
305 }
306
307 fn padded_dim(&self) -> usize {
308 self.padded_dim
309 }
310
311 fn rotate(&self, input: &[f32]) -> Vec<f32> {
312 assert_eq!(input.len(), self.dim);
313 let mut output = vec![0.0f32; self.padded_dim];
314 self.rotate_into(input, &mut output);
315 output
316 }
317
318 fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
319 assert_eq!(input.len(), self.dim);
320 assert_eq!(output.len(), self.padded_dim);
321
322 output[..self.dim].copy_from_slice(input);
324 output[self.dim..].fill(0.0);
325
326 let flip_offset = self.padded_dim / 8;
327
328 if self.trunc_dim == self.padded_dim {
329 for round in 0..4 {
332 let flip_start = round * flip_offset;
333 let flip_end = flip_start + flip_offset;
334 Self::flip_sign(output, &self.flip[flip_start..flip_end]);
335 Self::fht(output);
336 Self::rescale(output, self.fac);
337 }
338 } else {
339 let start = self.padded_dim - self.trunc_dim;
341
342 Self::flip_sign(output, &self.flip[0..flip_offset]);
344 Self::fht(&mut output[..self.trunc_dim]);
345 Self::rescale(&mut output[..self.trunc_dim], self.fac);
346 Self::kacs_walk(output);
347
348 Self::flip_sign(output, &self.flip[flip_offset..2 * flip_offset]);
350 Self::fht(&mut output[start..]);
351 Self::rescale(&mut output[start..], self.fac);
352 Self::kacs_walk(output);
353
354 Self::flip_sign(output, &self.flip[2 * flip_offset..3 * flip_offset]);
356 Self::fht(&mut output[..self.trunc_dim]);
357 Self::rescale(&mut output[..self.trunc_dim], self.fac);
358 Self::kacs_walk(output);
359
360 Self::flip_sign(output, &self.flip[3 * flip_offset..4 * flip_offset]);
362 Self::fht(&mut output[start..]);
363 Self::rescale(&mut output[start..], self.fac);
364 Self::kacs_walk(output);
365
366 Self::rescale(output, 0.25);
368 }
369 }
370
371 fn rotator_type(&self) -> RotatorType {
372 RotatorType::FhtKacRotator
373 }
374
375 fn serialize(&self) -> Vec<u8> {
376 self.flip.clone()
378 }
379
380 fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError> {
381 let expected_len = 4 * padded_dim / 8;
382 if data.len() != expected_len {
383 return Err(RabitqError::InvalidPersistence(
384 "FHT rotator flip bits length mismatch",
385 ));
386 }
387
388 let bottom_log_dim = floor_log2(dim);
389 let trunc_dim = 1 << bottom_log_dim;
390 let fac = 1.0 / (trunc_dim as f32).sqrt();
391
392 Ok(Self {
393 dim,
394 padded_dim,
395 flip: data.to_vec(),
396 trunc_dim,
397 fac,
398 })
399 }
400}
401
/// Returns `floor(log2(x))`, i.e. the index of the highest set bit of `x`.
///
/// # Panics
///
/// Panics if `x` is zero.
fn floor_log2(x: usize) -> usize {
    assert!(x > 0, "floor_log2 requires positive input");
    // Number of bits needed to represent `x`; at least 1 because `x > 0`.
    let significant_bits = usize::BITS - x.leading_zeros();
    (significant_bits - 1) as usize
}
407
408#[derive(Debug, Clone)]
410pub enum DynamicRotator {
411 Matrix(MatrixRotator),
412 Fht(FhtKacRotator),
413}
414
415impl DynamicRotator {
416 pub fn new(dim: usize, rotator_type: RotatorType, seed: u64) -> Self {
418 match rotator_type {
419 RotatorType::MatrixRotator => DynamicRotator::Matrix(MatrixRotator::new(dim, seed)),
420 RotatorType::FhtKacRotator => DynamicRotator::Fht(FhtKacRotator::new(dim, seed)),
421 }
422 }
423
424 pub fn dim(&self) -> usize {
425 match self {
426 DynamicRotator::Matrix(r) => r.dim(),
427 DynamicRotator::Fht(r) => r.dim(),
428 }
429 }
430
431 pub fn padded_dim(&self) -> usize {
432 match self {
433 DynamicRotator::Matrix(r) => r.padded_dim(),
434 DynamicRotator::Fht(r) => r.padded_dim(),
435 }
436 }
437
438 pub fn rotate(&self, input: &[f32]) -> Vec<f32> {
439 match self {
440 DynamicRotator::Matrix(r) => r.rotate(input),
441 DynamicRotator::Fht(r) => r.rotate(input),
442 }
443 }
444
445 pub fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
446 match self {
447 DynamicRotator::Matrix(r) => r.rotate_into(input, output),
448 DynamicRotator::Fht(r) => r.rotate_into(input, output),
449 }
450 }
451
452 pub fn rotator_type(&self) -> RotatorType {
453 match self {
454 DynamicRotator::Matrix(r) => r.rotator_type(),
455 DynamicRotator::Fht(r) => r.rotator_type(),
456 }
457 }
458
459 pub fn serialize(&self) -> Vec<u8> {
460 match self {
461 DynamicRotator::Matrix(r) => r.serialize(),
462 DynamicRotator::Fht(r) => r.serialize(),
463 }
464 }
465
466 pub fn deserialize(
467 dim: usize,
468 padded_dim: usize,
469 rotator_type: RotatorType,
470 data: &[u8],
471 ) -> Result<Self, RabitqError> {
472 match rotator_type {
473 RotatorType::MatrixRotator => Ok(DynamicRotator::Matrix(MatrixRotator::deserialize(
474 dim, padded_dim, data,
475 )?)),
476 RotatorType::FhtKacRotator => Ok(DynamicRotator::Fht(FhtKacRotator::deserialize(
477 dim, padded_dim, data,
478 )?)),
479 }
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Asserts that paired elements differ by less than `tolerance`.
    fn assert_elementwise_close(left: &[f32], right: &[f32], tolerance: f32) {
        for (a, b) in left.iter().zip(right.iter()) {
            assert!((a - b).abs() < tolerance);
        }
    }

    #[test]
    fn test_floor_log2() {
        let cases = [(1, 0), (2, 1), (3, 1), (4, 2), (7, 2), (8, 3), (960, 9)];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(floor_log2(input), expected);
        }
    }

    #[test]
    fn test_fht_basic() {
        // Applying the unnormalized transform twice scales by the length (4).
        let original = [1.0f32, 2.0, 3.0, 4.0];
        let mut data = original.to_vec();
        FhtKacRotator::fht(&mut data);
        FhtKacRotator::fht(&mut data);
        for (transformed, source) in data.iter().zip(original.iter()) {
            assert!((transformed - source * 4.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_matrix_rotator_orthogonality() {
        let dim = 16;
        let rotator = MatrixRotator::new(dim, 12345);

        let input = vec![1.0; dim];
        let output = rotator.rotate(&input);
        assert_eq!(output.len(), dim);

        // The rotation must preserve the vector's norm.
        assert!((l2_norm(&input) - l2_norm(&output)).abs() < 1e-4);
    }

    #[test]
    fn test_fht_rotator_basic() {
        let dim = 64;
        let rotator = FhtKacRotator::new(dim, 54321);
        assert_eq!(rotator.dim(), dim);
        assert_eq!(rotator.padded_dim(), 64);

        let output = rotator.rotate(&vec![1.0; dim]);
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_fht_rotator_non_power_of_2() {
        // 960 is a multiple of 64 but not a power of two.
        let dim = 960;
        let rotator = FhtKacRotator::new(dim, 98765);
        assert_eq!(rotator.dim(), dim);
        assert_eq!(rotator.padded_dim(), 960);

        let output = rotator.rotate(&vec![1.0; dim]);
        assert_eq!(output.len(), 960);
    }

    #[test]
    fn test_rotator_serialization() {
        let dim = 128;
        // Repeating 1, 2, 3 pattern.
        let input: Vec<f32> = (0..dim).map(|i| (i % 3 + 1) as f32).collect();

        let matrix_rot = MatrixRotator::new(dim, 11111);
        let matrix_restored =
            MatrixRotator::deserialize(dim, dim, &matrix_rot.serialize()).unwrap();
        assert_elementwise_close(
            &matrix_rot.rotate(&input),
            &matrix_restored.rotate(&input),
            1e-6,
        );

        let fht_rot = FhtKacRotator::new(dim, 22222);
        let fht_restored = FhtKacRotator::deserialize(dim, 128, &fht_rot.serialize()).unwrap();
        assert_elementwise_close(&fht_rot.rotate(&input), &fht_restored.rotate(&input), 1e-6);
    }
}