quantize_rs/calibration/
mod.rs1use crate::errors::{QuantizeError, Result};
9#[cfg(feature = "calibration")]
10use std::path::Path;
11
12pub mod stats;
13pub mod methods;
14#[cfg(feature = "calibration")]
15pub mod inference;
16
17#[cfg(feature = "calibration")]
18pub use inference::ActivationEstimator;
19
/// In-memory calibration inputs: a collection of samples that all share one
/// per-sample shape.
///
/// The constructors on this type validate that every sample holds exactly
/// `shape.iter().product()` values. Both fields are public, however, so code
/// that mutates them directly is responsible for keeping that invariant.
#[derive(Clone)]
pub struct CalibrationDataset {
    /// One flattened value buffer per sample.
    pub samples: Vec<Vec<f32>>,

    /// Shape of a single sample. It does not include a batch/sample axis;
    /// the number of samples is `samples.len()`.
    pub shape: Vec<usize>,
}
29
30impl std::fmt::Debug for CalibrationDataset {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("CalibrationDataset")
33 .field("num_samples", &self.samples.len())
34 .field("shape", &self.shape)
35 .finish()
36 }
37}
38
39impl CalibrationDataset {
40 #[cfg(feature = "calibration")]
51 pub fn from_numpy(path: impl AsRef<Path>) -> Result<Self> {
52 use ndarray::{Array, IxDyn};
53
54 let path = path.as_ref();
55
56 if !path.exists() {
57 return Err(QuantizeError::Calibration { reason: format!("File not found: {}", path.display()) });
58 }
59
60 let array: Array<f32, IxDyn> = if path.extension().and_then(|s| s.to_str()) == Some("npy") {
61 ndarray_npy::read_npy(path)
62 .map_err(|e| QuantizeError::Calibration { reason: format!("Failed to read NPY file '{}': {e}", path.display()) })?
63 } else {
64 return Err(QuantizeError::Calibration { reason: "Only .npy files supported currently".into() });
65 };
66
67 let shape: Vec<usize> = array.shape().to_vec();
68
69 if shape.is_empty() {
70 return Err(QuantizeError::Calibration { reason: "Invalid array shape".into() });
71 }
72
73 if shape.len() < 2 {
74 return Err(QuantizeError::Calibration { reason: format!("Calibration data must be at least 2-dimensional (batch, ...). Got shape {:?}", shape) });
75 }
76
77 let num_samples = shape[0];
78 let sample_size: usize = shape[1..].iter().product();
79
80 let data = array.into_raw_vec();
81 let mut samples = Vec::with_capacity(num_samples);
82
83 for i in 0..num_samples {
84 let start = i * sample_size;
85 let end = start + sample_size;
86 samples.push(data[start..end].to_vec());
87 }
88
89 Ok(Self {
90 samples,
91 shape: shape[1..].to_vec(),
92 })
93 }
94
95 pub fn random(shape: Vec<usize>, num_samples: usize, range: (f32, f32)) -> Result<Self> {
102 if shape.is_empty() || shape.contains(&0) {
103 return Err(QuantizeError::Calibration { reason: format!("Invalid shape: {:?} - all dimensions must be > 0", shape) });
104 }
105 if num_samples == 0 {
106 return Err(QuantizeError::Calibration { reason: "num_samples must be > 0".into() });
107 }
108 if range.0 >= range.1 {
109 return Err(QuantizeError::Calibration { reason: format!("Invalid range: ({}, {}) - min must be less than max", range.0, range.1) });
110 }
111 use rand::Rng;
112 let mut rng = rand::thread_rng();
113
114 let sample_size: usize = shape.iter().product();
115 let mut samples = Vec::with_capacity(num_samples);
116
117 for _ in 0..num_samples {
118 let sample: Vec<f32> = (0..sample_size)
119 .map(|_| rng.gen_range(range.0..range.1))
120 .collect();
121 samples.push(sample);
122 }
123
124 Ok(Self {
125 samples,
126 shape,
127 })
128 }
129
130 pub fn from_samples(samples: Vec<Vec<f32>>, shape: Vec<usize>) -> Result<Self> {
137 let num_samples = samples.len();
138
139 if num_samples == 0 {
140 return Err(QuantizeError::Calibration { reason: "No samples provided".into() });
141 }
142
143 let expected_size: usize = shape.iter().product();
144
145 for (i, sample) in samples.iter().enumerate() {
146 if sample.len() != expected_size {
147 return Err(QuantizeError::Calibration {
148 reason: format!(
149 "Sample {} has size {} but expected {} (shape: {:?})",
150 i, sample.len(), expected_size, shape
151 ),
152 });
153 }
154 }
155
156 Ok(Self {
157 samples,
158 shape,
159 })
160 }
161
162 pub fn sample_shape(&self) -> &[usize] {
164 &self.shape
165 }
166
167 pub fn len(&self) -> usize {
169 self.samples.len()
170 }
171
172 pub fn is_empty(&self) -> bool {
174 self.samples.is_empty()
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_dataset() {
        let dataset = CalibrationDataset::random(vec![3, 224, 224], 10, (-1.0, 1.0)).unwrap();

        assert_eq!(dataset.len(), 10);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.sample_shape(), &[3, 224, 224]);
        assert_eq!(dataset.samples[0].len(), 3 * 224 * 224);
        // Inclusive upper bound: some rand versions may round up to `max`.
        assert!(dataset
            .samples
            .iter()
            .flatten()
            .all(|&v| (-1.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_random_rejects_invalid_arguments() {
        // Empty shape and zero-sized dimension.
        assert!(CalibrationDataset::random(vec![], 1, (0.0, 1.0)).is_err());
        assert!(CalibrationDataset::random(vec![3, 0], 1, (0.0, 1.0)).is_err());
        // No samples requested.
        assert!(CalibrationDataset::random(vec![3], 0, (0.0, 1.0)).is_err());
        // Empty and inverted ranges.
        assert!(CalibrationDataset::random(vec![3], 1, (1.0, 1.0)).is_err());
        assert!(CalibrationDataset::random(vec![3], 1, (1.0, -1.0)).is_err());
    }

    #[test]
    fn test_from_samples() {
        let samples = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ];

        let dataset = CalibrationDataset::from_samples(samples, vec![3]).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_shape(), &[3]);
    }

    #[test]
    fn test_from_samples_rejects_bad_input() {
        assert!(CalibrationDataset::from_samples(Vec::new(), vec![3]).is_err());

        let ragged = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        assert!(CalibrationDataset::from_samples(ragged, vec![3]).is_err());
    }

    #[test]
    fn test_debug_omits_sample_values() {
        let dataset = CalibrationDataset::from_samples(vec![vec![1.0, 2.0]], vec![2]).unwrap();
        assert_eq!(
            format!("{:?}", dataset),
            "CalibrationDataset { num_samples: 1, shape: [2] }"
        );
    }
}