use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    F32,
    INT8,
    INT4,
}

#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    pub dtype: QuantDtype,
    pub symmetric: bool,
    pub per_channel: bool,
    pub axis: Option<usize>,
}

impl QuantizationConfig {
    pub fn int8() -> Self {
        Self {
            dtype: QuantDtype::INT8,
            symmetric: true,
            per_channel: false,
            axis: None,
        }
    }

    pub fn int4() -> Self {
        Self {
            dtype: QuantDtype::INT4,
            symmetric: true,
            per_channel: false,
            axis: None,
        }
    }

    pub fn per_channel_int8(axis: usize) -> Self {
        Self {
            dtype: QuantDtype::INT8,
            symmetric: true,
            per_channel: true,
            axis: Some(axis),
        }
    }
}
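
/// A quantized tensor: an integer payload plus the scale / zero-point
/// metadata needed to map values back to floating point.
///
/// A minimal round-trip sketch (illustrative only; `DenseTensor::new(data, shape)`
/// is the constructor the tests below use):
///
/// ```ignore
/// let t = DenseTensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]);
/// let q = QuantizedTensor::from_tensor(&t, QuantizationConfig::int8());
/// let back = q.dequantize(); // each value within ~scale/2 of the original
/// ```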
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    pub data: Vec<i8>,
    pub scale: Vec<f64>,
    pub zero_point: Vec<i8>,
    pub shape: Vec<usize>,
    pub config: QuantizationConfig,
    pub channel_scales: Option<Vec<f64>>,
    pub channel_zero_points: Option<Vec<i8>>,
}

impl QuantizedTensor {
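    /// Quantizes `tensor` according to `config`, dispatching on its dtype.
    ///
    /// A hedged usage sketch (`weights` is a placeholder `DenseTensor`):
    ///
    /// ```ignore
    /// let q = QuantizedTensor::from_tensor(&weights, QuantizationConfig::per_channel_int8(0));
    /// assert!(q.channel_scales.is_some());
    /// ```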
    pub fn from_tensor(tensor: &DenseTensor, config: QuantizationConfig) -> Self {
        match config.dtype {
            QuantDtype::INT8 => Self::quantize_int8(tensor, &config),
            QuantDtype::INT4 => Self::quantize_int4(tensor, &config),
            QuantDtype::F32 => {
                // Degenerate passthrough: values are truncated to i8, so this
                // arm is only lossless for data that is already small and integral.
                let data = tensor.data().iter().map(|&x| x as i8).collect();
                Self {
                    data,
                    scale: vec![1.0],
                    zero_point: vec![0],
                    shape: tensor.shape().to_vec(),
                    config,
                    channel_scales: None,
                    channel_zero_points: None,
                }
            }
        }
    }

    fn quantize_int8(tensor: &DenseTensor, config: &QuantizationConfig) -> Self {
        if config.per_channel {
            Self::quantize_int8_per_channel(tensor, config.axis.unwrap_or(0))
        } else {
            Self::quantize_int8_per_tensor(tensor)
        }
    }

    fn quantize_int8_per_tensor(tensor: &DenseTensor) -> Self {
        let data = tensor.data();

        let max_abs = data.iter().fold(0.0_f64, |max, &x| max.max(x.abs()));

        // Symmetric scale; guard an all-zero tensor to avoid 0/0 = NaN.
        let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };

        let quantized: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized,
            scale: vec![scale],
            zero_point: vec![0],
            shape: tensor.shape().to_vec(),
            config: QuantizationConfig::int8(),
            channel_scales: None,
            channel_zero_points: None,
        }
    }

    fn quantize_int8_per_channel(tensor: &DenseTensor, axis: usize) -> Self {
        let data = tensor.data();
        let shape = tensor.shape();

        if axis >= shape.len() {
            return Self::quantize_int8_per_tensor(tensor);
        }

        let channel_dim = shape[axis];
        let channels_before: usize = shape[..axis].iter().product();
        let channels_after: usize = shape[axis + 1..].iter().product();

        let mut channel_scales = Vec::with_capacity(channel_dim);
        let mut channel_zero_points = Vec::with_capacity(channel_dim);
        let mut quantized = Vec::with_capacity(data.len());

        for c in 0..channel_dim {
            let mut channel_max_abs = 0.0_f64;

            for cb in 0..channels_before {
                for ca in 0..channels_after {
                    let offset = (cb * channel_dim + c) * channels_after + ca;
                    channel_max_abs = channel_max_abs.max(data[offset].abs());
                }
            }

            // Symmetric per-channel scale (zero point stays 0), matching the
            // `q * scale` dequantization below; guard all-zero channels.
            let scale = if channel_max_abs == 0.0 {
                1.0
            } else {
                channel_max_abs / 127.0
            };
            let zero_point = 0i8;

            channel_scales.push(scale);
            channel_zero_points.push(zero_point);
        }

        for (i, &val) in data.iter().enumerate() {
            let c = (i / channels_after) % channel_dim;
            let scale = channel_scales[c];

            let q = (val / scale).round() as i32;
            let q = q.clamp(-128, 127) as i8;
            quantized.push(q);
        }

        Self {
            data: quantized,
            scale: vec![1.0],
            zero_point: vec![0],
            shape: shape.to_vec(),
            config: QuantizationConfig::per_channel_int8(axis),
            channel_scales: Some(channel_scales),
            channel_zero_points: Some(channel_zero_points),
        }
    }

    fn quantize_int4(tensor: &DenseTensor, config: &QuantizationConfig) -> Self {
        let data = tensor.data();

        let (min, max) = data
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
                (min.min(x), max.max(x))
            });

        // Affine quantization onto the unsigned 4-bit range [0, 15]. The range
        // is nudged to include zero so an integral zero point in [0, 15]
        // exists; a constant tensor is guarded to avoid a zero scale.
        let min = min.min(0.0);
        let max = max.max(0.0);
        let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
        let zero_point = (-min / scale).round() as i8;

        let quantize = |x: f64| -> u8 {
            let q = (x / scale).round() as i32 + zero_point as i32;
            q.clamp(0, 15) as u8
        };

        let mut packed_data = Vec::with_capacity(data.len().div_ceil(2));

        for i in (0..data.len()).step_by(2) {
            let q0 = quantize(data[i]);
            let q1 = if i + 1 < data.len() {
                quantize(data[i + 1])
            } else {
                0
            };

            // Two 4-bit codes per byte: low nibble first, high nibble second.
            let packed = (q1 << 4) | q0;
            packed_data.push(packed as i8);
        }

        Self {
            data: packed_data,
            scale: vec![scale],
            zero_point: vec![zero_point],
            shape: tensor.shape().to_vec(),
            config: config.clone(),
            channel_scales: None,
            channel_zero_points: None,
        }
    }
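
    /// Reconstructs an f64 tensor from the quantized payload. This is lossy:
    /// expect errors on the order of half a quantization step per element.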
    pub fn dequantize(&self) -> DenseTensor {
        match self.config.dtype {
            QuantDtype::INT8 => self.dequantize_int8(),
            QuantDtype::INT4 => self.dequantize_int4(),
            QuantDtype::F32 => {
                let data = self.data.iter().map(|&x| x as f64).collect();
                DenseTensor::new(data, self.shape.clone())
            }
        }
    }

    fn dequantize_int8(&self) -> DenseTensor {
        let data: Vec<f64> = if let Some(scales) = &self.channel_scales {
            let shape = &self.shape;
            let axis = self.config.axis.unwrap_or(0);
            let channel_dim = shape[axis];
            let channels_after: usize = shape[axis + 1..].iter().product();

            self.data
                .iter()
                .enumerate()
                .map(|(i, &q)| {
                    // Recover the channel index the same way quantization did.
                    let c = (i / channels_after) % channel_dim;
                    q as f64 * scales[c]
                })
                .collect()
        } else {
            let scale = self.scale[0];

            self.data.iter().map(|&q| q as f64 * scale).collect()
        };

        DenseTensor::new(data, self.shape.clone())
    }

    fn dequantize_int4(&self) -> DenseTensor {
        let scale = self.scale[0];
        let zero_point = self.zero_point[0] as i32;
        let mut data = Vec::with_capacity(self.shape.iter().product::<usize>());

        for &packed in &self.data {
            let q0 = ((packed as u8) & 0x0F) as i32;
            let q1 = ((packed as u8) >> 4) as i32;

            data.push((q0 - zero_point) as f64 * scale);
            data.push((q1 - zero_point) as f64 * scale);
        }

        // Drop the padding nibble that packing may have added for odd lengths.
        let total: usize = self.shape.iter().product();
        data.truncate(total);

        DenseTensor::new(data, self.shape.clone())
    }

    pub fn quantized_data(&self) -> &[i8] {
        &self.data
    }

    pub fn scale(&self) -> f64 {
        self.scale[0]
    }

    pub fn memory_bytes(&self) -> usize {
        let total_elements = self.shape.iter().product::<usize>();
        match self.config.dtype {
            QuantDtype::INT8 => total_elements,
            QuantDtype::INT4 => total_elements.div_ceil(2),
            QuantDtype::F32 => total_elements * 4,
        }
    }
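
    /// Compression relative to an f32 (4 bytes per element) baseline. For
    /// example, a 10x10 INT4 tensor packs into 50 bytes versus 400 bytes of
    /// f32, a ratio of 8.0 (see `test_compression_ratio`).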
    pub fn compression_ratio(&self) -> f64 {
        let original_bytes = self.shape.iter().product::<usize>() * 4;
        original_bytes as f64 / self.memory_bytes() as f64
    }
}

pub struct QuantizedMatMul;

impl QuantizedMatMul {
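    /// Multiplies two quantized matrices, accumulating in i32 and rescaling
    /// the result to f64.
    ///
    /// A hedged sketch of the intended call pattern (`a` and `b` are
    /// placeholder `DenseTensor`s):
    ///
    /// ```ignore
    /// let a_q = QuantizedTensor::from_tensor(&a, QuantizationConfig::int8());
    /// let b_q = QuantizedTensor::from_tensor(&b, QuantizationConfig::int8());
    /// let y = QuantizedMatMul::matmul(&a_q, &b_q); // f64 result, shape [m, n]
    /// ```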
    pub fn matmul(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        Self::gemm_int8(a, b)
    }

    pub fn matmul_qd(a: &QuantizedTensor, b: &DenseTensor) -> DenseTensor {
        // Quantize the dense operand on the fly, then reuse the int8 kernel.
        let b_q = QuantizedTensor::from_tensor(b, QuantizationConfig::int8());
        Self::gemm_int8(a, &b_q)
    }

    pub fn matmul_dq(a: &DenseTensor, b: &QuantizedTensor) -> DenseTensor {
        let a_q = QuantizedTensor::from_tensor(a, QuantizationConfig::int8());
        Self::gemm_int8(&a_q, b)
    }

    pub fn gemm_int8(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        let m = a.shape[0];
        let k = a.shape[1];
        let n = b.shape[1];

        assert_eq!(a.shape[1], b.shape[0], "Inner dimensions must match");

        // Per-channel scales compose cleanly only when `a` is quantized per
        // row (axis 0) and `b` per column (axis 1); for any other layout we
        // fall back to the mean channel scale.
        let mean = |s: &[f64]| s.iter().sum::<f64>() / s.len() as f64;
        let row_scales: Vec<f64> = match &a.channel_scales {
            Some(s) if s.len() == m => s.clone(),
            Some(s) => vec![mean(s); m],
            None => vec![a.scale[0]; m],
        };
        let col_scales: Vec<f64> = match &b.channel_scales {
            Some(s) if s.len() == n => s.clone(),
            Some(s) => vec![mean(s); n],
            None => vec![b.scale[0]; n],
        };

        let mut result = Vec::with_capacity(m * n);

        for i in 0..m {
            for j in 0..n {
                // Accumulate integer products in i32, then rescale once.
                let mut acc: i32 = 0;

                for p in 0..k {
                    let a_val = a.data[i * k + p];
                    let b_val = b.data[p * n + j];
                    acc += (a_val as i32) * (b_val as i32);
                }

                result.push(acc as f64 * row_scales[i] * col_scales[j]);
            }
        }

        DenseTensor::new(result, vec![m, n])
    }
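
    /// Cache-blocked variant of `gemm_int8` for per-tensor scales: it tiles
    /// the `i`/`j` loops in 32-wide blocks and unrolls the inner `j` loop by 4.
    ///
    /// A hedged sketch of when to prefer it:
    ///
    /// ```ignore
    /// // Same numerical result as `gemm_int8` for per-tensor quantization,
    /// // but with better locality on large matrices:
    /// let y = QuantizedMatMul::gemm_int8_optimized(&a_q, &b_q);
    /// ```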
    pub fn gemm_int8_optimized(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        let m = a.shape[0];
        let k = a.shape[1];
        let n = b.shape[1];

        assert_eq!(a.shape[1], b.shape[0], "Inner dimensions must match");

        // Note: this kernel assumes per-tensor quantization and ignores any
        // per-channel scales.
        let scale = a.scale[0] * b.scale[0];

        let mut result = vec![0.0f64; m * n];

        const BLOCK_SIZE: usize = 32;

        for i_block in (0..m).step_by(BLOCK_SIZE) {
            for j_block in (0..n).step_by(BLOCK_SIZE) {
                let i_end = (i_block + BLOCK_SIZE).min(m);
                let j_end = (j_block + BLOCK_SIZE).min(n);

                for p in 0..k {
                    for i in i_block..i_end {
                        let a_val = a.data[i * k + p] as i32;

                        // Manually unrolled by 4 across the row of `b`.
                        let mut j = j_block;
                        while j + 4 <= j_end {
                            let b0 = b.data[p * n + j] as i32;
                            let b1 = b.data[p * n + j + 1] as i32;
                            let b2 = b.data[p * n + j + 2] as i32;
                            let b3 = b.data[p * n + j + 3] as i32;

                            result[i * n + j] += (a_val * b0) as f64;
                            result[i * n + j + 1] += (a_val * b1) as f64;
                            result[i * n + j + 2] += (a_val * b2) as f64;
                            result[i * n + j + 3] += (a_val * b3) as f64;

                            j += 4;
                        }

                        while j < j_end {
                            let b_val = b.data[p * n + j] as i32;
                            result[i * n + j] += (a_val * b_val) as f64;
                            j += 1;
                        }
                    }
                }
            }
        }

        for val in &mut result {
            *val *= scale;
        }

        DenseTensor::new(result, vec![m, n])
    }
}

pub mod weight_quantization {
    use super::*;

    pub fn quantize_weights(weights: &DenseTensor) -> QuantizedTensor {
        QuantizedTensor::from_tensor(weights, QuantizationConfig::int8())
    }

    pub fn quantize_weights_per_channel(weights: &DenseTensor, axis: usize) -> QuantizedTensor {
        QuantizedTensor::from_tensor(weights, QuantizationConfig::per_channel_int8(axis))
    }

    // Embeddings get one scale per row (axis 0).
    pub fn quantize_embeddings(embeddings: &DenseTensor) -> QuantizedTensor {
        QuantizedTensor::from_tensor(embeddings, QuantizationConfig::per_channel_int8(0))
    }

    // Linear weights get one scale per column (axis 1).
    pub fn quantize_linear_weights(weights: &DenseTensor) -> QuantizedTensor {
        QuantizedTensor::from_tensor(weights, QuantizationConfig::per_channel_int8(1))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 0.25, 0.5, 0.75, 1.0], vec![1, 5]);
        let config = QuantizationConfig::int8();

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        assert_eq!(quantized.shape, vec![1, 5]);
        assert_eq!(quantized.data.len(), 5);

        let dequantized = quantized.dequantize();
        let original = tensor.data();
        let reconstructed = dequantized.data();

        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            assert!((orig - recon).abs() < 0.1, "Quantization error too large: orig={}, recon={}", orig, recon);
        }
    }

    #[test]
    fn test_int8_per_channel_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 1.0, 2.0, 10.0, 20.0, 30.0], vec![2, 3]);
        let config = QuantizationConfig::per_channel_int8(1);

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        assert!(quantized.channel_scales.is_some());
        assert_eq!(quantized.channel_scales.unwrap().len(), 3);
    }
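
    // A round-trip check for the per-channel path; with symmetric scales the
    // per-element reconstruction error should stay within half a step
    // (channel max_abs / 254). The tolerance here is deliberately loose.
    #[test]
    fn test_int8_per_channel_round_trip() {
        let tensor = DenseTensor::new(vec![0.0, 1.0, 2.0, 10.0, 20.0, 30.0], vec![2, 3]);
        let quantized =
            QuantizedTensor::from_tensor(&tensor, QuantizationConfig::per_channel_int8(1));

        let dequantized = quantized.dequantize();

        for (orig, recon) in tensor.data().iter().zip(dequantized.data().iter()) {
            assert!(
                (orig - recon).abs() < 0.2,
                "Per-channel error too large: orig={}, recon={}",
                orig,
                recon
            );
        }
    }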

    #[test]
    fn test_int4_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]);
        let config = QuantizationConfig::int4();

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        assert_eq!(quantized.data.len(), 2);
    }
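
    // Exercises the INT4 affine path on a signed range; the stored zero
    // point should recover the offset, with errors bounded by roughly half a
    // step ((max - min) / 30).
    #[test]
    fn test_int4_round_trip_signed() {
        let tensor = DenseTensor::new(vec![-1.0, -0.5, 0.0, 0.5, 1.0], vec![1, 5]);
        let quantized = QuantizedTensor::from_tensor(&tensor, QuantizationConfig::int4());

        let dequantized = quantized.dequantize();
        assert_eq!(dequantized.data().len(), 5);

        for (orig, recon) in tensor.data().iter().zip(dequantized.data().iter()) {
            assert!(
                (orig - recon).abs() < 0.15,
                "INT4 error too large: orig={}, recon={}",
                orig,
                recon
            );
        }
    }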

    #[test]
    fn test_compression_ratio() {
        let tensor = DenseTensor::new(vec![0.0; 100], vec![10, 10]);

        let int8 = QuantizedTensor::from_tensor(&tensor, QuantizationConfig::int8());
        assert!((int8.compression_ratio() - 4.0).abs() < 0.1);

        let int4 = QuantizedTensor::from_tensor(&tensor, QuantizationConfig::int4());
        assert!((int4.compression_ratio() - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_quantized_matmul() {
        let a = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);

        let a_q = QuantizedTensor::from_tensor(&a, QuantizationConfig::int8());
        let b_q = QuantizedTensor::from_tensor(&b, QuantizationConfig::int8());

        let result = QuantizedMatMul::matmul(&a_q, &b_q);

        assert_eq!(result.shape(), &[2, 2]);
    }
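
    // The blocked kernel should agree with the reference kernel whenever
    // both operands use per-tensor scales, since both perform the same
    // integer accumulation in a different loop order.
    #[test]
    fn test_gemm_optimized_matches_reference() {
        let a = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], vec![3, 2]);

        let a_q = QuantizedTensor::from_tensor(&a, QuantizationConfig::int8());
        let b_q = QuantizedTensor::from_tensor(&b, QuantizationConfig::int8());

        let reference = QuantizedMatMul::gemm_int8(&a_q, &b_q);
        let optimized = QuantizedMatMul::gemm_int8_optimized(&a_q, &b_q);

        for (r, o) in reference.data().iter().zip(optimized.data().iter()) {
            assert!((r - o).abs() < 1e-9, "Kernels disagree: {} vs {}", r, o);
        }
    }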

    #[test]
    fn test_weight_quantization() {
        let weights = DenseTensor::new(vec![-1.0, -0.5, 0.0, 0.5, 1.0], vec![1, 5]);

        let quantized = weight_quantization::quantize_weights(&weights);

        assert_eq!(quantized.config.dtype, QuantDtype::INT8);

        let dequantized = quantized.dequantize();
        let original = weights.data();
        let reconstructed = dequantized.data();

        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            assert!((orig - recon).abs() < 0.15, "Weight quantization error too large: orig={}, recon={}", orig, recon);
        }
    }
}