streaming_algorithms/
sample.rs

1use rand::{self, Rng, SeedableRng};
2use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
3use std::{convert::TryFrom, fmt, iter, ops, vec};
4
5/// Given population and sample sizes, returns true if this element is in the sample. Without replacement.
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct SampleTotal {
8	total: usize,
9	samples: usize,
10	picked: usize,
11	i: usize,
12}
13impl SampleTotal {
14	/// Create a `SampleTotal` that will provide a sample of size `samples` of a population of size `total`.
15	pub fn new(total: usize, samples: usize) -> Self {
16		assert!(total >= samples);
17		Self {
18			total,
19			samples,
20			picked: 0,
21			i: 0,
22		}
23	}
24
25	/// Returns whether or not to this value is in the sample
26	pub fn sample<R: Rng>(&mut self, rng: &mut R) -> bool {
27		let sample = rng.gen_range(0, self.total - self.i) < (self.samples - self.picked);
28		self.i += 1;
29		if sample {
30			self.picked += 1;
31		}
32		sample
33	}
34}
35impl Drop for SampleTotal {
36	fn drop(&mut self) {
37		assert_eq!(self.picked, self.samples);
38	}
39}
40
/// A `Vec` wrapper whose capacity is chosen at construction and asserted never to change.
///
/// The capacity doubles as the "target size" of the collection, which is why it has to survive
/// every operation, including `clone`.
struct FixedCapVec<T>(Vec<T>);
impl<T> Clone for FixedCapVec<T>
where
	T: Clone,
{
	/// Clones the elements into a new vector with the same capacity as `self`.
	///
	/// `Vec::clone` makes no promise about spare capacity, so deriving `Clone` would turn a
	/// partially filled vector into one whose capacity equals its length.
	fn clone(&self) -> Self {
		let cap = self.capacity();
		let mut copy = Vec::with_capacity(cap);
		copy.extend_from_slice(&self.0);
		assert_eq!(copy.capacity(), cap);
		Self(copy)
	}
}
impl<T> FixedCapVec<T> {
	/// Creates an empty vector with capacity exactly `cap`.
	///
	/// # Panics
	///
	/// Panics if the allocated capacity is not exactly `cap`.
	fn new(cap: usize) -> Self {
		let self_ = Self(Vec::with_capacity(cap));
		assert_eq!(self_.capacity(), cap);
		self_
	}
	/// Number of elements currently stored.
	fn len(&self) -> usize {
		self.0.len()
	}
	/// The fixed capacity.
	fn capacity(&self) -> usize {
		self.0.capacity()
	}
	/// Appends `t`.
	///
	/// # Panics
	///
	/// Panics if the vector is already full, or if the push changed the capacity.
	fn push(&mut self, t: T) {
		assert!(self.len() < self.capacity());
		let cap = self.capacity();
		self.0.push(t);
		assert_eq!(self.capacity(), cap);
	}
	/// Removes and returns the last element, or `None` if the vector is empty.
	fn pop(&mut self) -> Option<T> {
		let cap = self.capacity();
		let ret = self.0.pop();
		assert_eq!(self.capacity(), cap);
		ret
	}
	/// Consumes the wrapper and iterates over the stored elements by value.
	fn into_iter(self) -> std::vec::IntoIter<T> {
		self.0.into_iter()
	}
}
71impl<T, Idx> std::ops::Index<Idx> for FixedCapVec<T>
72where
73	Idx: std::slice::SliceIndex<[T]>,
74{
75	type Output = <Vec<T> as std::ops::Index<Idx>>::Output;
76	fn index(&self, index: Idx) -> &Self::Output {
77		std::ops::Index::index(&self.0, index)
78	}
79}
80impl<T, Idx> std::ops::IndexMut<Idx> for FixedCapVec<T>
81where
82	Idx: std::slice::SliceIndex<[T]>,
83{
84	fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
85		std::ops::IndexMut::index_mut(&mut self.0, index)
86	}
87}
88impl<T> fmt::Debug for FixedCapVec<T>
89where
90	T: fmt::Debug,
91{
92	fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
93		self.0.fmt(f)
94	}
95}
96impl<T> Serialize for FixedCapVec<T>
97where
98	T: Serialize,
99{
100	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101	where
102		S: Serializer,
103	{
104		<(usize, &Vec<T>)>::serialize(&(self.0.capacity(), &self.0), serializer)
105	}
106}
107impl<'de, T> Deserialize<'de> for FixedCapVec<T>
108where
109	T: Deserialize<'de>,
110{
111	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112	where
113		D: Deserializer<'de>,
114	{
115		<(usize, Vec<T>)>::deserialize(deserializer).map(|(cap, mut vec)| {
116			vec.reserve_exact(cap - vec.len());
117			assert_eq!(vec.capacity(), cap);
118			Self(vec)
119		})
120	}
121}
122
123/// [Reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling). Without replacement, and the returned order is unstable.
124#[derive(Clone, Debug, Serialize, Deserialize)]
125pub struct SampleUnstable<T> {
126	reservoir: FixedCapVec<T>,
127	i: usize,
128}
129impl<T> SampleUnstable<T> {
130	/// Create a `SampleUnstable` that will provide a sample of size `samples`.
131	pub fn new(samples: usize) -> Self {
132		Self {
133			reservoir: FixedCapVec::new(samples),
134			i: 0,
135		}
136	}
137
138	/// "Visit" this element
139	pub fn push<R: Rng>(&mut self, t: T, rng: &mut R) {
140		// TODO: https://dl.acm.org/citation.cfm?id=198435
141		if self.reservoir.len() < self.reservoir.capacity() {
142			self.reservoir.push(t);
143		} else {
144			let idx = rng.gen_range(0, self.i);
145			if idx < self.reservoir.capacity() {
146				self.reservoir[idx] = t;
147			}
148		}
149		self.i += 1;
150	}
151}
152impl<T> IntoIterator for SampleUnstable<T> {
153	type Item = T;
154	type IntoIter = vec::IntoIter<T>;
155
156	fn into_iter(self) -> vec::IntoIter<T> {
157		self.reservoir.into_iter()
158	}
159}
160impl<T> iter::Sum for SampleUnstable<T> {
161	fn sum<I>(iter: I) -> Self
162	where
163		I: Iterator<Item = Self>,
164	{
165		let mut total = Self::new(0); // TODO
166		for sample in iter {
167			total += sample;
168		}
169		total
170	}
171}
172impl<T> ops::Add for SampleUnstable<T> {
173	type Output = Self;
174
175	fn add(mut self, other: Self) -> Self {
176		self += other;
177		self
178	}
179}
180impl<T> ops::AddAssign for SampleUnstable<T> {
181	fn add_assign(&mut self, mut other: Self) {
182		if self.reservoir.capacity() > 0 {
183			// TODO
184			assert_eq!(self.reservoir.capacity(), other.reservoir.capacity());
185			let mut new = FixedCapVec::new(self.reservoir.capacity());
186			let (m, n) = (self.i, other.i);
187			let mut rng = rand::rngs::SmallRng::from_seed([
188				u8::try_from(m & 0xff).unwrap(),
189				u8::try_from(n & 0xff).unwrap(),
190				u8::try_from(self.reservoir.capacity() & 0xff).unwrap(),
191				3,
192				4,
193				5,
194				6,
195				7,
196				8,
197				9,
198				10,
199				11,
200				12,
201				13,
202				14,
203				15,
204			]); // TODO
205			for _ in 0..new.capacity() {
206				if rng.gen_range(0, m + n) < m {
207					new.push(self.reservoir.pop().unwrap());
208				} else {
209					new.push(other.reservoir.pop().unwrap());
210				}
211			}
212			self.reservoir = new;
213			self.i += other.i;
214		} else {
215			*self = other;
216		}
217	}
218}
219
#[cfg(test)]
mod test {
	use super::*;
	use std::collections::HashMap;

	/// Draws a 2-of-6 sample a million times with `SampleTotal` and prints how often each
	/// combination of indices came up.
	#[test]
	fn sample_without_replacement() {
		let (total, samples) = (6, 2);

		let mut counts = HashMap::new();
		for _ in 0..1_000_000 {
			let mut picked = Vec::with_capacity(samples);
			let mut sampler = SampleTotal::new(total, samples);
			for element in 0..total {
				if sampler.sample(&mut rand::thread_rng()) {
					picked.push(element);
				}
			}
			*counts.entry(picked).or_insert(0) += 1;
		}
		println!("{:#?}", counts);
	}

	/// Draws a 2-of-6 sample a million times with `SampleUnstable` and prints how often each
	/// resulting reservoir came up.
	#[test]
	fn sample_unstable() {
		let (total, samples) = (6, 2);

		let mut counts = HashMap::new();
		for _ in 0..1_000_000 {
			let mut sampler = SampleUnstable::new(samples);
			for element in 0..total {
				sampler.push(element, &mut rand::thread_rng());
			}
			let reservoir: Vec<_> = sampler.into_iter().collect();
			*counts.entry(reservoir).or_insert(0) += 1;
		}
		println!("{:#?}", counts);
	}
}