entrenar/prune/data_loader/
loader.rs1use super::config::CalibrationDataConfig;
4use super::iter::CalibrationDataIter;
5use crate::train::Batch;
6use crate::Tensor;
7
8#[derive(Debug, Clone)]
13pub struct CalibrationDataLoader {
14 config: CalibrationDataConfig,
16 data: Option<Vec<Batch>>,
18 pub(crate) position: usize,
20}
21
22impl CalibrationDataLoader {
23 pub fn new(config: CalibrationDataConfig) -> Self {
25 Self { config, data: None, position: 0 }
26 }
27
28 pub fn with_synthetic_data(config: CalibrationDataConfig) -> Self {
30 let mut loader = Self::new(config);
31 loader.generate_synthetic_data();
32 loader
33 }
34
35 fn generate_synthetic_data(&mut self) {
37 use rand::prelude::*;
38 use rand::rngs::StdRng;
39
40 let mut rng = StdRng::seed_from_u64(self.config.seed());
41 let mut batches = Vec::with_capacity(self.config.num_batches());
42
43 let mut samples_remaining = self.config.num_samples();
44
45 while samples_remaining > 0 {
46 let batch_size = samples_remaining.min(self.config.batch_size());
47 samples_remaining -= batch_size;
48
49 let input_size = batch_size * self.config.sequence_length();
51 let inputs: Vec<f32> = (0..input_size).map(|_| rng.random::<f32>()).collect();
52 let targets: Vec<f32> = (0..batch_size).map(|_| rng.random::<f32>()).collect();
53
54 batches.push(Batch::new(
55 Tensor::from_vec(inputs, false),
56 Tensor::from_vec(targets, false),
57 ));
58 }
59
60 self.data = Some(batches);
61 }
62
63 pub fn load(&mut self) -> Result<(), String> {
68 if self.data.is_some() {
69 return Ok(());
70 }
71
72 self.generate_synthetic_data();
75 Ok(())
76 }
77
78 pub fn config(&self) -> &CalibrationDataConfig {
80 &self.config
81 }
82
83 pub fn is_loaded(&self) -> bool {
85 self.data.is_some()
86 }
87
88 pub fn num_batches(&self) -> usize {
90 self.data.as_ref().map_or(0, Vec::len)
91 }
92
93 pub fn reset(&mut self) {
95 self.position = 0;
96 }
97
98 pub fn get_batch(&self, index: usize) -> Option<&Batch> {
100 self.data.as_ref().and_then(|d| d.get(index))
101 }
102
103 pub fn iter(&self) -> CalibrationDataIter<'_> {
105 CalibrationDataIter::new(self)
106 }
107}
108
109impl<'a> IntoIterator for &'a CalibrationDataLoader {
110 type Item = &'a Batch;
111 type IntoIter = CalibrationDataIter<'a>;
112
113 fn into_iter(self) -> Self::IntoIter {
114 self.iter()
115 }
116}