1use rand::{self, Rng, SeedableRng};
2use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
3use std::{convert::TryFrom, fmt, iter, ops, vec};
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct SampleTotal {
8 total: usize,
9 samples: usize,
10 picked: usize,
11 i: usize,
12}
13impl SampleTotal {
14 pub fn new(total: usize, samples: usize) -> Self {
16 assert!(total >= samples);
17 Self {
18 total,
19 samples,
20 picked: 0,
21 i: 0,
22 }
23 }
24
25 pub fn sample<R: Rng>(&mut self, rng: &mut R) -> bool {
27 let sample = rng.gen_range(0, self.total - self.i) < (self.samples - self.picked);
28 self.i += 1;
29 if sample {
30 self.picked += 1;
31 }
32 sample
33 }
34}
35impl Drop for SampleTotal {
36 fn drop(&mut self) {
37 assert_eq!(self.picked, self.samples);
38 }
39}
40
/// A `Vec` whose capacity is fixed when it is created and re-checked on every
/// mutation, so the capacity can be relied on as a hard size limit.
struct FixedCapVec<T>(Vec<T>);

impl<T> FixedCapVec<T> {
    /// Creates an empty vector with room for exactly `cap` elements.
    ///
    /// # Panics
    /// Panics if the allocation's capacity is not exactly `cap`, e.g. for a
    /// zero-sized `T` (a `Vec` of zero-sized elements always reports a capacity
    /// of `usize::MAX`).
    fn new(cap: usize) -> Self {
        let self_ = Self(Vec::with_capacity(cap));
        assert_eq!(self_.capacity(), cap);
        self_
    }

    /// Number of stored elements.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// The fixed capacity chosen at construction.
    fn capacity(&self) -> usize {
        self.0.capacity()
    }

    /// Appends `t`.
    ///
    /// # Panics
    /// Panics if the vector is already full.
    fn push(&mut self, t: T) {
        assert!(self.len() < self.capacity());
        let cap = self.capacity();
        self.0.push(t);
        assert_eq!(self.capacity(), cap);
    }

    /// Removes and returns the last element, or `None` if empty.
    fn pop(&mut self) -> Option<T> {
        let cap = self.capacity();
        let ret = self.0.pop();
        assert_eq!(self.capacity(), cap);
        ret
    }

    /// Consumes the vector, yielding its elements in storage order.
    fn into_iter(self) -> std::vec::IntoIter<T> {
        self.0.into_iter()
    }
}

/// Hand-written instead of derived: cloning the inner `Vec` does not keep its
/// spare capacity (in practice the clone's capacity equals its length), which
/// would silently shrink the fixed capacity of any non-full vector.
impl<T: Clone> Clone for FixedCapVec<T> {
    fn clone(&self) -> Self {
        let mut copy = Self::new(self.capacity());
        // `len <= capacity`, so this never reallocates.
        copy.0.extend_from_slice(&self.0);
        assert_eq!(copy.capacity(), self.capacity());
        copy
    }
}
71impl<T, Idx> std::ops::Index<Idx> for FixedCapVec<T>
72where
73 Idx: std::slice::SliceIndex<[T]>,
74{
75 type Output = <Vec<T> as std::ops::Index<Idx>>::Output;
76 fn index(&self, index: Idx) -> &Self::Output {
77 std::ops::Index::index(&self.0, index)
78 }
79}
80impl<T, Idx> std::ops::IndexMut<Idx> for FixedCapVec<T>
81where
82 Idx: std::slice::SliceIndex<[T]>,
83{
84 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
85 std::ops::IndexMut::index_mut(&mut self.0, index)
86 }
87}
88impl<T> fmt::Debug for FixedCapVec<T>
89where
90 T: fmt::Debug,
91{
92 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
93 self.0.fmt(f)
94 }
95}
96impl<T> Serialize for FixedCapVec<T>
97where
98 T: Serialize,
99{
100 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101 where
102 S: Serializer,
103 {
104 <(usize, &Vec<T>)>::serialize(&(self.0.capacity(), &self.0), serializer)
105 }
106}
107impl<'de, T> Deserialize<'de> for FixedCapVec<T>
108where
109 T: Deserialize<'de>,
110{
111 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112 where
113 D: Deserializer<'de>,
114 {
115 <(usize, Vec<T>)>::deserialize(deserializer).map(|(cap, mut vec)| {
116 vec.reserve_exact(cap - vec.len());
117 assert_eq!(vec.capacity(), cap);
118 Self(vec)
119 })
120 }
121}
122
123#[derive(Clone, Debug, Serialize, Deserialize)]
125pub struct SampleUnstable<T> {
126 reservoir: FixedCapVec<T>,
127 i: usize,
128}
129impl<T> SampleUnstable<T> {
130 pub fn new(samples: usize) -> Self {
132 Self {
133 reservoir: FixedCapVec::new(samples),
134 i: 0,
135 }
136 }
137
138 pub fn push<R: Rng>(&mut self, t: T, rng: &mut R) {
140 if self.reservoir.len() < self.reservoir.capacity() {
142 self.reservoir.push(t);
143 } else {
144 let idx = rng.gen_range(0, self.i);
145 if idx < self.reservoir.capacity() {
146 self.reservoir[idx] = t;
147 }
148 }
149 self.i += 1;
150 }
151}
152impl<T> IntoIterator for SampleUnstable<T> {
153 type Item = T;
154 type IntoIter = vec::IntoIter<T>;
155
156 fn into_iter(self) -> vec::IntoIter<T> {
157 self.reservoir.into_iter()
158 }
159}
160impl<T> iter::Sum for SampleUnstable<T> {
161 fn sum<I>(iter: I) -> Self
162 where
163 I: Iterator<Item = Self>,
164 {
165 let mut total = Self::new(0); for sample in iter {
167 total += sample;
168 }
169 total
170 }
171}
172impl<T> ops::Add for SampleUnstable<T> {
173 type Output = Self;
174
175 fn add(mut self, other: Self) -> Self {
176 self += other;
177 self
178 }
179}
180impl<T> ops::AddAssign for SampleUnstable<T> {
181 fn add_assign(&mut self, mut other: Self) {
182 if self.reservoir.capacity() > 0 {
183 assert_eq!(self.reservoir.capacity(), other.reservoir.capacity());
185 let mut new = FixedCapVec::new(self.reservoir.capacity());
186 let (m, n) = (self.i, other.i);
187 let mut rng = rand::rngs::SmallRng::from_seed([
188 u8::try_from(m & 0xff).unwrap(),
189 u8::try_from(n & 0xff).unwrap(),
190 u8::try_from(self.reservoir.capacity() & 0xff).unwrap(),
191 3,
192 4,
193 5,
194 6,
195 7,
196 8,
197 9,
198 10,
199 11,
200 12,
201 13,
202 14,
203 15,
204 ]); for _ in 0..new.capacity() {
206 if rng.gen_range(0, m + n) < m {
207 new.push(self.reservoir.pop().unwrap());
208 } else {
209 new.push(other.reservoir.pop().unwrap());
210 }
211 }
212 self.reservoir = new;
213 self.i += other.i;
214 } else {
215 *self = other;
216 }
217 }
218}
219
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// Number of independent draws each statistical test performs.
    const TRIALS: usize = 1_000_000;

    /// Binomial coefficient `n choose k` (small inputs only).
    fn choose(n: usize, k: usize) -> usize {
        // acc == C(n, j) at step j, and C(n, j) * (n - j) is divisible by j + 1.
        (0..k).fold(1, |acc, j| acc * (n - j) / (j + 1))
    }

    /// Asserts that exactly `subsets` distinct outcomes were seen and that each
    /// occurred within 5% of the uniform expectation `TRIALS / subsets`
    /// (many standard deviations at this trial count).
    fn assert_roughly_uniform(hist: &HashMap<Vec<usize>, usize>, subsets: usize) {
        assert_eq!(hist.len(), subsets, "unexpected set of outcomes: {:#?}", hist);
        let expected = TRIALS / subsets;
        for (outcome, &count) in hist {
            let diff = if count > expected {
                count - expected
            } else {
                expected - count
            };
            assert!(
                diff * 20 < expected,
                "outcome {:?} seen {} times, expected about {}",
                outcome,
                count,
                expected
            );
        }
    }

    #[test]
    fn sample_without_replacement() {
        let total = 6;
        let samples = 2;
        let mut rng = rand::thread_rng();

        let mut hist = HashMap::new();
        for _ in 0..TRIALS {
            let mut res = Vec::with_capacity(samples);
            let mut x = SampleTotal::new(total, samples);
            for i in 0..total {
                if x.sample(&mut rng) {
                    res.push(i);
                }
            }
            assert_eq!(res.len(), samples);
            *hist.entry(res).or_insert(0) += 1;
        }
        assert_roughly_uniform(&hist, choose(total, samples));
    }

    #[test]
    fn sample_unstable() {
        let total = 6;
        let samples = 2;
        let mut rng = rand::thread_rng();

        let mut hist = HashMap::new();
        for _ in 0..TRIALS {
            let mut x = SampleUnstable::new(samples);
            for i in 0..total {
                x.push(i, &mut rng);
            }
            // Reservoir order is unspecified; sort to get a canonical key.
            let mut picked: Vec<usize> = x.into_iter().collect();
            picked.sort_unstable();
            assert_eq!(picked.len(), samples);
            assert!(
                picked.windows(2).all(|w| w[0] < w[1]),
                "duplicate item in {:?}",
                picked
            );
            assert!(picked.iter().all(|&p| p < total));
            *hist.entry(picked).or_insert(0) += 1;
        }
        println!("{:#?}", hist);
    }
}