use super::SamplingError;
/// Writes a temperature-scaled softmax of `logits` into `out` (`f32` variant).
///
/// For every `i < logits.len()` this stores
/// `exp((logits[i] - m) / temperature) / Σ_j exp((logits[j] - m) / temperature)`,
/// where `m` is the largest logit. Entries of `out` at or beyond `logits.len()`
/// are left untouched.
///
/// The maximum is subtracted *before* dividing by `temperature`, so every
/// exponent is `<= 0`. A tiny temperature or very large logits can only push
/// exponents towards `-inf` (terms towards zero). They never overflow
/// `logit / temperature` to `+inf`, which would turn into `inf - inf = NaN`.
///
/// # Errors
///
/// - [`SamplingError::Empty`] if `logits` is empty.
/// - [`SamplingError::ShapeMismatch`] if `out` is shorter than `logits`.
/// - [`SamplingError::InvalidParameter`] in any of these cases:
///   - `temperature` is not finite or is `<= 0`.
///   - The largest logit is `+inf`.
///   - No logit is greater than `-inf` (all `-inf` and/or NaN).
///   - The normalizing sum is not finite and positive, e.g. a NaN logit
///     (assuming `crate::math::expf` propagates NaN).
///
/// On error, the first `logits.len()` entries of `out` may hold partial results.
pub fn softmax_temperature_f32(
    logits: &[f32],
    temperature: f32,
    out: &mut [f32],
) -> Result<(), SamplingError> {
    if logits.is_empty() {
        return Err(SamplingError::Empty);
    }
    if out.len() < logits.len() {
        return Err(SamplingError::ShapeMismatch);
    }
    if !temperature.is_finite() || temperature <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    // `temperature > 0`, so the largest raw logit also maximizes
    // `logit / temperature`; taking the max first avoids a scaled overflow.
    // `f32::max` returns the non-NaN operand, so NaN logits are skipped here.
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !max_logit.is_finite() {
        // `+inf` would give `inf - inf = NaN`; `-inf` means every logit is
        // `-inf` or NaN, leaving nothing to normalize.
        return Err(SamplingError::InvalidParameter);
    }

    let out = &mut out[..logits.len()];
    let mut sum = 0.0f32;
    for (dst, &logit) in out.iter_mut().zip(logits) {
        // `logit - max_logit <= 0` for non-NaN logits, so the exponent can
        // only underflow towards -inf (term -> 0), never overflow.
        let p = crate::math::expf((logit - max_logit) / temperature);
        *dst = p;
        sum += p;
    }
    // The max term contributes exp(0) = 1, so this only trips on NaN input.
    if !sum.is_finite() || sum <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    let inv = 1.0 / sum;
    for p in out.iter_mut() {
        *p *= inv;
    }
    Ok(())
}
pub fn softmax_temperature(
logits: &[f32],
temperature: f32,
out: &mut [f32],
) -> Result<(), SamplingError> {
softmax_temperature_f32(logits, temperature, out)
}
/// Writes a temperature-scaled softmax of `logits` into `out` (`f64` variant).
///
/// For every `i < logits.len()` this stores
/// `exp((logits[i] - m) / temperature) / Σ_j exp((logits[j] - m) / temperature)`,
/// where `m` is the largest logit. Entries of `out` at or beyond `logits.len()`
/// are left untouched.
///
/// The maximum is subtracted *before* dividing by `temperature`, so every
/// exponent is `<= 0`. A tiny temperature or very large logits can only push
/// exponents towards `-inf` (terms towards zero). They never overflow
/// `logit / temperature` to `+inf`, which would turn into `inf - inf = NaN`.
///
/// # Errors
///
/// - [`SamplingError::Empty`] if `logits` is empty.
/// - [`SamplingError::ShapeMismatch`] if `out` is shorter than `logits`.
/// - [`SamplingError::InvalidParameter`] in any of these cases:
///   - `temperature` is not finite or is `<= 0`.
///   - The largest logit is `+inf`.
///   - No logit is greater than `-inf` (all `-inf` and/or NaN).
///   - The normalizing sum is not finite and positive, e.g. a NaN logit
///     (assuming `crate::math::expd` propagates NaN).
///
/// On error, the first `logits.len()` entries of `out` may hold partial results.
pub fn softmax_temperature_f64(
    logits: &[f64],
    temperature: f64,
    out: &mut [f64],
) -> Result<(), SamplingError> {
    if logits.is_empty() {
        return Err(SamplingError::Empty);
    }
    if out.len() < logits.len() {
        return Err(SamplingError::ShapeMismatch);
    }
    if !temperature.is_finite() || temperature <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    // `temperature > 0`, so the largest raw logit also maximizes
    // `logit / temperature`; taking the max first avoids a scaled overflow.
    // `f64::max` returns the non-NaN operand, so NaN logits are skipped here.
    let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max_logit.is_finite() {
        // `+inf` would give `inf - inf = NaN`; `-inf` means every logit is
        // `-inf` or NaN, leaving nothing to normalize.
        return Err(SamplingError::InvalidParameter);
    }

    let out = &mut out[..logits.len()];
    let mut sum = 0.0f64;
    for (dst, &logit) in out.iter_mut().zip(logits) {
        // `logit - max_logit <= 0` for non-NaN logits, so the exponent can
        // only underflow towards -inf (term -> 0), never overflow.
        let p = crate::math::expd((logit - max_logit) / temperature);
        *dst = p;
        sum += p;
    }
    // The max term contributes exp(0) = 1, so this only trips on NaN input.
    if !sum.is_finite() || sum <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    let inv = 1.0f64 / sum;
    for p in out.iter_mut() {
        *p *= inv;
    }
    Ok(())
}