// quantrs2_ml/pytorch_api/data.rs
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
5
6pub trait DataLoader {
8 fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>>;
10
11 fn reset(&mut self);
13
14 fn batch_size(&self) -> usize;
16}
17
18pub struct MemoryDataLoader {
20 inputs: SciRS2Array,
22 targets: SciRS2Array,
24 batch_size_val: usize,
26 current_pos: usize,
28 shuffle: bool,
30 indices: Vec<usize>,
32}
33
34impl MemoryDataLoader {
35 pub fn new(
37 inputs: SciRS2Array,
38 targets: SciRS2Array,
39 batch_size: usize,
40 shuffle: bool,
41 ) -> Result<Self> {
42 let num_samples = inputs.data.shape()[0];
43 if targets.data.shape()[0] != num_samples {
44 return Err(MLError::InvalidConfiguration(
45 "Input and target batch sizes don't match".to_string(),
46 ));
47 }
48
49 let indices: Vec<usize> = (0..num_samples).collect();
50
51 Ok(Self {
52 inputs,
53 targets,
54 batch_size_val: batch_size,
55 current_pos: 0,
56 shuffle,
57 indices,
58 })
59 }
60
61 fn shuffle_indices(&mut self) {
63 if self.shuffle {
64 for i in (1..self.indices.len()).rev() {
65 let j = fastrand::usize(0..=i);
66 self.indices.swap(i, j);
67 }
68 }
69 }
70}
71
72impl DataLoader for MemoryDataLoader {
73 fn next_batch(&mut self) -> Result<Option<(SciRS2Array, SciRS2Array)>> {
74 if self.current_pos >= self.indices.len() {
75 return Ok(None);
76 }
77
78 let end_pos = (self.current_pos + self.batch_size_val).min(self.indices.len());
79 let _batch_indices = &self.indices[self.current_pos..end_pos];
80
81 let batch_inputs = self.inputs.clone();
83 let batch_targets = self.targets.clone();
84
85 self.current_pos = end_pos;
86
87 Ok(Some((batch_inputs, batch_targets)))
88 }
89
90 fn reset(&mut self) {
91 self.current_pos = 0;
92 self.shuffle_indices();
93 }
94
95 fn batch_size(&self) -> usize {
96 self.batch_size_val
97 }
98}