#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Reports whether the CPU running this process supports the AVX2 instruction set.
///
/// The answer comes from the standard library's runtime feature detection, so it
/// reflects the machine the binary is executing on rather than the compile target.
/// Only compiled on `x86_64`.
#[cfg(target_arch = "x86_64")]
pub fn is_avx2_available() -> bool {
    std::is_x86_feature_detected!("avx2")
}
/// Snaps every input vector to the valid state it is most aligned with, processing
/// eight vectors per iteration with AVX2.
///
/// For each `vectors[i]` the vector is scaled to unit length and its dot product
/// ("resonance") with every entry of `valid_states` is computed. `results[i]` is then
/// set to `(best_state, 1.0 - best_resonance)`. Comparisons are strict, so on a tie
/// the state with the lowest index wins.
///
/// Edge cases, handled identically in the 8-lane path and in the scalar tail:
///
/// * A vector whose norm is below `1e-10` cannot be normalised; its result is the
///   fixed value `([1.0, 0.0], 0.0)`.
/// * If `valid_states` is empty there is nothing to snap to, so the result slots of
///   non-degenerate vectors are left untouched.
///
/// The winning state index is tracked in 32-bit integer lanes, so only the first
/// `i32::MAX` states can be addressed by the vectorised path.
///
/// # Panics
///
/// Panics if a result slot that has to be written lies beyond the end of `results`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2, e.g. by checking
/// `is_x86_feature_detected!("avx2")` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn snap_batch_avx2(
    valid_states: &[[f32; 2]],
    vectors: &[[f32; 2]],
    results: &mut [([f32; 2], f32)],
) {
    /// Number of `f32` lanes in one 256-bit register.
    const LANES: usize = 8;
    /// Vectors shorter than this are treated as having no direction.
    const MIN_NORM: f32 = 1e-10;
    /// Result reported for vectors that are too short to normalise.
    const DEGENERATE: ([f32; 2], f32) = ([1.0, 0.0], 0.0);

    let full_chunks = vectors.len() / LANES;
    let tail_start = full_chunks * LANES;

    for chunk_idx in 0..full_chunks {
        let base = chunk_idx * LANES;

        // Normalise the eight inputs in scalar code and split them into x / y lanes.
        // Degenerate lanes keep a zero direction and are patched up after the search.
        let mut unit_x = [0.0f32; LANES];
        let mut unit_y = [0.0f32; LANES];
        let mut degenerate = [false; LANES];
        for (lane, v) in vectors[base..base + LANES].iter().enumerate() {
            let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
            if norm < MIN_NORM {
                degenerate[lane] = true;
            } else {
                unit_x[lane] = v[0] / norm;
                unit_y[lane] = v[1] / norm;
            }
        }

        // SAFETY: both arrays hold exactly LANES (8) f32 values; unaligned loads are used.
        let vx = _mm256_loadu_ps(unit_x.as_ptr());
        let vy = _mm256_loadu_ps(unit_y.as_ptr());

        let mut best_idx = _mm256_setzero_si256();
        let mut best_res = _mm256_set1_ps(f32::MIN);
        for (state_idx, state) in valid_states.iter().enumerate() {
            let sx = _mm256_set1_ps(state[0]);
            let sy = _mm256_set1_ps(state[1]);
            // Integer lanes keep the index exact; an f32 lane would round above 2^24.
            // The cast wraps only beyond i32::MAX states (see the function docs).
            let idx_v = _mm256_set1_epi32(state_idx as i32);

            // Separate multiply and add (no FMA) so lanes round exactly like the scalar tail.
            let resonance = _mm256_add_ps(_mm256_mul_ps(sx, vx), _mm256_mul_ps(sy, vy));

            // Per-lane mask: all ones where this state beats the running maximum.
            let better = _mm256_cmp_ps(resonance, best_res, _CMP_GT_OS);
            best_res = _mm256_blendv_ps(best_res, resonance, better);
            best_idx = _mm256_blendv_epi8(best_idx, idx_v, _mm256_castps_si256(better));
        }

        let mut best_res_arr = [0.0f32; LANES];
        let mut best_idx_arr = [0i32; LANES];
        // SAFETY: both destinations are 32 bytes long (8 x 4 bytes); unaligned stores are used.
        _mm256_storeu_ps(best_res_arr.as_mut_ptr(), best_res);
        _mm256_storeu_si256(best_idx_arr.as_mut_ptr().cast::<__m256i>(), best_idx);

        for lane in 0..LANES {
            let slot = base + lane;
            if degenerate[lane] {
                results[slot] = DEGENERATE;
                continue;
            }
            // A wrapped (negative) index becomes a huge usize and is rejected by `get`,
            // as is index 0 when `valid_states` is empty.
            let state_idx = best_idx_arr[lane] as usize;
            if let Some(&snapped) = valid_states.get(state_idx) {
                results[slot] = (snapped, 1.0 - best_res_arr[lane]);
            }
        }
    }

    // Scalar tail for the last `vectors.len() % 8` inputs.
    for idx in tail_start..vectors.len() {
        let v = vectors[idx];
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        if norm < MIN_NORM {
            results[idx] = DEGENERATE;
            continue;
        }
        let unit = [v[0] / norm, v[1] / norm];

        let mut best_resonance = f32::MIN;
        let mut best_idx = 0;
        for (j, state) in valid_states.iter().enumerate() {
            let resonance = state[0] * unit[0] + state[1] * unit[1];
            if resonance > best_resonance {
                best_resonance = resonance;
                best_idx = j;
            }
        }

        // Checked lookup: with no states the slot is left untouched, like the 8-lane path.
        if let Some(&snapped) = valid_states.get(best_idx) {
            results[idx] = (snapped, 1.0 - best_resonance);
        }
    }
}
/// Snaps every vector in `vectors` to its best-matching entry of `valid_states`,
/// using the AVX2 kernel when the CPU supports it and a scalar loop otherwise.
///
/// For each input the vector is normalised, its dot product ("resonance") with every
/// valid state is computed, and `results[i]` receives `(best_state, 1.0 - best_resonance)`.
/// Comparisons are strict, so on a tie the state with the lowest index wins.
///
/// In the scalar fallback:
///
/// * a vector whose norm is below `1e-10` yields the fixed result `([1.0, 0.0], 0.0)`;
/// * if `valid_states` is empty, the result slot of a non-degenerate vector is left
///   untouched instead of panicking on an out-of-bounds index, the same outcome the
///   vectorised 8-lane path produces through its bounds guard.
///
/// # Panics
///
/// Panics if `vectors` and `results` differ in length.
pub fn snap_batch_simd(
    valid_states: &[[f32; 2]],
    vectors: &[[f32; 2]],
    results: &mut [([f32; 2], f32)],
) {
    assert_eq!(
        vectors.len(),
        results.len(),
        "Input and output buffers must have same length"
    );

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 support was confirmed at runtime on the line above, which is
            // the only requirement of `snap_batch_avx2`.
            unsafe {
                snap_batch_avx2(valid_states, vectors, results);
            }
            return;
        }
    }

    // Portable scalar fallback; lengths were asserted equal, so `zip` visits every slot.
    for (slot, v) in results.iter_mut().zip(vectors) {
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        if norm < 1e-10 {
            // Too short to have a direction: report the fixed default.
            *slot = ([1.0, 0.0], 0.0);
            continue;
        }
        let unit = [v[0] / norm, v[1] / norm];

        let mut best_resonance = f32::MIN;
        let mut best_idx = 0;
        for (j, state) in valid_states.iter().enumerate() {
            let resonance = state[0] * unit[0] + state[1] * unit[1];
            if resonance > best_resonance {
                best_resonance = resonance;
                best_idx = j;
            }
        }

        // Checked lookup so an empty state table skips the write rather than panicking.
        if let Some(&snapped) = valid_states.get(best_idx) {
            *slot = (snapped, 1.0 - best_resonance);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward scalar implementation used as the expected answer.
    ///
    /// Only meant for non-zero input vectors: the norm is clamped to `1e-10` rather
    /// than special-cased, so it is not a reference for degenerate inputs.
    fn scalar_reference(states: &[[f32; 2]], vec: [f32; 2]) -> ([f32; 2], f32) {
        let norm = (vec[0] * vec[0] + vec[1] * vec[1]).sqrt().max(1e-10);
        let v_in = [vec[0] / norm, vec[1] / norm];
        let mut max_r = f32::MIN;
        let mut best = [1.0, 0.0];
        for &state in states {
            let r = state[0] * v_in[0] + state[1] * v_in[1];
            if r > max_r {
                max_r = r;
                best = state;
            }
        }
        (best, 1.0 - max_r)
    }

    /// Runs `snap_batch_simd` on `vectors` and checks every result against
    /// `scalar_reference` with a tolerance of 0.01 per component.
    fn assert_matches_scalar(states: &[[f32; 2]], vectors: &[[f32; 2]]) {
        let mut results_simd = vec![([0.0, 0.0], 0.0f32); vectors.len()];
        snap_batch_simd(states, vectors, &mut results_simd);

        for (i, (&vec, &simd)) in vectors.iter().zip(results_simd.iter()).enumerate() {
            let scalar = scalar_reference(states, vec);
            assert!(
                (simd.0[0] - scalar.0[0]).abs() < 0.01,
                "X mismatch at {}: simd={:?} scalar={:?}",
                i,
                simd.0,
                scalar.0
            );
            assert!(
                (simd.0[1] - scalar.0[1]).abs() < 0.01,
                "Y mismatch at {}: simd={:?} scalar={:?}",
                i,
                simd.0,
                scalar.0
            );
            assert!(
                (simd.1 - scalar.1).abs() < 0.01,
                "Noise mismatch at {}: simd={:?} scalar={:?}",
                i,
                simd.1,
                scalar.1
            );
        }
    }

    /// Exactly eight vectors: one full 8-lane chunk and no remainder.
    #[test]
    fn test_simd_vs_scalar() {
        let states: Vec<[f32; 2]> = vec![[1.0, 0.0], [0.0, 1.0], [0.6, 0.8], [0.8, 0.6]];
        let vectors: Vec<[f32; 2]> = vec![
            [0.59, 0.81],
            [0.01, 0.99],
            [0.99, 0.01],
            [0.61, 0.79],
            [0.7, 0.7],
            [0.5, 0.9],
            [0.9, 0.5],
            [0.3, 0.95],
        ];
        assert_matches_scalar(&states, &vectors);
    }

    /// Eleven vectors: a full chunk plus three leftovers, so the code that handles a
    /// batch length that is not a multiple of eight is exercised as well.
    #[test]
    fn test_simd_vs_scalar_with_remainder() {
        let states: Vec<[f32; 2]> = vec![[1.0, 0.0], [0.0, 1.0], [0.6, 0.8], [0.8, 0.6]];
        let vectors: Vec<[f32; 2]> = vec![
            [0.59, 0.81],
            [0.01, 0.99],
            [0.99, 0.01],
            [0.61, 0.79],
            [0.7, 0.7],
            [0.5, 0.9],
            [0.9, 0.5],
            [0.3, 0.95],
            [0.2, 0.1],
            [-1.0, 0.0],
            [0.05, 2.0],
        ];
        assert_matches_scalar(&states, &vectors);
    }

    /// Smoke test: feature detection must be callable; the outcome is only logged.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_avx2_available() {
        let has_avx2 = is_avx2_available();
        println!("AVX2 available: {}", has_avx2);
    }
}