//! Quantized linear layers and model-level quantization utilities:
//! on-the-fly dequantizing matmul kernels, parameter quantization, and a
//! simple binary serialization format for quantized models.

use crate::dequantize::dequantize_tensor;
use crate::quantize::quantize_tensor;
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

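/// Dot product of a Q8_0 block with a slice of activations, dequantizing the
/// int8 weights on the fly with the per-block scale.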
#[inline]
fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let mut sum = 0.0f32;
    for (d, a) in block.data.iter().zip(activations.iter()) {
        sum += (*d as f32) * a;
    }
    sum * scale
}

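/// Dot product of a Q4_0 block with a slice of activations: unpacks the
/// 4-bit weights and applies the per-block scale.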
#[inline]
fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32) * activations[i];
    }
    sum * scale
}

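/// Dot product of a Q4_1 block (per-block scale plus minimum offset) with a
/// slice of activations.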
#[inline]
fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let min = f32::from(block.min);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32 * scale + min) * activations[i];
    }
    sum
}

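/// Dot product of raw f16 weights with f32 activations.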
#[inline]
fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for i in 0..data.len().min(activations.len()) {
        sum += f32::from(data[i]) * activations[i];
    }
    sum
}

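/// Dispatches a block dot product to the kernel matching the block's
/// quantization format.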
#[inline]
fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
    match block {
        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            let values = b.unpack();
            values
                .iter()
                .zip(activations)
                .map(|(&v, &a)| v as f32 * scale * a)
                .sum()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            let values = b.unpack();
            values
                .iter()
                .zip(activations)
                .map(|(&v, &a)| (v as f32 * scale + min) * a)
                .sum()
        }
        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
        QuantizedBlock::F32(data) => {
            let mut sum = 0.0f32;
            for i in 0..data.len().min(activations.len()) {
                sum += data[i] * activations[i];
            }
            sum
        }
    }
}

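/// A linear (fully connected) layer whose weights are stored in a quantized
/// format and dequantized on the fly during the forward pass.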
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weight matrix, laid out as `[out_features, in_features]`.
    weight: QuantizedTensor,
    /// Optional bias vector of length `out_features`.
    bias: Option<Vec<f32>>,
    /// Number of input features.
    pub in_features: usize,
    /// Number of output features.
    pub out_features: usize,
    /// Quantization format used for the weights.
    pub quant_type: QuantType,
    /// Number of quantized blocks per weight row.
    blocks_per_row: usize,
}

impl QuantizedLinear {
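    /// Builds a `QuantizedLinear` from f32 weights laid out as
    /// `[out_features, in_features]` plus an optional bias, quantizing the
    /// weights with the requested `QuantType`.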
    pub fn from_linear_params(
        weight_data: &[f32],
        bias_data: Option<&[f32]>,
        in_features: usize,
        out_features: usize,
        quant_type: QuantType,
    ) -> Self {
        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
            .expect("Failed to create weight tensor for quantization");

        let weight =
            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");

        let block_size = quant_type.block_size();
        let blocks_per_row = in_features.div_ceil(block_size);

        QuantizedLinear {
            weight,
            bias: bias_data.map(|b| b.to_vec()),
            in_features,
            out_features,
            quant_type,
            blocks_per_row,
        }
    }

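    /// Forward pass on a flat f32 input of shape `[batch_size, in_features]`,
    /// returning a flat output of shape `[batch_size, out_features]`.
    /// Output rows are computed in parallel with rayon.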
    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * self.out_features];

        // Fast path: F16/F32 weights are not block-quantized, so flatten them
        // once and run a plain dense matmul.
        if !self.quant_type.is_block_quantized() {
            let weight_flat = self.extract_flat_weights();
            output
                .par_chunks_mut(self.out_features)
                .enumerate()
                .for_each(|(b, out_row)| {
                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
                    for o in 0..self.out_features {
                        let w_start = o * self.in_features;
                        let mut sum = 0.0f32;
                        for k in 0..self.in_features {
                            sum += weight_flat[w_start + k] * input_row[k];
                        }
                        if let Some(ref bias) = self.bias {
                            sum += bias[o];
                        }
                        out_row[o] = sum;
                    }
                });
            return output;
        }

        let block_size = self.quant_type.block_size();

        output
            .par_chunks_mut(self.out_features)
            .enumerate()
            .for_each(|(b, out_row)| {
                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];

                for o in 0..self.out_features {
                    let row_block_start = o * self.blocks_per_row;
                    let mut sum = 0.0f32;

                    for blk_idx in 0..self.blocks_per_row {
                        let act_start = blk_idx * block_size;
                        let act_end = (act_start + block_size).min(self.in_features);
                        let act_slice = &input_row[act_start..act_end];

                        let block = &self.weight.blocks[row_block_start + blk_idx];
                        sum += dot_block(block, act_slice);
                    }

                    if let Some(ref bias) = self.bias {
                        sum += bias[o];
                    }

                    out_row[o] = sum;
                }
            });

        output
    }

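    /// Flattens F16/F32 weight blocks into a single `Vec<f32>`.
    /// Only used on the non-block-quantized fast path.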
    fn extract_flat_weights(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
        for block in &self.weight.blocks {
            match block {
                QuantizedBlock::F16(data) => {
                    flat.extend(data.iter().map(|v| f32::from(*v)));
                }
                QuantizedBlock::F32(data) => {
                    flat.extend_from_slice(data);
                }
                // Block-quantized variants are never produced for F16/F32
                // tensors, so they are ignored here.
                _ => {}
            }
        }
        flat
    }

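    /// Forward pass on an autograd `Variable`: flattens its data, runs
    /// `forward_f32`, and wraps the result in a new `Variable` (without
    /// gradient tracking).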
    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
        let shape = input.shape();
        let batch = if shape.len() > 1 { shape[0] } else { 1 };
        let input_data = input.data().to_vec();

        let output_data = self.forward_f32(&input_data, batch);

        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
            .expect("Failed to create output tensor");

        axonml_autograd::Variable::new(output_tensor, false)
    }

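    /// Size of the quantized weights in bytes.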
    pub fn weight_bytes(&self) -> usize {
        self.weight.size_bytes()
    }

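    /// Compression ratio of the quantized weights relative to f32 storage.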
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

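    /// Dequantizes the weights back into a dense f32 tensor.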
    pub fn dequantize_weights(&self) -> Tensor<f32> {
        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
    }
}

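/// Quantizes a slice of module parameters in parallel, returning one
/// `QuantizedTensor` per parameter.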
pub fn quantize_parameters(
    params: &[axonml_nn::Parameter],
    quant_type: QuantType,
) -> Vec<QuantizedTensor> {
    params
        .par_iter()
        .map(|param| {
            let tensor = param.data();
            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
        })
        .collect()
}

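/// A fully quantized model: the quantized parameters of a module together
/// with bookkeeping about parameter count and storage size.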
pub struct QuantizedModel {
    /// Quantized parameter tensors, in the module's parameter order.
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization format used for all parameters.
    pub quant_type: QuantType,
    /// Total number of scalar parameters.
    pub total_params: usize,
    /// Total size of the quantized parameters in bytes.
    pub total_bytes: usize,
    /// Size the parameters would occupy as f32 (4 bytes each).
    pub original_bytes: usize,
}

impl QuantizedModel {
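    /// Quantizes all parameters of a module into a `QuantizedModel`.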
    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
        let params = module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let original_bytes = total_params * 4;

        let quantized_params = quantize_parameters(&params, quant_type);

        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();

        QuantizedModel {
            quantized_params,
            quant_type,
            total_params,
            total_bytes,
            original_bytes,
        }
    }

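    /// Dequantizes the stored parameters and writes them back into the
    /// module's parameters, pairing them in order.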
    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
        let params = module.parameters();
        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
            param.update_data(tensor);
        }
    }

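    /// Overall compression ratio: original f32 bytes over quantized bytes.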
    pub fn compression_ratio(&self) -> f32 {
        self.original_bytes as f32 / self.total_bytes as f32
    }

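    /// Human-readable one-line summary of the quantized model.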
    pub fn summary(&self) -> String {
        format!(
            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
            self.quant_type,
            self.total_params,
            self.original_bytes as f64 / 1024.0 / 1024.0,
            self.total_bytes as f64 / 1024.0 / 1024.0,
            self.compression_ratio(),
        )
    }
}

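/// Serializes a `QuantizedModel` into a simple binary format: magic `"AXQT"`,
/// a version byte, a quant-type byte, a `u32` tensor count, a `u64` total
/// parameter count, then for each tensor its shape (`u32` rank followed by
/// `u32` dims), a `u32` block count, and the raw block bytes.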
pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
    let mut buf = Vec::new();

    // Header: magic, version, quant type, tensor count, total parameter count.
    buf.extend_from_slice(b"AXQT");
    buf.push(1u8);
    buf.push(match model.quant_type {
        QuantType::Q8_0 => 0,
        QuantType::Q4_0 => 1,
        QuantType::Q4_1 => 2,
        QuantType::Q5_0 => 3,
        QuantType::Q5_1 => 4,
        QuantType::F16 => 5,
        QuantType::F32 => 6,
    });
    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());

    // Per-tensor payload: shape, block count, then the raw block bytes.
    for qt in &model.quantized_params {
        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
        for &dim in &qt.shape {
            buf.extend_from_slice(&(dim as u32).to_le_bytes());
        }
        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
        for block in &qt.blocks {
            match block {
                QuantizedBlock::Q8(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::F16(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
                QuantizedBlock::F32(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
            }
        }
    }

    buf
}

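/// Parses a `QuantizedModel` from the binary format written by
/// `serialize_quantized`. Q5_0/Q5_1 block data is not handled here and
/// produces an error.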
pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
    if data.len() < 18 || &data[0..4] != b"AXQT" {
        return Err("Invalid quantized model file (bad magic)".to_string());
    }

    let version = data[4];
    if version != 1 {
        return Err(format!("Unsupported quantized model version: {version}"));
    }

    let quant_type = match data[5] {
        0 => QuantType::Q8_0,
        1 => QuantType::Q4_0,
        2 => QuantType::Q4_1,
        3 => QuantType::Q5_0,
        4 => QuantType::Q5_1,
        5 => QuantType::F16,
        6 => QuantType::F32,
        x => return Err(format!("Unknown quant type byte: {x}")),
    };

    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
    let total_params = u64::from_le_bytes([
        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
    ]) as usize;

    let mut offset = 18usize;
    let mut quantized_params = Vec::with_capacity(num_tensors);

    let block_bytes = quant_type.bytes_per_block();

    for _ in 0..num_tensors {
        if offset + 4 > data.len() {
            return Err("Truncated quantized model file".to_string());
        }

        // Shape: u32 rank followed by u32 dims.
        let shape_len = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }

        let num_blocks = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        // Blocks: fixed-size records of `block_bytes` each.
        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            if offset + block_bytes > data.len() {
                return Err("Truncated block data".to_string());
            }

            let block = match quant_type {
                QuantType::Q8_0 => {
                    let b =
                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
                    QuantizedBlock::Q8(b)
                }
                QuantType::Q4_0 => {
                    let b =
                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
                    QuantizedBlock::Q4(b)
                }
                QuantType::Q4_1 => {
                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
                    let mut block_data = [0u8; 16];
                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
                }
                QuantType::F16 => {
                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    QuantizedBlock::F16(vec![v])
                }
                QuantType::F32 => {
                    let v = f32::from_le_bytes([
                        data[offset],
                        data[offset + 1],
                        data[offset + 2],
                        data[offset + 3],
                    ]);
                    QuantizedBlock::F32(vec![v])
                }
                // Q5_0/Q5_1 round-tripping is not implemented.
                _ => return Err("Unsupported quant type for deserialization".to_string()),
            };

            blocks.push(block);
            offset += block_bytes;
        }

        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
    }

    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
    let original_bytes = total_params * 4;

    Ok(QuantizedModel {
        quantized_params,
        quant_type,
        total_params,
        total_bytes,
        original_bytes,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare against an unquantized (F32) reference.
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}