Skip to main content

native_neural_network/sampling/
pick.rs

1use super::SamplingError;
2
3pub fn argmax_sample(probabilities: &[f32]) -> Result<usize, SamplingError> {
4    if probabilities.is_empty() {
5        return Err(SamplingError::Empty);
6    }
7    let mut idx = 0usize;
8    let mut best = probabilities[0];
9    for (i, &p) in probabilities.iter().enumerate().skip(1) {
10        if p > best {
11            best = p;
12            idx = i;
13        }
14    }
15    Ok(idx)
16}
17
18pub fn sample_from_cumulative(probabilities: &[f32], threshold: f32) -> Result<usize, SamplingError> {
19    if probabilities.is_empty() {
20        return Err(SamplingError::Empty);
21    }
22    if !threshold.is_finite() || threshold < 0.0 || threshold > 1.0 {
23        return Err(SamplingError::InvalidParameter);
24    }
25
26    let mut cumulative = 0.0f32;
27    for (i, &p) in probabilities.iter().enumerate() {
28        if !p.is_finite() || p < 0.0 {
29            return Err(SamplingError::InvalidParameter);
30        }
31        cumulative += p;
32        if cumulative >= threshold {
33            return Ok(i);
34        }
35    }
36    Ok(probabilities.len() - 1)
37}