pecos_core/sims_rngs/
choices.rs

1// Copyright 2024 The PECOS Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
5//
6//     https://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13use rand::distributions::{Distribution, WeightedIndex};
14use rand::RngCore;
15
/// Absolute tolerance allowed between the sum of the weights passed to `Choices::new` and 1.0.
const EPSILON: f64 = 1.0e-9;
17
18/// Struct to hold choices and pre-validated `WeightedIndex`
19/// The weights need to sum up close to 1.0 and will be re-normalized if they do so.
20#[derive(Debug)]
21pub struct Choices<T> {
22    items: Vec<T>,
23    weighted_index: WeightedIndex<f64>,
24}
25
26impl<T> Choices<T> {
27    /// Validate and normalize weights, then create Choices struct
28    /// # Panics
29    /// This will panic if the number of weights and number of items are not the same.
30    #[inline]
31    #[allow(clippy::float_arithmetic)]
32    #[must_use]
33    pub fn new(items: Vec<T>, weights: &[f64]) -> Self {
34        assert_eq!(
35            items.len(),
36            weights.len(),
37            "Number of items needs to equal number of weights"
38        );
39        assert!(
40            weights.iter().all(|&w| w >= 0.0f64),
41            "All weights must be positive numbers since they represent probabilities."
42        );
43
44        let sum_weights: f64 = weights.iter().sum();
45        assert!(
46            isclose(sum_weights, 1.0, EPSILON),
47            "Weights do not sum to 1 \u{b1} \u{3b5}" // 1 ± ε
48        );
49
50        let normalized_weights: Vec<f64> = weights.iter().map(|&w| w / sum_weights).collect();
51        let weighted_index = WeightedIndex::new(normalized_weights)
52            .expect("Failed to create WeightedIndex due to invalid weights");
53
54        Choices {
55            items,
56            weighted_index,
57        }
58    }
59
60    /// Sample a choice based on the weights
61    #[inline]
62    pub fn sample<R: RngCore>(&self, rng: &mut R) -> &T {
63        let index = self.weighted_index.sample(rng);
64        &self.items[index]
65    }
66}
67
/// Return `true` when `a` and `b` differ by at most `epsilon`, an absolute tolerance.
///
/// If the difference `a - b` is NaN (for example, either input is NaN, or both are the same
/// infinity), the comparison is `false`.
#[inline]
// Restriction lints: this is a deliberately tiny helper that performs float subtraction.
#[allow(clippy::single_call_fn, clippy::float_arithmetic)]
fn isclose(a: f64, b: f64, epsilon: f64) -> bool {
    let distance = f64::abs(a - b);
    distance <= epsilon
}