/// Strategy for choosing the dynamic quantization scale.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DynamicScaleMode {
    /// Scale from the maximum absolute value of the data.
    MaxAbs,
    /// Scale from the given percentile (in [0, 1]) of absolute values,
    /// clipping outliers above it.
    Percentile(f32),
}

/// Storage layout of a dynamically quantized tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DynQuantFormat {
    /// INT8 with a single scale for the whole tensor.
    Int8PerTensor,
    /// INT8 with one scale per row.
    Int8PerRow,
    /// INT4 values (stored unpacked, one per `i8`) with a single scale.
    Int4PerTensor,
}

/// A quantized tensor plus the metadata needed to dequantize it.
#[derive(Debug, Clone)]
pub struct DynQuantTensor {
    /// Quantized values; INT4 values are stored unpacked, one per `i8`.
    pub data: Vec<i8>,
    /// One scale for per-tensor formats, or one scale per row for
    /// `Int8PerRow`.
    pub scales: Vec<f32>,
    /// Logical shape of the original tensor.
    pub shape: Vec<usize>,
    /// Storage layout of `data` and `scales`.
    pub format: DynQuantFormat,
}

impl DynQuantTensor {
    /// Reconstructs `f32` values from the quantized representation.
    pub fn dequantize(&self) -> Vec<f32> {
        match self.format {
            // Per-tensor formats share one scale across all elements.
            DynQuantFormat::Int8PerTensor | DynQuantFormat::Int4PerTensor => {
                let scale = self.scales.first().copied().unwrap_or(0.0);
                self.data.iter().map(|&q| q as f32 * scale).collect()
            }
            DynQuantFormat::Int8PerRow => {
                if self.scales.is_empty() || self.data.is_empty() {
                    return Vec::new();
                }
                let rows = self.scales.len();
                let cols = self.data.len() / rows;
                let mut out = Vec::with_capacity(self.data.len());
                for (r, &scale) in self.scales.iter().enumerate() {
                    let start = r * cols;
                    let end = (start + cols).min(self.data.len());
                    for &q in &self.data[start..end] {
                        out.push(q as f32 * scale);
                    }
                }
                out
            }
        }
    }

    /// Bytes used by the quantized data (one byte per element) plus the
    /// `f32` scales.
    pub fn memory_bytes(&self) -> usize {
        self.data.len() + self.scales.len() * core::mem::size_of::<f32>()
    }

    /// Ratio of the original `f32` footprint to the quantized footprint.
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.data.len() * core::mem::size_of::<f32>();
        let quantized_bytes = self.memory_bytes();
        if quantized_bytes == 0 {
            return 1.0;
        }
        original_bytes as f32 / quantized_bytes as f32
    }

    /// Number of quantized elements.
    pub fn element_count(&self) -> usize {
        self.data.len()
    }
}

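// Worked example for the methods above: per-tensor INT8 with n elements
// stores n bytes of data plus one 4-byte scale, so compression_ratio()
// approaches 4x as n grows (n = 1024 gives 4096 / 1028 ≈ 3.98).
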
/// Computes a symmetric quantization scale so that `abs_max` maps to
/// `clip_val` (e.g. 127.0 for INT8, 7.0 for INT4). Returns 0.0 for empty
/// or all-zero input.
pub fn compute_scale(data: &[f32], clip_val: f32, mode: DynamicScaleMode) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    let abs_max = match mode {
        DynamicScaleMode::MaxAbs => data.iter().map(|x| x.abs()).fold(0.0_f32, f32::max),
        DynamicScaleMode::Percentile(p) => {
            let p_clamped = p.clamp(0.0, 1.0);
            let mut abs_vals: Vec<f32> = data.iter().map(|x| x.abs()).collect();
            abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
            let len = abs_vals.len();
            // Nearest-rank percentile index, clamped to a valid position.
            let idx = ((p_clamped * len as f32).ceil() as usize)
                .saturating_sub(1)
                .min(len - 1);
            abs_vals[idx]
        }
    };

    if abs_max == 0.0 {
        return 0.0;
    }

    abs_max / clip_val
}

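// Worked example for compute_scale: data = [1.0, -2.0, 0.5] with
// clip_val = 127.0 under MaxAbs gives abs_max = 2.0, so the scale is
// 2.0 / 127.0 ≈ 0.01575; -2.0 quantizes to round(-2.0 / scale) = -127
// and dequantizes back to exactly -2.0.
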
/// Quantizes `data` to INT8 with a single per-tensor scale computed on
/// the fly from the data itself.
pub fn dynamic_quantize_int8(data: &[f32], mode: DynamicScaleMode) -> DynQuantTensor {
    const CLIP_VAL: f32 = 127.0;

    if data.is_empty() {
        return DynQuantTensor {
            data: Vec::new(),
            scales: vec![0.0],
            shape: vec![0],
            format: DynQuantFormat::Int8PerTensor,
        };
    }

    let scale = compute_scale(data, CLIP_VAL, mode);

    let quantized: Vec<i8> = if scale == 0.0 {
        vec![0i8; data.len()]
    } else {
        data.iter()
            .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
            .collect()
    };

    DynQuantTensor {
        data: quantized,
        scales: vec![scale],
        shape: vec![data.len()],
        format: DynQuantFormat::Int8PerTensor,
    }
}

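// Usage sketch: quantize, inspect, reconstruct.
//
//     let qt = dynamic_quantize_int8(&[0.5, -1.0, 0.25], DynamicScaleMode::MaxAbs);
//     assert_eq!(qt.data, vec![64, -127, 32]);  // scale = 1.0 / 127.0
//     let recon = qt.dequantize();              // ≈ [0.504, -1.0, 0.252]
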
/// Quantizes a row-major `rows x cols` matrix to INT8 with one scale per
/// row, which isolates rows with very different dynamic ranges.
pub fn dynamic_quantize_int8_per_row(
    data: &[f32],
    rows: usize,
    cols: usize,
    mode: DynamicScaleMode,
) -> DynQuantTensor {
    const CLIP_VAL: f32 = 127.0;

    if data.is_empty() || rows == 0 || cols == 0 {
        return DynQuantTensor {
            data: Vec::new(),
            scales: Vec::new(),
            shape: vec![rows, cols],
            format: DynQuantFormat::Int8PerRow,
        };
    }

    let total = rows * cols;
    let actual_len = data.len().min(total);

    let mut quantized = Vec::with_capacity(actual_len);
    let mut scales = Vec::with_capacity(rows);

    for r in 0..rows {
        let start = r * cols;
        let end = (start + cols).min(data.len());
        if start >= data.len() {
            // Rows past the end of `data` are zero-filled.
            quantized.extend(core::iter::repeat(0i8).take(cols));
            scales.push(0.0_f32);
            continue;
        }
        let row = &data[start..end];
        let scale = compute_scale(row, CLIP_VAL, mode);
        scales.push(scale);
        if scale == 0.0 {
            quantized.extend(core::iter::repeat(0i8).take(row.len()));
        } else {
            for &x in row {
                quantized.push((x / scale).round().clamp(-127.0, 127.0) as i8);
            }
        }
    }

    DynQuantTensor {
        data: quantized,
        scales,
        shape: vec![rows, cols],
        format: DynQuantFormat::Int8PerRow,
    }
}

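// Layout note: element (r, c) of the matrix lives at index r * cols + c
// in `data`, and scales[r] applies to the whole of row r. See
// test_int8_per_row_roundtrip below for a concrete check.
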
/// Quantizes `data` to signed INT4 (values in [-7, 7]) with a single
/// per-tensor scale. Values are stored unpacked, one per `i8`.
pub fn dynamic_quantize_int4(data: &[f32], mode: DynamicScaleMode) -> DynQuantTensor {
    const CLIP_VAL: f32 = 7.0;

    if data.is_empty() {
        return DynQuantTensor {
            data: Vec::new(),
            scales: vec![0.0],
            shape: vec![0],
            format: DynQuantFormat::Int4PerTensor,
        };
    }

    let scale = compute_scale(data, CLIP_VAL, mode);

    let quantized: Vec<i8> = if scale == 0.0 {
        vec![0i8; data.len()]
    } else {
        data.iter()
            .map(|&x| (x / scale).round().clamp(-7.0, 7.0) as i8)
            .collect()
    };

    DynQuantTensor {
        data: quantized,
        scales: vec![scale],
        shape: vec![data.len()],
        format: DynQuantFormat::Int4PerTensor,
    }
}

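// The INT4 values above occupy a full byte each. A minimal packing sketch
// (a hypothetical helper, not used elsewhere in this module) showing how
// two signed 4-bit values could share one byte, low nibble first;
// unpacking would need to sign-extend each nibble back to `i8`.
#[allow(dead_code)]
fn pack_int4_pairs(vals: &[i8]) -> Vec<u8> {
    vals.chunks(2)
        .map(|pair| {
            let lo = (pair[0] & 0x0F) as u8;
            let hi = (*pair.get(1).unwrap_or(&0) & 0x0F) as u8;
            (hi << 4) | lo
        })
        .collect()
}
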
/// Mean absolute error between `original` and the dequantized tensor,
/// computed over the overlapping prefix.
pub fn quantization_mae(original: &[f32], quantized: &DynQuantTensor) -> f32 {
    let reconstructed = quantized.dequantize();
    let n = original.len().min(reconstructed.len());
    if n == 0 {
        return 0.0;
    }
    let sum_abs_err: f32 = original[..n]
        .iter()
        .zip(reconstructed[..n].iter())
        .map(|(&o, &r)| (o - r).abs())
        .sum();
    sum_abs_err / n as f32
}

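// Error bound: with MaxAbs scaling, each INT8 element is off by at most
// scale / 2 = abs_max / 254, so the MAE cannot exceed roughly 0.4% of the
// largest magnitude in the tensor (ignoring values clipped by Percentile
// mode).
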
/// Configuration for SmoothQuant-style difficulty migration.
#[derive(Debug, Clone)]
pub struct SmoothQuantConfig {
    /// Migration strength in [0, 1]: higher values shift more of the
    /// quantization difficulty from activations to weights.
    pub alpha: f32,
    /// Small constant guarding against division by zero.
    pub epsilon: f32,
}

impl SmoothQuantConfig {
    pub fn new(alpha: f32) -> Self {
        Self {
            alpha: alpha.clamp(0.0, 1.0),
            epsilon: 1e-5,
        }
    }

    /// The commonly used default of `alpha = 0.5`.
    pub fn default_alpha() -> Self {
        Self::new(0.5)
    }
}

/// Computes per-input-channel smoothing factors
/// `s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)`, where `X` is a
/// `tokens x in_features` activation matrix and `W` is an
/// `out_features x in_features` weight matrix, both row-major.
pub fn compute_smooth_factors(
    activations: &[f32],
    weights: &[f32],
    in_features: usize,
    tokens: usize,
    out_features: usize,
    config: &SmoothQuantConfig,
) -> Vec<f32> {
    if in_features == 0 {
        return Vec::new();
    }

    let alpha = config.alpha.clamp(0.0, 1.0);
    let epsilon = config.epsilon.max(1e-10);

    // Per-column maximum absolute activation.
    let mut act_max = vec![0.0_f32; in_features];
    for t in 0..tokens {
        for (j, slot) in act_max.iter_mut().enumerate() {
            let idx = t * in_features + j;
            if idx < activations.len() {
                let v = activations[idx].abs();
                if v > *slot {
                    *slot = v;
                }
            }
        }
    }

    // Per-column maximum absolute weight.
    let mut w_max = vec![0.0_f32; in_features];
    for o in 0..out_features {
        for (j, slot) in w_max.iter_mut().enumerate() {
            let idx = o * in_features + j;
            if idx < weights.len() {
                let v = weights[idx].abs();
                if v > *slot {
                    *slot = v;
                }
            }
        }
    }

    (0..in_features)
        .map(|j| {
            let a = (act_max[j] + epsilon).powf(alpha);
            let w = (w_max[j] + epsilon).powf(1.0 - alpha);
            (a / w).max(epsilon)
        })
        .collect()
}

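// With alpha = 0.5, s_j ≈ sqrt(max|X_j| / max|W_j|): an activation column
// peaking at 100.0 against a weight column peaking at 1.0 gets s_j = 10,
// so after smoothing both columns peak at 10.0 and neither side dominates
// the quantization range.
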
/// Divides each activation column by its smoothing factor in place.
pub fn smooth_activations(
    activations: &mut [f32],
    smooth_factors: &[f32],
    tokens: usize,
    in_features: usize,
) -> Result<(), DynQuantError> {
    if smooth_factors.len() != in_features {
        return Err(DynQuantError::FeatureDimMismatch {
            in_features,
            sf_len: smooth_factors.len(),
        });
    }
    let expected = tokens * in_features;
    if activations.len() != expected {
        return Err(DynQuantError::ShapeMismatch {
            expected,
            actual: activations.len(),
        });
    }
    for t in 0..tokens {
        for (j, &sf) in smooth_factors.iter().enumerate() {
            let idx = t * in_features + j;
            activations[idx] /= sf;
        }
    }
    Ok(())
}

/// Multiplies each weight column by its smoothing factor in place,
/// compensating for `smooth_activations` so that the matrix product is
/// unchanged.
pub fn smooth_weights(
    weights: &mut [f32],
    smooth_factors: &[f32],
    out_features: usize,
    in_features: usize,
) -> Result<(), DynQuantError> {
    if smooth_factors.len() != in_features {
        return Err(DynQuantError::FeatureDimMismatch {
            in_features,
            sf_len: smooth_factors.len(),
        });
    }
    let expected = out_features * in_features;
    if weights.len() != expected {
        return Err(DynQuantError::ShapeMismatch {
            expected,
            actual: weights.len(),
        });
    }
    for o in 0..out_features {
        for (j, &sf) in smooth_factors.iter().enumerate() {
            let idx = o * in_features + j;
            weights[idx] *= sf;
        }
    }
    Ok(())
}

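// Invariance: each product term becomes (x / s_j) * (w * s_j) = x * w, so
// X * W^T is preserved exactly (up to f32 rounding);
// test_smoothing_preserves_matmul below verifies this numerically.
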
/// W8A8 matrix-vector product: the weights arrive pre-quantized with one
/// scale per output row, the activation vector is quantized on the fly,
/// the dot products run in integer arithmetic, and the result is rescaled
/// back to `f32`.
pub fn w8a8_matvec(
    weight_i8: &[i8],
    weight_scales: &[f32],
    activation: &[f32],
    out_size: usize,
    in_size: usize,
) -> Result<Vec<f32>, DynQuantError> {
    if activation.is_empty() {
        return Err(DynQuantError::EmptyInput);
    }
    if activation.len() != in_size {
        return Err(DynQuantError::ShapeMismatch {
            expected: in_size,
            actual: activation.len(),
        });
    }
    let expected_w = out_size * in_size;
    if weight_i8.len() != expected_w {
        return Err(DynQuantError::ShapeMismatch {
            expected: expected_w,
            actual: weight_i8.len(),
        });
    }
    if weight_scales.len() != out_size {
        return Err(DynQuantError::ShapeMismatch {
            expected: out_size,
            actual: weight_scales.len(),
        });
    }

    // Quantize the activation vector with a single dynamic scale.
    let act_quant = dynamic_quantize_int8(activation, DynamicScaleMode::MaxAbs);
    let act_scale = act_quant.scales.first().copied().unwrap_or(0.0);
    let act_i8 = &act_quant.data;

    let mut output = vec![0.0_f32; out_size];

    for o in 0..out_size {
        let row_start = o * in_size;
        let row_end = row_start + in_size;
        let row = &weight_i8[row_start..row_end];

        // Integer accumulation; each term is at most 127 * 127 in
        // magnitude, so an i32 accumulator is safe for realistic row
        // lengths.
        let mut acc = 0_i32;
        for (&w, &a) in row.iter().zip(act_i8.iter()) {
            acc += w as i32 * a as i32;
        }

        // Rescale: dequantize both factors in a single multiply.
        output[o] = acc as f32 * weight_scales[o] * act_scale;
    }

    Ok(output)
}

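// Usage sketch: quantize a weight matrix per row once, then reuse it for
// many matvecs. (`weights`, `x`, `out`, and `inp` are placeholders for
// the caller's data; see test_w8a8_matvec_close_to_f32 below for a
// concrete check.)
//
//     let wq = dynamic_quantize_int8_per_row(&weights, out, inp, DynamicScaleMode::MaxAbs);
//     let y = w8a8_matvec(&wq.data, &wq.scales, &x, out, inp)?;
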
/// Summary statistics gathered from calibration data.
#[derive(Debug, Clone)]
pub struct CalibStats {
    /// Smallest observed value.
    pub min: f32,
    /// Largest observed value.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Population standard deviation.
    pub std_dev: f32,
    /// 99th percentile of absolute values.
    pub p99: f32,
    /// INT8 scale suggested from the p99 statistic.
    pub suggested_scale: f32,
}

impl CalibStats {
    /// Aggregates statistics over all calibration batches.
    pub fn collect(batches: &[Vec<f32>]) -> Self {
        let all_values: Vec<f32> = batches.iter().flat_map(|b| b.iter().copied()).collect();

        if all_values.is_empty() {
            return Self {
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std_dev: 0.0,
                p99: 0.0,
                suggested_scale: 0.0,
            };
        }

        let n = all_values.len();

        let min_val = all_values.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = all_values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let sum: f32 = all_values.iter().sum();
        let mean_val = sum / n as f32;

        let variance: f32 = all_values
            .iter()
            .map(|&x| {
                let d = x - mean_val;
                d * d
            })
            .sum::<f32>()
            / n as f32;
        let std_dev_val = variance.sqrt();

        // 99th percentile of absolute values (nearest-rank).
        let mut abs_vals: Vec<f32> = all_values.iter().map(|x| x.abs()).collect();
        abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
        let p99_idx = ((0.99_f32 * n as f32).ceil() as usize)
            .saturating_sub(1)
            .min(n - 1);
        let p99_val = abs_vals[p99_idx];

        // Prefer the p99-based scale; fall back to max-abs, then to a
        // conservative default when everything is zero.
        let suggested = if p99_val > 0.0 {
            p99_val / 127.0
        } else {
            let max_abs = abs_vals.last().copied().unwrap_or(0.0);
            if max_abs > 0.0 {
                max_abs / 127.0
            } else {
                1.0 / 127.0
            }
        };

        Self {
            min: min_val,
            max: max_val,
            mean: mean_val,
            std_dev: std_dev_val,
            p99: p99_val,
            suggested_scale: suggested,
        }
    }
}

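// Usage sketch (batch contents are placeholders): collect statistics over
// calibration data, then reuse stats.suggested_scale as a fixed scale for
// static quantization instead of recomputing one per tensor.
//
//     let stats = CalibStats::collect(&calibration_batches);
//     let q: Vec<i8> = tensor
//         .iter()
//         .map(|&x| (x / stats.suggested_scale).round().clamp(-127.0, 127.0) as i8)
//         .collect();
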
/// Errors surfaced by the dynamic-quantization helpers.
#[derive(Debug, thiserror::Error)]
pub enum DynQuantError {
    #[error("shape mismatch: expected {expected}, got {actual}")]
    ShapeMismatch { expected: usize, actual: usize },

    #[error("empty input")]
    EmptyInput,

    #[error("invalid alpha {0}: must be in [0, 1]")]
    InvalidAlpha(f32),

    #[error("dimension mismatch: in_features {in_features}, smooth_factors {sf_len}")]
    FeatureDimMismatch { in_features: usize, sf_len: usize },
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_scale_max_abs_basic() {
        let data = [1.0_f32, -2.0, 0.5];
        let scale = compute_scale(&data, 127.0, DynamicScaleMode::MaxAbs);
        let expected = 2.0 / 127.0;
        assert!(
            (scale - expected).abs() < 1e-6,
            "scale={scale}, expected={expected}"
        );
    }

    #[test]
    fn test_compute_scale_zeros() {
        let data = [0.0_f32; 8];
        let scale = compute_scale(&data, 127.0, DynamicScaleMode::MaxAbs);
        assert_eq!(scale, 0.0);
    }

    #[test]
    fn test_dequantize_roundtrip_int8() {
        let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.1).collect();
        let qt = dynamic_quantize_int8(&data, DynamicScaleMode::MaxAbs);
        let recon = qt.dequantize();
        let mae = quantization_mae(&data, &qt);
        let max_abs = data.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
        assert!(
            mae < 0.005 * max_abs,
            "MAE {mae} >= 0.5% of max_abs {max_abs}"
        );
        assert_eq!(recon.len(), data.len());
    }

    #[test]
    fn test_int4_range() {
        let data: Vec<f32> = (-50..=50).map(|i| i as f32 * 0.3).collect();
        let qt = dynamic_quantize_int4(&data, DynamicScaleMode::MaxAbs);
        for &q in &qt.data {
            assert!((-7..=7).contains(&q), "INT4 value {q} out of range [-7, 7]");
        }
    }

    #[test]
    fn test_smooth_quant_config_new() {
        let cfg = SmoothQuantConfig::new(0.7);
        assert!((cfg.alpha - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_smooth_quant_config_default_alpha() {
        let cfg = SmoothQuantConfig::default_alpha();
        assert!((cfg.alpha - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_calib_stats_basic() {
        let batches = vec![vec![1.0_f32, 2.0, 3.0], vec![-1.0_f32, 0.0, 4.0]];
        let stats = CalibStats::collect(&batches);
        assert!(stats.min <= stats.mean);
        assert!(stats.mean <= stats.max);
        assert!(stats.suggested_scale > 0.0);
    }
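
    // Added checks for the per-row, W8A8, and smoothing paths (not part of
    // the original suite); tolerances assume MaxAbs scaling as above.

    #[test]
    fn test_int8_per_row_roundtrip() {
        // Rows with very different ranges; per-row scales keep both precise.
        let data = [100.0_f32, -50.0, 25.0, 0.1, -0.2, 0.05];
        let qt = dynamic_quantize_int8_per_row(&data, 2, 3, DynamicScaleMode::MaxAbs);
        assert_eq!(qt.scales.len(), 2);
        let recon = qt.dequantize();
        assert_eq!(recon.len(), data.len());
        for (o, r) in data.iter().zip(recon.iter()) {
            let tol = o.abs().max(1.0) * 0.01;
            assert!((o - r).abs() < tol, "o={o}, r={r}");
        }
    }

    #[test]
    fn test_w8a8_matvec_close_to_f32() {
        let weights = [0.5_f32, -0.25, 0.125, 1.0, 0.75, -0.5];
        let activation = [1.0_f32, 2.0, -1.0];
        let wq = dynamic_quantize_int8_per_row(&weights, 2, 3, DynamicScaleMode::MaxAbs);
        let out = w8a8_matvec(&wq.data, &wq.scales, &activation, 2, 3).unwrap();
        // f32 reference: [0.5 - 0.5 - 0.125, 1.0 + 1.5 + 0.5].
        let reference = [-0.125_f32, 3.0];
        for (o, r) in out.iter().zip(reference.iter()) {
            assert!((o - r).abs() < 0.05, "out={o}, ref={r}");
        }
    }

    #[test]
    fn test_smoothing_preserves_matmul() {
        let (tokens, in_features, out_features) = (2, 3, 2);
        let mut acts = vec![10.0_f32, 0.5, -3.0, -8.0, 1.0, 2.0];
        let mut weights = vec![0.1_f32, 2.0, -0.5, 0.3, -1.0, 0.25];

        // Reference product X * W^T before smoothing.
        let mut reference = vec![0.0_f32; tokens * out_features];
        for t in 0..tokens {
            for o in 0..out_features {
                for j in 0..in_features {
                    reference[t * out_features + o] +=
                        acts[t * in_features + j] * weights[o * in_features + j];
                }
            }
        }

        let cfg = SmoothQuantConfig::default_alpha();
        let sf =
            compute_smooth_factors(&acts, &weights, in_features, tokens, out_features, &cfg);
        smooth_activations(&mut acts, &sf, tokens, in_features).unwrap();
        smooth_weights(&mut weights, &sf, out_features, in_features).unwrap();

        // (X / s) * (W * s)^T must match the original product.
        for t in 0..tokens {
            for o in 0..out_features {
                let mut acc = 0.0_f32;
                for j in 0..in_features {
                    acc += acts[t * in_features + j] * weights[o * in_features + j];
                }
                let r = reference[t * out_features + o];
                assert!((acc - r).abs() < 1e-3 * r.abs().max(1.0), "acc={acc}, ref={r}");
            }
        }
    }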
}