Skip to main content

entrenar/prune/data_loader/
loader.rs

1//! Calibration data loader implementation.
2
3use super::config::CalibrationDataConfig;
4use super::iter::CalibrationDataIter;
5use crate::train::Batch;
6use crate::Tensor;
7
8/// Calibration data loader for pruning.
9///
10/// Provides an iterator over calibration batches for collecting
11/// activation statistics during pruning.
12#[derive(Debug, Clone)]
13pub struct CalibrationDataLoader {
14    /// Configuration.
15    config: CalibrationDataConfig,
16    /// Pre-loaded data (if available).
17    data: Option<Vec<Batch>>,
18    /// Current position in iteration.
19    pub(crate) position: usize,
20}
21
22impl CalibrationDataLoader {
23    /// Create a new calibration data loader.
24    pub fn new(config: CalibrationDataConfig) -> Self {
25        Self { config, data: None, position: 0 }
26    }
27
28    /// Create a data loader with pre-loaded synthetic data for testing.
29    pub fn with_synthetic_data(config: CalibrationDataConfig) -> Self {
30        let mut loader = Self::new(config);
31        loader.generate_synthetic_data();
32        loader
33    }
34
35    /// Generate synthetic calibration data for testing.
36    fn generate_synthetic_data(&mut self) {
37        use rand::prelude::*;
38        use rand::rngs::StdRng;
39
40        let mut rng = StdRng::seed_from_u64(self.config.seed());
41        let mut batches = Vec::with_capacity(self.config.num_batches());
42
43        let mut samples_remaining = self.config.num_samples();
44
45        while samples_remaining > 0 {
46            let batch_size = samples_remaining.min(self.config.batch_size());
47            samples_remaining -= batch_size;
48
49            // Generate random input data
50            let input_size = batch_size * self.config.sequence_length();
51            let inputs: Vec<f32> = (0..input_size).map(|_| rng.random::<f32>()).collect();
52            let targets: Vec<f32> = (0..batch_size).map(|_| rng.random::<f32>()).collect();
53
54            batches.push(Batch::new(
55                Tensor::from_vec(inputs, false),
56                Tensor::from_vec(targets, false),
57            ));
58        }
59
60        self.data = Some(batches);
61    }
62
63    /// Load data from the configured source.
64    ///
65    /// This is a placeholder for actual dataset loading.
66    /// In production, this would load from C4, WikiText, etc.
67    pub fn load(&mut self) -> Result<(), String> {
68        if self.data.is_some() {
69            return Ok(());
70        }
71
72        // For now, generate synthetic data
73        // Real implementation would load from dataset
74        self.generate_synthetic_data();
75        Ok(())
76    }
77
78    /// Get the configuration.
79    pub fn config(&self) -> &CalibrationDataConfig {
80        &self.config
81    }
82
83    /// Check if data is loaded.
84    pub fn is_loaded(&self) -> bool {
85        self.data.is_some()
86    }
87
88    /// Get the number of batches.
89    pub fn num_batches(&self) -> usize {
90        self.data.as_ref().map_or(0, Vec::len)
91    }
92
93    /// Reset the iterator position.
94    pub fn reset(&mut self) {
95        self.position = 0;
96    }
97
98    /// Get a batch by index.
99    pub fn get_batch(&self, index: usize) -> Option<&Batch> {
100        self.data.as_ref().and_then(|d| d.get(index))
101    }
102
103    /// Create an iterator over batches.
104    pub fn iter(&self) -> CalibrationDataIter<'_> {
105        CalibrationDataIter::new(self)
106    }
107}
108
109impl<'a> IntoIterator for &'a CalibrationDataLoader {
110    type Item = &'a Batch;
111    type IntoIter = CalibrationDataIter<'a>;
112
113    fn into_iter(self) -> Self::IntoIter {
114        self.iter()
115    }
116}