//! Quantized inference utilities: block-wise dot-product kernels, a
//! `QuantizedLinear` layer, whole-model quantization, and a simple binary
//! serialization format for quantized parameters.

use crate::dequantize::dequantize_tensor;
use crate::quantize::quantize_tensor;
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

/// Dot product of a single Q8 block with a slice of activations.
#[inline]
fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let mut sum = 0.0f32;
    for (d, a) in block.data.iter().zip(activations.iter()) {
        sum += (*d as f32) * a;
    }
    sum * scale
}

/// Dot product of a single Q4_0 block with a slice of activations.
#[inline]
fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32) * activations[i];
    }
    sum * scale
}

/// Dot product of a single Q4_1 (scale + min) block with a slice of activations.
#[inline]
fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let min = f32::from(block.min);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32 * scale + min) * activations[i];
    }
    sum
}

/// Dot product of raw f16 weights with f32 activations.
#[inline]
fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for i in 0..data.len().min(activations.len()) {
        sum += f32::from(data[i]) * activations[i];
    }
    sum
}

/// Dispatches to the per-format dot-product kernel for one quantized block.
#[inline]
fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
    match block {
        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            let values = b.unpack();
            values.iter().zip(activations).map(|(&v, &a)| v as f32 * scale * a).sum()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            let values = b.unpack();
            values.iter().zip(activations).map(|(&v, &a)| (v as f32 * scale + min) * a).sum()
        }
        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
        QuantizedBlock::F32(data) => {
            let mut sum = 0.0f32;
            for i in 0..data.len().min(activations.len()) {
                sum += data[i] * activations[i];
            }
            sum
        }
    }
}
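
// The kernels above let a matrix-vector product run directly on quantized
// weights: each output element decomposes into per-block partial dot products.
// For Q8_0 with block size B, for example,
//
//     y[o] = bias[o] + sum_blk( scale_blk * sum_{i < B} q[blk][i] * x[blk*B + i] )
//
// so the weight matrix is never materialized in f32; only one scale multiply
// per block is added on top of the quantized-value-times-activation products.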

/// A linear (fully connected) layer whose weights are stored as quantized
/// blocks. The forward pass computes dot products directly on the quantized
/// data instead of dequantizing the full weight matrix.
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weight matrix, row-major `[out_features, in_features]`.
    weight: QuantizedTensor,
    /// Optional bias of length `out_features`, kept in f32.
    bias: Option<Vec<f32>>,
    /// Input dimension.
    pub in_features: usize,
    /// Output dimension.
    pub out_features: usize,
    /// Quantization scheme used for the weights.
    pub quant_type: QuantType,
    /// Number of quantized blocks per weight row.
    blocks_per_row: usize,
}

impl QuantizedLinear {
    /// Builds a quantized layer from raw f32 weights (row-major
    /// `[out_features, in_features]`) and an optional bias.
    pub fn from_linear_params(
        weight_data: &[f32],
        bias_data: Option<&[f32]>,
        in_features: usize,
        out_features: usize,
        quant_type: QuantType,
    ) -> Self {
        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
            .expect("Failed to create weight tensor for quantization");

        let weight =
            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");

        let block_size = quant_type.block_size();
        let blocks_per_row = in_features.div_ceil(block_size);

        QuantizedLinear {
            weight,
            bias: bias_data.map(|b| b.to_vec()),
            in_features,
            out_features,
            quant_type,
            blocks_per_row,
        }
    }

    /// Forward pass over a flat f32 input of shape `[batch_size, in_features]`,
    /// returning a flat `[batch_size, out_features]` output. Batch rows are
    /// processed in parallel with rayon.
    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * self.out_features];

        // F16/F32 weights are not block-quantized; fall back to a dense matmul.
        if !self.quant_type.is_block_quantized() {
            let weight_flat = self.extract_flat_weights();
            output
                .par_chunks_mut(self.out_features)
                .enumerate()
                .for_each(|(b, out_row)| {
                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
                    for o in 0..self.out_features {
                        let w_start = o * self.in_features;
                        let mut sum = 0.0f32;
                        for k in 0..self.in_features {
                            sum += weight_flat[w_start + k] * input_row[k];
                        }
                        if let Some(ref bias) = self.bias {
                            sum += bias[o];
                        }
                        out_row[o] = sum;
                    }
                });
            return output;
        }

        let block_size = self.quant_type.block_size();

        output
            .par_chunks_mut(self.out_features)
            .enumerate()
            .for_each(|(b, out_row)| {
                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];

                for o in 0..self.out_features {
                    let row_block_start = o * self.blocks_per_row;
                    let mut sum = 0.0f32;

                    for blk_idx in 0..self.blocks_per_row {
                        let act_start = blk_idx * block_size;
                        let act_end = (act_start + block_size).min(self.in_features);
                        let act_slice = &input_row[act_start..act_end];

                        let block = &self.weight.blocks[row_block_start + blk_idx];
                        sum += dot_block(block, act_slice);
                    }

                    if let Some(ref bias) = self.bias {
                        sum += bias[o];
                    }

                    out_row[o] = sum;
                }
            });

        output
    }

    /// Flattens F16/F32 blocks back into a contiguous f32 weight vector.
    fn extract_flat_weights(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
        for block in &self.weight.blocks {
            match block {
                QuantizedBlock::F16(data) => {
                    flat.extend(data.iter().map(|v| f32::from(*v)));
                }
                QuantizedBlock::F32(data) => {
                    flat.extend_from_slice(data);
                }
                _ => {}
            }
        }
        flat
    }

    /// Applies the layer to an autograd `Variable`. The returned `Variable`
    /// does not track gradients.
    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
        let shape = input.shape();
        let batch = if shape.len() > 1 { shape[0] } else { 1 };
        let input_data = input.data().to_vec();

        let output_data = self.forward_f32(&input_data, batch);

        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
            .expect("Failed to create output tensor");

        axonml_autograd::Variable::new(output_tensor, false)
    }

    /// Size of the quantized weights in bytes.
    pub fn weight_bytes(&self) -> usize {
        self.weight.size_bytes()
    }

    /// Compression ratio of the quantized weights relative to f32 storage.
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

    /// Dequantizes the weights back into a dense f32 tensor.
    pub fn dequantize_weights(&self) -> Tensor<f32> {
        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
    }
}
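
// Example (illustrative sketch, not part of the API): building a
// `QuantizedLinear` from raw f32 weights and running a single forward pass.
// The shapes and values below are placeholders.
//
//     let weight = vec![0.02f32; 16 * 64];      // [out_features, in_features]
//     let bias = vec![0.0f32; 16];
//     let ql = QuantizedLinear::from_linear_params(
//         &weight, Some(&bias), 64, 16, QuantType::Q8_0,
//     );
//     let input = vec![0.1f32; 64];             // one row of activations
//     let output = ql.forward_f32(&input, 1);   // output.len() == 16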

/// Quantizes a slice of module parameters in parallel with the given `quant_type`.
pub fn quantize_parameters(
    params: &[axonml_nn::Parameter],
    quant_type: QuantType,
) -> Vec<QuantizedTensor> {
    params
        .par_iter()
        .map(|param| {
            let tensor = param.data();
            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
        })
        .collect()
}

/// A fully quantized model: every parameter of a module quantized with one
/// scheme, together with size bookkeeping.
pub struct QuantizedModel {
    /// Quantized parameters, in the order returned by `Module::parameters()`.
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization scheme applied to all parameters.
    pub quant_type: QuantType,
    /// Total number of scalar parameters.
    pub total_params: usize,
    /// Size of the quantized parameters in bytes.
    pub total_bytes: usize,
    /// Size of the original f32 parameters in bytes.
    pub original_bytes: usize,
}

impl QuantizedModel {
    /// Quantizes every parameter of `module` with the given `quant_type`.
    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
        let params = module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let original_bytes = total_params * 4;

        let quantized_params = quantize_parameters(&params, quant_type);

        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();

        QuantizedModel {
            quantized_params,
            quant_type,
            total_params,
            total_bytes,
            original_bytes,
        }
    }

    /// Dequantizes the stored parameters and writes them back into `module`,
    /// pairing parameters by iteration order.
    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
        let params = module.parameters();
        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
            param.update_data(tensor);
        }
    }

    /// Compression ratio of the quantized parameters relative to f32 storage.
    pub fn compression_ratio(&self) -> f32 {
        self.original_bytes as f32 / self.total_bytes as f32
    }

    /// One-line human-readable summary: scheme, parameter count, sizes, ratio.
    pub fn summary(&self) -> String {
        format!(
            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
            self.quant_type,
            self.total_params,
            self.original_bytes as f64 / 1024.0 / 1024.0,
            self.total_bytes as f64 / 1024.0 / 1024.0,
            self.compression_ratio(),
        )
    }
}
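
// Example (illustrative sketch): quantizing an existing module and restoring
// its weights later. `my_model` stands in for any value implementing
// `axonml_nn::Module`.
//
//     let qmodel = QuantizedModel::from_module(&my_model, QuantType::Q4_0);
//     println!("{}", qmodel.summary());     // size and compression statistics
//     qmodel.load_into_module(&my_model);   // dequantize back into the module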

/// Serializes a [`QuantizedModel`] into a simple binary format:
/// magic `"AXQT"`, version byte (1), quant-type byte, tensor count (u32 LE),
/// total parameter count (u64 LE), then for each tensor its shape length,
/// dimensions, block count (all u32 LE) and the raw block bytes.
pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
    let mut buf = Vec::new();

    buf.extend_from_slice(b"AXQT");
    buf.push(1u8);
    buf.push(match model.quant_type {
        QuantType::Q8_0 => 0,
        QuantType::Q4_0 => 1,
        QuantType::Q4_1 => 2,
        QuantType::Q5_0 => 3,
        QuantType::Q5_1 => 4,
        QuantType::F16 => 5,
        QuantType::F32 => 6,
    });
    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());

    for qt in &model.quantized_params {
        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
        for &dim in &qt.shape {
            buf.extend_from_slice(&(dim as u32).to_le_bytes());
        }
        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
        for block in &qt.blocks {
            match block {
                QuantizedBlock::Q8(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::F16(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
                QuantizedBlock::F32(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
            }
        }
    }

    buf
}

/// Parses a buffer produced by [`serialize_quantized`] back into a
/// [`QuantizedModel`]. Q5_0 and Q5_1 payloads are not yet supported.
pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
    if data.len() < 18 || &data[0..4] != b"AXQT" {
        return Err("Invalid quantized model file (bad magic)".to_string());
    }

    let version = data[4];
    if version != 1 {
        return Err(format!("Unsupported quantized model version: {version}"));
    }

    let quant_type = match data[5] {
        0 => QuantType::Q8_0,
        1 => QuantType::Q4_0,
        2 => QuantType::Q4_1,
        3 => QuantType::Q5_0,
        4 => QuantType::Q5_1,
        5 => QuantType::F16,
        6 => QuantType::F32,
        x => return Err(format!("Unknown quant type byte: {x}")),
    };

    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
    let total_params = u64::from_le_bytes([
        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
    ]) as usize;

    let mut offset = 18usize;
    let mut quantized_params = Vec::with_capacity(num_tensors);

    let block_bytes = quant_type.bytes_per_block();

    for _ in 0..num_tensors {
        if offset + 4 > data.len() {
            return Err("Truncated quantized model file".to_string());
        }

        let shape_len = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }

        let num_blocks = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            if offset + block_bytes > data.len() {
                return Err("Truncated block data".to_string());
            }

            let block = match quant_type {
                QuantType::Q8_0 => {
                    let b =
                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
                    QuantizedBlock::Q8(b)
                }
                QuantType::Q4_0 => {
                    let b =
                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
                    QuantizedBlock::Q4(b)
                }
                QuantType::Q4_1 => {
                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
                    let mut block_data = [0u8; 16];
                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
                }
                QuantType::F16 => {
                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    QuantizedBlock::F16(vec![v])
                }
                QuantType::F32 => {
                    let v = f32::from_le_bytes([
                        data[offset],
                        data[offset + 1],
                        data[offset + 2],
                        data[offset + 3],
                    ]);
                    QuantizedBlock::F32(vec![v])
                }
                _ => return Err("Unsupported quant type for deserialization".to_string()),
            };

            blocks.push(block);
            offset += block_bytes;
        }

        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
    }

    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
    let original_bytes = total_params * 4;

    Ok(QuantizedModel {
        quantized_params,
        quant_type,
        total_params,
        total_bytes,
        original_bytes,
    })
}
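
// Example (illustrative sketch): persisting a quantized model and reading it
// back. The file path is a placeholder; any byte sink/source works.
//
//     let bytes = serialize_quantized(&qmodel);
//     std::fs::write("model.axqt", &bytes)?;
//     let restored = deserialize_quantized(&std::fs::read("model.axqt")?)?;
//     assert_eq!(restored.quant_type, qmodel.quant_type);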

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}