quantize_rs/calibration/
mod.rs1use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12#[cfg(feature = "calibration")]
13pub mod inference;
14pub mod methods;
15pub mod stats;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// A collection of calibration inputs together with the shape of one sample.
///
/// Each entry of `samples` is a flat `f32` buffer. The constructors in this
/// module build buffers whose length equals the product of `shape`; because
/// both fields are public, that relationship is not enforced afterwards.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// One flattened buffer of values per calibration sample.
    pub samples: Vec<Vec<f32>>,

    /// Dimensions of a single sample (no leading batch/sample axis).
    pub shape: Vec<usize>,
}
29
30impl std::fmt::Debug for CalibrationDataset {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("CalibrationDataset")
33 .field("num_samples", &self.samples.len())
34 .field("shape", &self.shape)
35 .finish()
36 }
37}
38
39impl CalibrationDataset {
40 #[cfg(feature = "calibration")]
51 pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52 use ndarray::{Array, IxDyn};
53
54 let path = path.as_ref();
55
56 if !path.exists() {
57 return Err(QuantizeError::Calibration {
58 reason: format!("File not found: {}", path.display()),
59 });
60 }
61
62 let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
63 ndarray_npy::read_npy(path).map_err(|e| QuantizeError::Calibration {
64 reason: format!("Failed to read NPY file '{}': {e}", path.display()),
65 })?
66 } else {
67 return Err(QuantizeError::Calibration {
68 reason: "Only .npy files supported currently".into(),
69 });
70 };
71
72 let shape: Vec<usize> = array.shape().to_vec();
73
74 if shape.is_empty() {
75 return Err(QuantizeError::Calibration {
76 reason: "Invalid array shape".into(),
77 });
78 }
79
80 if shape.len() < 2 {
81 return Err(QuantizeError::Calibration {
82 reason: format!(
83 "Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}",
84 shape
85 ),
86 });
87 }
88
89 let num_samples = shape[0];
90 let sample_size: usize = shape[1..].iter().product();
91
92 let data = array.into_raw_vec();
93 let mut samples = Vec::with_capacity(num_samples);
94
95 for i in 0..num_samples {
96 let start = i * sample_size;
97 let end = start + sample_size;
98 samples.push(data[start..end].to_vec());
99 }
100
101 Ok(Self {
102 samples,
103 shape: shape[1..].to_vec(),
104 })
105 }
106
107 pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
114 if shape.is_empty() || shape.contains(&0) {
115 return Err(QuantizeError::Calibration {
116 reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape),
117 });
118 }
119 if num_samples == 0 {
120 return Err(QuantizeError::Calibration {
121 reason: "num_samples must be > 0".into(),
122 });
123 }
124 if range.0 >= range.1 {
125 return Err(QuantizeError::Calibration {
126 reason: format!(
127 "Invalid range: ({}, {}) - min must be less than max",
128 range.0, range.1
129 ),
130 });
131 }
132 use rand::Rng;
133 let mut rng = rand::thread_rng();
134
135 let sample_size: usize = shape.iter().product();
136 let mut samples = Vec::with_capacity(num_samples);
137
138 for _ in 0..num_samples {
139 let sample: Vec<f32> = (0..sample_size)
140 .map(|_| rng.gen_range(range.0..range.1))
141 .collect();
142 samples.push(sample);
143 }
144
145 Ok(Self { samples, shape })
146 }
147
148 pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
155 let num_samples = samples.len();
156
157 if num_samples == 0 {
158 return Err(QuantizeError::Calibration {
159 reason: "No samples provided".into(),
160 });
161 }
162
163 let expected_size: usize = shape.iter().product();
164
165 for (i, sample) in samples.iter().enumerate() {
166 if sample.len() != expected_size {
167 return Err(QuantizeError::Calibration {
168 reason: format!(
169 "Sample {} has size {} but expected {} (shape: {:?})",
170 i,
171 sample.len(),
172 expected_size,
173 shape
174 ),
175 });
176 }
177 }
178
179 Ok(Self { samples, shape })
180 }
181
182 pub fn sample_shape(&self) -> &[usize] {
184 &self.shape
185 }
186
187 pub fn len(&self) -> usize {
189 self.samples.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.samples.is_empty()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// A random dataset reports the requested sample count and shape, and
    /// each sample holds one value per element of that shape.
    #[test]
    fn test_random_dataset() {
        let expected_shape: Vec<usize> = vec![3, 224, 224];
        let expected_len: usize = 3 * 224 * 224;

        let dataset = CalibrationDataset::random(expected_shape.clone(), 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.sample_shape(), expected_shape.as_slice());
        assert_eq!(dataset.samples[0].len(), expected_len);
    }

    /// Two correctly sized samples are accepted and both are retained.
    #[test]
    fn test_from_samples() {
        let first = vec![1.0, 2.0, 3.0];
        let second = vec![4.0, 5.0, 6.0];

        let dataset = CalibrationDataset::from_samples(vec![first, second], vec![3]).unwrap();

        assert_eq!(dataset.len(), 2);
    }
}