#![allow(unsafe_op_in_unsafe_fn)]
use std::sync::OnceLock;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdCapability {
Scalar,
Sse41,
Avx2,
Avx512,
Neon,
}
impl SimdCapability {
pub fn detect() -> Self {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return Self::Avx512;
}
if is_x86_feature_detected!("avx2") {
return Self::Avx2;
}
if is_x86_feature_detected!("sse4.1") {
return Self::Sse41;
}
}
#[cfg(target_arch = "aarch64")]
{
Self::Neon
}
#[cfg(not(target_arch = "aarch64"))]
Self::Scalar
}
pub fn width(&self) -> usize {
match self {
Self::Scalar => 1,
Self::Sse41 | Self::Neon => 4,
Self::Avx2 => 8,
Self::Avx512 => 16,
}
}
}
static CAPABILITY: OnceLock<SimdCapability> = OnceLock::new();
pub fn simd_capability() -> SimdCapability {
*CAPABILITY.get_or_init(SimdCapability::detect)
}
#[derive(Debug, Clone, Copy)]
pub struct HadamardKernel {
capability: SimdCapability,
}
impl HadamardKernel {
pub fn detect() -> Self {
Self {
capability: simd_capability(),
}
}
pub fn with_capability(capability: SimdCapability) -> Self {
Self { capability }
}
#[inline]
pub fn transform(&self, data: &mut [f32]) {
let n = data.len();
if n == 0 || !n.is_power_of_two() {
return;
}
match self.capability {
#[cfg(target_arch = "x86_64")]
SimdCapability::Avx512 => unsafe { hadamard_avx512(data) },
#[cfg(target_arch = "x86_64")]
SimdCapability::Avx2 => unsafe { hadamard_avx2(data) },
#[cfg(target_arch = "x86_64")]
SimdCapability::Sse41 => unsafe { hadamard_sse41(data) },
#[cfg(target_arch = "aarch64")]
SimdCapability::Neon => unsafe { hadamard_neon(data) },
_ => hadamard_scalar(data),
}
}
pub fn transform_batch(&self, flat_data: &mut [f32], dim: usize) {
if dim == 0 || !dim.is_power_of_two() {
return;
}
let num_vectors = flat_data.len() / dim;
for i in 0..num_vectors {
let start = i * dim;
let slice = &mut flat_data[start..start + dim];
self.transform(slice);
}
}
pub fn capability(&self) -> SimdCapability {
self.capability
}
}
impl Default for HadamardKernel {
fn default() -> Self {
Self::detect()
}
}
pub fn hadamard_scalar(data: &mut [f32]) {
let n = data.len();
if n == 0 || !n.is_power_of_two() {
return;
}
let mut h = 1;
while h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = data[j];
let y = data[j + h];
data[j] = x + y;
data[j + h] = x - y;
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
for x in data.iter_mut() {
*x *= scale;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn hadamard_avx2(data: &mut [f32]) {
use std::arch::x86_64::*;
unsafe {
let n = data.len();
if n < 8 {
hadamard_scalar(data);
return;
}
let mut h = 1;
while h < 8 && h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = *data.get_unchecked(j);
let y = *data.get_unchecked(j + h);
*data.get_unchecked_mut(j) = x + y;
*data.get_unchecked_mut(j + h) = x - y;
}
}
h *= 2;
}
while h < n {
let blocks = n / (h * 2);
for block in 0..blocks {
let base = block * h * 2;
for j in (0..h).step_by(8) {
let idx_a = base + j;
let idx_b = base + h + j;
let va = _mm256_loadu_ps(data.as_ptr().add(idx_a));
let vb = _mm256_loadu_ps(data.as_ptr().add(idx_b));
let sum = _mm256_add_ps(va, vb);
let diff = _mm256_sub_ps(va, vb);
_mm256_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
_mm256_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
}
let remainder = h % 8;
if remainder > 0 {
let start = h - remainder;
for j in start..h {
let idx_a = base + j;
let idx_b = base + h + j;
let x = *data.get_unchecked(idx_a);
let y = *data.get_unchecked(idx_b);
*data.get_unchecked_mut(idx_a) = x + y;
*data.get_unchecked_mut(idx_b) = x - y;
}
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
let vscale = _mm256_set1_ps(scale);
let chunks = n / 8;
for i in 0..chunks {
let offset = i * 8;
let v = _mm256_loadu_ps(data.as_ptr().add(offset));
let scaled = _mm256_mul_ps(v, vscale);
_mm256_storeu_ps(data.as_mut_ptr().add(offset), scaled);
}
for i in (chunks * 8)..n {
*data.get_unchecked_mut(i) *= scale;
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn hadamard_sse41(data: &mut [f32]) {
use std::arch::x86_64::*;
unsafe {
let n = data.len();
if n < 4 {
hadamard_scalar(data);
return;
}
let mut h = 1;
while h < 4 && h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = *data.get_unchecked(j);
let y = *data.get_unchecked(j + h);
*data.get_unchecked_mut(j) = x + y;
*data.get_unchecked_mut(j + h) = x - y;
}
}
h *= 2;
}
while h < n {
let blocks = n / (h * 2);
for block in 0..blocks {
let base = block * h * 2;
for j in (0..h).step_by(4) {
let idx_a = base + j;
let idx_b = base + h + j;
let va = _mm_loadu_ps(data.as_ptr().add(idx_a));
let vb = _mm_loadu_ps(data.as_ptr().add(idx_b));
let sum = _mm_add_ps(va, vb);
let diff = _mm_sub_ps(va, vb);
_mm_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
_mm_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
}
let remainder = h % 4;
if remainder > 0 {
let start = h - remainder;
for j in start..h {
let idx_a = base + j;
let idx_b = base + h + j;
let x = *data.get_unchecked(idx_a);
let y = *data.get_unchecked(idx_b);
*data.get_unchecked_mut(idx_a) = x + y;
*data.get_unchecked_mut(idx_b) = x - y;
}
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
let vscale = _mm_set1_ps(scale);
let chunks = n / 4;
for i in 0..chunks {
let offset = i * 4;
let v = _mm_loadu_ps(data.as_ptr().add(offset));
let scaled = _mm_mul_ps(v, vscale);
_mm_storeu_ps(data.as_mut_ptr().add(offset), scaled);
}
for i in (chunks * 4)..n {
*data.get_unchecked_mut(i) *= scale;
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn hadamard_avx512(data: &mut [f32]) {
use std::arch::x86_64::*;
unsafe {
let n = data.len();
if n < 16 {
hadamard_avx2(data);
return;
}
let mut h = 1;
while h < 16 && h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = *data.get_unchecked(j);
let y = *data.get_unchecked(j + h);
*data.get_unchecked_mut(j) = x + y;
*data.get_unchecked_mut(j + h) = x - y;
}
}
h *= 2;
}
while h < n {
let blocks = n / (h * 2);
for block in 0..blocks {
let base = block * h * 2;
for j in (0..h).step_by(16) {
let idx_a = base + j;
let idx_b = base + h + j;
let va = _mm512_loadu_ps(data.as_ptr().add(idx_a));
let vb = _mm512_loadu_ps(data.as_ptr().add(idx_b));
let sum = _mm512_add_ps(va, vb);
let diff = _mm512_sub_ps(va, vb);
_mm512_storeu_ps(data.as_mut_ptr().add(idx_a), sum);
_mm512_storeu_ps(data.as_mut_ptr().add(idx_b), diff);
}
let remainder = h % 16;
if remainder > 0 {
let start = h - remainder;
for j in start..h {
let idx_a = base + j;
let idx_b = base + h + j;
let x = *data.get_unchecked(idx_a);
let y = *data.get_unchecked(idx_b);
*data.get_unchecked_mut(idx_a) = x + y;
*data.get_unchecked_mut(idx_b) = x - y;
}
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
let vscale = _mm512_set1_ps(scale);
let chunks = n / 16;
for i in 0..chunks {
let offset = i * 16;
let v = _mm512_loadu_ps(data.as_ptr().add(offset));
let scaled = _mm512_mul_ps(v, vscale);
_mm512_storeu_ps(data.as_mut_ptr().add(offset), scaled);
}
for i in (chunks * 16)..n {
*data.get_unchecked_mut(i) *= scale;
}
}
}
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hadamard_neon(data: &mut [f32]) {
use std::arch::aarch64::*;
unsafe {
let n = data.len();
if n < 4 {
hadamard_scalar(data);
return;
}
let mut h = 1;
while h < 4 && h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = *data.get_unchecked(j);
let y = *data.get_unchecked(j + h);
*data.get_unchecked_mut(j) = x + y;
*data.get_unchecked_mut(j + h) = x - y;
}
}
h *= 2;
}
while h < n {
let blocks = n / (h * 2);
for block in 0..blocks {
let base = block * h * 2;
for j in (0..h).step_by(4) {
let idx_a = base + j;
let idx_b = base + h + j;
let va = vld1q_f32(data.as_ptr().add(idx_a));
let vb = vld1q_f32(data.as_ptr().add(idx_b));
let sum = vaddq_f32(va, vb);
let diff = vsubq_f32(va, vb);
vst1q_f32(data.as_mut_ptr().add(idx_a), sum);
vst1q_f32(data.as_mut_ptr().add(idx_b), diff);
}
let remainder = h % 4;
if remainder > 0 {
let start = h - remainder;
for j in start..h {
let idx_a = base + j;
let idx_b = base + h + j;
let x = *data.get_unchecked(idx_a);
let y = *data.get_unchecked(idx_b);
*data.get_unchecked_mut(idx_a) = x + y;
*data.get_unchecked_mut(idx_b) = x - y;
}
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
let vscale = vdupq_n_f32(scale);
let chunks = n / 4;
for i in 0..chunks {
let offset = i * 4;
let v = vld1q_f32(data.as_ptr().add(offset));
let scaled = vmulq_f32(v, vscale);
vst1q_f32(data.as_mut_ptr().add(offset), scaled);
}
for i in (chunks * 4)..n {
*data.get_unchecked_mut(i) *= scale;
}
}
}
#[inline]
pub fn hadamard_transform(data: &mut [f32]) {
HadamardKernel::detect().transform(data);
}
pub fn hadamard_transform_batch(flat_data: &mut [f32], dim: usize) {
HadamardKernel::detect().transform_batch(flat_data, dim);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scalar_basic() {
let mut data = vec![1.0, 0.0, 0.0, 0.0];
hadamard_scalar(&mut data);
for &x in &data {
assert!((x - 0.5).abs() < 0.01, "x = {}", x);
}
}
#[test]
fn test_scalar_identity() {
let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut data = original.clone();
hadamard_scalar(&mut data);
hadamard_scalar(&mut data);
for (a, b) in original.iter().zip(data.iter()) {
assert!((a - b).abs() < 0.01, "a = {}, b = {}", a, b);
}
}
#[test]
fn test_kernel_detection() {
let kernel = HadamardKernel::detect();
let cap = kernel.capability();
#[cfg(target_arch = "x86_64")]
assert!(matches!(
cap,
SimdCapability::Scalar
| SimdCapability::Sse41
| SimdCapability::Avx2
| SimdCapability::Avx512
));
#[cfg(target_arch = "aarch64")]
assert_eq!(cap, SimdCapability::Neon);
}
#[test]
fn test_kernel_consistency() {
let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut scalar_data = original.clone();
hadamard_scalar(&mut scalar_data);
let mut kernel_data = original.clone();
hadamard_transform(&mut kernel_data);
for (a, b) in scalar_data.iter().zip(kernel_data.iter()) {
assert!(
(a - b).abs() < 1e-5,
"Mismatch: scalar {} vs kernel {}",
a,
b
);
}
}
#[test]
fn test_preserves_norm() {
let mut data: Vec<f32> = (1..=16).map(|i| i as f32).collect();
let original_norm: f32 = data.iter().map(|x| x * x).sum();
hadamard_transform(&mut data);
let new_norm: f32 = data.iter().map(|x| x * x).sum();
assert!(
(original_norm - new_norm).abs() < 0.1,
"Norm changed: {} -> {}",
original_norm,
new_norm
);
}
#[test]
fn test_batch_transform() {
let dim = 8;
let num_vectors = 10;
let mut flat_data: Vec<f32> = (0..(dim * num_vectors)).map(|i| i as f32 / 100.0).collect();
hadamard_transform_batch(&mut flat_data, dim);
for i in 0..num_vectors {
let start = i * dim;
let vec = &flat_data[start..start + dim];
let norm: f32 = vec.iter().map(|x| x * x).sum();
assert!(norm > 0.0, "Vector {} has zero norm", i);
}
}
#[test]
fn test_non_power_of_two() {
let mut data = vec![1.0, 2.0, 3.0]; let original = data.clone();
hadamard_transform(&mut data);
assert_eq!(data, original);
}
#[test]
fn test_large_dimension() {
let dim = 1024; let mut data: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
let original_norm: f32 = data.iter().map(|x| x * x).sum();
hadamard_transform(&mut data);
let new_norm: f32 = data.iter().map(|x| x * x).sum();
let rel_error = (original_norm - new_norm).abs() / original_norm;
assert!(rel_error < 1e-5, "Norm error too large: {}", rel_error);
}
}