Skip to main content

quantize_rs/calibration/
mod.rs

//! Calibration datasets and activation-based range estimation.
//!
//! - [`CalibrationDataset`] — load or generate calibration samples
//! - [`methods::CalibrationMethod`] — range optimization strategies
//! - [`stats::ActivationStats`] — incremental min/max/histogram tracker
//! - [`inference::ActivationEstimator`] — run inference to collect activation stats
7
8use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12#[cfg(feature = "calibration")]
13pub mod inference;
14pub mod methods;
15pub mod stats;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// FP32 calibration samples used for range estimation.
///
/// Samples are stored one flattened vector per sample, together with the
/// shape that a single sample is meant to have.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// One entry per sample, each flattened to match `shape`.
    pub samples: Vec<Vec<f32>>,

    /// Dimensions of a single sample, without the leading batch dimension.
    pub shape: Vec<usize>,
}

// Hand-written so that debug output reports only the sample count and the
// per-sample shape instead of printing every stored value.
impl std::fmt::Debug for CalibrationDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sample_count = self.samples.len();

        let mut out = f.debug_struct("CalibrationDataset");
        out.field("num_samples", &sample_count);
        out.field("shape", &self.shape);
        out.finish()
    }
}
38
39impl CalibrationDataset {
40    /// Load calibration samples from a NumPy `.npy` file.
41    ///
42    /// The array must be at least 2-dimensional `[batch, ...]`.
43    ///
44    /// Requires the `calibration` feature (enabled by default).
45    ///
46    /// # Errors
47    ///
48    /// Returns [`QuantizeError::Calibration`] if the file is missing, not `.npy`,
49    /// or has an invalid shape.
50    #[cfg(feature = "calibration")]
51    pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52        use ndarray::{Array, IxDyn};
53
54        let path = path.as_ref();
55
56        if !path.exists() {
57            return Err(QuantizeError::Calibration {
58                reason: format!("File not found: {}", path.display()),
59            });
60        }
61
62        let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
63            ndarray_npy::read_npy(path).map_err(|e| QuantizeError::Calibration {
64                reason: format!("Failed to read NPY file '{}': {e}", path.display()),
65            })?
66        } else {
67            return Err(QuantizeError::Calibration {
68                reason: "Only .npy files supported currently".into(),
69            });
70        };
71
72        let shape: Vec<usize> = array.shape().to_vec();
73
74        if shape.is_empty() {
75            return Err(QuantizeError::Calibration {
76                reason: "Invalid array shape".into(),
77            });
78        }
79
80        if shape.len() < 2 {
81            return Err(QuantizeError::Calibration {
82                reason: format!(
83                    "Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}",
84                    shape
85                ),
86            });
87        }
88
89        let num_samples = shape[0];
90        let sample_size: usize = shape[1..].iter().product();
91
92        let data = array.into_raw_vec();
93        let mut samples = Vec::with_capacity(num_samples);
94
95        for i in 0..num_samples {
96            let start = i * sample_size;
97            let end = start + sample_size;
98            samples.push(data[start..end].to_vec());
99        }
100
101        Ok(Self {
102            samples,
103            shape: shape[1..].to_vec(),
104        })
105    }
106
107    /// Load calibration samples from a HuggingFace `.safetensors` file that
108    /// contains exactly one tensor.
109    ///
110    /// The tensor must be f32 and at least 2-dimensional `[batch, ...]`.
111    /// Requires the `safetensors-input` feature.
112    ///
113    /// For files with multiple named tensors, use
114    /// [`from_safetensors_named`](Self::from_safetensors_named) instead.
115    #[cfg(feature = "safetensors-input")]
116    pub fn from_safetensors(path: impl AsRef<Path>) -> Result<Self> {
117        let path = path.as_ref();
118        let buffer = std::fs::read(path).map_err(|e| QuantizeError::Calibration {
119            reason: format!("Failed to read safetensors file '{}': {e}", path.display()),
120        })?;
121        let tensors = safetensors::SafeTensors::deserialize(&buffer).map_err(|e| {
122            QuantizeError::Calibration {
123                reason: format!("Failed to parse safetensors file: {e}"),
124            }
125        })?;
126        let names: Vec<String> = tensors.names().into_iter().map(|s| s.to_string()).collect();
127        if names.is_empty() {
128            return Err(QuantizeError::Calibration {
129                reason: "safetensors file contains no tensors".into(),
130            });
131        }
132        if names.len() > 1 {
133            return Err(QuantizeError::Calibration {
134                reason: format!(
135                    "safetensors file contains {} tensors; pass one explicitly via \
136                     from_safetensors_named().  Available tensors: {}",
137                    names.len(),
138                    names.join(", ")
139                ),
140            });
141        }
142        Self::from_safetensors_view(&tensors, &names[0])
143    }
144
145    /// Load calibration samples from a specific named tensor inside a
146    /// `.safetensors` file.
147    ///
148    /// Requires the `safetensors-input` feature.
149    #[cfg(feature = "safetensors-input")]
150    pub fn from_safetensors_named(path: impl AsRef<Path>, tensor_name: &str) -> Result<Self> {
151        let path = path.as_ref();
152        let buffer = std::fs::read(path).map_err(|e| QuantizeError::Calibration {
153            reason: format!("Failed to read safetensors file '{}': {e}", path.display()),
154        })?;
155        let tensors = safetensors::SafeTensors::deserialize(&buffer).map_err(|e| {
156            QuantizeError::Calibration {
157                reason: format!("Failed to parse safetensors file: {e}"),
158            }
159        })?;
160        Self::from_safetensors_view(&tensors, tensor_name)
161    }
162
163    #[cfg(feature = "safetensors-input")]
164    fn from_safetensors_view(
165        tensors: &safetensors::SafeTensors<'_>,
166        tensor_name: &str,
167    ) -> Result<Self> {
168        use safetensors::Dtype;
169
170        let view = tensors
171            .tensor(tensor_name)
172            .map_err(|e| QuantizeError::Calibration {
173                reason: format!(
174                    "Tensor '{}' not found in safetensors file: {e}",
175                    tensor_name
176                ),
177            })?;
178
179        if view.dtype() != Dtype::F32 {
180            return Err(QuantizeError::Calibration {
181                reason: format!(
182                    "Tensor '{}' has dtype {:?}; only F32 is supported for calibration input",
183                    tensor_name,
184                    view.dtype()
185                ),
186            });
187        }
188
189        let shape: Vec<usize> = view.shape().to_vec();
190        if shape.len() < 2 {
191            return Err(QuantizeError::Calibration {
192                reason: format!(
193                    "Calibration tensor must be at least 2-dimensional (batch, ...). \
194                     Got shape {:?}",
195                    shape
196                ),
197            });
198        }
199        let expected_bytes: usize = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
200        let raw = view.data();
201        if raw.len() != expected_bytes {
202            return Err(QuantizeError::Calibration {
203                reason: format!(
204                    "Tensor '{}' data size {} bytes does not match shape {:?} \
205                     × 4 = {} bytes",
206                    tensor_name,
207                    raw.len(),
208                    shape,
209                    expected_bytes
210                ),
211            });
212        }
213
214        // safetensors stores data little-endian, which matches every target
215        // quantize-rs builds on today.  Decode per-f32 explicitly to stay
216        // endian-safe rather than relying on an unchecked cast.
217        let data: Vec<f32> = raw
218            .chunks_exact(4)
219            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
220            .collect();
221
222        let num_samples = shape[0];
223        let sample_size: usize = shape[1..].iter().product();
224        let mut samples = Vec::with_capacity(num_samples);
225        for i in 0..num_samples {
226            let start = i * sample_size;
227            let end = start + sample_size;
228            samples.push(data[start..end].to_vec());
229        }
230
231        Ok(Self {
232            samples,
233            shape: shape[1..].to_vec(),
234        })
235    }
236
237    /// Generate random calibration samples uniformly distributed in `range`.
238    ///
239    /// # Errors
240    ///
241    /// Returns [`QuantizeError::Calibration`] if shape is empty, `num_samples` is 0,
242    /// or the range is invalid.
243    pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
244        if shape.is_empty() || shape.contains(&0) {
245            return Err(QuantizeError::Calibration {
246                reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape),
247            });
248        }
249        if num_samples == 0 {
250            return Err(QuantizeError::Calibration {
251                reason: "num_samples must be > 0".into(),
252            });
253        }
254        if range.0 >= range.1 {
255            return Err(QuantizeError::Calibration {
256                reason: format!(
257                    "Invalid range: ({}, {}) - min must be less than max",
258                    range.0, range.1
259                ),
260            });
261        }
262        use rand::Rng;
263        let mut rng = rand::thread_rng();
264
265        let sample_size: usize = shape.iter().product();
266        let mut samples = Vec::with_capacity(num_samples);
267
268        for _ in 0..num_samples {
269            let sample: Vec<f32> = (0..sample_size)
270                .map(|_| rng.gen_range(range.0..range.1))
271                .collect();
272            samples.push(sample);
273        }
274
275        Ok(Self { samples, shape })
276    }
277
278    /// Create a dataset from pre-existing sample vectors.
279    ///
280    /// # Errors
281    ///
282    /// Returns [`QuantizeError::Calibration`] if `samples` is empty or any
283    /// sample has the wrong length for the given `shape`.
284    pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
285        let num_samples = samples.len();
286
287        if num_samples == 0 {
288            return Err(QuantizeError::Calibration {
289                reason: "No samples provided".into(),
290            });
291        }
292
293        let expected_size: usize = shape.iter().product();
294
295        for (i, sample) in samples.iter().enumerate() {
296            if sample.len() != expected_size {
297                return Err(QuantizeError::Calibration {
298                    reason: format!(
299                        "Sample {} has size {} but expected {} (shape: {:?})",
300                        i,
301                        sample.len(),
302                        expected_size,
303                        shape
304                    ),
305                });
306            }
307        }
308
309        Ok(Self { samples, shape })
310    }
311
312    /// Shape of a single sample (excluding batch dimension).
313    pub fn sample_shape(&self) -> &[usize] {
314        &self.shape
315    }
316
317    /// Number of samples in the dataset.
318    pub fn len(&self) -> usize {
319        self.samples.len()
320    }
321
322    /// Whether the dataset contains no samples.
323    pub fn is_empty(&self) -> bool {
324        self.samples.is_empty()
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// `random` honours the requested sample count and per-sample shape.
    #[test]
    fn test_random_dataset() {
        let shape = vec![3, 224, 224];
        let expected_len: usize = shape.iter().product();

        let dataset = CalibrationDataset::random(shape.clone(), 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.sample_shape(), shape.as_slice());
        assert_eq!(dataset.samples[0].len(), expected_len);
    }

    /// `from_samples` accepts samples whose length matches the shape.
    #[test]
    fn test_from_samples() {
        let dataset = CalibrationDataset::from_samples(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
            vec![3],
        )
        .unwrap();

        assert_eq!(dataset.len(), 2);
    }

    /// A single-tensor safetensors file written to disk loads back with the
    /// same batch size, sample shape and values.
    #[cfg(feature = "safetensors-input")]
    #[test]
    fn test_from_safetensors_round_trip() {
        use safetensors::{serialize, tensor::TensorView, Dtype};
        use std::collections::HashMap;

        // 3 samples of shape [2, 4] = 24 floats.
        let values: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let mut raw: Vec<u8> = Vec::with_capacity(values.len() * 4);
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        let mut tensors = HashMap::new();
        tensors.insert(
            "input".to_string(),
            TensorView::new(Dtype::F32, vec![3, 2, 4], &raw).unwrap(),
        );
        let encoded = serialize(&tensors, &None).unwrap();

        let file = tempfile::NamedTempFile::with_suffix(".safetensors").unwrap();
        std::fs::write(file.path(), &encoded).unwrap();

        let dataset = CalibrationDataset::from_safetensors(file.path()).unwrap();
        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.sample_shape(), &[2, 4]);
        // Each sample holds 2 * 4 = 8 floats.
        assert_eq!(dataset.samples[0].len(), 8);
        // Sample 0 starts at 0.0; sample 1 starts at flat index 8, i.e. 0.8.
        assert!((dataset.samples[0][0] - 0.0).abs() < 1e-6);
        assert!((dataset.samples[1][0] - 0.8).abs() < 1e-6);
    }

    /// A file with two tensors is rejected by `from_safetensors` but can be
    /// loaded by naming the tensor explicitly.
    #[cfg(feature = "safetensors-input")]
    #[test]
    fn test_from_safetensors_multi_tensor_errors_without_name() {
        use safetensors::{serialize, tensor::TensorView, Dtype};
        use std::collections::HashMap;

        let values: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let mut raw: Vec<u8> = Vec::with_capacity(values.len() * 4);
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        let mut tensors = HashMap::new();
        for name in ["a", "b"] {
            let view = TensorView::new(Dtype::F32, vec![2, 4], &raw).unwrap();
            tensors.insert(name.to_string(), view);
        }
        let encoded = serialize(&tensors, &None).unwrap();

        let file = tempfile::NamedTempFile::with_suffix(".safetensors").unwrap();
        std::fs::write(file.path(), &encoded).unwrap();

        let err = CalibrationDataset::from_safetensors(file.path()).unwrap_err();
        assert!(err.to_string().contains("contains 2 tensors"));

        // Named access still works on the same file.
        let dataset = CalibrationDataset::from_safetensors_named(file.path(), "a").unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_shape(), &[4]);
    }
}