Skip to main content

unit_intervals/
random.rs

1//! Random sampling support through [`rand_distr`].
2//!
3//! This module is available with the `rand_distr` crate feature. It implements
4//! [`Distribution`] for interval types where the sampled backing float is
5//! guaranteed to satisfy the interval invariant, and provides adapter
6//! distributions for arbitrary float distributions.
7//!
8//! [`StandardUniform`] samples [`UnitInterval`] values from `[0, 1)` and
9//! [`SignedUnitInterval`] values from `[-1, 1)`.
10//!
11//! ```
12//! use rand::{RngExt, SeedableRng, rngs::StdRng};
13//! use rand_distr::{Distribution, StandardUniform};
14//! use unit_intervals::{SignedUnitInterval, UnitInterval};
15//!
16//! let mut rng = StdRng::seed_from_u64(42);
17//!
18//! let unit: UnitInterval<f32> = StandardUniform.sample(&mut rng);
19//! let signed: SignedUnitInterval<f64> = rng.random();
20//!
21//! assert!(UnitInterval::<f32>::contains(unit.get()));
22//! assert!(SignedUnitInterval::<f64>::contains(signed.get()));
23//! ```
24//!
//! The bounded `rand_distr` float distributions [`Open01`], [`OpenClosed01`],
//! and [`Beta`], whose support lies inside `[0, 1]`, can also sample
//! [`UnitInterval`] values directly.
27//!
28//! ```
29//! use rand::{SeedableRng, rngs::StdRng};
30//! use rand_distr::{Beta, Distribution, Open01, OpenClosed01};
31//! use unit_intervals::UnitInterval;
32//!
33//! let mut rng = StdRng::seed_from_u64(42);
34//! let beta = Beta::new(2.0_f64, 5.0).unwrap();
35//!
36//! let open: UnitInterval<f32> = Open01.sample(&mut rng);
37//! let open_closed: UnitInterval<f64> = OpenClosed01.sample(&mut rng);
38//! let beta_value: UnitInterval<f64> = beta.sample(&mut rng);
39//!
40//! assert!(UnitInterval::<f32>::contains(open.get()));
41//! assert!(UnitInterval::<f64>::contains(open_closed.get()));
42//! assert!(UnitInterval::<f64>::contains(beta_value.get()));
43//! ```
44//!
45//! Arbitrary float distributions may produce values outside the target
46//! interval. Use checked adapters when out-of-range samples should become
47//! `None`, or saturating adapters when they should be clamped.
48//!
49//! ```
50//! use rand::{SeedableRng, rngs::StdRng};
51//! use rand_distr::{Distribution, Normal};
52//! use unit_intervals::{
53//!     UnitInterval,
54//!     random::{CheckedUnitIntervalDistribution, SaturatingUnitIntervalDistribution},
55//! };
56//!
57//! let normal = Normal::new(0.5_f32, 2.0).unwrap();
58//! let checked = CheckedUnitIntervalDistribution::new(normal);
59//! let saturating = SaturatingUnitIntervalDistribution::new(normal);
60//! let mut rng = StdRng::seed_from_u64(42);
61//!
62//! let checked_sample: Option<UnitInterval<f32>> = checked.sample(&mut rng);
63//! let saturating_sample: UnitInterval<f32> = saturating.sample(&mut rng);
64//!
65//! assert!(checked_sample.is_none_or(|value| UnitInterval::<f32>::contains(value.get())));
66//! assert!(UnitInterval::<f32>::contains(saturating_sample.get()));
67//! ```
68//!
69//! The same adapter pattern is available for [`SignedUnitInterval`].
70//!
71//! ```
72//! use rand::{SeedableRng, rngs::StdRng};
73//! use rand_distr::{Distribution, Normal};
74//! use unit_intervals::{
75//!     SignedUnitInterval,
76//!     random::{
77//!         CheckedSignedUnitIntervalDistribution, SaturatingSignedUnitIntervalDistribution,
78//!     },
79//! };
80//!
81//! let normal = Normal::new(0.0_f64, 2.0).unwrap();
82//! let checked = CheckedSignedUnitIntervalDistribution::new(normal);
83//! let saturating = SaturatingSignedUnitIntervalDistribution::new(normal);
84//! let mut rng = StdRng::seed_from_u64(42);
85//!
86//! let checked_sample: Option<SignedUnitInterval<f64>> = checked.sample(&mut rng);
87//! let saturating_sample: SignedUnitInterval<f64> = saturating.sample(&mut rng);
88//!
89//! assert!(checked_sample.is_none_or(|value| SignedUnitInterval::<f64>::contains(value.get())));
90//! assert!(SignedUnitInterval::<f64>::contains(saturating_sample.get()));
91//! ```
92
93use crate::{SignedUnitInterval, UnitInterval, UnitIntervalFloat};
94use ::rand::Rng;
95use ::rand_distr::{Beta, Distribution, Open01, OpenClosed01, StandardUniform};
96
/// Wraps a distribution over raw floats and turns each draw into a checked
/// [`UnitInterval`] sample.
///
/// Draws that fall outside the unit interval are reported as `None` rather
/// than being adjusted; [`SaturatingUnitIntervalDistribution`] is the
/// clamping counterpart.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CheckedUnitIntervalDistribution<D> {
    /// Source of the raw float samples.
    distribution: D,
}

impl<D> CheckedUnitIntervalDistribution<D> {
    /// Wraps `distribution` in a checked unit interval adapter.
    #[inline]
    pub const fn new(distribution: D) -> Self {
        Self { distribution }
    }

    /// Borrows the distribution this adapter draws from.
    #[inline]
    pub const fn as_inner(&self) -> &D {
        let Self { distribution } = self;
        distribution
    }

    /// Unwraps the adapter, handing back the distribution it was built from.
    #[inline]
    pub fn into_inner(self) -> D {
        let Self { distribution } = self;
        distribution
    }
}
123
/// Wraps a distribution over raw floats and turns each draw into a
/// saturating [`UnitInterval`] sample.
///
/// Draws that fall outside the unit interval are clamped into it instead of
/// being rejected; [`CheckedUnitIntervalDistribution`] is the rejecting
/// counterpart.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SaturatingUnitIntervalDistribution<D> {
    /// Source of the raw float samples.
    distribution: D,
}

impl<D> SaturatingUnitIntervalDistribution<D> {
    /// Wraps `distribution` in a saturating unit interval adapter.
    #[inline]
    pub const fn new(distribution: D) -> Self {
        Self { distribution }
    }

    /// Borrows the distribution this adapter draws from.
    #[inline]
    pub const fn as_inner(&self) -> &D {
        let Self { distribution } = self;
        distribution
    }

    /// Unwraps the adapter, handing back the distribution it was built from.
    #[inline]
    pub fn into_inner(self) -> D {
        let Self { distribution } = self;
        distribution
    }
}
150
/// Wraps a distribution over raw floats and turns each draw into a checked
/// [`SignedUnitInterval`] sample.
///
/// Draws that fall outside the signed unit interval are reported as `None`
/// rather than being adjusted; [`SaturatingSignedUnitIntervalDistribution`]
/// is the clamping counterpart.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CheckedSignedUnitIntervalDistribution<D> {
    /// Source of the raw float samples.
    distribution: D,
}

impl<D> CheckedSignedUnitIntervalDistribution<D> {
    /// Wraps `distribution` in a checked signed unit interval adapter.
    #[inline]
    pub const fn new(distribution: D) -> Self {
        Self { distribution }
    }

    /// Borrows the distribution this adapter draws from.
    #[inline]
    pub const fn as_inner(&self) -> &D {
        let Self { distribution } = self;
        distribution
    }

    /// Unwraps the adapter, handing back the distribution it was built from.
    #[inline]
    pub fn into_inner(self) -> D {
        let Self { distribution } = self;
        distribution
    }
}
177
/// Wraps a distribution over raw floats and turns each draw into a
/// saturating [`SignedUnitInterval`] sample.
///
/// Draws that fall outside the signed unit interval are clamped into it
/// instead of being rejected; [`CheckedSignedUnitIntervalDistribution`] is
/// the rejecting counterpart.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SaturatingSignedUnitIntervalDistribution<D> {
    /// Source of the raw float samples.
    distribution: D,
}

impl<D> SaturatingSignedUnitIntervalDistribution<D> {
    /// Wraps `distribution` in a saturating signed unit interval adapter.
    #[inline]
    pub const fn new(distribution: D) -> Self {
        Self { distribution }
    }

    /// Borrows the distribution this adapter draws from.
    #[inline]
    pub const fn as_inner(&self) -> &D {
        let Self { distribution } = self;
        distribution
    }

    /// Unwraps the adapter, handing back the distribution it was built from.
    #[inline]
    pub fn into_inner(self) -> D {
        let Self { distribution } = self;
        distribution
    }
}
204
205impl<T, D> Distribution<Option<UnitInterval<T>>> for CheckedUnitIntervalDistribution<D>
206where
207    T: UnitIntervalFloat,
208    D: Distribution<T>,
209{
210    #[inline]
211    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<UnitInterval<T>> {
212        UnitInterval::new(self.distribution.sample(rng))
213    }
214}
215
216impl<T, D> Distribution<UnitInterval<T>> for SaturatingUnitIntervalDistribution<D>
217where
218    T: UnitIntervalFloat,
219    D: Distribution<T>,
220{
221    #[inline]
222    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> UnitInterval<T> {
223        UnitInterval::saturating(self.distribution.sample(rng))
224    }
225}
226
227impl<T, D> Distribution<Option<SignedUnitInterval<T>>> for CheckedSignedUnitIntervalDistribution<D>
228where
229    T: UnitIntervalFloat,
230    D: Distribution<T>,
231{
232    #[inline]
233    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<SignedUnitInterval<T>> {
234        SignedUnitInterval::new(self.distribution.sample(rng))
235    }
236}
237
238impl<T, D> Distribution<SignedUnitInterval<T>> for SaturatingSignedUnitIntervalDistribution<D>
239where
240    T: UnitIntervalFloat,
241    D: Distribution<T>,
242{
243    #[inline]
244    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SignedUnitInterval<T> {
245        SignedUnitInterval::saturating(self.distribution.sample(rng))
246    }
247}
248
249impl<T> Distribution<UnitInterval<T>> for StandardUniform
250where
251    T: UnitIntervalFloat,
252    StandardUniform: Distribution<T>,
253{
254    #[inline]
255    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> UnitInterval<T> {
256        UnitInterval::from_inner(<Self as Distribution<T>>::sample(self, rng))
257    }
258}
259
260impl<T> Distribution<SignedUnitInterval<T>> for StandardUniform
261where
262    T: UnitIntervalFloat,
263    StandardUniform: Distribution<T>,
264{
265    #[inline]
266    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SignedUnitInterval<T> {
267        let value = <Self as Distribution<T>>::sample(self, rng) * (T::ONE + T::ONE) - T::ONE;
268
269        SignedUnitInterval::from_inner(value)
270    }
271}
272
// Implements `Distribution<UnitInterval<$f>>` for a distribution type whose
// raw `$f` samples are already known to lie inside `[0, 1]`. Each sample is
// wrapped with `from_inner` instead of the fallible `UnitInterval::new`.
macro_rules! impl_unit_interval_distribution {
    ($dist:ty, $f:ty) => {
        impl Distribution<UnitInterval<$f>> for $dist {
            #[inline]
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> UnitInterval<$f> {
                let raw: $f = Distribution::<$f>::sample(self, rng);
                UnitInterval::from_inner(raw)
            }
        }
    };
}
283
284impl_unit_interval_distribution!(Open01, f32);
285impl_unit_interval_distribution!(Open01, f64);
286impl_unit_interval_distribution!(OpenClosed01, f32);
287impl_unit_interval_distribution!(OpenClosed01, f64);
288impl_unit_interval_distribution!(Beta<f32>, f32);
289impl_unit_interval_distribution!(Beta<f64>, f64);