scirs2_datasets/streaming/
dataloader.rs1use crate::error::DatasetsError;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12
/// Policy that decides the order in which rows are visited during one epoch.
///
/// The loader consults this value only when `DataLoaderConfig::shuffle` is
/// `false`; with `shuffle == true` a plain uniform shuffle is used regardless
/// of the strategy stored here.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum SamplingStrategy {
    /// Visit rows in their natural order `0, 1, ..., n_rows - 1`.
    Sequential,

    /// Visit rows in a uniformly random permutation, redrawn every epoch.
    RandomShuffle,

    /// Interleave rows class by class so that consecutive rows (and therefore
    /// batches) mix the classes.
    ///
    /// The payload holds one class id per row: element `i` is the class of
    /// row `i`.
    Stratified(Vec<usize>),

    /// Order rows randomly so that rows with a larger weight tend to come
    /// first.
    ///
    /// The payload holds one weight per row: element `i` is the weight of
    /// row `i`.
    WeightedRandom(Vec<f64>),
}
45
46#[derive(Debug, Clone)]
52pub struct DataLoaderConfig {
53 pub batch_size: usize,
55 pub shuffle: bool,
58 pub drop_last: bool,
61 pub seed: u64,
63 pub sampling: SamplingStrategy,
66}
67
68impl Default for DataLoaderConfig {
69 fn default() -> Self {
70 Self {
71 batch_size: 32,
72 shuffle: true,
73 drop_last: false,
74 seed: 42,
75 sampling: SamplingStrategy::RandomShuffle,
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
86pub struct Batch {
87 pub features: Array2<f64>,
89 pub labels: Option<Array1<f64>>,
91 pub indices: Vec<usize>,
93}
94
95impl Batch {
96 pub fn batch_size(&self) -> usize {
98 self.features.nrows()
99 }
100
101 pub fn n_features(&self) -> usize {
103 self.features.ncols()
104 }
105}
106
107pub struct DataLoader {
117 features: Array2<f64>,
118 labels: Option<Vec<f64>>,
119 config: DataLoaderConfig,
120 indices: Vec<usize>,
122 current_pos: usize,
123 epoch: usize,
124 rng: StdRng,
125}
126
127impl DataLoader {
128 pub fn new(features: Array2<f64>, labels: Option<Vec<f64>>, config: DataLoaderConfig) -> Self {
133 let n_rows = features.nrows();
134 let mut rng = StdRng::seed_from_u64(config.seed);
135 let indices = Self::build_indices(n_rows, &config, &mut rng);
136 Self {
137 features,
138 labels,
139 config,
140 indices,
141 current_pos: 0,
142 epoch: 0,
143 rng,
144 }
145 }
146
147 pub fn n_batches(&self) -> usize {
152 let n = self.indices.len();
153 let bs = self.config.batch_size.max(1);
154 if self.config.drop_last {
155 n / bs
156 } else {
157 n.div_ceil(bs)
158 }
159 }
160
161 pub fn n_rows(&self) -> usize {
163 self.features.nrows()
164 }
165
166 pub fn n_features(&self) -> usize {
168 self.features.ncols()
169 }
170
171 pub fn epoch(&self) -> usize {
173 self.epoch
174 }
175
176 pub fn reset_epoch(&mut self) {
179 self.epoch += 1;
180 self.current_pos = 0;
181 let n_rows = self.features.nrows();
182 self.indices = Self::build_indices(n_rows, &self.config, &mut self.rng);
183 }
184
185 fn build_indices(n_rows: usize, config: &DataLoaderConfig, rng: &mut StdRng) -> Vec<usize> {
191 if n_rows == 0 {
192 return vec![];
193 }
194
195 if config.shuffle {
197 return Self::fisher_yates(n_rows, rng);
198 }
199
200 match &config.sampling {
201 SamplingStrategy::Sequential => (0..n_rows).collect(),
202
203 SamplingStrategy::RandomShuffle => Self::fisher_yates(n_rows, rng),
204
205 SamplingStrategy::Stratified(class_labels) => {
206 Self::stratified_indices(n_rows, class_labels, rng)
207 }
208
209 SamplingStrategy::WeightedRandom(weights) => {
210 Self::weighted_indices(n_rows, weights, rng)
211 }
212 }
213 }
214
215 fn fisher_yates(n: usize, rng: &mut StdRng) -> Vec<usize> {
217 let mut idx: Vec<usize> = (0..n).collect();
218 for i in (1..n).rev() {
219 let j = (rng.next_u64() as usize) % (i + 1);
220 idx.swap(i, j);
221 }
222 idx
223 }
224
225 fn stratified_indices(n_rows: usize, class_labels: &[usize], rng: &mut StdRng) -> Vec<usize> {
227 let max_class = class_labels.iter().copied().max().unwrap_or(0);
229 let mut buckets: Vec<Vec<usize>> = vec![vec![]; max_class + 1];
230 for (row, &cls) in class_labels.iter().enumerate().take(n_rows) {
231 buckets[cls].push(row);
232 }
233 for bucket in &mut buckets {
235 for i in (1..bucket.len()).rev() {
236 let j = (rng.next_u64() as usize) % (i + 1);
237 bucket.swap(i, j);
238 }
239 }
240 let mut result = Vec::with_capacity(n_rows);
242 let mut cursors = vec![0usize; buckets.len()];
243 let mut any_remaining = true;
244 while any_remaining {
245 any_remaining = false;
246 for (cls, bucket) in buckets.iter().enumerate() {
247 if cursors[cls] < bucket.len() {
248 result.push(bucket[cursors[cls]]);
249 cursors[cls] += 1;
250 any_remaining = true;
251 }
252 }
253 }
254 result
255 }
256
257 fn weighted_indices(n_rows: usize, weights: &[f64], rng: &mut StdRng) -> Vec<usize> {
263 let mut keyed: Vec<(f64, usize)> = (0..n_rows)
264 .map(|i| {
265 let w = if i < weights.len() {
266 weights[i].max(0.0)
267 } else {
268 1.0
269 };
270 let u = (rng.next_u64() as f64 + 1.0) / (u64::MAX as f64 + 1.0);
272 let key = if w > 0.0 { -u.ln() / w } else { f64::INFINITY };
273 (key, i)
274 })
275 .collect();
276 keyed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
277 keyed.into_iter().map(|(_, idx)| idx).collect()
278 }
279
280 fn extract_batch(&self, row_indices: &[usize]) -> Batch {
282 let nf = self.features.ncols();
283 let bs = row_indices.len();
284 let mut feat_flat = Vec::with_capacity(bs * nf);
285 let mut label_vals = Vec::with_capacity(bs);
286
287 for &ri in row_indices {
288 for j in 0..nf {
289 feat_flat.push(self.features[[ri, j]]);
290 }
291 if let Some(lbl_vec) = &self.labels {
292 label_vals.push(if ri < lbl_vec.len() { lbl_vec[ri] } else { 0.0 });
293 }
294 }
295
296 let features = Array2::from_shape_vec((bs, nf), feat_flat)
297 .unwrap_or_else(|_| Array2::zeros((bs, nf.max(1))));
298
299 let labels = if self.labels.is_some() {
300 Some(Array1::from_vec(label_vals))
301 } else {
302 None
303 };
304
305 Batch {
306 features,
307 labels,
308 indices: row_indices.to_vec(),
309 }
310 }
311}
312
313impl Iterator for DataLoader {
314 type Item = Batch;
315
316 fn next(&mut self) -> Option<Self::Item> {
317 let remaining = self.indices.len().saturating_sub(self.current_pos);
318 if remaining == 0 {
319 return None;
320 }
321
322 let bs = self.config.batch_size.max(1);
323 let batch_rows = remaining.min(bs);
324
325 if self.config.drop_last && batch_rows < bs {
327 return None;
328 }
329
330 let start = self.current_pos;
331 let end = start + batch_rows;
332 let row_indices: Vec<usize> = self.indices[start..end].to_vec();
333 self.current_pos = end;
334
335 Some(self.extract_batch(&row_indices))
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an `n x f` matrix filled row-major with `0.0, 1.0, 2.0, ...`.
    fn ramp_features(n: usize, f: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..n * f).map(|x| x as f64).collect();
        Array2::from_shape_vec((n, f), values).unwrap()
    }

    /// Non-shuffling configuration that uses the given sampling strategy.
    fn unshuffled(
        batch_size: usize,
        drop_last: bool,
        seed: u64,
        sampling: SamplingStrategy,
    ) -> DataLoaderConfig {
        DataLoaderConfig {
            batch_size,
            shuffle: false,
            drop_last,
            seed,
            sampling,
        }
    }

    /// Loader over a ramp matrix with labels cycling through `0, 1, 2`.
    fn make_loader(n: usize, f: usize, bs: usize, shuffle: bool) -> DataLoader {
        let labels: Vec<f64> = (0..n).map(|i| (i % 3) as f64).collect();
        let sampling = if shuffle {
            SamplingStrategy::RandomShuffle
        } else {
            SamplingStrategy::Sequential
        };
        let config = DataLoaderConfig {
            batch_size: bs,
            shuffle,
            drop_last: false,
            seed: 42,
            sampling,
        };
        DataLoader::new(ramp_features(n, f), Some(labels), config)
    }

    /// 100 rows in batches of 32 gives four batches covering every row.
    #[test]
    fn test_dataloader_basic() {
        let loader = make_loader(100, 4, 32, false);
        assert_eq!(loader.n_batches(), 4);

        let sizes: Vec<usize> = loader.map(|b| b.batch_size()).collect();
        assert_eq!(sizes.len(), 4);
        assert_eq!(sizes.iter().sum::<usize>(), 100);
    }

    /// Without `drop_last` the trailing partial batch (105 - 3 * 32 = 9) is kept.
    #[test]
    fn test_dataloader_last_batch() {
        let config = unshuffled(32, false, 0, SamplingStrategy::Sequential);
        let loader = DataLoader::new(ramp_features(105, 2), None, config);

        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches.last().unwrap().batch_size(), 9);
    }

    /// With `drop_last` only the three full batches are produced.
    #[test]
    fn test_dataloader_drop_last() {
        let config = unshuffled(32, true, 0, SamplingStrategy::Sequential);
        let loader = DataLoader::new(ramp_features(105, 2), None, config);

        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 3);
        for b in &batches {
            assert_eq!(b.batch_size(), 32);
        }
    }

    /// Two consecutive epochs must not visit the rows in the same order.
    #[test]
    fn test_dataloader_shuffle() {
        let config = DataLoaderConfig {
            batch_size: 50,
            shuffle: true,
            drop_last: false,
            seed: 99,
            sampling: SamplingStrategy::RandomShuffle,
        };
        let mut loader = DataLoader::new(ramp_features(50, 2), None, config);

        let first_batch = loader.next().expect("first epoch batch");
        loader.reset_epoch();
        let second_batch = loader.next().expect("second epoch batch");

        assert_ne!(first_batch.indices, second_batch.indices);
    }

    /// Stratified sampling mixes the three classes inside every batch.
    #[test]
    fn test_dataloader_stratified() {
        let n = 30usize;
        let class_labels: Vec<usize> = (0..n).map(|i| i % 3).collect();
        let label_f64: Vec<f64> = class_labels.iter().map(|&c| c as f64).collect();
        let config = unshuffled(6, false, 1, SamplingStrategy::Stratified(class_labels));
        let loader = DataLoader::new(ramp_features(n, 2), Some(label_f64), config);

        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 5);

        for batch in &batches {
            if let Some(lbls) = &batch.labels {
                let unique: std::collections::HashSet<i64> =
                    lbls.iter().map(|&x| x as i64).collect();
                assert!(
                    unique.len() >= 2,
                    "expected multiple classes per batch, got {unique:?}"
                );
            }
        }
    }

    /// `reset_epoch` increments the epoch counter by one each call.
    #[test]
    fn test_dataloader_epoch_count() {
        let mut loader = make_loader(20, 2, 5, true);
        assert_eq!(loader.epoch(), 0);

        for expected_epoch in 1..=2 {
            // Drain the current epoch before moving on to the next one.
            for _ in loader.by_ref() {}
            loader.reset_epoch();
            assert_eq!(loader.epoch(), expected_epoch);
        }
    }

    /// A matrix with zero rows yields no batches at all.
    #[test]
    fn test_dataloader_empty() {
        let loader = DataLoader::new(
            Array2::<f64>::zeros((0, 3)),
            None,
            DataLoaderConfig::default(),
        );
        assert_eq!(loader.n_batches(), 0);

        let batches: Vec<_> = loader.collect();
        assert!(batches.is_empty());
    }

    /// 64 rows in batches of 32 gives exactly two full batches.
    #[test]
    fn test_dataloader_exact_multiple() {
        let batches: Vec<_> = make_loader(64, 4, 32, false).collect();
        assert_eq!(batches.len(), 2);
        for b in &batches {
            assert_eq!(b.batch_size(), 32);
        }
    }

    /// Rows with a 100x larger weight should dominate the head of the order.
    #[test]
    fn test_dataloader_weighted_random() {
        let n = 40usize;
        let mut weights = vec![1.0_f64; n];
        weights[..10].fill(100.0);

        let config = unshuffled(n, false, 7, SamplingStrategy::WeightedRandom(weights));
        let mut loader = DataLoader::new(ramp_features(n, 2), None, config);

        let batch = loader.next().expect("batch");
        let heavy_in_top10 = batch.indices[..10].iter().filter(|&&i| i < 10).count();
        assert!(
            heavy_in_top10 >= 5,
            "expected heavy rows near top, got {heavy_in_top10}"
        );
    }
}