rabitq_rs/
rotation.rs

1use rand::prelude::*;
2use rand_distr::{Distribution, Normal, Uniform};
3
4use crate::math::{dot, normalize};
5use crate::RabitqError;
6
/// Type of rotator to use for data transformation.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so a variant can be turned into its
/// one-byte tag with `as u8`; `RotatorType::from_u8` performs the reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RotatorType {
    /// Matrix-based rotator using Gram-Schmidt orthogonalization (tag `0`)
    MatrixRotator = 0,
    /// Fast Hadamard Transform (FHT) with Kac Walk rotator (tag `1`)
    FhtKacRotator = 1,
}
16
17impl RotatorType {
18    pub fn from_u8(value: u8) -> Option<Self> {
19        match value {
20            0 => Some(RotatorType::MatrixRotator),
21            1 => Some(RotatorType::FhtKacRotator),
22            _ => None,
23        }
24    }
25
26    /// Get padding requirement for the rotator type.
27    pub fn padding_requirement(self, dim: usize) -> usize {
28        match self {
29            RotatorType::MatrixRotator => dim,
30            RotatorType::FhtKacRotator => round_up_to_multiple(dim, 64),
31        }
32    }
33}
34
/// Round `value` up to the nearest multiple of `multiple`.
///
/// A `value` that is already a multiple is returned unchanged. `multiple` must be non-zero;
/// zero makes the arithmetic below panic.
fn round_up_to_multiple(value: usize, multiple: usize) -> usize {
    // Bump past the next boundary, then drop whatever overshoots it.
    let bumped = value + multiple - 1;
    bumped - bumped % multiple
}
38
/// Trait for vector rotation operations.
///
/// A rotator maps an input of `dim()` components to an output of `padded_dim()` components.
/// The `Send + Sync` supertraits require implementors to be shareable across threads.
pub trait Rotator: Send + Sync {
    /// Get the original dimension
    fn dim(&self) -> usize;

    /// Get the padded dimension after rotation
    fn padded_dim(&self) -> usize;

    /// Apply rotation to input vector, returning rotated output
    ///
    /// The implementations in this file assert that `input.len() == dim()` and return a vector
    /// of `padded_dim()` elements.
    fn rotate(&self, input: &[f32]) -> Vec<f32>;

    /// Apply rotation into an existing buffer
    ///
    /// The implementations in this file assert that `input.len() == dim()` and
    /// `output.len() == padded_dim()`, and overwrite `output`.
    fn rotate_into(&self, input: &[f32], output: &mut [f32]);

    /// Get rotator type
    fn rotator_type(&self) -> RotatorType;

    /// Serialize rotator state for persistence
    ///
    /// The dimensions are not part of the returned bytes; `deserialize` takes them separately.
    fn serialize(&self) -> Vec<u8>;

    /// Deserialize rotator state from bytes
    ///
    /// `dim` and `padded_dim` are supplied by the caller; `data` is the payload produced by
    /// `serialize`. Returns an error when `data` does not match the given dimensions.
    /// The `Self: Sized` bound keeps this constructor out of trait objects.
    fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError>
    where
        Self: Sized;
}
64
/// Matrix-based rotator using Gram-Schmidt orthogonalization.
///
/// Rotation multiplies the (zero-padded) input by a dense `padded_dim x padded_dim` matrix whose
/// rows are built from orthogonalized, normalized Gaussian samples.
#[derive(Debug, Clone)]
pub struct MatrixRotator {
    // Length of the vectors accepted by `rotate` / `rotate_into`.
    dim: usize,
    // Length of the rotated output and side length of `matrix`.
    padded_dim: usize,
    matrix: Vec<f32>, // Row-major storage: padded_dim x padded_dim
}
72
73impl MatrixRotator {
74    /// Create a new matrix rotator with the provided seed.
75    pub fn new(dim: usize, seed: u64) -> Self {
76        let padded_dim = RotatorType::MatrixRotator.padding_requirement(dim);
77        let mut rng = StdRng::seed_from_u64(seed);
78        Self::with_rng(dim, padded_dim, &mut rng)
79    }
80
81    fn with_rng(dim: usize, padded_dim: usize, rng: &mut StdRng) -> Self {
82        let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");
83        let mut basis: Vec<Vec<f32>> = Vec::with_capacity(padded_dim);
84
85        for _ in 0..padded_dim {
86            let mut vec: Vec<f32> = (0..padded_dim).map(|_| normal.sample(rng) as f32).collect();
87
88            // Orthogonalize against previous basis vectors
89            for prev in &basis {
90                let proj = dot(&vec, prev);
91                for (v, p) in vec.iter_mut().zip(prev.iter()) {
92                    *v -= proj * *p;
93                }
94            }
95
96            let mut attempts = 0;
97            loop {
98                let norm = normalize(&mut vec);
99                if norm > f32::EPSILON {
100                    break;
101                }
102                attempts += 1;
103                if attempts > 8 {
104                    // Fallback to a canonical basis vector
105                    vec.fill(0.0);
106                    vec[attempts % padded_dim] = 1.0;
107                    break;
108                }
109                for value in vec.iter_mut() {
110                    *value = normal.sample(rng) as f32;
111                }
112                for prev in &basis {
113                    let proj = dot(&vec, prev);
114                    for (v, p) in vec.iter_mut().zip(prev.iter()) {
115                        *v -= proj * *p;
116                    }
117                }
118            }
119
120            basis.push(vec);
121        }
122
123        let mut matrix = Vec::with_capacity(padded_dim * padded_dim);
124        for row in basis {
125            matrix.extend_from_slice(&row);
126        }
127
128        Self {
129            dim,
130            padded_dim,
131            matrix,
132        }
133    }
134}
135
136impl Rotator for MatrixRotator {
137    fn dim(&self) -> usize {
138        self.dim
139    }
140
141    fn padded_dim(&self) -> usize {
142        self.padded_dim
143    }
144
145    fn rotate(&self, input: &[f32]) -> Vec<f32> {
146        assert_eq!(input.len(), self.dim);
147        let mut output = vec![0.0f32; self.padded_dim];
148        self.rotate_into(input, &mut output);
149        output
150    }
151
152    fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
153        assert_eq!(input.len(), self.dim);
154        assert_eq!(output.len(), self.padded_dim);
155
156        // Pad input with zeros
157        let mut padded_input = vec![0.0f32; self.padded_dim];
158        padded_input[..self.dim].copy_from_slice(input);
159
160        for (row_idx, chunk) in self.matrix.chunks(self.padded_dim).enumerate() {
161            let mut acc = 0.0f32;
162            for (value, &weight) in padded_input.iter().zip(chunk.iter()) {
163                acc += value * weight;
164            }
165            output[row_idx] = acc;
166        }
167    }
168
169    fn rotator_type(&self) -> RotatorType {
170        RotatorType::MatrixRotator
171    }
172
173    fn serialize(&self) -> Vec<u8> {
174        let mut bytes = Vec::new();
175        for &value in &self.matrix {
176            bytes.extend_from_slice(&value.to_le_bytes());
177        }
178        bytes
179    }
180
181    fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError> {
182        let expected_len = padded_dim * padded_dim * 4; // 4 bytes per f32
183        if data.len() != expected_len {
184            return Err(RabitqError::InvalidPersistence(
185                "rotator matrix length mismatch",
186            ));
187        }
188
189        let mut matrix = Vec::with_capacity(padded_dim * padded_dim);
190        for chunk in data.chunks_exact(4) {
191            let bytes: [u8; 4] = chunk.try_into().unwrap();
192            matrix.push(f32::from_le_bytes(bytes));
193        }
194
195        Ok(Self {
196            dim,
197            padded_dim,
198            matrix,
199        })
200    }
201}
202
/// Fast Hadamard Transform (FHT) rotator with Kac Walk.
/// This matches the C++ FhtKacRotator implementation for performance.
///
/// Only the random sign-flip bits are stored; the remaining fields are derived from `dim`.
#[derive(Debug, Clone)]
pub struct FhtKacRotator {
    // Length of the vectors accepted by `rotate` / `rotate_into`.
    dim: usize,
    // Length of the rotated output; `new` rounds `dim` up to a multiple of 64.
    padded_dim: usize,
    flip: Vec<u8>, // 4 * padded_dim / 8 bytes of random flip bits
    // Largest power of two <= dim; size of the window each FHT pass runs over.
    trunc_dim: usize,
    // 1 / sqrt(trunc_dim), applied after each FHT pass.
    fac: f32,
}
213
214impl FhtKacRotator {
215    /// Create a new FHT rotator with the provided seed.
216    pub fn new(dim: usize, seed: u64) -> Self {
217        let padded_dim = RotatorType::FhtKacRotator.padding_requirement(dim);
218        assert_eq!(
219            padded_dim % 64,
220            0,
221            "FHT rotator requires dimension to be multiple of 64"
222        );
223
224        let mut rng = StdRng::seed_from_u64(seed);
225        let uniform = Uniform::new_inclusive(0u8, 255u8);
226
227        // Generate 4 sets of random flip bits (4 rounds of FHT)
228        let flip_bytes = 4 * padded_dim / 8;
229        let flip: Vec<u8> = (0..flip_bytes).map(|_| uniform.sample(&mut rng)).collect();
230
231        // Compute truncated dimension (largest power of 2 <= dim)
232        let bottom_log_dim = floor_log2(dim);
233        let trunc_dim = 1 << bottom_log_dim;
234        let fac = 1.0 / (trunc_dim as f32).sqrt();
235
236        Self {
237            dim,
238            padded_dim,
239            flip,
240            trunc_dim,
241            fac,
242        }
243    }
244
245    /// Apply sign flip based on bit mask
246    fn flip_sign(data: &mut [f32], flip_bits: &[u8]) {
247        for (i, value) in data.iter_mut().enumerate() {
248            let byte_idx = i / 8;
249            let bit_idx = i % 8;
250            if byte_idx < flip_bits.len() {
251                let bit = (flip_bits[byte_idx] >> bit_idx) & 1;
252                if bit == 1 {
253                    *value = -*value;
254                }
255            }
256        }
257    }
258
259    /// Fast Hadamard Transform (FHT) for power-of-2 dimensions
260    fn fht(data: &mut [f32]) {
261        let n = data.len();
262        assert!(
263            n.is_power_of_two(),
264            "FHT requires power-of-2 dimension, got {}",
265            n
266        );
267
268        let mut h = 1;
269        while h < n {
270            for i in (0..n).step_by(h * 2) {
271                for j in i..(i + h) {
272                    let x = data[j];
273                    let y = data[j + h];
274                    data[j] = x + y;
275                    data[j + h] = x - y;
276                }
277            }
278            h *= 2;
279        }
280    }
281
282    /// Kac's walk: Hadamard-like operation
283    fn kacs_walk(data: &mut [f32]) {
284        let len = data.len();
285        let half = len / 2;
286        for i in 0..half {
287            let x = data[i];
288            let y = data[i + half];
289            data[i] = x + y;
290            data[i + half] = x - y;
291        }
292    }
293
294    /// Rescale vector by constant factor
295    fn rescale(data: &mut [f32], factor: f32) {
296        for value in data.iter_mut() {
297            *value *= factor;
298        }
299    }
300}
301
302impl Rotator for FhtKacRotator {
303    fn dim(&self) -> usize {
304        self.dim
305    }
306
307    fn padded_dim(&self) -> usize {
308        self.padded_dim
309    }
310
311    fn rotate(&self, input: &[f32]) -> Vec<f32> {
312        assert_eq!(input.len(), self.dim);
313        let mut output = vec![0.0f32; self.padded_dim];
314        self.rotate_into(input, &mut output);
315        output
316    }
317
318    fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
319        assert_eq!(input.len(), self.dim);
320        assert_eq!(output.len(), self.padded_dim);
321
322        // Copy input and pad with zeros
323        output[..self.dim].copy_from_slice(input);
324        output[self.dim..].fill(0.0);
325
326        let flip_offset = self.padded_dim / 8;
327
328        if self.trunc_dim == self.padded_dim {
329            // Case 1: trunc_dim == padded_dim (dimension is power of 2)
330            // Apply 4 rounds of: flip_sign -> FHT -> rescale
331            for round in 0..4 {
332                let flip_start = round * flip_offset;
333                let flip_end = flip_start + flip_offset;
334                Self::flip_sign(output, &self.flip[flip_start..flip_end]);
335                Self::fht(output);
336                Self::rescale(output, self.fac);
337            }
338        } else {
339            // Case 2: trunc_dim < padded_dim (dimension is not power of 2)
340            let start = self.padded_dim - self.trunc_dim;
341
342            // Round 1
343            Self::flip_sign(output, &self.flip[0..flip_offset]);
344            Self::fht(&mut output[..self.trunc_dim]);
345            Self::rescale(&mut output[..self.trunc_dim], self.fac);
346            Self::kacs_walk(output);
347
348            // Round 2
349            Self::flip_sign(output, &self.flip[flip_offset..2 * flip_offset]);
350            Self::fht(&mut output[start..]);
351            Self::rescale(&mut output[start..], self.fac);
352            Self::kacs_walk(output);
353
354            // Round 3
355            Self::flip_sign(output, &self.flip[2 * flip_offset..3 * flip_offset]);
356            Self::fht(&mut output[..self.trunc_dim]);
357            Self::rescale(&mut output[..self.trunc_dim], self.fac);
358            Self::kacs_walk(output);
359
360            // Round 4
361            Self::flip_sign(output, &self.flip[3 * flip_offset..4 * flip_offset]);
362            Self::fht(&mut output[start..]);
363            Self::rescale(&mut output[start..], self.fac);
364            Self::kacs_walk(output);
365
366            // Final rescale to match C++ normalization
367            Self::rescale(output, 0.25);
368        }
369    }
370
371    fn rotator_type(&self) -> RotatorType {
372        RotatorType::FhtKacRotator
373    }
374
375    fn serialize(&self) -> Vec<u8> {
376        // Only store the flip bits (much smaller than full matrix)
377        self.flip.clone()
378    }
379
380    fn deserialize(dim: usize, padded_dim: usize, data: &[u8]) -> Result<Self, RabitqError> {
381        let expected_len = 4 * padded_dim / 8;
382        if data.len() != expected_len {
383            return Err(RabitqError::InvalidPersistence(
384                "FHT rotator flip bits length mismatch",
385            ));
386        }
387
388        let bottom_log_dim = floor_log2(dim);
389        let trunc_dim = 1 << bottom_log_dim;
390        let fac = 1.0 / (trunc_dim as f32).sqrt();
391
392        Ok(Self {
393            dim,
394            padded_dim,
395            flip: data.to_vec(),
396            trunc_dim,
397            fac,
398        })
399    }
400}
401
/// Return `floor(log2(x))`, i.e. the index of the highest set bit of `x`.
///
/// # Panics
/// Panics if `x` is zero.
fn floor_log2(x: usize) -> usize {
    assert!(x > 0, "floor_log2 requires positive input");
    // Number of bits needed to represent `x`; the top one sits at index `count - 1`.
    let significant_bits = usize::BITS - x.leading_zeros();
    (significant_bits - 1) as usize
}
407
/// Dynamic rotator that can be either Matrix or FHT based
///
/// Wraps one concrete rotator per variant so the kind can be chosen at runtime from a
/// `RotatorType` value.
#[derive(Debug, Clone)]
pub enum DynamicRotator {
    /// Dense-matrix rotator.
    Matrix(MatrixRotator),
    /// Fast Hadamard Transform / Kac walk rotator.
    Fht(FhtKacRotator),
}
414
415impl DynamicRotator {
416    /// Create a new rotator of the specified type
417    pub fn new(dim: usize, rotator_type: RotatorType, seed: u64) -> Self {
418        match rotator_type {
419            RotatorType::MatrixRotator => DynamicRotator::Matrix(MatrixRotator::new(dim, seed)),
420            RotatorType::FhtKacRotator => DynamicRotator::Fht(FhtKacRotator::new(dim, seed)),
421        }
422    }
423
424    pub fn dim(&self) -> usize {
425        match self {
426            DynamicRotator::Matrix(r) => r.dim(),
427            DynamicRotator::Fht(r) => r.dim(),
428        }
429    }
430
431    pub fn padded_dim(&self) -> usize {
432        match self {
433            DynamicRotator::Matrix(r) => r.padded_dim(),
434            DynamicRotator::Fht(r) => r.padded_dim(),
435        }
436    }
437
438    pub fn rotate(&self, input: &[f32]) -> Vec<f32> {
439        match self {
440            DynamicRotator::Matrix(r) => r.rotate(input),
441            DynamicRotator::Fht(r) => r.rotate(input),
442        }
443    }
444
445    pub fn rotate_into(&self, input: &[f32], output: &mut [f32]) {
446        match self {
447            DynamicRotator::Matrix(r) => r.rotate_into(input, output),
448            DynamicRotator::Fht(r) => r.rotate_into(input, output),
449        }
450    }
451
452    pub fn rotator_type(&self) -> RotatorType {
453        match self {
454            DynamicRotator::Matrix(r) => r.rotator_type(),
455            DynamicRotator::Fht(r) => r.rotator_type(),
456        }
457    }
458
459    pub fn serialize(&self) -> Vec<u8> {
460        match self {
461            DynamicRotator::Matrix(r) => r.serialize(),
462            DynamicRotator::Fht(r) => r.serialize(),
463        }
464    }
465
466    pub fn deserialize(
467        dim: usize,
468        padded_dim: usize,
469        rotator_type: RotatorType,
470        data: &[u8],
471    ) -> Result<Self, RabitqError> {
472        match rotator_type {
473            RotatorType::MatrixRotator => Ok(DynamicRotator::Matrix(MatrixRotator::deserialize(
474                dim, padded_dim, data,
475            )?)),
476            RotatorType::FhtKacRotator => Ok(DynamicRotator::Fht(FhtKacRotator::deserialize(
477                dim, padded_dim, data,
478            )?)),
479        }
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Assert that paired elements of `lhs` and `rhs` differ by less than `tol`.
    fn assert_all_close(lhs: &[f32], rhs: &[f32], tol: f32) {
        for (a, b) in lhs.iter().zip(rhs.iter()) {
            assert!((a - b).abs() < tol);
        }
    }

    #[test]
    fn test_floor_log2() {
        let cases: [(usize, usize); 7] = [
            (1, 0),
            (2, 1),
            (3, 1),
            (4, 2),
            (7, 2),
            (8, 3),
            (960, 9),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(floor_log2(input), expected);
        }
    }

    #[test]
    fn test_fht_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        // Applying the unnormalized transform twice scales the input by its length (4).
        FhtKacRotator::fht(&mut data);
        FhtKacRotator::fht(&mut data);
        for (i, &val) in data.iter().enumerate() {
            assert!((val - (i + 1) as f32 * 4.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_matrix_rotator_orthogonality() {
        let dim = 16;
        let rotator = MatrixRotator::new(dim, 12345);

        let input = vec![1.0; dim];
        let output = rotator.rotate(&input);
        assert_eq!(output.len(), dim);

        // Rotation should preserve norm (approximately)
        assert!((l2_norm(&input) - l2_norm(&output)).abs() < 1e-4);
    }

    #[test]
    fn test_fht_rotator_basic() {
        let dim = 64;
        let rotator = FhtKacRotator::new(dim, 54321);
        assert_eq!(rotator.dim(), dim);
        assert_eq!(rotator.padded_dim(), 64);

        let output = rotator.rotate(&vec![1.0; dim]);
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_fht_rotator_non_power_of_2() {
        let dim = 960; // GIST dataset dimension
        let rotator = FhtKacRotator::new(dim, 98765);
        assert_eq!(rotator.dim(), dim);
        assert_eq!(rotator.padded_dim(), 960); // Already multiple of 64

        let output = rotator.rotate(&vec![1.0; dim]);
        assert_eq!(output.len(), 960);
    }

    #[test]
    fn test_rotator_serialization() {
        let dim = 128;
        // Repeating 1, 2, 3 pattern.
        let input: Vec<f32> = (0..dim).map(|i| [1.0, 2.0, 3.0][i % 3]).collect();

        // MatrixRotator round trip
        let matrix_rot = MatrixRotator::new(dim, 11111);
        let matrix_restored =
            MatrixRotator::deserialize(dim, dim, &matrix_rot.serialize()).unwrap();
        assert_all_close(
            &matrix_rot.rotate(&input),
            &matrix_restored.rotate(&input),
            1e-6,
        );

        // FhtKacRotator round trip
        let fht_rot = FhtKacRotator::new(dim, 22222);
        let fht_restored = FhtKacRotator::deserialize(dim, 128, &fht_rot.serialize()).unwrap();
        assert_all_close(&fht_rot.rotate(&input), &fht_restored.rotate(&input), 1e-6);
    }
}