#[cfg(feature = "simd")]
use wide::f32x8;
/// Computes `e^x` for every element of `input` and stores the results in `output`.
///
/// Only the first `min(input.len(), output.len())` elements are processed; any
/// further elements of `output` are left untouched. Full groups of eight values
/// are evaluated with [`f32x8::exp`], the remaining tail with [`f32::exp`].
#[cfg(feature = "simd")]
pub fn simd_exp_f32(input: &[f32], output: &mut [f32]) {
    const LANES: usize = 8;

    // Trim both slices to the common length once so the chunk iterators line up.
    let len = input.len().min(output.len());
    let (input, output) = (&input[..len], &mut output[..len]);

    let mut src_chunks = input.chunks_exact(LANES);
    let mut dst_chunks = output.chunks_exact_mut(LANES);
    for (src, dst) in src_chunks.by_ref().zip(dst_chunks.by_ref()) {
        let mut lanes = [0.0_f32; LANES];
        lanes.copy_from_slice(src);
        // Use the vector type's own `exp` rather than the truncated Taylor helper
        // `simd_fast_exp_f32x8`: that polynomial has no range reduction, so its
        // error grows quickly with |x| and full chunks would disagree with the
        // scalar tail below.
        let result: [f32; LANES] = f32x8::from(lanes).exp().into();
        dst.copy_from_slice(&result);
    }

    // Fewer than LANES elements remain; finish them one at a time.
    let src_tail = src_chunks.remainder();
    for (dst, &src) in dst_chunks.into_remainder().iter_mut().zip(src_tail) {
        *dst = src.exp();
    }
}
/// Computes the natural logarithm of every element of `input` and stores the
/// results in `output`.
///
/// Only the first `min(input.len(), output.len())` elements are processed; any
/// further elements of `output` are left untouched. Full groups of eight values
/// are evaluated with [`f32x8::ln`], the remaining tail with [`f32::ln`].
///
/// Values in full groups are raised to at least `1e-30` before the logarithm is
/// taken, so zero and negative inputs produce `ln(1e-30)` there. The scalar tail
/// is not clamped and follows [`f32::ln`] (`-inf` for zero, NaN for negatives).
/// NOTE(review): this asymmetry is inherited from the existing behaviour —
/// confirm which of the two is intended before unifying them.
#[cfg(feature = "simd")]
pub fn simd_log_f32(input: &[f32], output: &mut [f32]) {
    const LANES: usize = 8;
    // Lower bound applied to vectorised lanes before taking the logarithm.
    const MIN_LANE_INPUT: f32 = 1e-30;

    // Trim both slices to the common length once so the chunk iterators line up.
    let len = input.len().min(output.len());
    let (input, output) = (&input[..len], &mut output[..len]);
    let floor = f32x8::splat(MIN_LANE_INPUT);

    let mut src_chunks = input.chunks_exact(LANES);
    let mut dst_chunks = output.chunks_exact_mut(LANES);
    for (src, dst) in src_chunks.by_ref().zip(dst_chunks.by_ref()) {
        let mut lanes = [0.0_f32; LANES];
        lanes.copy_from_slice(src);
        let clamped = f32x8::from(lanes).max(floor);
        // Use the vector type's own `ln` rather than the cubic Taylor helper
        // `simd_fast_log_f32x8`: that expansion about 1 diverges for inputs
        // above 2, so full chunks would disagree with the scalar tail below.
        let result: [f32; LANES] = clamped.ln().into();
        dst.copy_from_slice(&result);
    }

    // Fewer than LANES elements remain; finish them one at a time.
    let src_tail = src_chunks.remainder();
    for (dst, &src) in dst_chunks.into_remainder().iter_mut().zip(src_tail) {
        *dst = src.ln();
    }
}
/// Approximates `e^x` in every lane with the degree-5 Taylor polynomial about
/// zero: `1 + x + x²/2 + x³/6 + x⁴/24 + x⁵/120`.
///
/// No range reduction is performed, so the truncation error grows with `|x|`;
/// the result is only close to `e^x` for inputs near zero.
#[cfg(feature = "simd")]
pub fn simd_fast_exp_f32x8(x: f32x8) -> f32x8 {
    // Reciprocal factorials 1/2!, 1/3!, 1/4! and 1/5!.
    let inv_2 = f32x8::splat(0.5);
    let inv_6 = f32x8::splat(1.0 / 6.0);
    let inv_24 = f32x8::splat(1.0 / 24.0);
    let inv_120 = f32x8::splat(1.0 / 120.0);

    // Powers of x, built from the square to keep the multiplication count low.
    let x2 = x * x;
    let x3 = x2 * x;
    let x4 = x2 * x2;
    let x5 = x4 * x;

    // Sum the terms from lowest to highest order.
    let acc = f32x8::splat(1.0) + x;
    let acc = acc + x2 * inv_2;
    let acc = acc + x3 * inv_6;
    let acc = acc + x4 * inv_24;
    acc + x5 * inv_120
}
/// Approximates `ln(x)` in every lane with the cubic Taylor polynomial about
/// one: `d - d²/2 + d³/3`, where `d = x - 1`.
///
/// The expansion is only close to `ln(x)` for inputs near one; the truncation
/// error grows as `x` moves away from that point.
#[cfg(feature = "simd")]
pub fn simd_fast_log_f32x8(x: f32x8) -> f32x8 {
    // Offset from the expansion point and its powers.
    let d = x - f32x8::splat(1.0);
    let d2 = d * d;
    let d3 = d2 * d;

    let quadratic = d2 * f32x8::splat(0.5);
    let cubic = d3 * f32x8::splat(1.0 / 3.0);
    (d - quadratic) + cubic
}
/// Scalar implementation of `simd_exp_f32`, compiled when the `simd` feature is
/// disabled.
///
/// Writes `e^x` for each element of `input` into the matching slot of `output`.
/// Only the first `min(input.len(), output.len())` elements are processed; any
/// further elements of `output` are left untouched.
#[cfg(not(feature = "simd"))]
pub fn simd_exp_f32(input: &[f32], output: &mut [f32]) {
    // `zip` stops at the shorter of the two slices.
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src.exp();
    }
}
/// Scalar implementation of `simd_log_f32`, compiled when the `simd` feature is
/// disabled.
///
/// Writes the natural logarithm of each element of `input` into the matching
/// slot of `output`, following [`f32::ln`] without any clamping. Only the first
/// `min(input.len(), output.len())` elements are processed; any further
/// elements of `output` are left untouched.
#[cfg(not(feature = "simd"))]
pub fn simd_log_f32(input: &[f32], output: &mut [f32]) {
    // `zip` stops at the shorter of the two slices.
    for (dst, &src) in output.iter_mut().zip(input.iter()) {
        *dst = src.ln();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails with a readable message unless `actual` is within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() <= tol,
            "expected {expected}, got {actual} (tolerance {tol})"
        );
    }

    #[test]
    fn test_simd_exp_f32() {
        // Two elements: shorter than one 8-lane group, so this covers the scalar path.
        let input = [0.0, 1.0];
        let mut output = [0.0; 2];
        simd_exp_f32(&input, &mut output);
        assert_close(output[0], 1.0, 1e-5);
        assert_close(output[1], std::f32::consts::E, 1e-5);
    }

    #[test]
    fn test_simd_log_f32() {
        // Two elements: shorter than one 8-lane group, so this covers the scalar path.
        let input = [1.0, std::f32::consts::E];
        let mut output = [0.0; 2];
        simd_log_f32(&input, &mut output);
        assert_close(output[0], 0.0, 1e-5);
        assert_close(output[1], 1.0, 1e-5);
    }

    #[test]
    fn test_simd_exp_f32_full_group_and_tail() {
        // Eleven elements: one full group of eight plus a three-element tail.
        // Inputs stay within [-0.5, 0.5], where even a low-order polynomial
        // kernel is accurate to well below the tolerance used here.
        let input: Vec<f32> = (0..11).map(|i| -0.5 + 0.1 * i as f32).collect();
        let mut output = vec![0.0_f32; input.len()];
        simd_exp_f32(&input, &mut output);
        for (&x, &y) in input.iter().zip(&output) {
            assert_close(y, x.exp(), 1e-3);
        }
    }

    #[test]
    fn test_simd_log_f32_full_group_and_tail() {
        // Eleven elements: one full group of eight plus a three-element tail.
        // Inputs stay within [0.9, 1.1], where even a low-order polynomial
        // kernel is accurate to well below the tolerance used here.
        let input: Vec<f32> = (0..11).map(|i| 0.9 + 0.02 * i as f32).collect();
        let mut output = vec![0.0_f32; input.len()];
        simd_log_f32(&input, &mut output);
        for (&x, &y) in input.iter().zip(&output) {
            assert_close(y, x.ln(), 1e-3);
        }
    }

    #[test]
    fn test_mismatched_lengths_only_touch_common_prefix() {
        let input = [0.0_f32, 0.0, 0.0];

        // Output longer than input: trailing slots keep their sentinel value.
        let mut longer = [7.0_f32; 5];
        simd_exp_f32(&input, &mut longer);
        assert_close(longer[2], 1.0, 1e-5);
        assert_eq!(longer[3], 7.0);
        assert_eq!(longer[4], 7.0);

        // Output shorter than input: must not panic and fills every slot.
        let mut shorter = [7.0_f32; 2];
        simd_exp_f32(&input, &mut shorter);
        assert_close(shorter[0], 1.0, 1e-5);
        assert_close(shorter[1], 1.0, 1e-5);

        // Empty input is a no-op.
        let mut untouched = [7.0_f32; 2];
        simd_log_f32(&[], &mut untouched);
        assert_eq!(untouched, [7.0, 7.0]);
    }
}