use wide::f32x8;
/// Applies ReLU in place: every element of `s` is replaced by `max(x, 0.0)`.
///
/// The slice is processed eight lanes at a time with `f32x8`; whatever does
/// not fill a full group of eight is handled by a scalar loop.
///
/// The scalar path uses `f32::max`, which returns the non-NaN operand, so a
/// NaN in the tail becomes `0.0`.
/// NOTE(review): NaN handling of the SIMD `max` depends on the `wide`
/// version in use — confirm it matches the scalar tail if that matters.
pub(crate) fn relu(s: &mut [f32]) {
    let (body, rest) = split8(s);
    let floor = f32x8::splat(0.0);

    body.chunks_exact_mut(8).for_each(|block| {
        // `chunks_exact_mut(8)` only yields 8-element slices, so this copy
        // into a fixed-size array cannot panic.
        let mut lanes = [0.0_f32; 8];
        lanes.copy_from_slice(block);

        let out: [f32; 8] = f32x8::from(lanes).max(floor).into();
        block.copy_from_slice(&out);
    });

    rest.iter_mut().for_each(|x| *x = x.max(0.0));
}
/// Applies the logistic sigmoid `1 / (1 + e^(-x))` in place to every element
/// of `s`.
///
/// Inputs are clamped to `[-88, 88]` before exponentiation; presumably this
/// keeps `e^(-x)` finite in `f32` (ln(f32::MAX) is roughly 88.7).
///
/// Full groups of eight elements go through `f32x8`; the remainder goes
/// through a scalar loop using `f32::clamp` and `f32::exp`.
/// NOTE(review): `f32::clamp` propagates NaN, whereas the SIMD `max`/`min`
/// chain may not — confirm whether NaN inputs need consistent treatment.
pub(crate) fn sigmoid(s: &mut [f32]) {
    const LIMIT: f32 = 88.0;

    let (body, rest) = split8(s);
    let unit = f32x8::splat(1.0);
    let floor = f32x8::splat(-LIMIT);
    let ceil = f32x8::splat(LIMIT);

    for block in body.chunks_exact_mut(8) {
        // `chunks_exact_mut(8)` only yields 8-element slices, so this copy
        // into a fixed-size array cannot panic.
        let mut lanes = [0.0_f32; 8];
        lanes.copy_from_slice(block);

        let bounded = f32x8::from(lanes).max(floor).min(ceil);
        let denom = unit + (-bounded).exp();
        let out: [f32; 8] = (unit / denom).into();
        block.copy_from_slice(&out);
    }

    rest.iter_mut().for_each(|x| {
        let bounded = x.clamp(-LIMIT, LIMIT);
        *x = 1.0 / (1.0 + (-bounded).exp());
    });
}
/// Replaces every element of `s` with `e^x`, in place.
///
/// Full groups of eight elements are evaluated with `f32x8::exp`; the
/// remaining elements use the scalar `f32::exp`.
/// NOTE(review): the SIMD and scalar `exp` are different implementations and
/// are not guaranteed to agree bit-for-bit — confirm if that matters.
pub(crate) fn exp(s: &mut [f32]) {
    let (body, rest) = split8(s);

    body.chunks_exact_mut(8).for_each(|block| {
        // `chunks_exact_mut(8)` only yields 8-element slices, so this copy
        // into a fixed-size array cannot panic.
        let mut lanes = [0.0_f32; 8];
        lanes.copy_from_slice(block);

        let out: [f32; 8] = f32x8::from(lanes).exp().into();
        block.copy_from_slice(&out);
    });

    rest.iter_mut().for_each(|x| *x = x.exp());
}
/// Replaces every element of `s` with its square root, in place.
///
/// Full groups of eight elements are evaluated with `f32x8::sqrt`; the
/// remaining elements use the scalar `f32::sqrt` (which yields NaN for
/// negative inputs other than `-0.0`).
pub(crate) fn sqrt(s: &mut [f32]) {
    let (body, rest) = split8(s);

    body.chunks_exact_mut(8).for_each(|block| {
        // `chunks_exact_mut(8)` only yields 8-element slices, so this copy
        // into a fixed-size array cannot panic.
        let mut lanes = [0.0_f32; 8];
        lanes.copy_from_slice(block);

        let out: [f32; 8] = f32x8::from(lanes).sqrt().into();
        block.copy_from_slice(&out);
    });

    rest.iter_mut().for_each(|x| *x = x.sqrt());
}
/// Splits `s` into a prefix whose length is the largest multiple of 8 not
/// exceeding `s.len()`, and a suffix holding the remaining 0..=7 elements.
#[inline(always)]
fn split8(s: &mut [f32]) -> (&mut [f32], &mut [f32]) {
    let leftover = s.len() % 8;
    let cut = s.len() - leftover;
    s.split_at_mut(cut)
}