Skip to main content

diskann_quantization/
random.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use std::hash::{Hash, Hasher};
7
8use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};
9
10/// Creation of random number generator in potentially parallelized applications.
11pub trait RngBuilder<T> {
12    type Rng: Rng + 'static;
13
14    // Construct an Rng with the provided value for mixing randomness.
15    fn build_rng(&self, mixin: T) -> Self::Rng;
16}
17
/// A [`RngBuilder`] that returns a [`rand::rngs::StdRng`].
///
/// The builder's seed is hashed into a [`std::hash::DefaultHasher`] when the builder is
/// constructed. Each call to [`RngBuilder::build_rng`] clones that hasher state and hashes
/// the `mixin` value into the clone. The resulting 64-bit digest then seeds a fresh `StdRng`.
/// `build_rng` takes `&self`, so the builder's own state is never modified and one builder
/// can serve many calls.
///
/// # Reproducibility
///
/// The algorithm behind `DefaultHasher` is unspecified by the standard library and may
/// change between Rust releases. The `StdRng` algorithm is likewise not guaranteed to stay
/// the same across `rand` versions. Do not rely on the stream produced for a given
/// `(seed, mixin)` pair being identical across toolchain or dependency upgrades.
#[derive(Debug, Clone)]
pub struct StdRngBuilder {
    hasher: std::hash::DefaultHasher,
}
23
24impl StdRngBuilder {
25    /// Construct a new `StdRngBuilder` using the given seed.
26    pub fn new(seed: u64) -> Self {
27        let mut hasher = std::hash::DefaultHasher::new();
28        seed.hash(&mut hasher);
29        Self { hasher }
30    }
31}
32
33impl<T> RngBuilder<T> for StdRngBuilder
34where
35    T: std::hash::Hash,
36{
37    type Rng = StdRng;
38
39    fn build_rng(&self, mixin: T) -> Self::Rng {
40        let mut hasher = self.hasher.clone();
41        mixin.hash(&mut hasher);
42        StdRng::seed_from_u64(hasher.finish())
43    }
44}
45
46/// An object-safe version of `RngBuilder`.
47pub trait BoxedRngBuilder<T> {
48    fn build_boxed_rng(&self, mixin: T) -> Box<dyn RngCore>;
49}
50
51impl<T, M> BoxedRngBuilder<M> for T
52where
53    T: RngBuilder<M>,
54{
55    fn build_boxed_rng(&self, mixin: M) -> Box<dyn RngCore> {
56        Box::new(self.build_rng(mixin))
57    }
58}
59
///////////
// Tests //
///////////
63
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use rand::distr::{Distribution, StandardUniform};

    use super::*;

    /// Checks the basic determinism contract of `builder` and returns the two sampled values.
    ///
    /// * Two generators built from the same mixin must produce the same first value.
    /// * A generator built from a different mixin must produce a different first value.
    /// * If `last` is provided (the result of a previous call with a differently-seeded
    ///   builder), the returned pair must differ from it.
    fn test_builder<T>(builder: T, last: Option<[u32; 2]>) -> [u32; 2]
    where
        T: RngBuilder<u64>,
    {
        let standard = StandardUniform;

        let seed_0: u64 = 0xd60f0bc189624369;
        let seed_1: u64 = 0x7478ac104ed40abb;

        // Make sure that the random number generator returned is the same when the seed
        // is the same.
        let mut rng0 = builder.build_rng(seed_0);
        let mut rng1 = builder.build_rng(seed_0);

        let v0: u32 = standard.sample(&mut rng0);
        let v1: u32 = standard.sample(&mut rng1);
        assert_eq!(v0, v1);

        // Changing the seed should change the random number generator.
        let mut rng1 = builder.build_rng(seed_1);
        let v1: u32 = standard.sample(&mut rng1);
        assert_ne!(v0, v1);

        let v = [v0, v1];
        if let Some(last) = last {
            assert_ne!(v, last);
        }
        v
    }

    #[test]
    fn test_stdrng_builder() {
        let builder = StdRngBuilder::new(0x376226f7d2d5a16b);
        let v = test_builder(builder, None);

        // If we change the seed for the builder - the returned values should be different.
        let builder = StdRngBuilder::new(0x1f197993987ed14f);
        let _ = test_builder(builder, Some(v));
    }

    // Make sure the test actually panics if the results are the same.
    #[test]
    #[should_panic]
    fn test_stdrng_builder_test_panics() {
        let builder = StdRngBuilder::new(0x49a85a468d6865e6);
        let v = test_builder(builder, None);

        let builder = StdRngBuilder::new(0x49a85a468d6865e6);
        // Panics
        let _ = test_builder(builder, Some(v));
    }

    // The type-erased generator from `BoxedRngBuilder` must produce exactly the same stream
    // as the concrete generator returned by `RngBuilder::build_rng` for the same mixin.
    #[test]
    fn test_boxed_rng_builder_matches_direct() {
        let builder = StdRngBuilder::new(0x376226f7d2d5a16b);
        let mixin: u64 = 0xd60f0bc189624369;

        let mut direct = builder.build_rng(mixin);
        let mut boxed = builder.build_boxed_rng(mixin);
        for _ in 0..8 {
            assert_eq!(direct.next_u64(), boxed.next_u64());
        }
    }
}