use crate::dequantize::dequantize_tensor;
use crate::quantize::quantize_tensor;
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

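/// Dot product of a single Q8 block with a slice of activations. The
/// per-block scale is applied once, after accumulating the raw products.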
#[inline]
fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let mut sum = 0.0f32;
    for (d, a) in block.data.iter().zip(activations.iter()) {
        sum += (*d as f32) * a;
    }
    sum * scale
}

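/// Dot product of a single Q4 block with activations. The packed nibbles
/// are unpacked first; the shared scale is applied once at the end.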
#[inline]
fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32) * activations[i];
    }
    sum * scale
}

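/// Dot product of a Q4_1 block (scale plus minimum) with activations; each
/// value dequantizes as `v * scale + min` before multiplying.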
#[inline]
fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let min = f32::from(block.min);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32 * scale + min) * activations[i];
    }
    sum
}

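/// Dot product of an f16 block with activations, accumulating in f32.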
#[inline]
fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for i in 0..data.len().min(activations.len()) {
        sum += f32::from(data[i]) * activations[i];
    }
    sum
}

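/// Dispatches a block/activation dot product to the kernel matching the
/// block's quantization format; the Q5 variants are handled inline.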
#[inline]
fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
    match block {
        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            let values = b.unpack();
            values
                .iter()
                .zip(activations)
                .map(|(&v, &a)| v as f32 * scale * a)
                .sum()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            let values = b.unpack();
            values
                .iter()
                .zip(activations)
                .map(|(&v, &a)| (v as f32 * scale + min) * a)
                .sum()
        }
        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
        QuantizedBlock::F32(data) => {
            let mut sum = 0.0f32;
            for i in 0..data.len().min(activations.len()) {
                sum += data[i] * activations[i];
            }
            sum
        }
    }
}

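/// A linear layer whose weights are stored quantized and used directly at
/// inference time: activations stay in f32 while each weight block is
/// dequantized on the fly inside the block dot-product kernels above.
///
/// A rough usage sketch (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let ql = QuantizedLinear::from_linear_params(&weight, Some(&bias), 64, 16, QuantType::Q8_0);
/// let output = ql.forward_f32(&input, 1); // batch_size * out_features values
/// ```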
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weight matrix, stored row-major as `[out_features, in_features]`.
    weight: QuantizedTensor,
    /// Optional f32 bias, one entry per output feature.
    bias: Option<Vec<f32>>,
    /// Number of input features.
    pub in_features: usize,
    /// Number of output features.
    pub out_features: usize,
    /// Quantization format used for the weights.
    pub quant_type: QuantType,
    /// Number of quantized blocks covering one weight row.
    blocks_per_row: usize,
}

impl QuantizedLinear {
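    /// Builds a quantized linear layer from raw f32 weights (row-major,
    /// `out_features x in_features`) and an optional bias, quantizing the
    /// weights to `quant_type`. The bias, if any, stays in f32.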
    pub fn from_linear_params(
        weight_data: &[f32],
        bias_data: Option<&[f32]>,
        in_features: usize,
        out_features: usize,
        quant_type: QuantType,
    ) -> Self {
        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
            .expect("Failed to create weight tensor for quantization");

        let weight =
            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");

        let block_size = quant_type.block_size();
        let blocks_per_row = in_features.div_ceil(block_size);

        QuantizedLinear {
            weight,
            bias: bias_data.map(|b| b.to_vec()),
            in_features,
            out_features,
            quant_type,
            blocks_per_row,
        }
    }

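    /// Forward pass over a row-major input of shape `[batch_size, in_features]`,
    /// returning a `[batch_size, out_features]` buffer. Output rows are computed
    /// in parallel with rayon; for block-quantized weights each block is
    /// dequantized on the fly inside the inner dot product.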
    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * self.out_features];

        if !self.quant_type.is_block_quantized() {
            // F16/F32 weights: flatten once, then do a plain dense matmul.
            let weight_flat = self.extract_flat_weights();
            output
                .par_chunks_mut(self.out_features)
                .enumerate()
                .for_each(|(b, out_row)| {
                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
                    for o in 0..self.out_features {
                        let w_start = o * self.in_features;
                        let mut sum = 0.0f32;
                        for k in 0..self.in_features {
                            sum += weight_flat[w_start + k] * input_row[k];
                        }
                        if let Some(ref bias) = self.bias {
                            sum += bias[o];
                        }
                        out_row[o] = sum;
                    }
                });
            return output;
        }

        let block_size = self.quant_type.block_size();

        output
            .par_chunks_mut(self.out_features)
            .enumerate()
            .for_each(|(b, out_row)| {
                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];

                for o in 0..self.out_features {
                    let row_block_start = o * self.blocks_per_row;
                    let mut sum = 0.0f32;

                    for blk_idx in 0..self.blocks_per_row {
                        let act_start = blk_idx * block_size;
                        let act_end = (act_start + block_size).min(self.in_features);
                        let act_slice = &input_row[act_start..act_end];

                        let block = &self.weight.blocks[row_block_start + blk_idx];
                        sum += dot_block(block, act_slice);
                    }

                    if let Some(ref bias) = self.bias {
                        sum += bias[o];
                    }

                    out_row[o] = sum;
                }
            });

        output
    }

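    /// Flattens F16/F32 (non-block-quantized) weights into a contiguous f32
    /// vector for the dense matmul path of `forward_f32`.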
    fn extract_flat_weights(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
        for block in &self.weight.blocks {
            match block {
                QuantizedBlock::F16(data) => {
                    flat.extend(data.iter().map(|v| f32::from(*v)));
                }
                QuantizedBlock::F32(data) => {
                    flat.extend_from_slice(data);
                }
                // Block-quantized variants never reach this path: forward_f32
                // only calls this when the weights are not block-quantized.
                _ => {}
            }
        }
        flat
    }

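    /// Runs the layer on an autograd `Variable` by delegating to
    /// `forward_f32`. The result is wrapped in a fresh `Variable` created
    /// without gradient tracking, so this path is inference-only.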
    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
        let shape = input.shape();
        let batch = if shape.len() > 1 { shape[0] } else { 1 };
        let input_data = input.data().to_vec();

        let output_data = self.forward_f32(&input_data, batch);

        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
            .expect("Failed to create output tensor");

        axonml_autograd::Variable::new(output_tensor, false)
    }

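    /// Size of the quantized weight storage in bytes.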
    pub fn weight_bytes(&self) -> usize {
        self.weight.size_bytes()
    }

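    /// Compression ratio of the quantized weights relative to f32 storage.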
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

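    /// Reconstructs the full f32 weight tensor from the quantized blocks.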
    pub fn dequantize_weights(&self) -> Tensor<f32> {
        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
    }
}

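/// Quantizes a slice of parameters in parallel (via rayon), returning one
/// `QuantizedTensor` per parameter in the same order.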
pub fn quantize_parameters(
    params: &[axonml_nn::Parameter],
    quant_type: QuantType,
) -> Vec<QuantizedTensor> {
    params
        .par_iter()
        .map(|param| {
            let tensor = param.data();
            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
        })
        .collect()
}

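/// A fully quantized set of module parameters plus size bookkeeping,
/// produced by [`QuantizedModel::from_module`] and serializable with
/// [`serialize_quantized`].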
pub struct QuantizedModel {
    /// Quantized tensors, in the module's parameter order.
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization format shared by all tensors.
    pub quant_type: QuantType,
    /// Total number of scalar parameters across all tensors.
    pub total_params: usize,
    /// Total quantized size in bytes.
    pub total_bytes: usize,
    /// Size of the same parameters stored as f32 (4 bytes each).
    pub original_bytes: usize,
}

impl QuantizedModel {
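    /// Quantizes every parameter of `module` to `quant_type` and records
    /// size statistics (f32 parameters are counted as 4 bytes each).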
    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
        let params = module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let original_bytes = total_params * 4;

        let quantized_params = quantize_parameters(&params, quant_type);

        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();

        QuantizedModel {
            quantized_params,
            quant_type,
            total_params,
            total_bytes,
            original_bytes,
        }
    }

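    /// Dequantizes each stored tensor into the corresponding parameter of
    /// `module`, pairing parameters positionally; the module must yield its
    /// parameters in the same order used at quantization time.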
    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
        let params = module.parameters();
        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
            param.update_data(tensor);
        }
    }

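    /// Overall compression ratio: original f32 bytes over quantized bytes.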
    pub fn compression_ratio(&self) -> f32 {
        self.original_bytes as f32 / self.total_bytes as f32
    }

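    /// One-line summary of quantization type, parameter count, and sizes.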
    pub fn summary(&self) -> String {
        format!(
            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
            self.quant_type,
            self.total_params,
            self.original_bytes as f64 / 1024.0 / 1024.0,
            self.total_bytes as f64 / 1024.0 / 1024.0,
            self.compression_ratio(),
        )
    }
}

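/// Serializes a `QuantizedModel` into the `AXQT` binary format:
///
/// - magic `b"AXQT"` (4 bytes)
/// - format version (u8, currently 1)
/// - quant type tag (u8)
/// - number of tensors (u32 LE)
/// - total parameter count (u64 LE)
/// - per tensor: shape length (u32 LE), each dim (u32 LE),
///   block count (u32 LE), then the raw block bytes.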
pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
    let mut buf = Vec::new();

    buf.extend_from_slice(b"AXQT");
    buf.push(1u8);
    buf.push(match model.quant_type {
        QuantType::Q8_0 => 0,
        QuantType::Q4_0 => 1,
        QuantType::Q4_1 => 2,
        QuantType::Q5_0 => 3,
        QuantType::Q5_1 => 4,
        QuantType::F16 => 5,
        QuantType::F32 => 6,
    });
    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());

    for qt in &model.quantized_params {
        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
        for &dim in &qt.shape {
            buf.extend_from_slice(&(dim as u32).to_le_bytes());
        }
        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
        for block in &qt.blocks {
            match block {
                QuantizedBlock::Q8(b) => buf.extend_from_slice(&b.to_bytes()),
                QuantizedBlock::Q4(b) => buf.extend_from_slice(&b.to_bytes()),
                QuantizedBlock::Q4_1(b) => buf.extend_from_slice(&b.to_bytes()),
                QuantizedBlock::Q5(b) => buf.extend_from_slice(&b.to_bytes()),
                QuantizedBlock::Q5_1(b) => buf.extend_from_slice(&b.to_bytes()),
                QuantizedBlock::F16(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
                QuantizedBlock::F32(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
            }
        }
    }

    buf
}

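/// Parses the `AXQT` format written by [`serialize_quantized`]. Note that
/// Q5_0/Q5_1 models can currently be serialized but not deserialized; they
/// fall through to an error below.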
pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
    // Fixed header: 4-byte magic + version + quant type + u32 count + u64 params = 18 bytes.
    if data.len() < 18 || &data[0..4] != b"AXQT" {
        return Err("Invalid quantized model file (bad magic)".to_string());
    }

    let version = data[4];
    if version != 1 {
        return Err(format!("Unsupported quantized model version: {version}"));
    }

    let quant_type = match data[5] {
        0 => QuantType::Q8_0,
        1 => QuantType::Q4_0,
        2 => QuantType::Q4_1,
        3 => QuantType::Q5_0,
        4 => QuantType::Q5_1,
        5 => QuantType::F16,
        6 => QuantType::F32,
        x => return Err(format!("Unknown quant type byte: {x}")),
    };

    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
    let total_params = u64::from_le_bytes([
        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
    ]) as usize;

    let mut offset = 18usize;
    let mut quantized_params = Vec::with_capacity(num_tensors);

    let block_bytes = quant_type.bytes_per_block();

    for _ in 0..num_tensors {
        if offset + 4 > data.len() {
            return Err("Truncated quantized model file".to_string());
        }

        let shape_len = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }

        let num_blocks = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            if offset + block_bytes > data.len() {
                return Err("Truncated block data".to_string());
            }

            let block = match quant_type {
                QuantType::Q8_0 => {
                    let b =
                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
                    QuantizedBlock::Q8(b)
                }
                QuantType::Q4_0 => {
                    let b =
                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
                    QuantizedBlock::Q4(b)
                }
                QuantType::Q4_1 => {
                    // Block layout: f16 scale, f16 min, then 16 bytes of packed nibbles.
                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
                    let mut block_data = [0u8; 16];
                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
                }
                QuantType::F16 => {
                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    QuantizedBlock::F16(vec![v])
                }
                QuantType::F32 => {
                    let v = f32::from_le_bytes([
                        data[offset],
                        data[offset + 1],
                        data[offset + 2],
                        data[offset + 3],
                    ]);
                    QuantizedBlock::F32(vec![v])
                }
                // Q5_0 and Q5_1 have no from_bytes path here yet.
                _ => return Err("Unsupported quant type for deserialization".to_string()),
            };

            blocks.push(block);
            offset += block_bytes;
        }

        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
    }

    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
    let original_bytes = total_params * 4;

    Ok(QuantizedModel {
        quantized_params,
        quant_type,
        total_params,
        total_bytes,
        original_bytes,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare against an unquantized (F32) reference of the same layer.
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}