scirs2_neural/data/
dataloader.rs1use crate::data::Dataset;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::num_integer::div_ceil;
7use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
8use scirs2_core::random::rngs::SmallRng;
9use scirs2_core::random::seq::SliceRandom;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13
14type BatchResult<F> = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
16
17pub struct DataLoader<
19 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
20 D: Dataset<F> + Send + Sync,
21> {
22 pub dataset: D,
24 pub batch_size: usize,
26 pub shuffle: bool,
28 pub drop_last: bool,
30 indices: Vec<usize>,
32 position: usize,
34 _phantom: PhantomData<F>,
36}
37
38impl<
39 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
40 D: Dataset<F> + Send + Sync,
41 > DataLoader<F, D>
42{
43 pub fn new(dataset: D, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
51 let indices: Vec<usize> = (0..dataset.len()).collect();
52 Self {
53 dataset,
54 batch_size,
55 shuffle,
56 drop_last,
57 indices,
58 position: 0,
59 _phantom: PhantomData,
60 }
61 }
62
63 pub fn reset(&mut self) {
65 if self.shuffle {
66 let mut rng = SmallRng::from_rng(&mut thread_rng());
67 self.indices.shuffle(&mut rng);
68 }
69 self.position = 0;
70 }
71
72 pub fn num_batches(&self) -> usize {
74 let num = div_ceil(self.dataset.len(), self.batch_size);
75 if self.drop_last && num > 0 && self.dataset.len() % self.batch_size != 0 {
76 num - 1
77 } else {
78 num
79 }
80 }
81
82 pub fn len(&self) -> usize {
84 self.dataset.len()
85 }
86
87 pub fn is_empty(&self) -> bool {
89 self.len() == 0
90 }
91
92 pub fn next_batch(&mut self) -> Option<BatchResult<F>> {
94 if self.position >= self.dataset.len() {
95 return None;
96 }
97
98 let remaining = self.dataset.len() - self.position;
99 let batch_size = if remaining < self.batch_size {
100 if self.drop_last {
101 return None;
102 }
103 remaining
104 } else {
105 self.batch_size
106 };
107
108 let batch_indices: Vec<usize> =
110 self.indices[self.position..self.position + batch_size].to_vec();
111 self.position += batch_size;
112
113 let result = self.load_batch(&batch_indices);
115 Some(result)
116 }
117
118 fn load_batch(&self, indices: &[usize]) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
120 let (first_x, first_y) = self.dataset.get(indices[0])?;
122
123 let batch_x_shape = [indices.len()]
125 .iter()
126 .chain(first_x.shape())
127 .cloned()
128 .collect::<Vec<_>>();
129 let batch_y_shape = [indices.len()]
130 .iter()
131 .chain(first_y.shape())
132 .cloned()
133 .collect::<Vec<_>>();
134
135 let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
136 let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
137
138 for (i, &idx) in indices.iter().enumerate() {
140 let (x, y) = self.dataset.get(idx)?;
141
142 let mut batch_x_slice = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
144 batch_x_slice.assign(&x);
145
146 let mut batch_y_slice = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
147 batch_y_slice.assign(&y);
148 }
149
150 Ok((batch_x, batch_y))
151 }
152}
153
154impl<
155 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
156 D: Dataset<F> + Send + Sync,
157 > Iterator for DataLoader<F, D>
158{
159 type Item = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
160
161 fn next(&mut self) -> Option<Self::Item> {
162 self.next_batch()
163 }
164}
165
166#[allow(dead_code)]
168pub fn iter_batches<
169 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
170 D: Dataset<F> + Send + Sync,
171>(
172 dataset: D,
173 batch_size: usize,
174 shuffle: bool,
175 drop_last: bool,
176) -> DataLoader<F, D> {
177 DataLoader::new(dataset, batch_size, shuffle, drop_last)
178}