/// Applies a numerically stable softmax to `values` in place (`f32`).
///
/// The largest entry is subtracted from every element before it is passed to
/// `crate::math::expf`, so no exponent is positive. The exponentials are then
/// scaled by the reciprocal of their sum.
///
/// # Returns
///
/// * `Some(())` once `values` has been overwritten with the normalised weights.
/// * `None` if `values` is empty, contains a NaN, or its maximum is not finite
///   (a `+inf` entry, or every entry is `-inf`). These inputs are rejected
///   *before* any element is written, so the slice is left untouched.
/// * `None` if the accumulated sum is non-finite or not positive. This guard
///   runs after the exponentials have been written, so in that case the slice
///   holds the unnormalised exponentials.
///
/// NOTE(review): the early rejection assumes `crate::math::expf` propagates
/// NaN the way an IEEE `exp` does; under that assumption the set of inputs
/// yielding `None` is unchanged from the previous implementation.
pub fn stable_softmax_row_f32(values: &mut [f32]) -> Option<()> {
    if values.is_empty() {
        return None;
    }

    // Read-only scan: locate the maximum and bail out on NaN before mutating.
    let mut max_v = f32::NEG_INFINITY;
    for &v in values.iter() {
        if v.is_nan() {
            return None;
        }
        if v > max_v {
            max_v = v;
        }
    }

    // `+inf - +inf` and `-inf - -inf` are NaN, so a non-finite maximum can
    // never be normalised; reject it while the input is still intact.
    if !max_v.is_finite() {
        return None;
    }

    let mut sum = 0.0f32;
    for v in values.iter_mut() {
        *v = crate::math::expf(*v - max_v);
        sum += *v;
    }

    // Defensive: guards the division below against an unexpected result from
    // the external exponential helper.
    if !sum.is_finite() || sum <= 0.0 {
        return None;
    }

    // One division, then a multiply per element.
    let inv = 1.0f32 / sum;
    for v in values.iter_mut() {
        *v *= inv;
    }
    Some(())
}
/// Applies a numerically stable softmax to `values` in place (`f64`).
///
/// The largest entry is subtracted from every element before it is passed to
/// `crate::math::expd`, so no exponent is positive. The exponentials are then
/// scaled by the reciprocal of their sum.
///
/// # Returns
///
/// * `Some(())` once `values` has been overwritten with the normalised weights.
/// * `None` if `values` is empty, contains a NaN, or its maximum is not finite
///   (a `+inf` entry, or every entry is `-inf`). These inputs are rejected
///   *before* any element is written, so the slice is left untouched.
/// * `None` if the accumulated sum is non-finite or not positive. This guard
///   runs after the exponentials have been written, so in that case the slice
///   holds the unnormalised exponentials.
///
/// NOTE(review): the early rejection assumes `crate::math::expd` propagates
/// NaN the way an IEEE `exp` does; under that assumption the set of inputs
/// yielding `None` is unchanged from the previous implementation.
pub fn stable_softmax_row_f64(values: &mut [f64]) -> Option<()> {
    if values.is_empty() {
        return None;
    }

    // Read-only scan: locate the maximum and bail out on NaN before mutating.
    let mut max_v = f64::NEG_INFINITY;
    for &v in values.iter() {
        if v.is_nan() {
            return None;
        }
        if v > max_v {
            max_v = v;
        }
    }

    // `+inf - +inf` and `-inf - -inf` are NaN, so a non-finite maximum can
    // never be normalised; reject it while the input is still intact.
    if !max_v.is_finite() {
        return None;
    }

    let mut sum = 0.0f64;
    for v in values.iter_mut() {
        *v = crate::math::expd(*v - max_v);
        sum += *v;
    }

    // Defensive: guards the division below against an unexpected result from
    // the external exponential helper.
    if !sum.is_finite() || sum <= 0.0 {
        return None;
    }

    // One division, then a multiply per element.
    let inv = 1.0f64 / sum;
    for v in values.iter_mut() {
        *v *= inv;
    }
    Some(())
}