Skip to main content

quantize_rs/calibration/
mod.rs

1//! Calibration datasets and activation-based range estimation.
2//!
3//! - [`CalibrationDataset`] — load or generate calibration samples
4//! - [`methods::CalibrationMethod`] — range optimization strategies
5//! - [`stats::ActivationStats`] — incremental min/max/histogram tracker
6//! - [`inference::ActivationEstimator`] — run inference to collect activation stats
7
8use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12pub mod stats;
13pub mod methods;
14#[cfg(feature = "calibration")]
15pub mod inference;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// FP32 calibration data: a batch of samples that all share one per-sample shape.
///
/// The crate's constructors store every sample flattened, holding exactly
/// `shape.iter().product()` values. Because both fields are public, code that
/// builds the struct directly is responsible for keeping that invariant.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// Flattened sample data, one `Vec<f32>` per sample.
    pub samples: Vec<Vec<f32>>,

    /// Per-sample dimensions, without the leading batch axis.
    pub shape: Vec<usize>,
}
29
30impl std::fmt::Debug for CalibrationDataset {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("CalibrationDataset")
33            .field("num_samples", &self.samples.len())
34            .field("shape", &self.shape)
35            .finish()
36    }
37}
38
39impl CalibrationDataset {
40    /// Load calibration samples from a NumPy `.npy` file.
41    ///
42    /// The array must be at least 2-dimensional `[batch, ...]`.
43    ///
44    /// Requires the `calibration` feature (enabled by default).
45    ///
46    /// # Errors
47    ///
48    /// Returns [`QuantizeError::Calibration`] if the file is missing, not `.npy`,
49    /// or has an invalid shape.
50    #[cfg(feature = "calibration")]
51    pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52        use ndarray::{Array, IxDyn};
53
54        let path = path.as_ref();
55
56        if !path.exists() {
57            return Err(QuantizeError::Calibration { reason: format!("File not found: {}", path.display()) });
58        }
59
60        let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
61            ndarray_npy::read_npy(path)
62                .map_err(|e| QuantizeError::Calibration { reason: format!("Failed to read NPY file '{}': {e}", path.display()) })?
63        } else {
64            return Err(QuantizeError::Calibration { reason: "Only .npy files supported currently".into() });
65        };
66
67        let shape: Vec<usize> = array.shape().to_vec();
68
69        if shape.is_empty() {
70            return Err(QuantizeError::Calibration { reason: "Invalid array shape".into() });
71        }
72
73        if shape.len() < 2 {
74            return Err(QuantizeError::Calibration { reason: format!("Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}", shape) });
75        }
76
77        let num_samples = shape[0];
78        let sample_size: usize = shape[1..].iter().product();
79
80        let data = array.into_raw_vec();
81        let mut samples = Vec::with_capacity(num_samples);
82
83        for i in 0..num_samples {
84            let start = i * sample_size;
85            let end = start + sample_size;
86            samples.push(data[start..end].to_vec());
87        }
88
89        Ok(Self {
90            samples,
91            shape: shape[1..].to_vec(),
92        })
93    }
94    
95    /// Generate random calibration samples uniformly distributed in `range`.
96    ///
97    /// # Errors
98    ///
99    /// Returns [`QuantizeError::Calibration`] if shape is empty, `num_samples` is 0,
100    /// or the range is invalid.
101    pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
102        if shape.is_empty() || shape.contains(&0) {
103            return Err(QuantizeError::Calibration { reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape) });
104        }
105        if num_samples == 0 {
106            return Err(QuantizeError::Calibration { reason: "num_samples must be > 0".into() });
107        }
108        if range.0 >= range.1 {
109            return Err(QuantizeError::Calibration { reason: format!("Invalid range: ({}, {}) - min must be less than max", range.0, range.1) });
110        }
111        use rand::Rng;
112        let mut rng = rand::thread_rng();
113
114        let sample_size: usize = shape.iter().product();
115        let mut samples = Vec::with_capacity(num_samples);
116
117        for _ in 0..num_samples {
118            let sample: Vec<f32> = (0..sample_size)
119                .map(|_| rng.gen_range(range.0..range.1))
120                .collect();
121            samples.push(sample);
122        }
123
124        Ok(Self {
125            samples,
126            shape,
127        })
128    }
129    
130    /// Create a dataset from pre-existing sample vectors.
131    ///
132    /// # Errors
133    ///
134    /// Returns [`QuantizeError::Calibration`] if `samples` is empty or any
135    /// sample has the wrong length for the given `shape`.
136    pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
137        let num_samples = samples.len();
138
139        if num_samples == 0 {
140            return Err(QuantizeError::Calibration { reason: "No samples provided".into() });
141        }
142
143        let expected_size: usize = shape.iter().product();
144
145        for (i, sample) in samples.iter().enumerate() {
146            if sample.len() != expected_size {
147                return Err(QuantizeError::Calibration {
148                    reason: format!(
149                        "Sample {} has size {} but expected {} (shape: {:?})",
150                        i, sample.len(), expected_size, shape
151                    ),
152                });
153            }
154        }
155        
156        Ok(Self {
157            samples,
158            shape,
159        })
160    }
161    
162    /// Shape of a single sample (excluding batch dimension).
163    pub fn sample_shape(&self) -> &[usize] {
164        &self.shape
165    }
166
167    /// Number of samples in the dataset.
168    pub fn len(&self) -> usize {
169        self.samples.len()
170    }
171
172    /// Whether the dataset contains no samples.
173    pub fn is_empty(&self) -> bool {
174        self.samples.is_empty()
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_dataset() {
        let dataset = CalibrationDataset::random(vec![3, 224, 224], 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.sample_shape(), &[3, 224, 224]);
        assert_eq!(dataset.samples[0].len(), 3 * 224 * 224);
    }

    #[test]
    fn test_random_values_within_range() {
        let dataset = CalibrationDataset::random(vec![4, 4], 3, (-0.5, 0.5)).unwrap();

        assert!(dataset
            .samples
            .iter()
            .flatten()
            .all(|&v| (-0.5..=0.5).contains(&v)));
    }

    #[test]
    fn test_random_rejects_invalid_arguments() {
        // Empty shape and zero-sized dimension.
        assert!(CalibrationDataset::random(vec![], 1, (0.0, 1.0)).is_err());
        assert!(CalibrationDataset::random(vec![2, 0], 1, (0.0, 1.0)).is_err());
        // No samples requested.
        assert!(CalibrationDataset::random(vec![2], 0, (0.0, 1.0)).is_err());
        // Degenerate and inverted ranges.
        assert!(CalibrationDataset::random(vec![2], 1, (1.0, 1.0)).is_err());
        assert!(CalibrationDataset::random(vec![2], 1, (2.0, 1.0)).is_err());
    }

    #[test]
    fn test_from_samples() {
        let samples = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ];

        let dataset = CalibrationDataset::from_samples(samples, vec![3]).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_shape(), &[3]);
        assert_eq!(dataset.samples[1], vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_from_samples_rejects_empty_and_mismatched() {
        assert!(CalibrationDataset::from_samples(Vec::new(), vec![3]).is_err());

        let samples = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        assert!(CalibrationDataset::from_samples(samples, vec![3]).is_err());
    }

    #[test]
    fn test_debug_omits_sample_data() {
        let dataset = CalibrationDataset::from_samples(vec![vec![0.0; 4]; 2], vec![2, 2]).unwrap();

        assert_eq!(
            format!("{dataset:?}"),
            "CalibrationDataset { num_samples: 2, shape: [2, 2] }"
        );
    }

    #[cfg(feature = "calibration")]
    #[test]
    fn test_from_numpy_missing_file() {
        assert!(CalibrationDataset::from_numpy("this/file/does/not/exist.npy").is_err());
    }
}