// diskann_quantization/random.rs

use std::hash::{Hash, Hasher};

use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};

10pub trait RngBuilder<T> {
12 type Rng: Rng + 'static;
13
14 fn build_rng(&self, mixin: T) -> Self::Rng;
16}
17
/// An [`RngBuilder`] that produces [`StdRng`] generators.
///
/// The builder is created from a `u64` seed, which is hashed once into an internal
/// hasher. Building a generator clones that hasher, hashes the caller's mixin into
/// the clone, and seeds a fresh [`StdRng`] from the resulting 64-bit hash. The
/// builder itself is never mutated, so it can be shared and reused freely.
///
/// NOTE: `std::hash::DefaultHasher`'s algorithm is unspecified by the standard
/// library and may change between Rust releases. Derived seeds, and therefore RNG
/// streams, should not be assumed stable across toolchain versions.
#[derive(Debug, Clone)]
pub struct StdRngBuilder {
    /// Hasher pre-loaded with the builder's seed; cloned (never mutated) per RNG.
    hasher: std::hash::DefaultHasher,
}

24impl StdRngBuilder {
25 pub fn new(seed: u64) -> Self {
27 let mut hasher = std::hash::DefaultHasher::new();
28 seed.hash(&mut hasher);
29 Self { hasher }
30 }
31}
32
33impl<T> RngBuilder<T> for StdRngBuilder
34where
35 T: std::hash::Hash,
36{
37 type Rng = StdRng;
38
39 fn build_rng(&self, mixin: T) -> Self::Rng {
40 let mut hasher = self.hasher.clone();
41 mixin.hash(&mut hasher);
42 StdRng::seed_from_u64(hasher.finish())
43 }
44}
45
46pub trait BoxedRngBuilder<T> {
48 fn build_boxed_rng(&self, mixin: T) -> Box<dyn RngCore>;
49}
50
51impl<T, M> BoxedRngBuilder<M> for T
52where
53 T: RngBuilder<M>,
54{
55 fn build_boxed_rng(&self, mixin: M) -> Box<dyn RngCore> {
56 Box::new(self.build_rng(mixin))
57 }
58}
59
#[cfg(test)]
mod tests {
    use rand::distr::{Distribution, StandardUniform};

    use super::*;

    /// Draws one `u32` from RNGs built by `builder` with two fixed mixins and returns
    /// `[sample_for_seed_0, sample_for_seed_1]`.
    ///
    /// Along the way it asserts that:
    /// * building twice with the same mixin yields the same first sample;
    /// * different mixins yield different first samples;
    /// * when `last` is `Some`, the returned pair differs from it. Callers use this
    ///   to check that builders created from different seeds diverge.
    ///
    /// # Panics
    ///
    /// Panics if any check fails. The `last` check panics with a message containing
    /// "builders with different seeds produced identical samples".
    fn test_builder<T>(builder: T, last: Option<[u32; 2]>) -> [u32; 2]
    where
        T: RngBuilder<u64>,
    {
        let standard = StandardUniform;

        let seed_0: u64 = 0xd60f0bc189624369;
        let seed_1: u64 = 0x7478ac104ed40abb;

        // Two RNGs built from the same mixin must agree.
        let mut rng0 = builder.build_rng(seed_0);
        let mut rng1 = builder.build_rng(seed_0);

        let v0: u32 = standard.sample(&mut rng0);
        let v1: u32 = standard.sample(&mut rng1);
        assert_eq!(v0, v1, "same mixin produced different samples");

        // A different mixin must change the stream.
        let mut rng1 = builder.build_rng(seed_1);
        let v1: u32 = standard.sample(&mut rng1);
        assert_ne!(v0, v1, "different mixins produced identical samples");

        let v = [v0, v1];
        if let Some(last) = last {
            assert_ne!(
                v, last,
                "builders with different seeds produced identical samples"
            );
        }
        v
    }

    #[test]
    fn test_stdrng_builder() {
        let builder = StdRngBuilder::new(0x376226f7d2d5a16b);
        let v = test_builder(builder, None);

        // A builder with a different seed must produce different samples.
        let builder = StdRngBuilder::new(0x1f197993987ed14f);
        let _ = test_builder(builder, Some(v));
    }

    // Sanity check on `test_builder` itself: two builders with the *same* seed must
    // trip the cross-builder check. `expected` pins the panic to that assertion, so an
    // unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "builders with different seeds produced identical samples")]
    fn test_stdrng_builder_test_panics() {
        let builder = StdRngBuilder::new(0x49a85a468d6865e6);
        let v = test_builder(builder, None);

        let builder = StdRngBuilder::new(0x49a85a468d6865e6);
        let _ = test_builder(builder, Some(v));
    }
}