rand/distr/distribution.rs
1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013-2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
//! The [`Distribution`] trait and associated items
11
12use crate::Rng;
13#[cfg(feature = "alloc")]
14use alloc::string::String;
15use core::iter;
16
17#[cfg(doc)]
18use crate::RngExt;
19
20/// Types (distributions) that can be used to create a random instance of `T`.
21///
22/// It is possible to sample from a distribution through both the
23/// `Distribution` and [`RngExt`] traits, via `distr.sample(&mut rng)` and
24/// `rng.sample(distr)`. They also both offer the [`sample_iter`] method, which
25/// produces an iterator that samples from the distribution.
26///
27/// All implementations are expected to be immutable; this has the significant
28/// advantage of not needing to consider thread safety, and for most
29/// distributions efficient state-less sampling algorithms are available.
30///
31/// Implementations are typically expected to be portable with reproducible
32/// results when used with a PRNG with fixed seed; see the
33/// [portability chapter](https://rust-random.github.io/book/portability.html)
34/// of The Rust Rand Book. In some cases this does not apply, e.g. the `usize`
35/// type requires different sampling on 32-bit and 64-bit machines.
36///
37/// [`sample_iter`]: Distribution::sample_iter
38pub trait Distribution<T> {
39 /// Generate a random value of `T`, using `rng` as the source of randomness.
40 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T;
41
42 /// Create an iterator that generates random values of `T`, using `rng` as
43 /// the source of randomness.
44 ///
45 /// Note that this function takes `self` by value. This works since
46 /// `Distribution<T>` is impl'd for `&D` where `D: Distribution<T>`,
47 /// however borrowing is not automatic hence `distr.sample_iter(...)` may
48 /// need to be replaced with `(&distr).sample_iter(...)` to borrow or
49 /// `(&*distr).sample_iter(...)` to reborrow an existing reference.
50 ///
51 /// # Example
52 ///
53 /// ```
54 /// use rand::distr::{Distribution, Alphanumeric, Uniform, StandardUniform};
55 ///
56 /// let mut rng = rand::rng();
57 ///
58 /// // Vec of 16 x f32:
59 /// let v: Vec<f32> = StandardUniform.sample_iter(&mut rng).take(16).collect();
60 ///
61 /// // String:
62 /// let s: String = Alphanumeric
63 /// .sample_iter(&mut rng)
64 /// .take(7)
65 /// .map(char::from)
66 /// .collect();
67 ///
68 /// // Dice-rolling:
69 /// let die_range = Uniform::new_inclusive(1, 6).unwrap();
70 /// let mut roll_die = die_range.sample_iter(&mut rng);
71 /// while roll_die.next().unwrap() != 6 {
72 /// println!("Not a 6; rolling again!");
73 /// }
74 /// ```
75 fn sample_iter<R>(self, rng: R) -> Iter<Self, R, T>
76 where
77 R: Rng,
78 Self: Sized,
79 {
80 Iter {
81 distr: self,
82 rng,
83 phantom: core::marker::PhantomData,
84 }
85 }
86
87 /// Map sampled values to type `S`
88 ///
89 /// # Example
90 ///
91 /// ```
92 /// use rand::distr::{Distribution, Uniform};
93 ///
94 /// let die = Uniform::new_inclusive(1, 6).unwrap();
95 /// let even_number = die.map(|num| num % 2 == 0);
96 /// while !even_number.sample(&mut rand::rng()) {
97 /// println!("Still odd; rolling again!");
98 /// }
99 /// ```
100 fn map<F, S>(self, func: F) -> Map<Self, F, T, S>
101 where
102 F: Fn(T) -> S,
103 Self: Sized,
104 {
105 Map {
106 distr: self,
107 func,
108 phantom: core::marker::PhantomData,
109 }
110 }
111}
112
113impl<T, D: Distribution<T> + ?Sized> Distribution<T> for &D {
114 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
115 (*self).sample(rng)
116 }
117}
118
/// An iterator over a [`Distribution`]
///
/// This iterator yields random values of type `T` with distribution `D`
/// from a random generator of type `R`.
///
/// Construct this `struct` using [`Distribution::sample_iter`] or
/// [`RngExt::sample_iter`]. It is also used by [`RngExt::random_iter`] and
/// [`crate::random_iter`].
#[derive(Debug)]
pub struct Iter<D, R, T> {
    /// The distribution every item is sampled from.
    distr: D,
    /// The generator supplying randomness for each sample.
    rng: R,
    // `Iter` never stores a `T`; it only produces them. A `fn() -> T` marker
    // is covariant in `T` (as `PhantomData<T>` would be) but does not make
    // `Iter`'s auto traits (`Send`, `Sync`, `Unpin`, ...) depend on `T`. This
    // mirrors `Map`, which uses `PhantomData<fn(T) -> S>`.
    phantom: core::marker::PhantomData<fn() -> T>,
}
133
134impl<D, R, T> Iterator for Iter<D, R, T>
135where
136 D: Distribution<T>,
137 R: Rng,
138{
139 type Item = T;
140
141 #[inline(always)]
142 fn next(&mut self) -> Option<T> {
143 // Here, self.rng may be a reference, but we must take &mut anyway.
144 // Even if sample could take an R: Rng by value, we would need to do this
145 // since Rng is not copyable and we cannot enforce that this is "reborrowable".
146 Some(self.distr.sample(&mut self.rng))
147 }
148
149 fn size_hint(&self) -> (usize, Option<usize>) {
150 (usize::MAX, None)
151 }
152}
153
154impl<D, R, T> iter::FusedIterator for Iter<D, R, T>
155where
156 D: Distribution<T>,
157 R: Rng,
158{
159}
160
/// A [`Distribution`] that transforms each sample of an inner distribution
/// into a value of type `S`.
///
/// Values of this type are returned by [`Distribution::map`]; refer to that
/// method for more details.
#[derive(Debug)]
pub struct Map<D, F, T, S> {
    /// The underlying distribution producing values of type `T`.
    distr: D,
    /// The function applied to each value sampled from `distr`.
    func: F,
    // Relates `T` and `S` to the type without storing either of them.
    phantom: core::marker::PhantomData<fn(T) -> S>,
}
171
172impl<D, F, T, S> Distribution<S> for Map<D, F, T, S>
173where
174 D: Distribution<T>,
175 F: Fn(T) -> S,
176{
177 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
178 (self.func)(self.distr.sample(rng))
179 }
180}
181
/// Generate or extend a [`String`] with random characters.
///
/// Implementors supply [`append_string`](SampleString::append_string); the
/// provided [`sample_string`](SampleString::sample_string) builds a fresh
/// [`String`] on top of it.
#[cfg(feature = "alloc")]
pub trait SampleString {
    /// Push `len` random chars onto the end of `string`.
    ///
    /// Note: `string` may be left with spare capacity. If that is unwanted,
    /// call [`String::shrink_to_fit`] once this method returns.
    fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize);

    /// Produce a new [`String`] consisting of `len` random chars.
    ///
    /// Note: the returned string may carry spare capacity. If that is
    /// unwanted, call [`String::shrink_to_fit`] on the result.
    #[inline]
    fn sample_string<R: Rng + ?Sized>(&self, rng: &mut R, len: usize) -> String {
        let mut buf = String::default();
        self.append_string(rng, &mut buf, len);
        buf
    }
}
206
#[cfg(test)]
mod tests {
    use crate::Rng;
    use crate::distr::{Distribution, Uniform};

    #[test]
    fn test_distributions_iter() {
        use crate::distr::Open01;
        let mut rng = crate::test::rng(210);
        let distr = Open01;
        let mut iter = Distribution::<f32>::sample_iter(distr, &mut rng);
        let mut sum: f32 = 0.;
        for _ in 0..100 {
            sum += iter.next().unwrap();
        }
        assert!(0. < sum && sum < 100.);
    }

    #[test]
    fn test_distributions_map() {
        let dist = Uniform::new_inclusive(0, 5).unwrap().map(|val| val + 15);

        let mut rng = crate::test::rng(212);
        let val = dist.sample(&mut rng);
        assert!((15..=20).contains(&val));
    }

    #[test]
    fn test_make_an_iter() {
        fn ten_dice_rolls_other_than_five<R: Rng>(rng: &mut R) -> impl Iterator<Item = i32> + '_ {
            Uniform::new_inclusive(1, 6)
                .unwrap()
                .sample_iter(rng)
                .filter(|x| *x != 5)
                .take(10)
        }

        let mut rng = crate::test::rng(211);
        let mut count = 0;
        for val in ten_dice_rolls_other_than_five(&mut rng) {
            assert!((1..=6).contains(&val) && val != 5);
            count += 1;
        }
        assert_eq!(count, 10);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_dist_string() {
        use crate::distr::{Alphabetic, Alphanumeric, SampleString, StandardUniform};
        let mut rng = crate::test::rng(213);

        // A `String` is valid UTF-8 by construction, so re-validating it proves
        // nothing; check each distribution's character set instead.
        let s1 = Alphanumeric.sample_string(&mut rng, 20);
        assert_eq!(s1.len(), 20);
        assert!(s1.chars().all(|c| c.is_ascii_alphanumeric()), "{s1:?}");

        let s2 = StandardUniform.sample_string(&mut rng, 20);
        assert_eq!(s2.chars().count(), 20);

        let s3 = Alphabetic.sample_string(&mut rng, 20);
        assert_eq!(s3.len(), 20);
        assert!(s3.chars().all(|c| c.is_ascii_alphabetic()), "{s3:?}");
    }
}