use super::{softmax_stable, InferenceError};
/// Overwrites `logits` with the output of [`softmax_stable`] applied to it.
///
/// The result is first written to a fixed-size stack scratch buffer and then
/// copied back, so no heap allocation happens in this function itself.
///
/// # Errors
///
/// Returns [`InferenceError::ShapeMismatch`] when `logits` is empty or holds
/// more elements than the scratch buffer (4096). Any error reported by
/// `softmax_stable` is propagated; in that case `logits` has not been
/// written to by this function.
pub fn normalize_logits_in_place(logits: &mut [f32]) -> Result<(), InferenceError> {
    // Capacity of the on-stack scratch buffer; longer inputs are rejected.
    const SCRATCH_CAPACITY: usize = 4096;

    let count = logits.len();
    if count == 0 || count > SCRATCH_CAPACITY {
        return Err(InferenceError::ShapeMismatch);
    }

    let mut scratch = [0.0f32; SCRATCH_CAPACITY];
    softmax_stable(logits, &mut scratch[..count])?;

    // Only the first `count` scratch slots were handed to `softmax_stable`.
    logits.copy_from_slice(&scratch[..count]);
    Ok(())
}
/// Returns the index of the largest value in `logits`, or `None` when the
/// slice is empty.
///
/// Ties resolve to the earliest index, because a later value must be strictly
/// greater to replace the current best.
///
/// NaN values are never selected over a non-NaN value, regardless of their
/// position. Every comparison against NaN is false, so a NaN running best
/// would otherwise never be replaced; it is therefore displaced by the first
/// non-NaN value encountered. If every element is NaN, `Some(0)` is returned.
pub fn argmax_index(logits: &[f32]) -> Option<usize> {
    let (&first, rest) = logits.split_first()?;

    let mut best_idx = 0usize;
    let mut best = first;
    for (offset, &v) in rest.iter().enumerate() {
        // `v > best` is false whenever either side is NaN, so a NaN `best`
        // needs the explicit second condition to be displaced.
        if v > best || (best.is_nan() && !v.is_nan()) {
            best = v;
            // `rest` starts at element 1 of the original slice.
            best_idx = offset + 1;
        }
    }
    Some(best_idx)
}