burn_dataset/transform/
options.rs1use rand::SeedableRng;
2use rand::prelude::StdRng;
3use rand::rngs::SysRng;
4
5#[derive(Debug, Default, PartialEq, Eq)]
31#[allow(clippy::large_enum_variant)]
32pub enum RngSource {
33 #[default]
35 Default,
36
37 Seed(u64),
39
40 Rng(StdRng),
42}
43
44impl From<RngSource> for StdRng {
45 fn from(source: RngSource) -> Self {
46 match source {
47 RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(),
48 RngSource::Rng(rng) => rng,
49 RngSource::Seed(seed) => StdRng::seed_from_u64(seed),
50 }
51 }
52}
53
54impl From<u64> for RngSource {
55 fn from(seed: u64) -> Self {
56 Self::Seed(seed)
57 }
58}
59
60impl From<StdRng> for RngSource {
61 fn from(rng: StdRng) -> Self {
62 Self::Rng(rng)
63 }
64}
65
66impl From<&mut StdRng> for RngSource {
72 fn from(rng: &mut StdRng) -> Self {
73 Self::Rng(rng.fork())
74 }
75}
76
/// A target size, expressed either relative to a source size or as an absolute count.
///
/// Turn it into a concrete number with [`SizeConfig::resolve`].
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum SizeConfig {
    /// Keep the source size unchanged.
    #[default]
    Default,

    /// Scale the source size by this factor (must be non-negative); the product is
    /// truncated toward zero.
    Ratio(f64),

    /// Use exactly this size, ignoring the source size.
    Fixed(usize),
}

impl SizeConfig {
    /// Returns the configuration that mirrors the source size ([`SizeConfig::Default`]).
    pub fn source() -> Self {
        SizeConfig::Default
    }

    /// Computes the concrete size for a source containing `source_size` elements.
    ///
    /// - [`SizeConfig::Default`] returns `source_size`.
    /// - [`SizeConfig::Fixed`] returns the stored size.
    /// - [`SizeConfig::Ratio`] returns `source_size * ratio`, truncated toward zero
    ///   (results beyond `usize::MAX` saturate, per Rust's float-to-int `as` cast).
    ///
    /// # Panics
    ///
    /// Panics if a [`SizeConfig::Ratio`] holds a negative or NaN ratio.
    pub fn resolve(self, source_size: usize) -> usize {
        let ratio = match self {
            Self::Default => return source_size,
            Self::Fixed(size) => return size,
            Self::Ratio(ratio) => ratio,
        };

        // `NaN >= 0.0` is false, so NaN is rejected here as well.
        assert!(ratio >= 0.0, "Ratio must be positive: {ratio}");
        (source_size as f64 * ratio) as usize
    }
}

/// Treats a bare `usize` as an absolute size, i.e. [`SizeConfig::Fixed`].
impl From<usize> for SizeConfig {
    fn from(count: usize) -> Self {
        SizeConfig::Fixed(count)
    }
}

/// Treats a bare `f64` as a scaling factor, i.e. [`SizeConfig::Ratio`].
impl From<f64> for SizeConfig {
    fn from(factor: f64) -> Self {
        SizeConfig::Ratio(factor)
    }
}
131
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    #[test]
    fn test_rng_source_default() {
        let rng_source: RngSource = Default::default();
        assert_eq!(&rng_source, &RngSource::Default);
        assert_eq!(&rng_source, &RngSource::default());

        // Output is non-deterministic; this only checks that system seeding succeeds.
        let _rng: StdRng = rng_source.into();
    }

    #[test]
    fn test_rng_source_seed() {
        let rng_source = RngSource::from(42);
        assert_eq!(&rng_source, &RngSource::Seed(42));

        let rng: StdRng = rng_source.into();
        let expected = StdRng::seed_from_u64(42);

        assert_eq!(rng, expected);
    }

    #[test]
    fn test_rng_source_rng() {
        {
            // An owned generator is passed through unchanged.
            let original = StdRng::seed_from_u64(42);

            let rng_source = RngSource::from(original);
            let rng: StdRng = rng_source.into();
            let original = StdRng::seed_from_u64(42);

            assert_eq!(rng, original);
        }

        {
            // A borrowed generator is forked: the source holds the fork, and the
            // parent ends up in the same state as one that called `fork()` directly.
            let mut original = StdRng::seed_from_u64(42);
            let mut rng = StdRng::seed_from_u64(42);
            let rng_forked = rng.fork();

            let rng_source = RngSource::from(&mut original);

            assert_eq!(original, rng);

            let rng: StdRng = rng_source.into();

            assert_eq!(rng, rng_forked);
        }
    }

    #[test]
    fn test_size_config() {
        assert_eq!(SizeConfig::default(), SizeConfig::Default);

        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));

        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));

        assert_eq!(SizeConfig::source(), SizeConfig::Default);
        assert_eq!(SizeConfig::source().resolve(50), 50);
    }

    #[test]
    fn test_size_config_resolve() {
        // A fixed size ignores the source size entirely.
        assert_eq!(SizeConfig::Fixed(7).resolve(100), 7);

        assert_eq!(SizeConfig::Ratio(0.5).resolve(10), 5);
        assert_eq!(SizeConfig::Ratio(0.0).resolve(10), 0);
        assert_eq!(SizeConfig::Ratio(2.5).resolve(4), 10);

        // Fractional products are truncated toward zero.
        assert_eq!(SizeConfig::Ratio(0.99).resolve(10), 9);
    }

    #[test]
    #[should_panic(expected = "Ratio must be positive")]
    fn test_size_config_negative_ratio_panics() {
        SizeConfig::Ratio(-0.1).resolve(10);
    }
}