1use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12#[cfg(feature = "calibration")]
13pub mod inference;
14pub mod methods;
15pub mod stats;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// A set of representative input samples used to estimate activation
/// ranges during quantization calibration.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// Flattened sample data: one `Vec<f32>` per sample, each of length
    /// `shape.iter().product()`.
    pub samples: Vec<Vec<f32>>,

    /// Shape of a single sample (the leading batch dimension is excluded;
    /// loaders store `shape[1..]` of the source array here).
    pub shape: Vec<usize>,
}
29
30impl std::fmt::Debug for CalibrationDataset {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("CalibrationDataset")
33 .field("num_samples", &self.samples.len())
34 .field("shape", &self.shape)
35 .finish()
36 }
37}
38
39impl CalibrationDataset {
40 #[cfg(feature = "calibration")]
51 pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52 use ndarray::{Array, IxDyn};
53
54 let path = path.as_ref();
55
56 if !path.exists() {
57 return Err(QuantizeError::Calibration {
58 reason: format!("File not found: {}", path.display()),
59 });
60 }
61
62 let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
63 ndarray_npy::read_npy(path).map_err(|e| QuantizeError::Calibration {
64 reason: format!("Failed to read NPY file '{}': {e}", path.display()),
65 })?
66 } else {
67 return Err(QuantizeError::Calibration {
68 reason: "Only .npy files supported currently".into(),
69 });
70 };
71
72 let shape: Vec<usize> = array.shape().to_vec();
73
74 if shape.is_empty() {
75 return Err(QuantizeError::Calibration {
76 reason: "Invalid array shape".into(),
77 });
78 }
79
80 if shape.len() < 2 {
81 return Err(QuantizeError::Calibration {
82 reason: format!(
83 "Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}",
84 shape
85 ),
86 });
87 }
88
89 let num_samples = shape[0];
90 let sample_size: usize = shape[1..].iter().product();
91
92 let data = array.into_raw_vec();
93 let mut samples = Vec::with_capacity(num_samples);
94
95 for i in 0..num_samples {
96 let start = i * sample_size;
97 let end = start + sample_size;
98 samples.push(data[start..end].to_vec());
99 }
100
101 Ok(Self {
102 samples,
103 shape: shape[1..].to_vec(),
104 })
105 }
106
107 #[cfg(feature = "safetensors-input")]
116 pub fn from_safetensors(path: impl AsRef<Path>) -> Result<Self> {
117 let path = path.as_ref();
118 let buffer = std::fs::read(path).map_err(|e| QuantizeError::Calibration {
119 reason: format!("Failed to read safetensors file '{}': {e}", path.display()),
120 })?;
121 let tensors = safetensors::SafeTensors::deserialize(&buffer).map_err(|e| {
122 QuantizeError::Calibration {
123 reason: format!("Failed to parse safetensors file: {e}"),
124 }
125 })?;
126 let names: Vec<String> = tensors.names().into_iter().map(|s| s.to_string()).collect();
127 if names.is_empty() {
128 return Err(QuantizeError::Calibration {
129 reason: "safetensors file contains no tensors".into(),
130 });
131 }
132 if names.len() > 1 {
133 return Err(QuantizeError::Calibration {
134 reason: format!(
135 "safetensors file contains {} tensors; pass one explicitly via \
136 from_safetensors_named(). Available tensors: {}",
137 names.len(),
138 names.join(", ")
139 ),
140 });
141 }
142 Self::from_safetensors_view(&tensors, &names[0])
143 }
144
145 #[cfg(feature = "safetensors-input")]
150 pub fn from_safetensors_named(path: impl AsRef<Path>, tensor_name: &str) -> Result<Self> {
151 let path = path.as_ref();
152 let buffer = std::fs::read(path).map_err(|e| QuantizeError::Calibration {
153 reason: format!("Failed to read safetensors file '{}': {e}", path.display()),
154 })?;
155 let tensors = safetensors::SafeTensors::deserialize(&buffer).map_err(|e| {
156 QuantizeError::Calibration {
157 reason: format!("Failed to parse safetensors file: {e}"),
158 }
159 })?;
160 Self::from_safetensors_view(&tensors, tensor_name)
161 }
162
163 #[cfg(feature = "safetensors-input")]
164 fn from_safetensors_view(
165 tensors: &safetensors::SafeTensors<'_>,
166 tensor_name: &str,
167 ) -> Result<Self> {
168 use safetensors::Dtype;
169
170 let view = tensors
171 .tensor(tensor_name)
172 .map_err(|e| QuantizeError::Calibration {
173 reason: format!(
174 "Tensor '{}' not found in safetensors file: {e}",
175 tensor_name
176 ),
177 })?;
178
179 if view.dtype() != Dtype::F32 {
180 return Err(QuantizeError::Calibration {
181 reason: format!(
182 "Tensor '{}' has dtype {:?}; only F32 is supported for calibration input",
183 tensor_name,
184 view.dtype()
185 ),
186 });
187 }
188
189 let shape: Vec<usize> = view.shape().to_vec();
190 if shape.len() < 2 {
191 return Err(QuantizeError::Calibration {
192 reason: format!(
193 "Calibration tensor must be at least 2-dimensional (batch, ...). \
194 Got shape {:?}",
195 shape
196 ),
197 });
198 }
199 let expected_bytes: usize = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
200 let raw = view.data();
201 if raw.len() != expected_bytes {
202 return Err(QuantizeError::Calibration {
203 reason: format!(
204 "Tensor '{}' data size {} bytes does not match shape {:?} \
205 × 4 = {} bytes",
206 tensor_name,
207 raw.len(),
208 shape,
209 expected_bytes
210 ),
211 });
212 }
213
214 let data: Vec<f32> = raw
218 .chunks_exact(4)
219 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
220 .collect();
221
222 let num_samples = shape[0];
223 let sample_size: usize = shape[1..].iter().product();
224 let mut samples = Vec::with_capacity(num_samples);
225 for i in 0..num_samples {
226 let start = i * sample_size;
227 let end = start + sample_size;
228 samples.push(data[start..end].to_vec());
229 }
230
231 Ok(Self {
232 samples,
233 shape: shape[1..].to_vec(),
234 })
235 }
236
237 pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
244 if shape.is_empty() || shape.contains(&0) {
245 return Err(QuantizeError::Calibration {
246 reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape),
247 });
248 }
249 if num_samples == 0 {
250 return Err(QuantizeError::Calibration {
251 reason: "num_samples must be > 0".into(),
252 });
253 }
254 if range.0 >= range.1 {
255 return Err(QuantizeError::Calibration {
256 reason: format!(
257 "Invalid range: ({}, {}) - min must be less than max",
258 range.0, range.1
259 ),
260 });
261 }
262 use rand::Rng;
263 let mut rng = rand::thread_rng();
264
265 let sample_size: usize = shape.iter().product();
266 let mut samples = Vec::with_capacity(num_samples);
267
268 for _ in 0..num_samples {
269 let sample: Vec<f32> = (0..sample_size)
270 .map(|_| rng.gen_range(range.0..range.1))
271 .collect();
272 samples.push(sample);
273 }
274
275 Ok(Self { samples, shape })
276 }
277
278 pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
285 let num_samples = samples.len();
286
287 if num_samples == 0 {
288 return Err(QuantizeError::Calibration {
289 reason: "No samples provided".into(),
290 });
291 }
292
293 let expected_size: usize = shape.iter().product();
294
295 for (i, sample) in samples.iter().enumerate() {
296 if sample.len() != expected_size {
297 return Err(QuantizeError::Calibration {
298 reason: format!(
299 "Sample {} has size {} but expected {} (shape: {:?})",
300 i,
301 sample.len(),
302 expected_size,
303 shape
304 ),
305 });
306 }
307 }
308
309 Ok(Self { samples, shape })
310 }
311
    /// Shape of a single sample (batch dimension excluded).
    pub fn sample_shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of samples in the dataset.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` if the dataset holds no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_dataset() {
        // `shape` is per-sample; the sample count is a separate argument.
        let dataset = CalibrationDataset::random(vec![3, 224, 224], 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.sample_shape(), &[3, 224, 224]);
        assert_eq!(dataset.samples[0].len(), 3 * 224 * 224);
    }

    #[test]
    fn test_from_samples() {
        // Two pre-flattened samples, each matching the declared shape [3].
        let samples = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

        let dataset = CalibrationDataset::from_samples(samples, vec![3]).unwrap();
        assert_eq!(dataset.len(), 2);
    }

    #[cfg(feature = "safetensors-input")]
    #[test]
    fn test_from_safetensors_round_trip() {
        use safetensors::{serialize, tensor::TensorView, Dtype};
        use std::collections::HashMap;

        // Serialize a (3, 2, 4) f32 tensor to little-endian bytes, write it
        // to a temp file, then load it back and check that the batch axis
        // was split off correctly.
        let data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let raw: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
        let view = TensorView::new(Dtype::F32, vec![3, 2, 4], &raw).unwrap();
        let mut tensors = HashMap::new();
        tensors.insert("input".to_string(), view);
        let bytes = serialize(&tensors, &None).unwrap();

        let tmp = tempfile::NamedTempFile::with_suffix(".safetensors").unwrap();
        std::fs::write(tmp.path(), &bytes).unwrap();

        let dataset = CalibrationDataset::from_safetensors(tmp.path()).unwrap();
        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.sample_shape(), &[2, 4]);
        assert_eq!(dataset.samples[0].len(), 8);
        // Sample 1 starts at flat index 8, i.e. value 8 * 0.1 = 0.8.
        assert!((dataset.samples[0][0] - 0.0).abs() < 1e-6);
        assert!((dataset.samples[1][0] - 0.8).abs() < 1e-6);
    }

    #[cfg(feature = "safetensors-input")]
    #[test]
    fn test_from_safetensors_multi_tensor_errors_without_name() {
        use safetensors::{serialize, tensor::TensorView, Dtype};
        use std::collections::HashMap;

        // Two tensors in one file: the unnamed loader must refuse with a
        // message naming the count, while the named loader succeeds.
        let data: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let raw: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
        let v1 = TensorView::new(Dtype::F32, vec![2, 4], &raw).unwrap();
        let v2 = TensorView::new(Dtype::F32, vec![2, 4], &raw).unwrap();
        let mut tensors = HashMap::new();
        tensors.insert("a".to_string(), v1);
        tensors.insert("b".to_string(), v2);
        let bytes = serialize(&tensors, &None).unwrap();

        let tmp = tempfile::NamedTempFile::with_suffix(".safetensors").unwrap();
        std::fs::write(tmp.path(), &bytes).unwrap();

        let err = CalibrationDataset::from_safetensors(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("contains 2 tensors"));

        let dataset = CalibrationDataset::from_safetensors_named(tmp.path(), "a").unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_shape(), &[4]);
    }
}
402}