1use super::SamplingError;
2
3pub fn argmax_sample(probabilities: &[f32]) -> Result<usize, SamplingError> {
4 if probabilities.is_empty() {
5 return Err(SamplingError::Empty);
6 }
7 let mut idx = 0usize;
8 let mut best = probabilities[0];
9 for (i, &p) in probabilities.iter().enumerate().skip(1) {
10 if p > best {
11 best = p;
12 idx = i;
13 }
14 }
15 Ok(idx)
16}
17
18pub fn sample_from_cumulative(probabilities: &[f32], threshold: f32) -> Result<usize, SamplingError> {
19 if probabilities.is_empty() {
20 return Err(SamplingError::Empty);
21 }
22 if !threshold.is_finite() || threshold < 0.0 || threshold > 1.0 {
23 return Err(SamplingError::InvalidParameter);
24 }
25
26 let mut cumulative = 0.0f32;
27 for (i, &p) in probabilities.iter().enumerate() {
28 if !p.is_finite() || p < 0.0 {
29 return Err(SamplingError::InvalidParameter);
30 }
31 cumulative += p;
32 if cumulative >= threshold {
33 return Ok(i);
34 }
35 }
36 Ok(probabilities.len() - 1)
37}