quantrs2_ml/pytorch_api/
data.rs

1//! Data loading utilities for PyTorch-like API
2
3use crate::error::{MLError, Result};
4use crate::scirs2_integration::SciRS2Array;
5
6/// Data loader trait
7pub trait DataLoader {
8    /// Get next batch
9    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;
10
11    /// Reset to beginning
12    fn reset(&mut self);
13
14    /// Get batch size
15    fn batch_size(&self) -> usize;
16}
17
18/// Simple in-memory data loader
19pub struct MemoryDataLoader {
20    /// Input data
21    inputs: SciRS2Array,
22    /// Target data
23    targets: SciRS2Array,
24    /// Batch size
25    batch_size_val: usize,
26    /// Current position
27    current_pos: usize,
28    /// Shuffle data
29    shuffle: bool,
30    /// Indices for shuffling
31    indices: Vec<usize>,
32}
33
34impl MemoryDataLoader {
35    /// Create new memory data loader
36    pub fn new(
37        inputs: SciRS2Array,
38        targets: SciRS2Array,
39        batch_size: usize,
40        shuffle: bool,
41    ) -> Result<Self> {
42        let num_samples = inputs.data.shape()[0];
43        if targets.data.shape()[0] != num_samples {
44            return Err(MLError::InvalidConfiguration(
45                "Input and target batch sizes don't match".to_string(),
46            ));
47        }
48
49        let indices: Vec<usize> = (0..num_samples).collect();
50
51        Ok(Self {
52            inputs,
53            targets,
54            batch_size_val: batch_size,
55            current_pos: 0,
56            shuffle,
57            indices,
58        })
59    }
60
61    /// Shuffle indices
62    fn shuffle_indices(&mut self) {
63        if self.shuffle {
64            for i in (1..self.indices.len()).rev() {
65                let j = fastrand::usize(0..=i);
66                self.indices.swap(i, j);
67            }
68        }
69    }
70}
71
72impl DataLoader for MemoryDataLoader {
73    fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
74        if self.current_pos >= self.indices.len() {
75            return Ok(None);
76        }
77
78        let end_pos = (self.current_pos + self.batch_size_val).min(self.indices.len());
79        let _batch_indices = &self.indices[self.current_pos..end_pos];
80
81        // Extract batch data (simplified - would use proper indexing)
82        let batch_inputs = self.inputs.clone();
83        let batch_targets = self.targets.clone();
84
85        self.current_pos = end_pos;
86
87        Ok(Some((batch_inputs, batch_targets)))
88    }
89
90    fn reset(&mut self) {
91        self.current_pos = 0;
92        self.shuffle_indices();
93    }
94
95    fn batch_size(&self) -> usize {
96        self.batch_size_val
97    }
98}