//! Fast Walsh-Hadamard transform with optional seeded sign randomization,
//! SIMD (NEON/AVX2) fast paths, and batch helpers.

use crate::error::{Result, RuvLLMError};
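/// Largest supported log-dimension: transforms up to 2^20 = 1,048,576 elements.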
pub const MAX_LOG_DIM: u32 = 20;
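// f32 lanes per SIMD register on the current target:
// NEON = 4, AVX2 = 8, scalar fallback = 1.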
#[cfg(target_arch = "aarch64")]
pub const SIMD_LANES: usize = 4;
#[cfg(target_arch = "x86_64")]
pub const SIMD_LANES: usize = 8;
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
pub const SIMD_LANES: usize = 1;
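/// Fast Walsh-Hadamard transform over vectors of power-of-two length,
/// optionally composed with a seeded diagonal ±1 sign flip (a randomized
/// Hadamard transform, e.g. to decorrelate values before quantization).
/// Forward followed by inverse round-trips to within floating-point error.
///
/// A minimal usage sketch (marked `ignore` since the crate path is not
/// shown here):
///
/// ```ignore
/// let t = HadamardTransform::randomized(3, 42)?; // dim = 2^3 = 8
/// let mut v: Vec<f32> = (0..8).map(|i| i as f32).collect();
/// t.forward_inplace(&mut v);
/// t.inverse_inplace(&mut v); // recovers the original vector
/// ```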
#[derive(Debug, Clone)]
pub struct HadamardTransform {
/// log2 of the transform dimension.
log_dim: u32,
/// Transform dimension n = 2^log_dim.
dim: usize,
/// Per-element ±1 signs (all +1 for the deterministic transform).
signs: Vec<i8>,
/// Orthonormal scaling factor 1/sqrt(n).
norm_factor: f32,
/// Whether a seeded sign vector is applied (randomized Hadamard).
randomized: bool,
}
impl HadamardTransform {
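/// Builds a transform of dimension 2^log_dim. `Some(seed)` produces a
/// randomized transform whose ±1 sign vector is derived deterministically
/// from the seed; `None` yields the plain Hadamard transform.
///
/// Returns an error if `log_dim` exceeds [`MAX_LOG_DIM`].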
pub fn new(log_dim: u32, seed: Option<u64>) -> Result<Self> {
if log_dim > MAX_LOG_DIM {
return Err(RuvLLMError::Quantization(format!(
"Hadamard dimension 2^{} exceeds maximum supported 2^{}",
log_dim, MAX_LOG_DIM
)));
}
let dim = 1usize << log_dim;
let norm_factor = 1.0 / (dim as f32).sqrt();
let (signs, randomized) = match seed {
Some(s) => {
// Derive a reproducible ±1 sign vector from the seed with Knuth's
// MMIX linear congruential generator; the top bit of the state is
// its highest-quality bit, so it selects the sign.
let mut rng_state = s;
let signs: Vec<i8> = (0..dim)
.map(|_| {
rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
if (rng_state >> 63) & 1 == 0 { 1 } else { -1 }
})
.collect();
(signs, true)
}
// Without a seed this is the plain Hadamard transform: all signs +1.
None => (vec![1i8; dim], false),
};
Ok(Self {
log_dim,
dim,
signs,
norm_factor,
randomized,
})
}
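/// Convenience constructor for the plain (unrandomized) transform.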
pub fn deterministic(log_dim: u32) -> Result<Self> {
Self::new(log_dim, None)
}
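/// Convenience constructor for the seeded, sign-randomized transform.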
pub fn randomized(log_dim: u32, seed: u64) -> Result<Self> {
Self::new(log_dim, Some(seed))
}
/// Transform dimension n (always a power of two).
#[inline]
pub fn dim(&self) -> usize {
self.dim
}
/// log2 of the transform dimension.
#[inline]
pub fn log_dim(&self) -> u32 {
self.log_dim
}
/// Whether a seeded ±1 sign vector is applied around the butterfly.
#[inline]
pub fn is_randomized(&self) -> bool {
self.randomized
}
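/// Applies the forward transform in place: y = (1/sqrt(n)) * H * (S * x),
/// where H is the n x n Hadamard matrix and S is the diagonal ±1 sign
/// matrix (the identity when not randomized).
///
/// Panics if `data.len()` differs from the transform dimension.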
pub fn forward_inplace(&self, data: &mut [f32]) {
assert_eq!(
data.len(),
self.dim,
"Data length {} must match transform dimension {}",
data.len(),
self.dim
);
if self.randomized {
self.apply_signs(data);
}
self.hadamard_butterfly(data);
self.normalize(data);
}
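/// Applies the inverse in place. Because H * H = n * I and S * S = I,
/// the inverse is the same butterfly and 1/sqrt(n) scaling followed by
/// the sign flip: x = S * (1/sqrt(n)) * H * y. Note that S and H are
/// applied in the reverse order relative to the forward pass.
///
/// Panics if `data.len()` differs from the transform dimension.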
pub fn inverse_inplace(&self, data: &mut [f32]) {
assert_eq!(
data.len(),
self.dim,
"Data length {} must match transform dimension {}",
data.len(),
self.dim
);
self.hadamard_butterfly(data);
self.normalize(data);
if self.randomized {
self.apply_signs(data);
}
}
#[inline]
fn apply_signs(&self, data: &mut [f32]) {
// Dispatch to the widest SIMD path on this target; the scalar loop
// covers every other architecture.
#[cfg(target_arch = "aarch64")]
{
// NEON is mandatory on AArch64, so this call is always valid.
unsafe { self.apply_signs_neon(data) };
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { self.apply_signs_avx2(data) };
return;
}
}
#[allow(unreachable_code)]
self.apply_signs_scalar(data);
}
#[inline]
fn apply_signs_scalar(&self, data: &mut [f32]) {
for (d, &s) in data.iter_mut().zip(self.signs.iter()) {
*d *= s as f32;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn apply_signs_neon(&self, data: &mut [f32]) {
use std::arch::aarch64::*;
let n = data.len();
let chunks = n / 4;
let data_ptr = data.as_mut_ptr();
// Widen the i8 signs to f32 four at a time and multiply; the tail
// that does not fill a register is handled after the loop.
for i in 0..chunks {
let idx = i * 4;
let v = vld1q_f32(data_ptr.add(idx));
let signs = [
self.signs[idx] as f32,
self.signs[idx + 1] as f32,
self.signs[idx + 2] as f32,
self.signs[idx + 3] as f32,
];
let s = vld1q_f32(signs.as_ptr());
let result = vmulq_f32(v, s);
vst1q_f32(data_ptr.add(idx), result);
}
for i in (chunks * 4)..n {
data[i] *= self.signs[i] as f32;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn apply_signs_avx2(&self, data: &mut [f32]) {
use std::arch::x86_64::*;
let n = data.len();
let chunks = n / 8;
let data_ptr = data.as_mut_ptr();
// Widen the i8 signs to f32 eight at a time and multiply; the tail
// that does not fill a register is handled after the loop.
for i in 0..chunks {
let idx = i * 8;
let v = _mm256_loadu_ps(data_ptr.add(idx));
let signs: [f32; 8] = [
self.signs[idx] as f32,
self.signs[idx + 1] as f32,
self.signs[idx + 2] as f32,
self.signs[idx + 3] as f32,
self.signs[idx + 4] as f32,
self.signs[idx + 5] as f32,
self.signs[idx + 6] as f32,
self.signs[idx + 7] as f32,
];
let s = _mm256_loadu_ps(signs.as_ptr());
let result = _mm256_mul_ps(v, s);
_mm256_storeu_ps(data_ptr.add(idx), result);
}
for i in (chunks * 8)..n {
data[i] *= self.signs[i] as f32;
}
}
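/// In-place fast Walsh-Hadamard butterfly: log2(n) passes of paired
/// add/sub updates over strides h = 1, 2, 4, ..., n/2, for O(n log n)
/// work overall. Dispatches to a SIMD path when the dimension is large
/// enough for it to help.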
fn hadamard_butterfly(&self, data: &mut [f32]) {
let n = self.dim;
// The vector paths only pay off once the butterfly stride can fill a
// whole register (4 f32 lanes for NEON, 8 for AVX2), hence the
// minimum-dimension gates below.
#[cfg(target_arch = "aarch64")]
{
if n >= 8 {
unsafe { self.hadamard_butterfly_neon(data) };
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if n >= 16 && is_x86_feature_detected!("avx2") {
unsafe { self.hadamard_butterfly_avx2(data) };
return;
}
}
self.hadamard_butterfly_scalar(data);
}
fn hadamard_butterfly_scalar(&self, data: &mut [f32]) {
let n = self.dim;
// Classic iterative FWHT: for each stride h = 1, 2, ..., n/2, replace
// every pair (a, b) = (data[j+k], data[j+k+h]) with (a + b, a - b).
let mut h = 1;
while h < n {
let mut j = 0;
while j < n {
for k in 0..h {
let a = data[j + k];
let b = data[j + k + h];
data[j + k] = a + b;
data[j + k + h] = a - b;
}
j += h * 2;
}
h *= 2;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn hadamard_butterfly_neon(&self, data: &mut [f32]) {
use std::arch::aarch64::*;
let n = self.dim;
let mut h = 1;
while h < n {
// Process four butterflies per NEON register once the stride h is
// wide enough; narrower strides use the scalar pair update.
if h >= 4 {
let mut j = 0;
while j < n {
let mut k = 0;
while k + 4 <= h {
let ptr_a = data.as_mut_ptr().add(j + k);
let ptr_b = data.as_mut_ptr().add(j + k + h);
let a = vld1q_f32(ptr_a);
let b = vld1q_f32(ptr_b);
let sum = vaddq_f32(a, b);
let diff = vsubq_f32(a, b);
vst1q_f32(ptr_a, sum);
vst1q_f32(ptr_b, diff);
k += 4;
}
while k < h {
let a = data[j + k];
let b = data[j + k + h];
data[j + k] = a + b;
data[j + k + h] = a - b;
k += 1;
}
j += h * 2;
}
} else {
let mut j = 0;
while j < n {
for k in 0..h {
let a = data[j + k];
let b = data[j + k + h];
data[j + k] = a + b;
data[j + k + h] = a - b;
}
j += h * 2;
}
}
h *= 2;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn hadamard_butterfly_avx2(&self, data: &mut [f32]) {
use std::arch::x86_64::*;
let n = self.dim;
let mut h = 1;
while h < n {
// Process eight butterflies per AVX2 register once the stride h is
// wide enough; narrower strides use the scalar pair update.
if h >= 8 {
let mut j = 0;
while j < n {
let mut k = 0;
while k + 8 <= h {
let ptr_a = data.as_mut_ptr().add(j + k);
let ptr_b = data.as_mut_ptr().add(j + k + h);
let a = _mm256_loadu_ps(ptr_a);
let b = _mm256_loadu_ps(ptr_b);
let sum = _mm256_add_ps(a, b);
let diff = _mm256_sub_ps(a, b);
_mm256_storeu_ps(ptr_a, sum);
_mm256_storeu_ps(ptr_b, diff);
k += 8;
}
while k < h {
let a = data[j + k];
let b = data[j + k + h];
data[j + k] = a + b;
data[j + k + h] = a - b;
k += 1;
}
j += h * 2;
}
} else {
let mut j = 0;
while j < n {
for k in 0..h {
let a = data[j + k];
let b = data[j + k + h];
data[j + k] = a + b;
data[j + k + h] = a - b;
}
j += h * 2;
}
}
h *= 2;
}
}
#[inline]
fn normalize(&self, data: &mut [f32]) {
// Scale by 1/sqrt(n) so the overall transform stays orthonormal.
#[cfg(target_arch = "aarch64")]
{
unsafe { self.normalize_neon(data) };
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { self.normalize_avx2(data) };
return;
}
}
#[allow(unreachable_code)]
for d in data.iter_mut() {
*d *= self.norm_factor;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn normalize_neon(&self, data: &mut [f32]) {
use std::arch::aarch64::*;
let n = data.len();
let chunks = n / 4;
let norm = vdupq_n_f32(self.norm_factor);
let data_ptr = data.as_mut_ptr();
for i in 0..chunks {
let idx = i * 4;
let v = vld1q_f32(data_ptr.add(idx));
let result = vmulq_f32(v, norm);
vst1q_f32(data_ptr.add(idx), result);
}
for i in (chunks * 4)..n {
data[i] *= self.norm_factor;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn normalize_avx2(&self, data: &mut [f32]) {
use std::arch::x86_64::*;
let n = data.len();
let chunks = n / 8;
let norm = _mm256_set1_ps(self.norm_factor);
let data_ptr = data.as_mut_ptr();
for i in 0..chunks {
let idx = i * 8;
let v = _mm256_loadu_ps(data_ptr.add(idx));
let result = _mm256_mul_ps(v, norm);
_mm256_storeu_ps(data_ptr.add(idx), result);
}
for i in (chunks * 8)..n {
data[i] *= self.norm_factor;
}
}
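/// Checks the orthogonality invariant empirically by round-tripping a
/// fixed test vector through `forward_inplace` and `inverse_inplace`
/// and comparing against the original within `tolerance`.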
pub fn verify_orthogonality(&self, tolerance: f32) -> bool {
let mut data: Vec<f32> = (0..self.dim)
.map(|i| (i as f32 + 1.0) / self.dim as f32)
.collect();
let original = data.clone();
self.forward_inplace(&mut data);
self.inverse_inplace(&mut data);
for (a, b) in data.iter().zip(original.iter()) {
if (a - b).abs() > tolerance {
return false;
}
}
true
}
}
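/// Applies the forward transform to `batch_size` consecutive rows of
/// length `transform.dim()` stored back to back in `data`.
///
/// A minimal sketch (row-major layout assumed):
///
/// ```ignore
/// let t = HadamardTransform::deterministic(3)?; // dim = 8
/// let mut rows = vec![0.0f32; 2 * t.dim()];     // two rows
/// hadamard_batch_transform(&t, &mut rows, 2)?;
/// ```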
pub fn hadamard_batch_transform(
transform: &HadamardTransform,
data: &mut [f32],
batch_size: usize,
) -> Result<()> {
let dim = transform.dim();
if data.len() != batch_size * dim {
return Err(RuvLLMError::Quantization(format!(
"Data length {} doesn't match batch_size {} * dim {}",
data.len(),
batch_size,
dim
)));
}
for i in 0..batch_size {
let start = i * dim;
let end = start + dim;
transform.forward_inplace(&mut data[start..end]);
}
Ok(())
}
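/// Applies the inverse transform to `batch_size` consecutive rows of
/// length `transform.dim()` stored back to back in `data`.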
pub fn hadamard_batch_inverse(
transform: &HadamardTransform,
data: &mut [f32],
batch_size: usize,
) -> Result<()> {
let dim = transform.dim();
if data.len() != batch_size * dim {
return Err(RuvLLMError::Quantization(format!(
"Data length {} doesn't match batch_size {} * dim {}",
data.len(),
batch_size,
dim
)));
}
for i in 0..batch_size {
let start = i * dim;
let end = start + dim;
transform.inverse_inplace(&mut data[start..end]);
}
Ok(())
}
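/// Smallest power of two greater than or equal to `n` (1 for `n == 0`).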
#[inline]
pub fn next_power_of_2(n: usize) -> usize {
// `usize::next_power_of_two` already returns 1 for an input of 0,
// so no special case is needed.
n.next_power_of_two()
}
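/// Returns log2(n) when `n` is an exact power of two, otherwise `None`.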
#[inline]
pub fn log2_exact(n: usize) -> Option<u32> {
// `is_power_of_two` is false for 0, so the zero case is covered.
if !n.is_power_of_two() {
return None;
}
Some(n.trailing_zeros())
}
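/// Copies `data` into a new vector zero-padded to the next power-of-two
/// length, making it a valid input for `HadamardTransform`.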
pub fn pad_to_power_of_2(data: &[f32]) -> Vec<f32> {
let target_len = next_power_of_2(data.len());
let mut padded = data.to_vec();
padded.resize(target_len, 0.0);
padded
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hadamard_basic() {
let transform = HadamardTransform::deterministic(3).unwrap();
assert_eq!(transform.dim(), 8);
assert!(!transform.is_randomized());
}
#[test]
fn test_hadamard_roundtrip() {
let transform = HadamardTransform::deterministic(4).unwrap();
let original: Vec<f32> = (0..16).map(|i| i as f32).collect();
let mut data = original.clone();
transform.forward_inplace(&mut data);
transform.inverse_inplace(&mut data);
for (a, b) in data.iter().zip(original.iter()) {
assert!((a - b).abs() < 1e-5, "Roundtrip failed: {} vs {}", a, b);
}
}
#[test]
fn test_hadamard_randomized_roundtrip() {
let transform = HadamardTransform::randomized(5, 12345).unwrap();
let original: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 10.0).collect();
let mut data = original.clone();
transform.forward_inplace(&mut data);
transform.inverse_inplace(&mut data);
for (a, b) in data.iter().zip(original.iter()) {
assert!(
(a - b).abs() < 1e-5,
"Randomized roundtrip failed: {} vs {}",
a,
b
);
}
}
#[test]
fn test_orthogonality_property_inv4() {
let transform = HadamardTransform::deterministic(6).unwrap();
assert!(
transform.verify_orthogonality(1e-5),
"Orthogonality property (INV-4) violated"
);
}
#[test]
fn test_orthogonality_randomized() {
let transform = HadamardTransform::randomized(6, 42).unwrap();
assert!(
transform.verify_orthogonality(1e-5),
"Randomized transform orthogonality violated"
);
}
#[test]
fn test_hadamard_known_values() {
let transform = HadamardTransform::deterministic(2).unwrap();
let mut data = vec![1.0, 0.0, 0.0, 0.0];
transform.forward_inplace(&mut data);
for &v in &data {
assert!((v - 0.5).abs() < 1e-5, "Expected 0.5, got {}", v);
}
}
#[test]
fn test_energy_preservation() {
let transform = HadamardTransform::deterministic(4).unwrap();
let original: Vec<f32> = (0..16).map(|i| (i as f32) * 0.1).collect();
let original_norm: f32 = original.iter().map(|x| x * x).sum::<f32>().sqrt();
let mut data = original.clone();
transform.forward_inplace(&mut data);
let transformed_norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(original_norm - transformed_norm).abs() < 1e-4,
"Energy not preserved: {} vs {}",
original_norm,
transformed_norm
);
}
#[test]
fn test_batch_transform() {
let transform = HadamardTransform::deterministic(3).unwrap();
let batch_size = 4;
let mut data: Vec<f32> = (0..32).map(|i| i as f32).collect();
let original = data.clone();
hadamard_batch_transform(&transform, &mut data, batch_size).unwrap();
hadamard_batch_inverse(&transform, &mut data, batch_size).unwrap();
for (a, b) in data.iter().zip(original.iter()) {
assert!((a - b).abs() < 1e-5);
}
}
#[test]
fn test_next_power_of_2() {
assert_eq!(next_power_of_2(1), 1);
assert_eq!(next_power_of_2(2), 2);
assert_eq!(next_power_of_2(3), 4);
assert_eq!(next_power_of_2(5), 8);
assert_eq!(next_power_of_2(16), 16);
assert_eq!(next_power_of_2(17), 32);
}
#[test]
fn test_log2_exact() {
assert_eq!(log2_exact(1), Some(0));
assert_eq!(log2_exact(2), Some(1));
assert_eq!(log2_exact(4), Some(2));
assert_eq!(log2_exact(1024), Some(10));
assert_eq!(log2_exact(3), None);
assert_eq!(log2_exact(0), None);
}
#[test]
fn test_large_dimension() {
let transform = HadamardTransform::deterministic(8).unwrap();
let original: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 100.0).collect();
let mut data = original.clone();
transform.forward_inplace(&mut data);
transform.inverse_inplace(&mut data);
let max_error: f32 = data
.iter()
.zip(original.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, |a, b| a.max(b));
assert!(
max_error < 1e-4,
"Large dimension roundtrip error too high: {}",
max_error
);
}
#[test]
fn test_error_on_invalid_dimension() {
let result = HadamardTransform::new(MAX_LOG_DIM + 1, None);
assert!(result.is_err());
}
}