// gumbel_top_bucket/lib.rs

use rand::distributions::{Distribution, Uniform};

/// A bucket for drawing indices from a discrete, softmax-like distribution without
/// replacement.
///
/// Every score is perturbed once with Gumbel noise when the bucket is built, and
/// indices are then handed out in order of their noisy score. Because the noise is
/// added up front, no softmax has to be re-computed between draws, and *no index is
/// ever drawn twice*, even when several scores are equal.
///
/// This comes at a memory cost: the bucket stores one `(index, noisy score)` pair per
/// input score, in addition to the caller's original scores.
#[derive(Debug, Clone)]
pub struct GumbelTopBucket {
    /// Number of entries that have not been drawn yet.
    scores_len: usize,
    /// Remaining `(original index, noisy score)` pairs, kept ordered by noisy score.
    noisy_scores: Vec<(usize, f64)>,
}

/// Hook that lets [`GumbelTopBucket`] work with multiple score types.
///
/// Implementations are provided for `f32` and `f64`. Any other score type can opt in,
/// as long as it has a way to add an `f64` to itself and produce an `f64`. It is
/// recommended to mark `float_add` as `#[inline]`, since it is called once for each
/// score in the bucket.
pub trait F64Add {
    /// Returns the sum of `self` and `other`, expressed as an `f64`.
    fn float_add(self, other: f64) -> f64;
}

impl F64Add for f32 {
    /// Widens the `f32` to `f64` (a lossless conversion) and then adds `other`.
    #[inline]
    fn float_add(self, other: f64) -> f64 {
        f64::from(self) + other
    }
}

impl F64Add for f64 {
    /// Plain `f64` addition.
    #[inline]
    fn float_add(self, other: f64) -> f64 {
        self + other
    }
}

42impl GumbelTopBucket {
43    /// Create a new GumbelTopBucket from a slice of scores and a temperature. Typically,
44    /// scores should be in the range [0, 1], and the temperature should be > 0. It is
45    /// possible to use scores outside of this range, but the results may be unexpected;
46    /// the temperature can be utilized to adjust the range of the scores. A temperature
47    /// of 1.0 is recommended for most use cases.
48    pub fn new<T>(scores: &[T], temperature: f64) -> GumbelTopBucket
49    where
50        T: F64Add + Copy,
51    {
52        let scores_len = scores.len();
53        let noises = GumbelTopBucket::gumbel_noise(scores_len, temperature);
54        let mut noisy_scores: Vec<(usize, f64)> = scores
55            .iter()
56            .enumerate()
57            .map(|(i, &score)| (i, score.float_add(noises[i])))
58            .collect();
59        noisy_scores
60            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
61
62        GumbelTopBucket {
63            scores_len,
64            noisy_scores,
65        }
66    }
67
68    /// Generate a vector of Gumbel noise. This is used internally to generate the
69    /// noisy scores. It is exposed as a public function in case you want to use
70    /// the Gumbel noise for something else.
71    pub fn gumbel_noise(size: usize, temperature: f64) -> Vec<f64> {
72        let mut rng = rand::thread_rng();
73        let between = Uniform::from(1e-10f64..(1.0 - 1e-10f64));
74        let u: Vec<f64> = between.sample_iter(&mut rng).take(size).collect();
75        u.iter()
76            .map(|&x| -((-(x.ln())).ln()) * temperature)
77            .collect()
78    }
79
80    /// Draw a score from the bucket. This returns the index of the score in the original list,
81    /// as well as the *noisy* score. The score index will be removed from the list and never
82    /// sampled again. The method will return None if the bucket is empty.
83    pub fn draw_with_score(&mut self) -> Option<(usize, f64)> {
84        if self.scores_len == 0 {
85            return None;
86        }
87        let (idx_max, noisy_score) = self.noisy_scores.remove(0);
88        self.scores_len -= 1;
89        Some((idx_max, noisy_score))
90    }
91
92    /// Draws a score from the bucket. This returns the index of the score in the original list.
93    /// The score index will be removed from the list and never sampled again. The method will
94    /// return None if the bucket is empty.
95    pub fn draw(&mut self) -> Option<usize> {
96        let (idx_max, _) = self.draw_with_score()?;
97        Some(idx_max)
98    }
99}