1use crate::dequantize::dequantize_tensor;
22use crate::quantize::quantize_tensor;
23use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
24use axonml_tensor::Tensor;
25use half::f16;
26use rayon::prelude::*;
27
28#[inline]
35fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
36 let scale = f32::from(block.scale);
37 let mut sum = 0.0f32;
38 for (d, a) in block.data.iter().zip(activations.iter()) {
39 sum += (*d as f32) * a;
40 }
41 sum * scale
42}
43
44#[inline]
46fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
47 let scale = f32::from(block.scale);
48 let unpacked = block.unpack();
49 let mut sum = 0.0f32;
50 for i in 0..unpacked.len().min(activations.len()) {
51 sum += (unpacked[i] as f32) * activations[i];
52 }
53 sum * scale
54}
55
56#[inline]
58fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
59 let scale = f32::from(block.scale);
60 let min = f32::from(block.min);
61 let unpacked = block.unpack();
62 let mut sum = 0.0f32;
63 for i in 0..unpacked.len().min(activations.len()) {
64 sum += (unpacked[i] as f32 * scale + min) * activations[i];
65 }
66 sum
67}
68
69#[inline]
71fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
72 let mut sum = 0.0f32;
73 for i in 0..data.len().min(activations.len()) {
74 sum += f32::from(data[i]) * activations[i];
75 }
76 sum
77}
78
79#[inline]
81fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
82 match block {
83 QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
84 QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
85 QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
86 QuantizedBlock::F16(data) => dot_f16_block(data, activations),
87 QuantizedBlock::F32(data) => {
88 let mut sum = 0.0f32;
89 for i in 0..data.len().min(activations.len()) {
90 sum += data[i] * activations[i];
91 }
92 sum
93 }
94 }
95}
96
/// A linear (fully-connected) layer whose weight matrix is stored quantized.
///
/// Weights are laid out row-major as `[out_features, in_features]`; each
/// output row is covered by `blocks_per_row` quantized blocks. The bias, if
/// present, is kept in full f32 precision.
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    // Quantized weight matrix, stored block-by-block in row-major order.
    weight: QuantizedTensor,
    // Optional per-output-feature bias, kept unquantized (f32).
    bias: Option<Vec<f32>>,
    /// Number of input features (columns of the weight matrix).
    pub in_features: usize,
    /// Number of output features (rows of the weight matrix).
    pub out_features: usize,
    /// Quantization scheme used for the weights.
    pub quant_type: QuantType,
    // Blocks covering one weight row: ceil(in_features / block_size).
    blocks_per_row: usize,
}
129
130impl QuantizedLinear {
131 pub fn from_linear_params(
133 weight_data: &[f32],
134 bias_data: Option<&[f32]>,
135 in_features: usize,
136 out_features: usize,
137 quant_type: QuantType,
138 ) -> Self {
139 let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
141 .expect("Failed to create weight tensor for quantization");
142
143 let weight =
144 quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");
145
146 let block_size = quant_type.block_size();
147 let blocks_per_row = in_features.div_ceil(block_size);
148
149 QuantizedLinear {
150 weight,
151 bias: bias_data.map(|b| b.to_vec()),
152 in_features,
153 out_features,
154 quant_type,
155 blocks_per_row,
156 }
157 }
158
159 pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
168 let mut output = vec![0.0f32; batch_size * self.out_features];
169
170 if !self.quant_type.is_block_quantized() {
172 let weight_flat = self.extract_flat_weights();
173 output
174 .par_chunks_mut(self.out_features)
175 .enumerate()
176 .for_each(|(b, out_row)| {
177 let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
178 for o in 0..self.out_features {
179 let w_start = o * self.in_features;
180 let mut sum = 0.0f32;
181 for k in 0..self.in_features {
182 sum += weight_flat[w_start + k] * input_row[k];
183 }
184 if let Some(ref bias) = self.bias {
185 sum += bias[o];
186 }
187 out_row[o] = sum;
188 }
189 });
190 return output;
191 }
192
193 let block_size = self.quant_type.block_size();
195
196 output
197 .par_chunks_mut(self.out_features)
198 .enumerate()
199 .for_each(|(b, out_row)| {
200 let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
201
202 for o in 0..self.out_features {
203 let row_block_start = o * self.blocks_per_row;
204 let mut sum = 0.0f32;
205
206 for blk_idx in 0..self.blocks_per_row {
207 let act_start = blk_idx * block_size;
208 let act_end = (act_start + block_size).min(self.in_features);
209 let act_slice = &input_row[act_start..act_end];
210
211 let block = &self.weight.blocks[row_block_start + blk_idx];
212 sum += dot_block(block, act_slice);
213 }
214
215 if let Some(ref bias) = self.bias {
216 sum += bias[o];
217 }
218
219 out_row[o] = sum;
220 }
221 });
222
223 output
224 }
225
226 fn extract_flat_weights(&self) -> Vec<f32> {
228 let mut flat = Vec::with_capacity(self.in_features * self.out_features);
229 for block in &self.weight.blocks {
230 match block {
231 QuantizedBlock::F16(data) => {
232 flat.extend(data.iter().map(|v| f32::from(*v)));
233 }
234 QuantizedBlock::F32(data) => {
235 flat.extend_from_slice(data);
236 }
237 _ => {} }
239 }
240 flat
241 }
242
243 pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
248 let shape = input.shape();
249 let batch = if shape.len() > 1 { shape[0] } else { 1 };
250 let input_data = input.data().to_vec();
251
252 let output_data = self.forward_f32(&input_data, batch);
253
254 let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
255 .expect("Failed to create output tensor");
256
257 axonml_autograd::Variable::new(output_tensor, false)
258 }
259
260 pub fn weight_bytes(&self) -> usize {
262 self.weight.size_bytes()
263 }
264
265 pub fn compression_ratio(&self) -> f32 {
267 self.weight.compression_ratio()
268 }
269
270 pub fn dequantize_weights(&self) -> Tensor<f32> {
272 dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
273 }
274}
275
276pub fn quantize_parameters(
285 params: &[axonml_nn::Parameter],
286 quant_type: QuantType,
287) -> Vec<QuantizedTensor> {
288 params
289 .par_iter()
290 .map(|param| {
291 let tensor = param.data();
292 quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
293 })
294 .collect()
295}
296
/// A whole model's parameters in quantized form, plus size statistics.
pub struct QuantizedModel {
    /// Quantized tensors, one per module parameter, in parameter order.
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization scheme applied to every parameter.
    pub quant_type: QuantType,
    /// Total number of scalar parameters across all tensors.
    pub total_params: usize,
    /// Total size of the quantized representation in bytes.
    pub total_bytes: usize,
    /// Size of the original f32 representation in bytes (total_params * 4).
    pub original_bytes: usize,
}
327
328impl QuantizedModel {
329 pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
331 let params = module.parameters();
332 let total_params: usize = params.iter().map(|p| p.numel()).sum();
333 let original_bytes = total_params * 4;
334
335 let quantized_params = quantize_parameters(¶ms, quant_type);
336
337 let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
338
339 QuantizedModel {
340 quantized_params,
341 quant_type,
342 total_params,
343 total_bytes,
344 original_bytes,
345 }
346 }
347
348 pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
354 let params = module.parameters();
355 for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
356 let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
357 param.update_data(tensor);
358 }
359 }
360
361 pub fn compression_ratio(&self) -> f32 {
363 self.original_bytes as f32 / self.total_bytes as f32
364 }
365
366 pub fn summary(&self) -> String {
368 format!(
369 "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
370 self.quant_type,
371 self.total_params,
372 self.original_bytes as f64 / 1024.0 / 1024.0,
373 self.total_bytes as f64 / 1024.0 / 1024.0,
374 self.compression_ratio(),
375 )
376 }
377}
378
379pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
385 let mut buf = Vec::new();
386
387 buf.extend_from_slice(b"AXQT");
389 buf.push(1u8);
391 buf.push(match model.quant_type {
393 QuantType::Q8_0 => 0,
394 QuantType::Q4_0 => 1,
395 QuantType::Q4_1 => 2,
396 QuantType::Q5_0 => 3,
397 QuantType::Q5_1 => 4,
398 QuantType::F16 => 5,
399 QuantType::F32 => 6,
400 });
401 buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
403 buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());
405
406 for qt in &model.quantized_params {
408 buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
410 for &dim in &qt.shape {
411 buf.extend_from_slice(&(dim as u32).to_le_bytes());
412 }
413 buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
415 for block in &qt.blocks {
417 match block {
418 QuantizedBlock::Q8(b) => {
419 buf.extend_from_slice(&b.to_bytes());
420 }
421 QuantizedBlock::Q4(b) => {
422 buf.extend_from_slice(&b.to_bytes());
423 }
424 QuantizedBlock::Q4_1(b) => {
425 buf.extend_from_slice(&b.to_bytes());
426 }
427 QuantizedBlock::F16(data) => {
428 for &v in data {
429 buf.extend_from_slice(&v.to_le_bytes());
430 }
431 }
432 QuantizedBlock::F32(data) => {
433 for &v in data {
434 buf.extend_from_slice(&v.to_le_bytes());
435 }
436 }
437 }
438 }
439 }
440
441 buf
442}
443
444pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
446 if data.len() < 18 || &data[0..4] != b"AXQT" {
447 return Err("Invalid quantized model file (bad magic)".to_string());
448 }
449
450 let version = data[4];
451 if version != 1 {
452 return Err(format!("Unsupported quantized model version: {version}"));
453 }
454
455 let quant_type = match data[5] {
456 0 => QuantType::Q8_0,
457 1 => QuantType::Q4_0,
458 2 => QuantType::Q4_1,
459 3 => QuantType::Q5_0,
460 4 => QuantType::Q5_1,
461 5 => QuantType::F16,
462 6 => QuantType::F32,
463 x => return Err(format!("Unknown quant type byte: {x}")),
464 };
465
466 let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
467 let total_params = u64::from_le_bytes([
468 data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
469 ]) as usize;
470
471 let mut offset = 18usize;
472 let mut quantized_params = Vec::with_capacity(num_tensors);
473
474 let block_bytes = quant_type.bytes_per_block();
475
476 for _ in 0..num_tensors {
477 if offset + 4 > data.len() {
478 return Err("Truncated quantized model file".to_string());
479 }
480
481 let shape_len = u32::from_le_bytes([
483 data[offset],
484 data[offset + 1],
485 data[offset + 2],
486 data[offset + 3],
487 ]) as usize;
488 offset += 4;
489
490 let mut shape = Vec::with_capacity(shape_len);
491 for _ in 0..shape_len {
492 let dim = u32::from_le_bytes([
493 data[offset],
494 data[offset + 1],
495 data[offset + 2],
496 data[offset + 3],
497 ]) as usize;
498 shape.push(dim);
499 offset += 4;
500 }
501
502 let num_blocks = u32::from_le_bytes([
504 data[offset],
505 data[offset + 1],
506 data[offset + 2],
507 data[offset + 3],
508 ]) as usize;
509 offset += 4;
510
511 let mut blocks = Vec::with_capacity(num_blocks);
513 for _ in 0..num_blocks {
514 if offset + block_bytes > data.len() {
515 return Err("Truncated block data".to_string());
516 }
517
518 let block = match quant_type {
519 QuantType::Q8_0 => {
520 let b =
521 Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
522 QuantizedBlock::Q8(b)
523 }
524 QuantType::Q4_0 => {
525 let b =
526 Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
527 QuantizedBlock::Q4(b)
528 }
529 QuantType::Q4_1 => {
530 let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
531 let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
532 let mut block_data = [0u8; 16];
533 block_data.copy_from_slice(&data[offset + 4..offset + 20]);
534 QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
535 }
536 QuantType::F16 => {
537 let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
538 QuantizedBlock::F16(vec![v])
539 }
540 QuantType::F32 => {
541 let v = f32::from_le_bytes([
542 data[offset],
543 data[offset + 1],
544 data[offset + 2],
545 data[offset + 3],
546 ]);
547 QuantizedBlock::F32(vec![v])
548 }
549 _ => return Err("Unsupported quant type for deserialization".to_string()),
550 };
551
552 blocks.push(block);
553 offset += block_bytes;
554 }
555
556 quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
557 }
558
559 let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
560 let original_bytes = total_params * 4;
561
562 Ok(QuantizedModel {
563 quantized_params,
564 quant_type,
565 total_params,
566 total_bytes,
567 original_bytes,
568 })
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    // Q8 forward must produce real output and stay close to an F32
    // reference of the same layer.
    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        // Sanity: the layer must actually produce signal, not zeros.
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Unquantized (F32) reference of the same weights/bias.
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        // Max absolute error bound of 1.0 is a loose 8-bit tolerance for
        // this weight/activation range.
        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    // Q4_0 forward: shape and non-zero output only (no accuracy bound).
    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    // Batched forward returns batch * out_features values.
    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    // Compression ratios: Q8 ~4x (1 byte + scale vs 4 bytes), Q4 ~7-8x.
    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    // Serialize then deserialize a one-tensor model; metadata must survive.
    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        // total_bytes deliberately 0: it is recomputed on deserialization.
        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    // Variable-based forward: a [2, in_f] batch yields a [2, out_f] output.
    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}