// Source file: scirs2_neural/data/dataloader.rs

//! DataLoader implementation for efficient batch loading
2
3use crate::data::Dataset;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::num_integer::div_ceil;
7use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
8use scirs2_core::random::rngs::SmallRng;
9use scirs2_core::random::seq::SliceRandom;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13
14/// Type alias for batch result
15type BatchResult<F> = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
16
17/// Data loader for efficient batch processing
18pub struct DataLoader<
19    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
20    D: Dataset<F> + Send + Sync,
21> {
22    /// Dataset to load from
23    pub dataset: D,
24    /// Batch size
25    pub batch_size: usize,
26    /// Whether to shuffle the data
27    pub shuffle: bool,
28    /// Whether to drop the last batch if it's smaller than batch_size
29    pub drop_last: bool,
30    /// Current indices for iteration
31    indices: Vec<usize>,
32    /// Current position in indices
33    position: usize,
34    /// Phantom data for float type
35    _phantom: PhantomData<F>,
36}
37
38impl<
39        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
40        D: Dataset<F> + Send + Sync,
41    > DataLoader<F, D>
42{
43    /// Create a new data loader
44    ///
45    /// # Arguments
46    /// * `dataset` - Dataset to load from
47    /// * `batch_size` - Number of samples per batch
48    /// * `shuffle` - Whether to shuffle the data
49    /// * `drop_last` - Whether to drop the last batch if it's smaller than batch_size
50    pub fn new(dataset: D, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
51        let indices: Vec<usize> = (0..dataset.len()).collect();
52        Self {
53            dataset,
54            batch_size,
55            shuffle,
56            drop_last,
57            indices,
58            position: 0,
59            _phantom: PhantomData,
60        }
61    }
62
63    /// Reset the data loader state
64    pub fn reset(&mut self) {
65        if self.shuffle {
66            let mut rng = SmallRng::from_rng(&mut thread_rng());
67            self.indices.shuffle(&mut rng);
68        }
69        self.position = 0;
70    }
71
72    /// Get the number of batches in the dataset
73    pub fn num_batches(&self) -> usize {
74        let num = div_ceil(self.dataset.len(), self.batch_size);
75        if self.drop_last && num > 0 && self.dataset.len() % self.batch_size != 0 {
76            num - 1
77        } else {
78            num
79        }
80    }
81
82    /// Get the dataset len
83    pub fn len(&self) -> usize {
84        self.dataset.len()
85    }
86
87    /// Check if the dataloader is empty
88    pub fn is_empty(&self) -> bool {
89        self.len() == 0
90    }
91
92    /// Get the next batch from the dataset
93    pub fn next_batch(&mut self) -> Option<BatchResult<F>> {
94        if self.position >= self.dataset.len() {
95            return None;
96        }
97
98        let remaining = self.dataset.len() - self.position;
99        let batch_size = if remaining < self.batch_size {
100            if self.drop_last {
101                return None;
102            }
103            remaining
104        } else {
105            self.batch_size
106        };
107
108        // Collect batch indices
109        let batch_indices: Vec<usize> =
110            self.indices[self.position..self.position + batch_size].to_vec();
111        self.position += batch_size;
112
113        // Load data
114        let result = self.load_batch(&batch_indices);
115        Some(result)
116    }
117
118    /// Load a batch of data using the given indices
119    fn load_batch(&self, indices: &[usize]) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
120        // Load first sample to determine shapes
121        let (first_x, first_y) = self.dataset.get(indices[0])?;
122
123        // Create batch arrays
124        let batch_x_shape = [indices.len()]
125            .iter()
126            .chain(first_x.shape())
127            .cloned()
128            .collect::<Vec<_>>();
129        let batch_y_shape = [indices.len()]
130            .iter()
131            .chain(first_y.shape())
132            .cloned()
133            .collect::<Vec<_>>();
134
135        let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
136        let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
137
138        // Fill batch arrays
139        for (i, &idx) in indices.iter().enumerate() {
140            let (x, y) = self.dataset.get(idx)?;
141
142            // Copy data into batch arrays
143            let mut batch_x_slice = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
144            batch_x_slice.assign(&x);
145
146            let mut batch_y_slice = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
147            batch_y_slice.assign(&y);
148        }
149
150        Ok((batch_x, batch_y))
151    }
152}
153
154impl<
155        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
156        D: Dataset<F> + Send + Sync,
157    > Iterator for DataLoader<F, D>
158{
159    type Item = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
160
161    fn next(&mut self) -> Option<Self::Item> {
162        self.next_batch()
163    }
164}
165
166/// Helper function to create an iterator over the dataset in batches
167#[allow(dead_code)]
168pub fn iter_batches<
169    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
170    D: Dataset<F> + Send + Sync,
171>(
172    dataset: D,
173    batch_size: usize,
174    shuffle: bool,
175    drop_last: bool,
176) -> DataLoader<F, D> {
177    DataLoader::new(dataset, batch_size, shuffle, drop_last)
178}