Skip to main content

burn_dataset/transform/
options.rs

1use rand::prelude::StdRng;
2use rand::{RngCore, SeedableRng};
3
4/// Defines a source for a `StdRng`.
5///
6/// # Examples
7///
8/// ```rust,no_run
9/// use rand::rngs::StdRng;
10/// use rand::SeedableRng;
11/// use burn_dataset::transform::RngSource;
12///
13/// // Default via `StdRng::from_os_rng()` (`RngSource::Default`)
14/// let system: RngSource = RngSource::default();
15///
16/// // From a fixed seed (`RngSource::Seed`)
17/// let seeded: RngSource = 42.into();
18///
19/// // From an existing rng (`RngSource::Rng`)
20/// let rng = StdRng::seed_from_u64(123);
21/// let with_rng: RngSource = rng.into();
22///
23/// // From a snapshot of the current state (`RngSource::Rng`)
24/// let rng = StdRng::seed_from_u64(123);
25/// let snapshot: RngSource = (&rng).into();
26///
27/// // Advances the original RNG and then clones its new state
28/// let mut rng = StdRng::seed_from_u64(123);
29/// let stateful: RngSource = (&mut rng).into();
30/// ```
31#[derive(Debug, Clone, Default, PartialEq, Eq)]
32#[allow(clippy::large_enum_variant)]
33pub enum RngSource {
34    /// Build a new rng from the system.
35    #[default]
36    Default,
37
38    /// The rng is passed as a seed.
39    Seed(u64),
40
41    /// The rng is passed as an option.
42    Rng(StdRng),
43}
44
45impl From<RngSource> for StdRng {
46    fn from(source: RngSource) -> Self {
47        match &source {
48            RngSource::Default => StdRng::from_os_rng(),
49            RngSource::Rng(rng) => rng.clone(),
50            RngSource::Seed(seed) => StdRng::seed_from_u64(*seed),
51        }
52    }
53}
54
55impl From<u64> for RngSource {
56    fn from(seed: u64) -> Self {
57        Self::Seed(seed)
58    }
59}
60
61impl From<StdRng> for RngSource {
62    fn from(rng: StdRng) -> Self {
63        Self::Rng(rng)
64    }
65}
66
67impl From<&StdRng> for RngSource {
68    fn from(rng: &StdRng) -> Self {
69        Self::Rng(rng.clone())
70    }
71}
72
73/// Users calling with a mutable rng expect state advancement,
74/// So conversion from `&mut StdRng` advances the rng before cloning.
75impl From<&mut StdRng> for RngSource {
76    fn from(rng: &mut StdRng) -> Self {
77        rng.next_u64();
78        Self::Rng(rng.clone())
79    }
80}
81
/// Helper option to describe the size of a wrapper dataset, relative to the
/// dataset it wraps.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum SizeConfig {
    /// Use the size of the source dataset.
    #[default]
    Default,

    /// Use the size as a ratio of the source dataset size.
    ///
    /// Must be finite and >= 0. The resolved size is truncated toward zero.
    Ratio(f64),

    /// Use a fixed size, independent of the source dataset.
    Fixed(usize),
}

impl SizeConfig {
    /// Construct a config that resolves to the source dataset's size.
    ///
    /// Equivalent to [`SizeConfig::Default`].
    pub fn source() -> Self {
        Self::Default
    }

    /// Resolve the effective size.
    ///
    /// ## Arguments
    ///
    /// - `source_size`: the size of the source dataset.
    ///
    /// ## Returns
    ///
    /// The resolved size of the wrapper dataset:
    /// - `Default` → `source_size`;
    /// - `Ratio(r)` → `source_size * r`, truncated toward zero;
    /// - `Fixed(n)` → `n`.
    ///
    /// ## Panics
    ///
    /// For `Ratio(r)`: panics if `r` is negative or NaN, or if `r` is infinite.
    pub fn resolve(self, source_size: usize) -> usize {
        match self {
            SizeConfig::Default => source_size,
            SizeConfig::Ratio(ratio) => {
                // `>=` is false for NaN, so NaN is rejected here as well.
                assert!(ratio >= 0.0, "Ratio must be positive: {ratio}");
                // An infinite ratio would otherwise silently saturate the cast
                // below to `usize::MAX`.
                assert!(ratio.is_finite(), "Ratio must be finite: {ratio}");
                // Float-to-int `as` truncates toward zero.
                ((source_size as f64) * ratio) as usize
            }
            SizeConfig::Fixed(size) => size,
        }
    }
}

impl From<usize> for SizeConfig {
    /// An integer is interpreted as a fixed size ([`SizeConfig::Fixed`]).
    fn from(size: usize) -> Self {
        Self::Fixed(size)
    }
}

impl From<f64> for SizeConfig {
    /// A float is interpreted as a ratio of the source size ([`SizeConfig::Ratio`]).
    fn from(ratio: f64) -> Self {
        Self::Ratio(ratio)
    }
}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    #[test]
    fn test_rng_source_default() {
        let rng_source: RngSource = Default::default();
        assert_eq!(&rng_source, &RngSource::Default);
        assert_eq!(&rng_source, &RngSource::default());

        // Exercise the `from_os_rng()` path; its seed is unknown, so there is
        // nothing to compare the result against.
        let _rng: StdRng = rng_source.into();
    }

    #[test]
    fn test_rng_source_seed() {
        let rng_source = RngSource::from(42);
        assert_eq!(&rng_source, &RngSource::Seed(42));

        let rng: StdRng = rng_source.into();
        let expected = StdRng::seed_from_u64(42);

        assert_eq!(rng, expected);
    }

    #[test]
    fn test_rng_source_rng() {
        let original = StdRng::seed_from_u64(42);

        // From StdRng.
        {
            let rng_source = RngSource::from(original.clone());
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, original);
        }

        // From &StdRng.
        {
            let rng_source = RngSource::from(&original);
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, original);
        }

        // From &mut StdRng.
        {
            let mut stateful = original.clone();

            // Converting from `&mut` advances `stateful`...
            let rng_source = RngSource::from(&mut stateful);
            assert_ne!(stateful, original);

            // ...and the source captures the advanced state.
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, stateful);
        }
    }

    #[test]
    fn test_size_config() {
        assert_eq!(SizeConfig::default(), SizeConfig::Default);

        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));

        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));

        assert_eq!(SizeConfig::source(), SizeConfig::Default);
        assert_eq!(SizeConfig::source().resolve(50), 50);
    }

    #[test]
    fn test_size_config_resolve() {
        assert_eq!(SizeConfig::Default.resolve(7), 7);
        assert_eq!(SizeConfig::Fixed(3).resolve(100), 3);
        assert_eq!(SizeConfig::Ratio(0.5).resolve(10), 5);
        // Fractional results are truncated toward zero.
        assert_eq!(SizeConfig::Ratio(1.5).resolve(3), 4);
        assert_eq!(SizeConfig::Ratio(0.0).resolve(10), 0);
    }

    #[test]
    #[should_panic(expected = "Ratio must be positive")]
    fn test_size_config_negative_ratio_panics() {
        SizeConfig::Ratio(-0.5).resolve(10);
    }
}