Skip to main content

quantize_rs/calibration/
mod.rs

1//! Calibration datasets and activation-based range estimation.
2//!
3//! - [`CalibrationDataset`] — load or generate calibration samples
4//! - [`methods::CalibrationMethod`] — range optimization strategies
5//! - [`stats::ActivationStats`] — incremental min/max/histogram tracker
6//! - [`inference::ActivationEstimator`] — run inference to collect activation stats
7
8use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12#[cfg(feature = "calibration")]
13pub mod inference;
14pub mod methods;
15pub mod stats;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// In-memory set of FP32 calibration inputs used to estimate activation ranges.
///
/// Each entry of `samples` is one input, flattened into a single vector whose
/// length equals the product of the dimensions in `shape`. The constructors on
/// this type check that invariant; building the struct directly does not.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// Flattened calibration inputs, one `Vec<f32>` per sample.
    pub samples: Vec<Vec<f32>>,
    /// Per-sample shape; the batch axis is not part of it.
    pub shape: Vec<usize>,
}
29
30impl std::fmt::Debug for CalibrationDataset {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("CalibrationDataset")
33            .field("num_samples", &self.samples.len())
34            .field("shape", &self.shape)
35            .finish()
36    }
37}
38
39impl CalibrationDataset {
40    /// Load calibration samples from a NumPy `.npy` file.
41    ///
42    /// The array must be at least 2-dimensional `[batch, ...]`.
43    ///
44    /// Requires the `calibration` feature (enabled by default).
45    ///
46    /// # Errors
47    ///
48    /// Returns [`QuantizeError::Calibration`] if the file is missing, not `.npy`,
49    /// or has an invalid shape.
50    #[cfg(feature = "calibration")]
51    pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52        use ndarray::{Array, IxDyn};
53
54        let path = path.as_ref();
55
56        if !path.exists() {
57            return Err(QuantizeError::Calibration {
58                reason: format!("File not found: {}", path.display()),
59            });
60        }
61
62        let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
63            ndarray_npy::read_npy(path).map_err(|e| QuantizeError::Calibration {
64                reason: format!("Failed to read NPY file '{}': {e}", path.display()),
65            })?
66        } else {
67            return Err(QuantizeError::Calibration {
68                reason: "Only .npy files supported currently".into(),
69            });
70        };
71
72        let shape: Vec<usize> = array.shape().to_vec();
73
74        if shape.is_empty() {
75            return Err(QuantizeError::Calibration {
76                reason: "Invalid array shape".into(),
77            });
78        }
79
80        if shape.len() < 2 {
81            return Err(QuantizeError::Calibration {
82                reason: format!(
83                    "Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}",
84                    shape
85                ),
86            });
87        }
88
89        let num_samples = shape[0];
90        let sample_size: usize = shape[1..].iter().product();
91
92        let data = array.into_raw_vec();
93        let mut samples = Vec::with_capacity(num_samples);
94
95        for i in 0..num_samples {
96            let start = i * sample_size;
97            let end = start + sample_size;
98            samples.push(data[start..end].to_vec());
99        }
100
101        Ok(Self {
102            samples,
103            shape: shape[1..].to_vec(),
104        })
105    }
106
107    /// Generate random calibration samples uniformly distributed in `range`.
108    ///
109    /// # Errors
110    ///
111    /// Returns [`QuantizeError::Calibration`] if shape is empty, `num_samples` is 0,
112    /// or the range is invalid.
113    pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
114        if shape.is_empty() || shape.contains(&0) {
115            return Err(QuantizeError::Calibration {
116                reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape),
117            });
118        }
119        if num_samples == 0 {
120            return Err(QuantizeError::Calibration {
121                reason: "num_samples must be > 0".into(),
122            });
123        }
124        if range.0 >= range.1 {
125            return Err(QuantizeError::Calibration {
126                reason: format!(
127                    "Invalid range: ({}, {}) - min must be less than max",
128                    range.0, range.1
129                ),
130            });
131        }
132        use rand::Rng;
133        let mut rng = rand::thread_rng();
134
135        let sample_size: usize = shape.iter().product();
136        let mut samples = Vec::with_capacity(num_samples);
137
138        for _ in 0..num_samples {
139            let sample: Vec<f32> = (0..sample_size)
140                .map(|_| rng.gen_range(range.0..range.1))
141                .collect();
142            samples.push(sample);
143        }
144
145        Ok(Self { samples, shape })
146    }
147
148    /// Create a dataset from pre-existing sample vectors.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`QuantizeError::Calibration`] if `samples` is empty or any
153    /// sample has the wrong length for the given `shape`.
154    pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
155        let num_samples = samples.len();
156
157        if num_samples == 0 {
158            return Err(QuantizeError::Calibration {
159                reason: "No samples provided".into(),
160            });
161        }
162
163        let expected_size: usize = shape.iter().product();
164
165        for (i, sample) in samples.iter().enumerate() {
166            if sample.len() != expected_size {
167                return Err(QuantizeError::Calibration {
168                    reason: format!(
169                        "Sample {} has size {} but expected {} (shape: {:?})",
170                        i,
171                        sample.len(),
172                        expected_size,
173                        shape
174                    ),
175                });
176            }
177        }
178
179        Ok(Self { samples, shape })
180    }
181
182    /// Shape of a single sample (excluding batch dimension).
183    pub fn sample_shape(&self) -> &[usize] {
184        &self.shape
185    }
186
187    /// Number of samples in the dataset.
188    pub fn len(&self) -> usize {
189        self.samples.len()
190    }
191
192    /// Whether the dataset contains no samples.
193    pub fn is_empty(&self) -> bool {
194        self.samples.is_empty()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `err` is a calibration error whose reason contains `needle`.
    fn is_calibration_err(err: &QuantizeError, needle: &str) -> bool {
        matches!(err, QuantizeError::Calibration { reason } if reason.contains(needle))
    }

    #[test]
    fn test_random_dataset() {
        let dataset = CalibrationDataset::random(vec![3, 224, 224], 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.sample_shape(), &[3, 224, 224]);
        assert!(dataset.samples.iter().all(|s| s.len() == 3 * 224 * 224));
    }

    #[test]
    fn test_random_values_within_range() {
        let dataset = CalibrationDataset::random(vec![4, 8], 5, (-0.5, 0.25)).unwrap();
        // Inclusive upper bound: float sampling may round to the boundary.
        assert!(dataset
            .samples
            .iter()
            .flatten()
            .all(|&v| (-0.5..=0.25).contains(&v)));
    }

    #[test]
    fn test_random_rejects_invalid_arguments() {
        let err = CalibrationDataset::random(vec![], 1, (0.0, 1.0)).unwrap_err();
        assert!(is_calibration_err(&err, "Invalid shape"));

        let err = CalibrationDataset::random(vec![2, 0], 1, (0.0, 1.0)).unwrap_err();
        assert!(is_calibration_err(&err, "Invalid shape"));

        let err = CalibrationDataset::random(vec![2], 0, (0.0, 1.0)).unwrap_err();
        assert!(is_calibration_err(&err, "num_samples"));

        for range in [(1.0, 1.0), (2.0, -2.0)] {
            let err = CalibrationDataset::random(vec![2], 1, range).unwrap_err();
            assert!(is_calibration_err(&err, "Invalid range"));
        }
    }

    #[test]
    fn test_from_samples() {
        let samples = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

        let dataset = CalibrationDataset::from_samples(samples.clone(), vec![3]).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_shape(), &[3]);
        assert_eq!(dataset.samples, samples);
    }

    #[test]
    fn test_from_samples_rejects_empty_and_mismatched() {
        let err = CalibrationDataset::from_samples(vec![], vec![3]).unwrap_err();
        assert!(is_calibration_err(&err, "No samples provided"));

        let err =
            CalibrationDataset::from_samples(vec![vec![1.0, 2.0, 3.0], vec![1.0, 2.0]], vec![3])
                .unwrap_err();
        assert!(is_calibration_err(&err, "Sample 1 has size 2 but expected 3"));
    }

    #[test]
    fn test_debug_omits_sample_values() {
        let dataset =
            CalibrationDataset::from_samples(vec![vec![0.0; 4]; 3], vec![2, 2]).unwrap();
        assert_eq!(
            format!("{:?}", dataset),
            "CalibrationDataset { num_samples: 3, shape: [2, 2] }"
        );
    }

    #[cfg(feature = "calibration")]
    #[test]
    fn test_from_numpy_missing_file() {
        let err = CalibrationDataset::from_numpy("this/path/does/not/exist.npy").unwrap_err();
        assert!(is_calibration_err(&err, "File not found"));
    }
}
218}