burn_dataset/transform/
options.rs1use rand::prelude::StdRng;
2use rand::{RngCore, SeedableRng};
3
4#[derive(Debug, Clone, Default, PartialEq, Eq)]
32#[allow(clippy::large_enum_variant)]
33pub enum RngSource {
34 #[default]
36 Default,
37
38 Seed(u64),
40
41 Rng(StdRng),
43}
44
45impl From<RngSource> for StdRng {
46 fn from(source: RngSource) -> Self {
47 match &source {
48 RngSource::Default => StdRng::from_os_rng(),
49 RngSource::Rng(rng) => rng.clone(),
50 RngSource::Seed(seed) => StdRng::seed_from_u64(*seed),
51 }
52 }
53}
54
55impl From<u64> for RngSource {
56 fn from(seed: u64) -> Self {
57 Self::Seed(seed)
58 }
59}
60
61impl From<StdRng> for RngSource {
62 fn from(rng: StdRng) -> Self {
63 Self::Rng(rng)
64 }
65}
66
67impl From<&StdRng> for RngSource {
68 fn from(rng: &StdRng) -> Self {
69 Self::Rng(rng.clone())
70 }
71}
72
73impl From<&mut StdRng> for RngSource {
76 fn from(rng: &mut StdRng) -> Self {
77 rng.next_u64();
78 Self::Rng(rng.clone())
79 }
80}
81
/// How many items to produce, expressed relative to a source size.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum SizeConfig {
    /// Same size as the source.
    #[default]
    Default,

    /// A multiple of the source size, truncated toward zero (e.g. `0.5` means half).
    ///
    /// Must be non-negative; see [`SizeConfig::resolve`].
    Ratio(f64),

    /// An absolute size, independent of the source.
    Fixed(usize),
}

impl SizeConfig {
    /// Returns the configuration that keeps the source size ([`SizeConfig::Default`]).
    #[must_use]
    pub fn source() -> Self {
        Self::Default
    }

    /// Resolves this configuration to a concrete size for a source of `source_size` items.
    ///
    /// - `Default` → `source_size`
    /// - `Ratio(r)` → `source_size * r`, truncated toward zero. Results larger than
    ///   `usize::MAX` saturate, as Rust's float-to-integer `as` cast does.
    /// - `Fixed(n)` → `n`
    ///
    /// # Panics
    ///
    /// Panics if the configuration is a `Ratio` that is negative or NaN.
    #[must_use]
    pub fn resolve(self, source_size: usize) -> usize {
        match self {
            Self::Default => source_size,
            Self::Ratio(ratio) => {
                // `>=` is false for NaN, so NaN is rejected here too. Zero is allowed.
                assert!(ratio >= 0.0, "Ratio must be non-negative: {ratio}");
                // `as` truncates toward zero and saturates at `usize::MAX`.
                (source_size as f64 * ratio) as usize
            }
            Self::Fixed(size) => size,
        }
    }
}

impl From<usize> for SizeConfig {
    /// An absolute size ([`SizeConfig::Fixed`]).
    fn from(size: usize) -> Self {
        Self::Fixed(size)
    }
}

impl From<f64> for SizeConfig {
    /// A size relative to the source ([`SizeConfig::Ratio`]).
    fn from(ratio: f64) -> Self {
        Self::Ratio(ratio)
    }
}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    #[test]
    fn test_rng_source_default() {
        let rng_source: RngSource = Default::default();
        assert_eq!(&rng_source, &RngSource::Default);
        assert_eq!(&rng_source, &RngSource::default());

        // OS-seeded: the state is unpredictable, so only check that conversion works.
        let _rng: StdRng = rng_source.into();
    }

    #[test]
    fn test_rng_source_seed() {
        let rng_source = RngSource::from(42);
        assert_eq!(&rng_source, &RngSource::Seed(42));

        let rng: StdRng = rng_source.into();
        let expected = StdRng::seed_from_u64(42);

        assert_eq!(rng, expected);
    }

    #[test]
    fn test_rng_source_rng() {
        let original = StdRng::seed_from_u64(42);

        {
            // Owned generator: passed through unchanged.
            let rng_source = RngSource::from(original.clone());
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, original);
        }

        {
            // Borrowed generator: snapshotted without advancing it.
            let rng_source = RngSource::from(&original);
            let rng: StdRng = rng_source.into();
            assert_eq!(rng, original);
        }

        {
            // Mutably borrowed generator: advanced once, then forked.
            let mut stateful = original.clone();

            let rng_source = RngSource::from(&mut stateful);
            assert_ne!(stateful, original);

            let rng: StdRng = rng_source.into();
            // The fork captures the parent's post-advance state.
            assert_eq!(rng, stateful);
        }
    }

    #[test]
    fn test_size_config() {
        assert_eq!(SizeConfig::default(), SizeConfig::Default);

        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));

        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));

        assert_eq!(SizeConfig::source(), SizeConfig::Default);
        assert_eq!(SizeConfig::source().resolve(50), 50);
    }

    #[test]
    fn test_size_config_resolve() {
        assert_eq!(SizeConfig::Default.resolve(0), 0);
        assert_eq!(SizeConfig::Fixed(7).resolve(100), 7);
        assert_eq!(SizeConfig::Ratio(0.5).resolve(10), 5);
        assert_eq!(SizeConfig::Ratio(0.0).resolve(10), 0);
        // 3 * 1.5 = 4.5, truncated toward zero.
        assert_eq!(SizeConfig::Ratio(1.5).resolve(3), 4);
    }

    #[test]
    #[should_panic(expected = "Ratio must be")]
    fn test_size_config_negative_ratio_panics() {
        let _ = SizeConfig::Ratio(-0.5).resolve(10);
    }
}