kizzasi_core/dataloader.rs

//! DataLoader for time-series training
//!
//! Provides efficient data loading, batching, and preprocessing for time-series
//! signal prediction tasks.
//!
//! # Features
//!
//! - **Windowing**: Sliding window extraction from continuous signals
//! - **Batching**: Efficient mini-batch creation with shuffling
//! - **Prefetching**: Async data loading for GPU transfer
//! - **Augmentation**: Time-series specific augmentations
//! - **Multi-GPU**: Distributed data loading support
//!
//! # Examples
//!
//! ```rust
//! use kizzasi_core::dataloader::{TimeSeriesDataLoader, DataLoaderConfig};
//! use scirs2_core::ndarray::Array2;
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let data = Array2::<f32>::zeros((1000, 3));  // 1000 timesteps, 3 features
//! let config = DataLoaderConfig::default()
//!     .with_window_size(64)
//!     .with_batch_size(32)
//!     .with_shuffle(true);
//!
//! let mut loader = TimeSeriesDataLoader::new(data, config)?;
//!
//! for batch in loader.iter_batches() {
//!     let (inputs, targets) = batch?;
//!     // Train on batch
//! }
//! # Ok(())
//! # }
//! ```

use crate::error::{CoreError, CoreResult};
use candle_core::{Device, Tensor};
use scirs2_core::ndarray::{s, Array2};
use serde::{Deserialize, Serialize};

/// Configuration for time-series data loader
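///
/// # Examples
///
/// A small builder sketch (mirroring the module-level example; parameter values
/// are illustrative):
///
/// ```rust
/// use kizzasi_core::dataloader::DataLoaderConfig;
///
/// // 64-step input windows, 4-step prediction targets, 50% window overlap.
/// let config = DataLoaderConfig::default()
///     .with_window_size(64)
///     .with_horizon(4)
///     .with_overlap(0.5);
/// assert_eq!(config.window_size, 64);
/// assert_eq!(config.horizon, 4);
/// ```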
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataLoaderConfig {
    /// Window size for sliding window extraction
    pub window_size: usize,
    /// Prediction horizon (number of steps ahead to predict)
    pub horizon: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Whether to shuffle data
    pub shuffle: bool,
    /// Overlap between consecutive windows (0.0 = no overlap, 0.5 = 50% overlap)
    pub overlap: f32,
    /// Whether to drop the last incomplete batch
    pub drop_last: bool,
    /// Number of workers for parallel loading (reserved for future use)
    pub num_workers: usize,
}

impl Default for DataLoaderConfig {
    fn default() -> Self {
        Self {
            window_size: 64,
            horizon: 1,
            batch_size: 32,
            shuffle: true,
            overlap: 0.0,
            drop_last: false,
            num_workers: 1,
        }
    }
}

impl DataLoaderConfig {
    /// Create a configuration with default values
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the sliding-window length (number of input timesteps)
    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

    /// Set the prediction horizon (number of target timesteps)
    pub fn with_horizon(mut self, horizon: usize) -> Self {
        self.horizon = horizon;
        self
    }

    /// Set the mini-batch size
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Enable or disable shuffling of window indices between epochs
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Set the overlap between consecutive windows (clamped to [0.0, 1.0])
    pub fn with_overlap(mut self, overlap: f32) -> Self {
        self.overlap = overlap.clamp(0.0, 1.0);
        self
    }

    /// Choose whether to drop the last incomplete batch
    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }
}

/// Time-series data loader
pub struct TimeSeriesDataLoader {
    data: Array2<f32>,
    config: DataLoaderConfig,
    indices: Vec<usize>,
    current_epoch: usize,
}

impl TimeSeriesDataLoader {
    /// Create a new data loader
    ///
    /// # Arguments
    /// * `data` - Time-series data of shape [timesteps, features]
    /// * `config` - DataLoader configuration
    pub fn new(data: Array2<f32>, config: DataLoaderConfig) -> CoreResult<Self> {
        if data.nrows() < config.window_size + config.horizon {
            return Err(CoreError::InvalidConfig(format!(
                "Data length {} is too short for window_size {} + horizon {}",
                data.nrows(),
                config.window_size,
                config.horizon
            )));
        }

        // Calculate stride based on overlap
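        // For example, window_size = 64 with overlap = 0.5 gives stride = 32, so
        // consecutive windows share half their timesteps; overlap = 0.0 gives
        // non-overlapping windows (stride = window_size).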
        let stride = ((config.window_size as f32) * (1.0 - config.overlap)).max(1.0) as usize;

        // Generate window start indices
        let max_start = data.nrows() - config.window_size - config.horizon + 1;
        let indices: Vec<usize> = (0..max_start).step_by(stride).collect();

        Ok(Self {
            data,
            config,
            indices,
            current_epoch: 0,
        })
    }

    /// Get number of batches per epoch
    pub fn num_batches(&self) -> usize {
        let num_samples = self.indices.len();
        if self.config.drop_last {
            num_samples / self.config.batch_size
        } else {
            num_samples.div_ceil(self.config.batch_size)
        }
    }

    /// Get total number of samples
    pub fn num_samples(&self) -> usize {
        self.indices.len()
    }

    /// Shuffle indices for new epoch
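    ///
    /// This is a no-op when `config.shuffle` is false.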
    pub fn shuffle(&mut self) {
        if self.config.shuffle {
            use scirs2_core::convenience::uniform;
            // Fisher-Yates shuffle
            for i in (1..self.indices.len()).rev() {
                let j = (uniform() * (i + 1) as f64) as usize;
                self.indices.swap(i, j);
            }
        }
    }

    /// Extract a single window
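    ///
    /// Returns `(input, target)`: `input` covers rows `start_idx..start_idx + window_size`
    /// with shape `[window_size, features]`, and `target` covers the following `horizon`
    /// rows with shape `[horizon, features]`.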
    fn extract_window(&self, start_idx: usize) -> CoreResult<(Array2<f32>, Array2<f32>)> {
        let end_input = start_idx + self.config.window_size;
        let end_target = end_input + self.config.horizon;

        if end_target > self.data.nrows() {
            return Err(CoreError::Generic(format!(
                "Window exceeds data bounds: {} > {}",
                end_target,
                self.data.nrows()
            )));
        }

        let input = self.data.slice(s![start_idx..end_input, ..]).to_owned();

        let target = self.data.slice(s![end_input..end_target, ..]).to_owned();

        Ok((input, target))
    }

    /// Create a batch of windows
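    ///
    /// Windows are stacked along the row axis, so the returned arrays are flat:
    /// `[batch * window_size, features]` for inputs and `[batch * horizon, features]`
    /// for targets; `to_tensors` reshapes them into `[batch, seq, features]`.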
    fn create_batch(&self, batch_indices: &[usize]) -> CoreResult<(Array2<f32>, Array2<f32>)> {
        let mut inputs = Vec::new();
        let mut targets = Vec::new();

        for &idx in batch_indices {
            let start = self.indices[idx];
            let (input, target) = self.extract_window(start)?;
            inputs.push(input);
            targets.push(target);
        }

        // Stack windows along the row axis: [batch * time, features]
        let batch_size = inputs.len();
        let window_size = self.config.window_size;
        let horizon = self.config.horizon;
        let n_features = self.data.ncols();

        let mut batch_input = Array2::zeros((batch_size * window_size, n_features));
        let mut batch_target = Array2::zeros((batch_size * horizon, n_features));

        for (i, (inp, tgt)) in inputs.iter().zip(targets.iter()).enumerate() {
            let input_start = i * window_size;
            let input_end = input_start + window_size;
            batch_input
                .slice_mut(s![input_start..input_end, ..])
                .assign(inp);

            let target_start = i * horizon;
            let target_end = target_start + horizon;
            batch_target
                .slice_mut(s![target_start..target_end, ..])
                .assign(tgt);
        }

        Ok((batch_input, batch_target))
    }

    /// Iterate over batches
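    ///
    /// From the second epoch onward the window indices are reshuffled (when
    /// shuffling is enabled) so each epoch visits the windows in a new order.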
    pub fn iter_batches(&mut self) -> BatchIterator<'_> {
        if self.current_epoch > 0 {
            self.shuffle();
        }
        self.current_epoch += 1;

        BatchIterator {
            loader: self,
            current_batch: 0,
        }
    }

    /// Convert batch to candle tensors
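    ///
    /// Expects the flat `[batch * window_size, features]` and `[batch * horizon, features]`
    /// arrays produced by `iter_batches`, and returns tensors of shape
    /// `[batch, window_size, features]` and `[batch, horizon, features]` on `device`.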
    pub fn to_tensors(
        &self,
        inputs: &Array2<f32>,
        targets: &Array2<f32>,
        device: &Device,
    ) -> CoreResult<(Tensor, Tensor)> {
        let batch_size = inputs.nrows() / self.config.window_size;
        let window_size = self.config.window_size;
        let horizon = self.config.horizon;
        let n_features = inputs.ncols();

        // Flatten to Vec<f32>
        let input_vec: Vec<f32> = inputs.iter().copied().collect();
        let target_vec: Vec<f32> = targets.iter().copied().collect();

        // Create tensors with shape [batch, seq, features]
        let input_tensor =
            Tensor::from_vec(input_vec, &[batch_size, window_size, n_features], device)
                .map_err(|e| CoreError::Generic(format!("Failed to create input tensor: {}", e)))?;

        let target_tensor =
            Tensor::from_vec(target_vec, &[batch_size, horizon, n_features], device).map_err(
                |e| CoreError::Generic(format!("Failed to create target tensor: {}", e)),
            )?;

        Ok((input_tensor, target_tensor))
    }

    /// Get configuration
    pub fn config(&self) -> &DataLoaderConfig {
        &self.config
    }
}

/// Iterator over batches
pub struct BatchIterator<'a> {
    loader: &'a TimeSeriesDataLoader,
    current_batch: usize,
}

impl<'a> Iterator for BatchIterator<'a> {
    type Item = CoreResult<(Array2<f32>, Array2<f32>)>;

    fn next(&mut self) -> Option<Self::Item> {
        let num_batches = self.loader.num_batches();
        if self.current_batch >= num_batches {
            return None;
        }

        let start_idx = self.current_batch * self.loader.config.batch_size;
        let end_idx = (start_idx + self.loader.config.batch_size).min(self.loader.indices.len());

        // Check if we should drop last incomplete batch
        if self.loader.config.drop_last && end_idx - start_idx < self.loader.config.batch_size {
            return None;
        }

        let batch_indices: Vec<usize> = (start_idx..end_idx).collect();
        self.current_batch += 1;

        Some(self.loader.create_batch(&batch_indices))
    }
}

/// Data augmentation for time-series
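///
/// # Examples
///
/// A minimal sketch of chaining the augmentations below (parameter values are
/// illustrative only):
///
/// ```rust
/// use kizzasi_core::dataloader::TimeSeriesAugmentation;
/// use scirs2_core::ndarray::Array2;
///
/// let x = Array2::<f32>::ones((100, 3));
/// let x = TimeSeriesAugmentation::add_noise(&x, 0.05);
/// let x = TimeSeriesAugmentation::scale(&x, 0.8, 1.2);
/// let x = TimeSeriesAugmentation::mask(&x, 0.1);
/// assert_eq!(x.dim(), (100, 3));
/// ```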
pub struct TimeSeriesAugmentation;

impl TimeSeriesAugmentation {
    /// Add Gaussian noise
    pub fn add_noise(data: &Array2<f32>, std: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let noise = Array2::from_shape_fn(data.dim(), |_| {
            // Box-Muller transform for Gaussian
            let u1 = uniform().max(f64::MIN_POSITIVE); // guard against ln(0)
            let u2 = uniform();
            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            (z0 * std as f64) as f32
        });
        data + &noise
    }

    /// Scale by random factor
    pub fn scale(data: &Array2<f32>, min_scale: f32, max_scale: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let scale = uniform() * (max_scale - min_scale) as f64 + min_scale as f64;
        data * (scale as f32)
    }

    /// Time shift (circular shift along time axis)
    ///
    /// Row `i` of the output takes the value of row `(i + shift) % n`, i.e. the
    /// series is rotated so that it starts `shift` steps later and wraps around.
    pub fn time_shift(data: &Array2<f32>, max_shift: usize) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let shift = (uniform() * max_shift as f64) as usize;

        let mut shifted = data.clone();
        if shift > 0 {
            let n = data.nrows();
            for i in 0..n {
                let src = (i + shift) % n;
                shifted.row_mut(i).assign(&data.row(src));
            }
        }
        shifted
    }

    /// Apply random masking (set random timesteps to zero)
    pub fn mask(data: &Array2<f32>, mask_prob: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let mut masked = data.clone();
        for i in 0..masked.nrows() {
            if uniform() < mask_prob as f64 {
                masked.row_mut(i).fill(0.0);
            }
        }
        masked
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataloader_creation() {
        let data = Array2::<f32>::zeros((1000, 3));
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_batch_size(32);

        let loader = TimeSeriesDataLoader::new(data, config);
        assert!(loader.is_ok());
    }

    #[test]
    fn test_dataloader_insufficient_data() {
        let data = Array2::<f32>::zeros((50, 3)); // Too short
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_horizon(1);

        let loader = TimeSeriesDataLoader::new(data, config);
        assert!(loader.is_err());
    }

    #[test]
    fn test_num_batches() {
        let data = Array2::<f32>::zeros((1000, 3));
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_batch_size(32)
            .with_overlap(0.0);

        let loader = TimeSeriesDataLoader::new(data, config).unwrap();
        assert!(loader.num_batches() > 0);
    }

    #[test]
    fn test_batch_iteration() {
        let data = Array2::<f32>::from_shape_fn((200, 3), |(i, j)| (i + j) as f32);
        let config = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(4)
            .with_horizon(1)
            .with_shuffle(false);

        let mut loader = TimeSeriesDataLoader::new(data, config).unwrap();

        let mut batch_count = 0;
        for batch in loader.iter_batches() {
            let (inputs, targets) = batch.unwrap();
            assert_eq!(inputs.ncols(), 3);
            assert_eq!(targets.ncols(), 3);
            batch_count += 1;
        }

        assert!(batch_count > 0);
        assert_eq!(batch_count, loader.num_batches());
    }

    #[test]
    fn test_tensor_conversion() {
        let data = Array2::<f32>::from_shape_fn((200, 3), |(i, j)| (i + j) as f32);
        let config = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(4)
            .with_horizon(1);

        let mut loader = TimeSeriesDataLoader::new(data, config).unwrap();

        // Test just one batch
        let batch = loader.iter_batches().next().unwrap();
        let (inputs, targets) = batch.unwrap();
        let device = Device::Cpu;

        let (input_tensor, target_tensor) = loader.to_tensors(&inputs, &targets, &device).unwrap();

        assert_eq!(input_tensor.dims().len(), 3); // [batch, seq, features]
        assert_eq!(target_tensor.dims().len(), 3);
        assert_eq!(input_tensor.dims()[2], 3); // 3 features
    }

    #[test]
    fn test_overlap() {
        let data = Array2::<f32>::zeros((200, 3));
        let config_no_overlap = DataLoaderConfig::default()
            .with_window_size(10)
            .with_overlap(0.0);

        let config_overlap = DataLoaderConfig::default()
            .with_window_size(10)
            .with_overlap(0.5);

        let loader_no_overlap = TimeSeriesDataLoader::new(data.clone(), config_no_overlap).unwrap();
        let loader_overlap = TimeSeriesDataLoader::new(data, config_overlap).unwrap();

        // With overlap, we should have more samples
        assert!(loader_overlap.num_samples() > loader_no_overlap.num_samples());
    }

    #[test]
    fn test_augmentation_noise() {
        let data = Array2::<f32>::zeros((100, 3));
        let augmented = TimeSeriesAugmentation::add_noise(&data, 0.1);

        assert_eq!(augmented.dim(), data.dim());
        // With noise, not all values should be exactly zero
        assert!(augmented.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_augmentation_scale() {
        let data = Array2::<f32>::ones((100, 3));
        let augmented = TimeSeriesAugmentation::scale(&data, 0.5, 1.5);

        assert_eq!(augmented.dim(), data.dim());
        // Values should be scaled
        let mean = augmented.mean().unwrap();
        assert!((0.5..=1.5).contains(&mean));
    }

    #[test]
    fn test_drop_last() {
        let data = Array2::<f32>::zeros((100, 3));
        let config_drop = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(7)
            .with_drop_last(true);

        let config_no_drop = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(7)
            .with_drop_last(false);

        let loader_drop = TimeSeriesDataLoader::new(data.clone(), config_drop).unwrap();
        let loader_no_drop = TimeSeriesDataLoader::new(data, config_no_drop).unwrap();

        // Without drop_last, we might have more batches
        assert!(loader_no_drop.num_batches() >= loader_drop.num_batches());
    }
}