use super::pc_map;
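/// Maps each `f32` in `input` to its unsigned integer code via
/// `pc_map::forward_f32`, writing one `u32` per element into `output`.
/// When `bits == 32`, x86/x86_64 targets with SSE2 dispatch to a vectorized
/// kernel; every other width (and every other target) uses the scalar loop.
///
/// # Panics
///
/// Panics if `output.len() < input.len()`.
///
/// # Example
///
/// ```ignore
/// let input = [0.0f32, -1.5, 42.5];
/// let mut mapped = [0u32; 3];
/// forward_batch_f32(&input, &mut mapped, 32);
/// ```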
pub fn forward_batch_f32(input: &[f32], output: &mut [u32], bits: u32) {
assert!(
output.len() >= input.len(),
"output buffer too small: {} < {}",
output.len(),
input.len()
);
    // SSE2 is part of the x86_64 baseline; on 32-bit x86 it has to be
    // promised by the target itself (e.g. i686), so gate on the feature to
    // avoid emitting SSE2 instructions for targets such as i586.
    #[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))]
    {
        if bits == 32 {
            forward_batch_f32_sse2_full(input, output);
            return;
        }
    }
for (i, &d) in input.iter().enumerate() {
output[i] = pc_map::forward_f32(d, bits);
}
}
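/// Inverse of [`forward_batch_f32`]: decodes each `u32` code in `input` back
/// into an `f32` via `pc_map::inverse_f32`, with the same SSE2 fast path for
/// `bits == 32`.
///
/// # Panics
///
/// Panics if `output.len() < input.len()`.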
pub fn inverse_batch_f32(input: &[u32], output: &mut [f32], bits: u32) {
assert!(
output.len() >= input.len(),
"output buffer too small: {} < {}",
output.len(),
input.len()
);
    // Same SSE2 gating as in `forward_batch_f32`.
    #[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))]
    {
        if bits == 32 {
            inverse_batch_f32_sse2_full(input, output);
            return;
        }
    }
for (i, &r) in input.iter().enumerate() {
output[i] = pc_map::inverse_f32(r, bits);
}
}
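/// Maps each `f64` in `input` to its unsigned integer code via
/// `pc_map::forward_f64`. There is currently no SIMD path for `f64`, so this
/// is always a scalar loop.
///
/// # Panics
///
/// Panics if `output.len() < input.len()`.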
pub fn forward_batch_f64(input: &[f64], output: &mut [u64], bits: u32) {
assert!(
output.len() >= input.len(),
"output buffer too small: {} < {}",
output.len(),
input.len()
);
for (i, &d) in input.iter().enumerate() {
output[i] = pc_map::forward_f64(d, bits);
}
}
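/// Inverse of [`forward_batch_f64`]: decodes each `u64` code in `input` back
/// into an `f64` via `pc_map::inverse_f64`.
///
/// # Panics
///
/// Panics if `output.len() < input.len()`.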
pub fn inverse_batch_f64(input: &[u64], output: &mut [f64], bits: u32) {
assert!(
output.len() >= input.len(),
"output buffer too small: {} < {}",
output.len(),
input.len()
);
for (i, &r) in input.iter().enumerate() {
output[i] = pc_map::inverse_f64(r, bits);
}
}
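// SSE2 kernel for the `bits == 32` forward map, four lanes per iteration.
// Each lane computes `(!x) ^ mask`, where `mask` is `0x7FFF_FFFF` when the
// complemented sign bit is set and `0` otherwise. The net effect is the
// usual sortable-float transform: non-negative inputs get only their sign
// bit flipped, negative inputs get all bits flipped. The tail (`n % 4`
// elements) goes through the scalar `pc_map` fallback.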
#[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))]
fn forward_batch_f32_sse2_full(input: &[f32], output: &mut [u32]) {
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
let n = input.len();
let chunks = n / 4;
let remainder = n % 4;
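    // SAFETY: only SSE2 intrinsics are used, and this function is compiled
    // only for targets where SSE2 is available. `loadu`/`storeu` impose no
    // alignment requirements, `offset + 4 <= n` holds for every chunk, and
    // the caller guarantees `output.len() >= input.len()`.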
unsafe {
let ones = _mm_set1_epi32(-1i32);
let zero = _mm_setzero_si128();
for i in 0..chunks {
            let offset = i * 4;
            // Reinterpret four f32 values as their raw bit patterns.
            let r = _mm_loadu_si128(input[offset..].as_ptr() as *const __m128i);
            // Complement every lane.
            let r = _mm_xor_si128(r, ones);
            // Per-lane mask: 0x7FFF_FFFF where the complemented sign bit is
            // set (originally non-negative inputs), 0 otherwise.
            let sign = _mm_srli_epi32(r, 31);
            let neg_sign = _mm_sub_epi32(zero, sign);
            let xor_mask = _mm_srli_epi32(neg_sign, 1);
            let r = _mm_xor_si128(r, xor_mask);
            _mm_storeu_si128(output[offset..].as_mut_ptr() as *mut __m128i, r);
}
}
let scalar_start = chunks * 4;
for i in 0..remainder {
output[scalar_start + i] = pc_map::forward_f32(input[scalar_start + i], 32);
}
}
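// SSE2 kernel inverting `forward_batch_f32_sse2_full`: the conditional
// `0x7FFF_FFFF` mask is applied first and every lane is complemented second,
// undoing the forward transform bit-for-bit. The tail again goes through the
// scalar `pc_map` fallback.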
#[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))]
fn inverse_batch_f32_sse2_full(input: &[u32], output: &mut [f32]) {
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
let n = input.len();
let chunks = n / 4;
let remainder = n % 4;
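    // SAFETY: as in the forward kernel, only SSE2 intrinsics with unaligned
    // loads/stores are used, and every chunk stays within the bounds the
    // caller's length assertion established.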
unsafe {
let ones = _mm_set1_epi32(-1i32);
let zero = _mm_setzero_si128();
for i in 0..chunks {
let offset = i * 4;
            let r = _mm_loadu_si128(input[offset..].as_ptr() as *const __m128i);
            // Per-lane mask: 0x7FFF_FFFF where the stored sign bit is set,
            // matching the mask the forward kernel applied last.
            let sign = _mm_srli_epi32(r, 31);
            let neg_sign = _mm_sub_epi32(zero, sign);
            let xor_mask = _mm_srli_epi32(neg_sign, 1);
            let r = _mm_xor_si128(r, xor_mask);
            // Complement every lane to recover the original f32 bit patterns.
            let r = _mm_xor_si128(r, ones);
            _mm_storeu_si128(output[offset..].as_mut_ptr() as *mut __m128i, r);
}
}
let scalar_start = chunks * 4;
for i in 0..remainder {
output[scalar_start + i] = pc_map::inverse_f32(input[scalar_start + i], 32);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn forward_batch_f32_matches_scalar() {
let input: Vec<f32> = (-50..50).map(|i| i as f32 * 0.7).collect();
let mut simd_out = vec![0u32; input.len()];
let mut scalar_out = vec![0u32; input.len()];
forward_batch_f32(&input, &mut simd_out, 32);
for (i, &d) in input.iter().enumerate() {
scalar_out[i] = pc_map::forward_f32(d, 32);
}
assert_eq!(simd_out, scalar_out);
}
#[test]
fn inverse_batch_f32_matches_scalar() {
let input: Vec<f32> = (-50..50).map(|i| i as f32 * 0.7).collect();
let mut mapped = vec![0u32; input.len()];
for (i, &d) in input.iter().enumerate() {
mapped[i] = pc_map::forward_f32(d, 32);
}
let mut simd_out = vec![0.0f32; input.len()];
let mut scalar_out = vec![0.0f32; input.len()];
inverse_batch_f32(&mapped, &mut simd_out, 32);
for (i, &r) in mapped.iter().enumerate() {
scalar_out[i] = pc_map::inverse_f32(r, 32);
}
for i in 0..input.len() {
assert_eq!(
simd_out[i].to_bits(),
scalar_out[i].to_bits(),
"mismatch at index {i}"
);
}
}
#[test]
fn forward_batch_f32_round_trip() {
let input: Vec<f32> = vec![
0.0,
-0.0,
1.0,
-1.0,
f32::INFINITY,
f32::NEG_INFINITY,
f32::NAN,
42.5,
-1e10,
f32::EPSILON,
f32::MAX,
f32::MIN,
];
let mut mapped = vec![0u32; input.len()];
let mut output = vec![0.0f32; input.len()];
forward_batch_f32(&input, &mut mapped, 32);
inverse_batch_f32(&mapped, &mut output, 32);
for i in 0..input.len() {
if input[i].is_nan() {
assert!(output[i].is_nan(), "expected NaN at index {i}");
} else {
assert_eq!(
input[i].to_bits(),
output[i].to_bits(),
"mismatch at index {i}: {} vs {}",
input[i],
output[i]
);
}
}
}
#[test]
fn forward_batch_f64_matches_scalar() {
let input: Vec<f64> = (-50..50).map(|i| i as f64 * 0.7).collect();
let mut batch_out = vec![0u64; input.len()];
let mut scalar_out = vec![0u64; input.len()];
forward_batch_f64(&input, &mut batch_out, 64);
for (i, &d) in input.iter().enumerate() {
scalar_out[i] = pc_map::forward_f64(d, 64);
}
assert_eq!(batch_out, scalar_out);
}
#[test]
fn forward_batch_f32_non_aligned_length() {
for len in [1, 2, 3, 5, 7, 13, 17] {
let input: Vec<f32> = (0..len).map(|i| i as f32).collect();
let mut simd_out = vec![0u32; len];
let mut scalar_out = vec![0u32; len];
forward_batch_f32(&input, &mut simd_out, 32);
for (i, &d) in input.iter().enumerate() {
scalar_out[i] = pc_map::forward_f32(d, 32);
}
assert_eq!(simd_out, scalar_out, "failed for len={len}");
}
}
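    // Mirrors `forward_batch_f32_non_aligned_length` so the SSE2 remainder
    // handling is exercised for the inverse path as well.
    #[test]
    fn inverse_batch_f32_non_aligned_length() {
        for len in [1usize, 2, 3, 5, 7, 13, 17] {
            let floats: Vec<f32> = (0..len).map(|i| i as f32 - 8.0).collect();
            let input: Vec<u32> = floats.iter().map(|&d| pc_map::forward_f32(d, 32)).collect();
            let mut simd_out = vec![0.0f32; len];
            let mut scalar_out = vec![0.0f32; len];
            inverse_batch_f32(&input, &mut simd_out, 32);
            for (i, &r) in input.iter().enumerate() {
                scalar_out[i] = pc_map::inverse_f32(r, 32);
            }
            for i in 0..len {
                assert_eq!(
                    simd_out[i].to_bits(),
                    scalar_out[i].to_bits(),
                    "failed for len={len}, index {i}"
                );
            }
        }
    }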
#[test]
fn forward_batch_f32_reduced_precision() {
let input: Vec<f32> = (-20..20).map(|i| i as f32 * 1.5).collect();
for bits in [8, 16, 24] {
let mut batch_out = vec![0u32; input.len()];
let mut scalar_out = vec![0u32; input.len()];
forward_batch_f32(&input, &mut batch_out, bits);
for (i, &d) in input.iter().enumerate() {
scalar_out[i] = pc_map::forward_f32(d, bits);
}
assert_eq!(batch_out, scalar_out, "failed for bits={bits}");
}
}
}