Skip to main content

burn_dataset/transform/
options.rs

1use rand::SeedableRng;
2use rand::prelude::StdRng;
3use rand::rngs::SysRng;
4
5/// Defines a source for a `StdRng`.
6///
7/// # Examples
8///
9/// ```rust,no_run
10/// use rand::rngs::StdRng;
11/// use rand::SeedableRng;
12/// use burn_dataset::transform::RngSource;
13///
14/// // Default via `StdRng::from_os_rng()` (`RngSource::Default`)
15/// let system: RngSource = RngSource::default();
16///
17/// // From a fixed seed (`RngSource::Seed`)
18/// let seeded: RngSource = 42.into();
19///
20/// // From an existing rng (`RngSource::Rng`)
21/// let rng = StdRng::seed_from_u64(123);
22/// let with_rng: RngSource = rng.into();
23///
24/// // Forks the parent RNG to derive an independent, deterministic child RNG.
25/// // The original `rng` is modified, and the resulting `RngSource` contains
26/// // a new RNG starting from a unique state.
27/// let mut rng = StdRng::seed_from_u64(123);
28/// let forked: RngSource = (&mut rng).into();
29/// ```
30#[derive(Debug, Default, PartialEq, Eq)]
31#[allow(clippy::large_enum_variant)]
32pub enum RngSource {
33    /// Build a new rng from the system.
34    #[default]
35    Default,
36
37    /// The rng is passed as a seed.
38    Seed(u64),
39
40    /// The rng is passed as an option.
41    Rng(StdRng),
42}
43
44impl From<RngSource> for StdRng {
45    fn from(source: RngSource) -> Self {
46        match source {
47            RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(),
48            RngSource::Rng(rng) => rng,
49            RngSource::Seed(seed) => StdRng::seed_from_u64(seed),
50        }
51    }
52}
53
54impl From<u64> for RngSource {
55    fn from(seed: u64) -> Self {
56        Self::Seed(seed)
57    }
58}
59
60impl From<StdRng> for RngSource {
61    fn from(rng: StdRng) -> Self {
62        Self::Rng(rng)
63    }
64}
65
66/// Derive an independent RNG from a mutable parent RNG.
67///
68/// This advances the parent RNG and creates a new RNG seeded from its output.
69/// The derived RNG is *not* a clone of the parent's state, but an independent
70/// stream (equivalent to `SeedableRng::fork`).
71impl From<&mut StdRng> for RngSource {
72    fn from(rng: &mut StdRng) -> Self {
73        Self::Rng(rng.fork())
74    }
75}
76
/// Helper option to describe the size of a wrapper, relative to a wrapped object.
///
/// Build one from a `usize` (fixed size) or an `f64` (ratio of the source size),
/// then call [`SizeConfig::resolve`] with the source size to get the effective size.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum SizeConfig {
    /// Use the size of the source dataset.
    #[default]
    Default,

    /// Use the size as a ratio of the source dataset size.
    ///
    /// Must be >= 0 (NaN is rejected too); [`SizeConfig::resolve`] panics otherwise.
    Ratio(f64),

    /// Use a fixed size.
    Fixed(usize),
}

impl SizeConfig {
    /// Construct a source which will have the same size as the source dataset.
    pub fn source() -> Self {
        Self::Default
    }

    /// Resolve the effective size.
    ///
    /// ## Arguments
    ///
    /// - `source_size`: the size of the source dataset.
    ///
    /// ## Returns
    ///
    /// The resolved size of the wrapper dataset:
    ///
    /// - `Default`: `source_size` itself.
    /// - `Ratio(r)`: `source_size * r`, truncated toward zero.
    /// - `Fixed(n)`: `n`, regardless of `source_size`.
    ///
    /// ## Panics
    ///
    /// Panics if the ratio of a `Ratio` variant is negative or NaN.
    pub fn resolve(self, source_size: usize) -> usize {
        match self {
            SizeConfig::Default => source_size,
            SizeConfig::Ratio(ratio) => {
                // Zero is a valid ratio, so the message says "non-negative".
                // `NaN >= 0.0` is false, which makes NaN fail this check as well.
                assert!(ratio >= 0.0, "Ratio must be non-negative: {ratio}");
                // A float-to-int `as` cast truncates toward zero and saturates at
                // `usize::MAX`, so an oversized product cannot wrap around.
                ((source_size as f64) * ratio) as usize
            }
            SizeConfig::Fixed(size) => size,
        }
    }
}

/// A bare `usize` is interpreted as a fixed size.
impl From<usize> for SizeConfig {
    fn from(size: usize) -> Self {
        Self::Fixed(size)
    }
}

/// A bare `f64` is interpreted as a ratio of the source size.
///
/// The value is stored unchecked; it is validated when the size is resolved.
impl From<f64> for SizeConfig {
    fn from(ratio: f64) -> Self {
        Self::Ratio(ratio)
    }
}
131
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// The default source is `RngSource::Default` and can be converted into an rng.
    #[test]
    fn test_rng_source_default() {
        let rng_source: RngSource = Default::default();
        assert_eq!(&rng_source, &RngSource::Default);
        assert_eq!(&rng_source, &RngSource::default());

        // Exercise the system-seeded conversion; its seed is unknown, so the
        // only thing checked is that the conversion runs.
        let _rng: StdRng = rng_source.into();
    }

    /// A `u64` becomes `RngSource::Seed` and yields the same rng as `seed_from_u64`.
    #[test]
    fn test_rng_source_seed() {
        let rng_source = RngSource::from(42);
        assert_eq!(&rng_source, &RngSource::Seed(42));

        let rng: StdRng = rng_source.into();
        let expected = StdRng::seed_from_u64(42);

        assert_eq!(rng, expected);
    }

    /// Owned rngs pass through untouched; borrowed rngs are forked.
    #[test]
    fn test_rng_source_rng() {
        // From StdRng (owned).
        {
            let original = StdRng::seed_from_u64(42);

            let rng_source = RngSource::from(original);
            let rng: StdRng = rng_source.into();

            // `original` was moved, so compare against a freshly seeded twin:
            // the from/into round trip must not have advanced the state.
            let expected = StdRng::seed_from_u64(42);
            assert_eq!(rng, expected);
        }

        // From &mut StdRng (forks parent)
        {
            let mut original = StdRng::seed_from_u64(42);

            // Reference pair: an identically seeded parent and its fork.
            let mut reference_parent = StdRng::seed_from_u64(42);
            let reference_fork = reference_parent.fork();

            let rng_source = RngSource::from(&mut original);

            // Ensure the original was advanced exactly like a forked parent.
            assert_eq!(original, reference_parent);

            // Ensure the sourced RNG matches the fork
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, reference_fork);
        }
    }

    /// Constructors and conversions pick the expected variants.
    #[test]
    fn test_size_config() {
        assert_eq!(SizeConfig::default(), SizeConfig::Default);

        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));

        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));

        assert_eq!(SizeConfig::source(), SizeConfig::Default);
        assert_eq!(SizeConfig::source().resolve(50), 50);
    }

    /// `resolve` handles the `Fixed` and `Ratio` variants.
    #[test]
    fn test_size_config_resolve() {
        assert_eq!(SizeConfig::Fixed(7).resolve(50), 7);

        assert_eq!(SizeConfig::Ratio(0.0).resolve(50), 0);
        assert_eq!(SizeConfig::Ratio(1.5).resolve(50), 75);

        // The scaled size is truncated, not rounded.
        assert_eq!(SizeConfig::Ratio(0.5).resolve(5), 2);
    }

    /// A negative ratio is rejected when the size is resolved.
    #[test]
    #[should_panic(expected = "Ratio must be")]
    fn test_size_config_negative_ratio_panics() {
        let _ = SizeConfig::Ratio(-0.5).resolve(50);
    }
}