use crate::error::{CoreError, CoreResult};
use candle_core::{Device, Tensor};
use scirs2_core::ndarray::{s, Array2};
use serde::{Deserialize, Serialize};

/// Configuration for sliding-window batching of time-series data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataLoaderConfig {
    /// Number of time steps in each input window.
    pub window_size: usize,
    /// Number of future time steps in each prediction target.
    pub horizon: usize,
    /// Number of windows per batch.
    pub batch_size: usize,
    /// Whether to shuffle window order between epochs.
    pub shuffle: bool,
    /// Fractional overlap between consecutive windows, clamped to `[0.0, 1.0]`.
    pub overlap: f32,
    /// Whether to drop a final batch that is smaller than `batch_size`.
    pub drop_last: bool,
    /// Number of data-loading workers.
    pub num_workers: usize,
}

impl Default for DataLoaderConfig {
    fn default() -> Self {
        Self {
            window_size: 64,
            horizon: 1,
            batch_size: 32,
            shuffle: true,
            overlap: 0.0,
            drop_last: false,
            num_workers: 1,
        }
    }
}

impl DataLoaderConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

    pub fn with_horizon(mut self, horizon: usize) -> Self {
        self.horizon = horizon;
        self
    }

    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    pub fn with_overlap(mut self, overlap: f32) -> Self {
        self.overlap = overlap.clamp(0.0, 1.0);
        self
    }

    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }
}

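/// Sliding-window loader that turns a 2-D time-series array of shape
/// `(time_steps, features)` into `(input window, forecast target)` batches.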
pub struct TimeSeriesDataLoader {
    /// Underlying series with shape `(time_steps, features)`.
    data: Array2<f32>,
    /// Batching configuration.
    config: DataLoaderConfig,
    /// Start indices of the candidate windows.
    indices: Vec<usize>,
    /// Number of epochs started so far; used to decide when to reshuffle.
    current_epoch: usize,
}

impl TimeSeriesDataLoader {
    /// Creates a new loader, validating that the series is long enough for at
    /// least one full window plus the forecast horizon.
    pub fn new(data: Array2<f32>, config: DataLoaderConfig) -> CoreResult<Self> {
        if data.nrows() < config.window_size + config.horizon {
            return Err(CoreError::InvalidConfig(format!(
                "Data length {} is too short for window_size {} + horizon {}",
                data.nrows(),
                config.window_size,
                config.horizon
            )));
        }

        // Stride between window starts; an overlap of 0.0 yields non-overlapping windows.
        let stride = ((config.window_size as f32) * (1.0 - config.overlap)).max(1.0) as usize;

        // Every start index for which both the window and its horizon fit in the data.
        let max_start = data.nrows() - config.window_size - config.horizon + 1;
        let indices: Vec<usize> = (0..max_start).step_by(stride).collect();

        Ok(Self {
            data,
            config,
            indices,
            current_epoch: 0,
        })
    }

    /// Returns the number of batches per epoch, honoring `drop_last`.
    pub fn num_batches(&self) -> usize {
        let num_samples = self.indices.len();
        if self.config.drop_last {
            num_samples / self.config.batch_size
        } else {
            num_samples.div_ceil(self.config.batch_size)
        }
    }

    /// Returns the total number of windows available per epoch.
    pub fn num_samples(&self) -> usize {
        self.indices.len()
    }

    /// Shuffles the window order in place (Fisher-Yates) when shuffling is enabled.
    pub fn shuffle(&mut self) {
        if self.config.shuffle {
            use scirs2_core::convenience::uniform;
            for i in (1..self.indices.len()).rev() {
                let j = (uniform() * (i + 1) as f64) as usize;
                self.indices.swap(i, j);
            }
        }
    }

    /// Extracts one (input, target) window pair starting at `start_idx`.
    fn extract_window(&self, start_idx: usize) -> CoreResult<(Array2<f32>, Array2<f32>)> {
        let end_input = start_idx + self.config.window_size;
        let end_target = end_input + self.config.horizon;

        if end_target > self.data.nrows() {
            return Err(CoreError::Generic(format!(
                "Window exceeds data bounds: {} > {}",
                end_target,
                self.data.nrows()
            )));
        }

        let input = self.data.slice(s![start_idx..end_input, ..]).to_owned();
        let target = self.data.slice(s![end_input..end_target, ..]).to_owned();

        Ok((input, target))
    }

    /// Stacks the windows selected by `batch_indices` into flat arrays of shape
    /// `(batch_size * window_size, features)` and `(batch_size * horizon, features)`.
    fn create_batch(&self, batch_indices: &[usize]) -> CoreResult<(Array2<f32>, Array2<f32>)> {
        let mut inputs = Vec::new();
        let mut targets = Vec::new();

        for &idx in batch_indices {
            let start = self.indices[idx];
            let (input, target) = self.extract_window(start)?;
            inputs.push(input);
            targets.push(target);
        }

        let batch_size = inputs.len();
        let window_size = self.config.window_size;
        let horizon = self.config.horizon;
        let n_features = self.data.ncols();

        let mut batch_input = Array2::zeros((batch_size * window_size, n_features));
        let mut batch_target = Array2::zeros((batch_size * horizon, n_features));

        // Copy each window into its row block of the flat batch arrays.
        for (i, (inp, tgt)) in inputs.iter().zip(targets.iter()).enumerate() {
            let input_start = i * window_size;
            let input_end = input_start + window_size;
            batch_input
                .slice_mut(s![input_start..input_end, ..])
                .assign(inp);

            let target_start = i * horizon;
            let target_end = target_start + horizon;
            batch_target
                .slice_mut(s![target_start..target_end, ..])
                .assign(tgt);
        }

        Ok((batch_input, batch_target))
    }
234
235 pub fn iter_batches(&mut self) -> BatchIterator<'_> {
237 if self.current_epoch > 0 {
238 self.shuffle();
239 }
240 self.current_epoch += 1;
241
242 BatchIterator {
243 loader: self,
244 current_batch: 0,
245 }
246 }

    /// Converts a flat batch into candle tensors of shape
    /// `(batch_size, window_size, features)` and `(batch_size, horizon, features)`.
    pub fn to_tensors(
        &self,
        inputs: &Array2<f32>,
        targets: &Array2<f32>,
        device: &Device,
    ) -> CoreResult<(Tensor, Tensor)> {
        let batch_size = inputs.nrows() / self.config.window_size;
        let window_size = self.config.window_size;
        let horizon = self.config.horizon;
        let n_features = inputs.ncols();

        let input_vec: Vec<f32> = inputs.iter().copied().collect();
        let target_vec: Vec<f32> = targets.iter().copied().collect();

        let input_tensor =
            Tensor::from_vec(input_vec, &[batch_size, window_size, n_features], device)
                .map_err(|e| CoreError::Generic(format!("Failed to create input tensor: {}", e)))?;

        let target_tensor =
            Tensor::from_vec(target_vec, &[batch_size, horizon, n_features], device).map_err(
                |e| CoreError::Generic(format!("Failed to create target tensor: {}", e)),
            )?;

        Ok((input_tensor, target_tensor))
    }

    /// Returns a reference to the loader configuration.
    pub fn config(&self) -> &DataLoaderConfig {
        &self.config
    }
}

/// Iterator over the batches of a single epoch of a [`TimeSeriesDataLoader`].
pub struct BatchIterator<'a> {
    loader: &'a TimeSeriesDataLoader,
    current_batch: usize,
}

impl<'a> Iterator for BatchIterator<'a> {
    type Item = CoreResult<(Array2<f32>, Array2<f32>)>;

    fn next(&mut self) -> Option<Self::Item> {
        let num_batches = self.loader.num_batches();
        if self.current_batch >= num_batches {
            return None;
        }

        let start_idx = self.current_batch * self.loader.config.batch_size;
        let end_idx = (start_idx + self.loader.config.batch_size).min(self.loader.indices.len());

        // Drop a trailing partial batch when requested.
        if self.loader.config.drop_last && end_idx - start_idx < self.loader.config.batch_size {
            return None;
        }

        let batch_indices: Vec<usize> = (start_idx..end_idx).collect();
        self.current_batch += 1;

        Some(self.loader.create_batch(&batch_indices))
    }
}

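/// Simple random augmentations (noise, scaling, time shifting, masking) for time-series arrays.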
pub struct TimeSeriesAugmentation;

impl TimeSeriesAugmentation {
    /// Adds zero-mean Gaussian noise with standard deviation `std`, using the
    /// Box-Muller transform on uniform samples.
    pub fn add_noise(data: &Array2<f32>, std: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let noise = Array2::from_shape_fn(data.dim(), |_| {
            // Box-Muller: turn two uniform samples into one standard-normal sample.
            let u1 = uniform();
            let u2 = uniform();
            let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
            (z0 * std as f64) as f32
        });
        data + &noise
    }

    /// Multiplies the series by a factor drawn uniformly from `[min_scale, max_scale)`.
    pub fn scale(data: &Array2<f32>, min_scale: f32, max_scale: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let scale = uniform() * (max_scale - min_scale) as f64 + min_scale as f64;
        data * (scale as f32)
    }

    /// Applies a random circular shift of up to `max_shift` steps along the time axis.
    pub fn time_shift(data: &Array2<f32>, max_shift: usize) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let shift = (uniform() * max_shift as f64) as usize;

        let mut shifted = data.clone();
        if shift > 0 {
            let n = data.nrows();
            for i in 0..n {
                let src = (i + shift) % n;
                shifted.row_mut(i).assign(&data.row(src));
            }
        }
        shifted
    }

    /// Zeroes out each time step independently with probability `mask_prob`.
    pub fn mask(data: &Array2<f32>, mask_prob: f32) -> Array2<f32> {
        use scirs2_core::convenience::uniform;
        let mut masked = data.clone();
        for i in 0..masked.nrows() {
            if uniform() < mask_prob as f64 {
                masked.row_mut(i).fill(0.0);
            }
        }
        masked
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataloader_creation() {
        let data = Array2::<f32>::zeros((1000, 3));
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_batch_size(32);

        let loader = TimeSeriesDataLoader::new(data, config);
        assert!(loader.is_ok());
    }

    #[test]
    fn test_dataloader_insufficient_data() {
        let data = Array2::<f32>::zeros((50, 3));
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_horizon(1);

        let loader = TimeSeriesDataLoader::new(data, config);
        assert!(loader.is_err());
    }

    #[test]
    fn test_num_batches() {
        let data = Array2::<f32>::zeros((1000, 3));
        let config = DataLoaderConfig::default()
            .with_window_size(64)
            .with_batch_size(32)
            .with_overlap(0.0);

        let loader = TimeSeriesDataLoader::new(data, config).unwrap();
        assert!(loader.num_batches() > 0);
    }

    #[test]
    fn test_batch_iteration() {
        let data = Array2::<f32>::from_shape_fn((200, 3), |(i, j)| (i + j) as f32);
        let config = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(4)
            .with_horizon(1)
            .with_shuffle(false);

        let mut loader = TimeSeriesDataLoader::new(data, config).unwrap();

        let mut batch_count = 0;
        for batch in loader.iter_batches() {
            let (inputs, targets) = batch.unwrap();
            assert_eq!(inputs.ncols(), 3);
            assert_eq!(targets.ncols(), 3);
            batch_count += 1;
        }

        assert!(batch_count > 0);
        assert_eq!(batch_count, loader.num_batches());
    }

    #[test]
    fn test_tensor_conversion() {
        let data = Array2::<f32>::from_shape_fn((200, 3), |(i, j)| (i + j) as f32);
        let config = DataLoaderConfig::default()
            .with_window_size(10)
            .with_batch_size(4)
            .with_horizon(1);

        let mut loader = TimeSeriesDataLoader::new(data, config).unwrap();

        let batch = loader.iter_batches().next().unwrap();
        let (inputs, targets) = batch.unwrap();
        let device = Device::Cpu;

        let (input_tensor, target_tensor) = loader.to_tensors(&inputs, &targets, &device).unwrap();

        assert_eq!(input_tensor.dims().len(), 3);
        assert_eq!(target_tensor.dims().len(), 3);
        assert_eq!(input_tensor.dims()[2], 3);
    }

    #[test]
    fn test_overlap() {
        let data = Array2::<f32>::zeros((200, 3));
        let config_no_overlap = DataLoaderConfig::default()
            .with_window_size(10)
            .with_overlap(0.0);

        let config_overlap = DataLoaderConfig::default()
            .with_window_size(10)
            .with_overlap(0.5);

        let loader_no_overlap = TimeSeriesDataLoader::new(data.clone(), config_no_overlap).unwrap();
        let loader_overlap = TimeSeriesDataLoader::new(data, config_overlap).unwrap();

        assert!(loader_overlap.num_samples() > loader_no_overlap.num_samples());
    }

466
467 #[test]
468 fn test_augmentation_noise() {
469 let data = Array2::<f32>::zeros((100, 3));
470 let augmented = TimeSeriesAugmentation::add_noise(&data, 0.1);
471
472 assert_eq!(augmented.dim(), data.dim());
473 assert!(augmented.iter().any(|&x| x != 0.0));
475 }
476
477 #[test]
478 fn test_augmentation_scale() {
479 let data = Array2::<f32>::ones((100, 3));
480 let augmented = TimeSeriesAugmentation::scale(&data, 0.5, 1.5);
481
482 assert_eq!(augmented.dim(), data.dim());
483 let mean = augmented.mean().unwrap();
485 assert!((0.5..=1.5).contains(&mean));
486 }
487
488 #[test]
489 fn test_drop_last() {
490 let data = Array2::<f32>::zeros((100, 3));
491 let config_drop = DataLoaderConfig::default()
492 .with_window_size(10)
493 .with_batch_size(7)
494 .with_drop_last(true);
495
496 let config_no_drop = DataLoaderConfig::default()
497 .with_window_size(10)
498 .with_batch_size(7)
499 .with_drop_last(false);
500
501 let loader_drop = TimeSeriesDataLoader::new(data.clone(), config_drop).unwrap();
502 let loader_no_drop = TimeSeriesDataLoader::new(data, config_no_drop).unwrap();
503
504 assert!(loader_no_drop.num_batches() >= loader_drop.num_batches());
506 }
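
    // Illustrative additional test (not in the original file): it exercises the
    // remaining augmentation helpers using only APIs defined in this module, and
    // assumes `uniform()` yields values in `[0.0, 1.0)` as the other helpers do.
    #[test]
    fn test_augmentation_time_shift_and_mask() {
        let data = Array2::<f32>::from_shape_fn((50, 2), |(i, j)| (i * 2 + j) as f32);

        // A circular shift preserves the shape and, for these integer-valued
        // entries, the exact element sum.
        let shifted = TimeSeriesAugmentation::time_shift(&data, 10);
        assert_eq!(shifted.dim(), data.dim());
        assert_eq!(shifted.sum(), data.sum());

        // Masking with probability 0.0 should leave the data unchanged.
        let masked = TimeSeriesAugmentation::mask(&data, 0.0);
        assert_eq!(masked, data);
    }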
}