use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
/// Randomized Hadamard rotation over a power-of-two padded dimension.
///
/// Holds the logical vector dimension, the padded dimension the transform
/// actually runs on, and one `+1.0` / `-1.0` sign per padded coordinate.
#[derive(Debug, Clone, PartialEq)]
pub struct Rotator {
    /// Logical dimension; outputs of the rotation are truncated to this length.
    dim: u32,
    /// Length of the buffer the Hadamard transform is applied to.
    padded_dim: u32,
    /// Per-coordinate sign flips (`1.0` or `-1.0`), one per padded coordinate.
    signs: Vec<f32>,
}
impl Rotator {
    /// Fixed RNG seed: the sign pattern is derived from this constant only,
    /// so construction is deterministic.
    const ROTATE_SEED: u64 = 0xDEAD_BEEF_CAFE_1234;

    /// Creates a rotator for `dim`-dimensional vectors.
    ///
    /// The working dimension is `dim` rounded up to a power of two (at least 1),
    /// and one random sign is drawn per padded coordinate from a generator
    /// seeded with [`Self::ROTATE_SEED`].
    ///
    /// # Panics
    ///
    /// Panics if the next power of two of `dim` does not fit in a `u32`
    /// (i.e. `dim > 2^31`).
    pub fn new(dim: u32) -> Self {
        let padded_dim = Self::next_power_of_two(dim);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::ROTATE_SEED);
        let signs: Vec<f32> = (0..padded_dim)
            .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
            .collect();
        Self {
            dim,
            padded_dim,
            signs,
        }
    }

    /// Returns the smallest power of two that is `>= n` (`1` for `n == 0`).
    ///
    /// # Panics
    ///
    /// Panics if that power of two is not representable as a `u32`. The
    /// previous doubling loop overflowed in that case instead of reporting it.
    fn next_power_of_two(n: u32) -> u32 {
        n.checked_next_power_of_two()
            .expect("dimension too large: next power of two does not fit in u32")
    }

    /// Applies the rotation to `x` and returns a vector of length `dim`.
    ///
    /// `x` is zero-padded to `padded_dim`, each coordinate is multiplied by its
    /// sign, the Hadamard transform is applied, and the result is truncated to
    /// `dim` entries and scaled by `1 / sqrt(padded_dim)`.
    ///
    /// Because of the truncation, transformed coordinates beyond `dim` are
    /// discarded whenever `dim` is not a power of two; the norm of the input is
    /// only preserved exactly (up to rounding) when `dim == padded_dim`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() > padded_dim`.
    pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
        let padded = self.padded_dim as usize;
        assert!(
            x.len() <= padded,
            "input length {} exceeds padded dimension {}",
            x.len(),
            padded
        );

        let mut v = vec![0.0f32; padded];
        for ((slot, &val), &sign) in v.iter_mut().zip(x).zip(&self.signs) {
            *slot = val * sign;
        }

        self.hadamard_transform(&mut v);
        v.truncate(self.dim as usize);

        let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
        for val in &mut v {
            *val *= norm_factor;
        }
        v
    }

    /// Applies the reverse sequence of [`Self::rotate`]: zero-pad, Hadamard
    /// transform, per-coordinate sign flip, scale by `1 / sqrt(padded_dim)`,
    /// and truncate to `dim`.
    ///
    /// This undoes `rotate` exactly (up to rounding) when `dim == padded_dim`
    /// and `x` is a full rotated vector. When `dim` is smaller than
    /// `padded_dim`, the coordinates dropped by `rotate` are treated as zero
    /// here, so the reconstruction is not exact.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() > padded_dim`.
    pub fn rotate_inverse(&self, x: &[f32]) -> Vec<f32> {
        let padded = self.padded_dim as usize;
        // Without this check an over-long input would fail with a bare
        // index-out-of-bounds panic while copying into the padded buffer.
        assert!(
            x.len() <= padded,
            "input length {} exceeds padded dimension {}",
            x.len(),
            padded
        );

        let mut v = vec![0.0f32; padded];
        v[..x.len()].copy_from_slice(x);

        self.hadamard_transform(&mut v);
        // Only the first `dim` coordinates are returned, so drop the rest
        // before the sign flip and scaling.
        v.truncate(self.dim as usize);

        let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
        for (val, &sign) in v.iter_mut().zip(&self.signs) {
            *val *= sign;
            *val *= norm_factor;
        }
        v
    }

    /// In-place, unnormalized Walsh–Hadamard butterfly over `v`.
    ///
    /// # Panics
    ///
    /// Panics if `v.len()` is not a power of two.
    fn hadamard_transform(&self, v: &mut [f32]) {
        let n = v.len();
        assert!(n.is_power_of_two());

        let mut h = 1;
        while h < n {
            // `n` and `h` are powers of two with `h < n`, so `2 * h` divides
            // `n` and `chunks_exact_mut` leaves no remainder.
            for pair in v.chunks_exact_mut(h * 2) {
                let (lo, hi) = pair.split_at_mut(h);
                for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                    let (x, y) = (*a, *b);
                    *a = x + y;
                    *b = x - y;
                }
            }
            h *= 2;
        }
    }

    /// Logical dimension of the vectors returned by this rotator.
    pub fn dim(&self) -> u32 {
        self.dim
    }

    /// Power-of-two dimension the transform runs on internally.
    pub fn padded_dim(&self) -> u32 {
        self.padded_dim
    }
}
/// Randomized Hadamard rotation applied independently to fixed-size blocks
/// of a vector.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockRotator {
    /// Logical dimension of the vectors being rotated.
    dim: u32,
    /// Number of coordinates per block.
    block_size: u32,
    /// Number of blocks needed to cover `dim` coordinates.
    num_blocks: u32,
    /// Per-coordinate sign flips (`1.0` or `-1.0`), `num_blocks * block_size` entries.
    signs: Vec<f32>,
}
impl BlockRotator {
    /// Fixed RNG seed: the sign pattern is derived from this constant only,
    /// so construction is deterministic.
    const BLOCK_ROT_SEED: u64 = 0xB10C_B0A7_CAFE_5678;

    /// Creates a block rotator for `dim`-dimensional vectors split into blocks
    /// of `block_size` coordinates.
    ///
    /// One random sign is drawn for every coordinate of every block, including
    /// the padding of a partial last block.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is not a power of two (zero is rejected too), or
    /// if `num_blocks * block_size` does not fit in a `usize`.
    pub fn new(dim: u32, block_size: u32) -> Self {
        assert!(
            block_size.is_power_of_two(),
            "block_size must be a non-zero power of two, got {}",
            block_size
        );
        // `div_ceil` avoids the overflow of `dim + block_size - 1` for `dim`
        // close to `u32::MAX`.
        let num_blocks = dim.div_ceil(block_size);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::BLOCK_ROT_SEED);
        // Computed in `usize` with a check: the padded size can exceed
        // `u32::MAX` even when `dim` itself fits.
        let total_size = (num_blocks as usize)
            .checked_mul(block_size as usize)
            .expect("padded dimension does not fit in usize");
        let signs: Vec<f32> = (0..total_size)
            .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
            .collect();
        Self {
            dim,
            block_size,
            num_blocks,
            signs,
        }
    }

    /// Rotates `x` block by block and returns a vector of length `dim`.
    ///
    /// Each block is sign-flipped, Hadamard-transformed and scaled by
    /// `1 / sqrt(block_size)`. Coordinates of `x` at or beyond `dim` are
    /// ignored, and missing coordinates (when `x` is shorter than `dim`) are
    /// treated as zero.
    ///
    /// When `dim` is not a multiple of `block_size`, the last block is
    /// zero-padded before the transform and only its first `dim - start`
    /// outputs are kept.
    pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.dim as usize;
        let block_size = self.block_size as usize;
        // Loop-invariant scale factor, hoisted out of the per-block loop.
        let norm = 1.0 / (self.block_size as f32).sqrt();

        let mut result = vec![0.0f32; dim];
        // Single scratch buffer reused for every block.
        let mut block = vec![0.0f32; block_size];

        for (block_idx, out) in result.chunks_mut(block_size).enumerate() {
            let start = block_idx * block_size;

            block.fill(0.0);
            // Empty when `x` ends before this block starts.
            let input = x.get(start..).unwrap_or(&[]);
            let signs = &self.signs[start..];
            // `out.len()` is shorter than `block_size` only for a partial last
            // block; the remaining scratch entries stay zero.
            for ((slot, &val), &sign) in block[..out.len()].iter_mut().zip(input).zip(signs) {
                *slot = val * sign;
            }

            Self::hadamard_transform_block(&mut block);

            for (dst, &src) in out.iter_mut().zip(&block) {
                *dst = src * norm;
            }
        }
        result
    }

    /// In-place, unnormalized Walsh–Hadamard butterfly over `v`.
    ///
    /// # Panics
    ///
    /// Panics if `v` is non-empty and its length is not a power of two. The
    /// check is done up front so `v` is never left partially transformed.
    fn hadamard_transform_block(v: &mut [f32]) {
        let n = v.len();
        assert!(
            n == 0 || n.is_power_of_two(),
            "block length must be a power of two, got {}",
            n
        );

        let mut h = 1;
        while h < n {
            // `2 * h` divides `n` here, so `chunks_exact_mut` leaves no remainder.
            for pair in v.chunks_exact_mut(h * 2) {
                let (lo, hi) = pair.split_at_mut(h);
                for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                    let (x, y) = (*a, *b);
                    *a = x + y;
                    *b = x - y;
                }
            }
            h *= 2;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|c| c * c).sum::<f32>().sqrt()
    }

    /// Linear ramp of `len` values: `(i - offset) * scale` for `i` in `0..len`.
    fn ramp(len: i32, offset: f32, scale: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - offset) * scale).collect()
    }

    /// A full power-of-two rotation keeps the Euclidean norm of the input.
    #[test]
    fn test_rotator_preserves_norm() {
        let rotator = Rotator::new(64);
        let x = ramp(64, 32.0, 0.1);
        let norm_before = l2_norm(&x);
        let norm_after = l2_norm(&rotator.rotate(&x));
        let drift = (norm_before - norm_after).abs();
        assert!(
            drift < 0.01,
            "Norms differ: {} vs {}",
            norm_before,
            norm_after
        );
    }

    /// `rotate_inverse` recovers the input of `rotate` for a power-of-two dimension.
    #[test]
    fn test_rotation_roundtrip() {
        let rotator = Rotator::new(64);
        let x = ramp(64, 32.0, 0.1);
        let z = rotator.rotate_inverse(&rotator.rotate(&x));
        for (a, b) in x.iter().zip(z.iter()) {
            assert!((a - b).abs() < 0.01, "Mismatch: {} vs {}", a, b);
        }
    }

    /// Rotating a unit basis vector spreads equal magnitude over all outputs.
    #[test]
    fn test_hadamard_basic() {
        let rotator = Rotator::new(4);
        let x = vec![1.0, 0.0, 0.0, 0.0];
        let y = rotator.rotate(&x);
        let reference = y[0].abs();
        let uniform = y.iter().all(|&v| (v.abs() - reference).abs() < 0.01);
        assert!(uniform);
    }

    /// Block-wise rotation keeps the norm when `dim` is a multiple of the block size.
    #[test]
    fn test_block_rotator() {
        let rotator = BlockRotator::new(768, 64);
        let x = ramp(768, 384.0, 0.01);
        let norm_before = l2_norm(&x);
        let norm_after = l2_norm(&rotator.rotate(&x));
        assert!((norm_before - norm_after).abs() < 0.1);
    }
}