use super::SamplingError;
/// Greedy selection: returns the index of the largest value in `probabilities`.
///
/// Values are compared with `>`, so ties resolve to the lowest index. Negative
/// and infinite values are allowed (e.g. raw or masked logits) and are ordered
/// normally.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] if any entry is NaN. NaN is unordered,
///   so without this check a leading NaN would make every comparison false and
///   index 0 would be returned regardless of the other values.
pub fn argmax_sample(probabilities: &[f32]) -> Result<usize, SamplingError> {
    let (&first, rest) = probabilities
        .split_first()
        .ok_or(SamplingError::Empty)?;
    if first.is_nan() {
        return Err(SamplingError::InvalidParameter);
    }
    let mut best_idx = 0;
    let mut best = first;
    for (offset, &p) in rest.iter().enumerate() {
        if p.is_nan() {
            return Err(SamplingError::InvalidParameter);
        }
        // Strict `>` keeps the earliest index on ties.
        if p > best {
            best = p;
            // `rest` starts at index 1 of `probabilities`.
            best_idx = offset + 1;
        }
    }
    Ok(best_idx)
}
/// Inverse-CDF sampling: walks `probabilities` in order, accumulating weight,
/// and returns the first index at which the running sum reaches `threshold`.
///
/// `threshold` is expected to be a draw in `[0, 1]`. Zero-weight entries are
/// never selected. In particular, a `threshold` of `0.0` picks the first
/// positive-weight entry rather than index 0. The weights need not sum to
/// exactly 1. If the running total never reaches `threshold` (unnormalized
/// input or `f32` rounding), the last positive-weight index is returned. If
/// every weight is zero, the last index is returned.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `probabilities` is empty.
/// * [`SamplingError::InvalidParameter`] if `threshold` is not in `[0, 1]`
///   (this also rejects NaN and infinities), or if any entry is non-finite or
///   negative. Every entry is checked, so the result does not depend on where
///   `threshold` happens to land.
pub fn sample_from_cumulative(
    probabilities: &[f32],
    threshold: f32,
) -> Result<usize, SamplingError> {
    if probabilities.is_empty() {
        return Err(SamplingError::Empty);
    }
    // `RangeInclusive::contains` is false for NaN and +/-inf, so no separate
    // finiteness check is needed.
    if !(0.0..=1.0).contains(&threshold) {
        return Err(SamplingError::InvalidParameter);
    }
    // Validate the whole slice up front. Checking lazily during the scan would
    // accept or reject the same distribution depending on the random threshold.
    if probabilities.iter().any(|&p| !p.is_finite() || p < 0.0) {
        return Err(SamplingError::InvalidParameter);
    }

    let mut cumulative = 0.0f32;
    let mut last_positive = None;
    for (i, &p) in probabilities.iter().enumerate() {
        // Entries are validated non-negative, so this only skips zero weights.
        // Without it, `threshold == 0.0` would select index 0 even at weight 0.
        if p <= 0.0 {
            continue;
        }
        cumulative += p;
        last_positive = Some(i);
        if cumulative >= threshold {
            return Ok(i);
        }
    }
    // The total fell short of `threshold`. Prefer an entry that actually has
    // weight; fall back to the last index only for an all-zero distribution.
    Ok(last_positive.unwrap_or(probabilities.len() - 1))
}