Skip to main content

axonml_quant/
inference.rs

1//! Quantized Inference — fast inference with quantized weights
2//!
3//! # File
4//! `crates/axonml-quant/src/inference.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18// =============================================================================
19// Imports
20// =============================================================================
21
22use crate::dequantize::dequantize_tensor;
23use crate::quantize::quantize_tensor;
24use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
25use axonml_tensor::Tensor;
26use half::f16;
27use rayon::prelude::*;
28
29// =============================================================================
30// Quantized Matmul Kernels
31// =============================================================================
32
33/// Compute dot product of f32 activation vector with Q8_0 quantized weight row.
34/// This is the inner kernel — called per output element.
35#[inline]
36fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
37    let scale = f32::from(block.scale);
38    let mut sum = 0.0f32;
39    for (d, a) in block.data.iter().zip(activations.iter()) {
40        sum += (*d as f32) * a;
41    }
42    sum * scale
43}
44
45/// Compute dot product of f32 activation vector with Q4_0 quantized weight row.
46#[inline]
47fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
48    let scale = f32::from(block.scale);
49    let unpacked = block.unpack();
50    let mut sum = 0.0f32;
51    for i in 0..unpacked.len().min(activations.len()) {
52        sum += (unpacked[i] as f32) * activations[i];
53    }
54    sum * scale
55}
56
57/// Compute dot product of f32 activation vector with Q4_1 quantized weight row.
58#[inline]
59fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
60    let scale = f32::from(block.scale);
61    let min = f32::from(block.min);
62    let unpacked = block.unpack();
63    let mut sum = 0.0f32;
64    for i in 0..unpacked.len().min(activations.len()) {
65        sum += (unpacked[i] as f32 * scale + min) * activations[i];
66    }
67    sum
68}
69
70/// Compute dot product of f32 activation vector with F16 weight block.
71#[inline]
72fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
73    let mut sum = 0.0f32;
74    for i in 0..data.len().min(activations.len()) {
75        sum += f32::from(data[i]) * activations[i];
76    }
77    sum
78}
79
80/// Dot product of an activation vector against a single quantized block.
81#[inline]
82fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
83    match block {
84        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
85        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
86        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
87        QuantizedBlock::Q5(b) => {
88            let scale = b.scale.to_f32();
89            let values = b.unpack();
90            values
91                .iter()
92                .zip(activations)
93                .map(|(&v, &a)| v as f32 * scale * a)
94                .sum()
95        }
96        QuantizedBlock::Q5_1(b) => {
97            let scale = b.scale.to_f32();
98            let min = b.min.to_f32();
99            let values = b.unpack();
100            values
101                .iter()
102                .zip(activations)
103                .map(|(&v, &a)| (v as f32 * scale + min) * a)
104                .sum()
105        }
106        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
107        QuantizedBlock::F32(data) => {
108            let mut sum = 0.0f32;
109            for i in 0..data.len().min(activations.len()) {
110                sum += data[i] * activations[i];
111            }
112            sum
113        }
114    }
115}
116
117// =============================================================================
118// QuantizedLinear — drop-in replacement for Linear
119// =============================================================================
120
121/// A linear layer with quantized weights for fast inference.
122///
123/// Stores weights as `QuantizedTensor` (Q8/Q4/F16) and dequantizes on-the-fly
124/// during the matrix multiply. Bias remains in f32.
125///
126/// # Usage
127/// ```ignore
128/// use axonml_quant::inference::QuantizedLinear;
129/// use axonml_quant::QuantType;
130///
131/// let qlinear = QuantizedLinear::from_linear_params(&weights, Some(&bias), 512, 128, QuantType::Q8_0);
132/// let output = qlinear.forward_f32(&input_data, 1);
133/// ```
134#[derive(Debug, Clone)]
135pub struct QuantizedLinear {
136    /// Quantized weight matrix (out_features x in_features, row-major).
137    weight: QuantizedTensor,
138    /// Bias vector (f32, not quantized).
139    bias: Option<Vec<f32>>,
140    /// Input feature dimension.
141    pub in_features: usize,
142    /// Output feature dimension.
143    pub out_features: usize,
144    /// Quantization type used.
145    pub quant_type: QuantType,
146    /// Number of blocks per output row (cached for fast access).
147    blocks_per_row: usize,
148}
149
150impl QuantizedLinear {
151    /// Create a QuantizedLinear from an axonml_nn::Linear layer.
152    pub fn from_linear_params(
153        weight_data: &[f32],
154        bias_data: Option<&[f32]>,
155        in_features: usize,
156        out_features: usize,
157        quant_type: QuantType,
158    ) -> Self {
159        // Weight shape is [out_features, in_features]
160        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
161            .expect("Failed to create weight tensor for quantization");
162
163        let weight =
164            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");
165
166        let block_size = quant_type.block_size();
167        let blocks_per_row = in_features.div_ceil(block_size);
168
169        QuantizedLinear {
170            weight,
171            bias: bias_data.map(|b| b.to_vec()),
172            in_features,
173            out_features,
174            quant_type,
175            blocks_per_row,
176        }
177    }
178
179    /// Forward pass: f32 input → f32 output.
180    ///
181    /// Input shape: `[batch, in_features]`
182    /// Output shape: `[batch, out_features]`
183    ///
184    /// Performs quantized matrix multiplication: each output element is computed
185    /// by iterating over the weight row's quantized blocks and accumulating
186    /// dot products with the corresponding activation slices.
187    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
188        let mut output = vec![0.0f32; batch_size * self.out_features];
189
190        // For non-block types (F16, F32), extract flat weight data and do direct matmul
191        if !self.quant_type.is_block_quantized() {
192            let weight_flat = self.extract_flat_weights();
193            output
194                .par_chunks_mut(self.out_features)
195                .enumerate()
196                .for_each(|(b, out_row)| {
197                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
198                    for o in 0..self.out_features {
199                        let w_start = o * self.in_features;
200                        let mut sum = 0.0f32;
201                        for k in 0..self.in_features {
202                            sum += weight_flat[w_start + k] * input_row[k];
203                        }
204                        if let Some(ref bias) = self.bias {
205                            sum += bias[o];
206                        }
207                        out_row[o] = sum;
208                    }
209                });
210            return output;
211        }
212
213        // Block-quantized path (Q8, Q4, Q4_1): iterate over blocks per weight row
214        let block_size = self.quant_type.block_size();
215
216        output
217            .par_chunks_mut(self.out_features)
218            .enumerate()
219            .for_each(|(b, out_row)| {
220                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
221
222                for o in 0..self.out_features {
223                    let row_block_start = o * self.blocks_per_row;
224                    let mut sum = 0.0f32;
225
226                    for blk_idx in 0..self.blocks_per_row {
227                        let act_start = blk_idx * block_size;
228                        let act_end = (act_start + block_size).min(self.in_features);
229                        let act_slice = &input_row[act_start..act_end];
230
231                        let block = &self.weight.blocks[row_block_start + blk_idx];
232                        sum += dot_block(block, act_slice);
233                    }
234
235                    if let Some(ref bias) = self.bias {
236                        sum += bias[o];
237                    }
238
239                    out_row[o] = sum;
240                }
241            });
242
243        output
244    }
245
246    /// Extract flat f32 weights (for F16/F32 non-block types).
247    fn extract_flat_weights(&self) -> Vec<f32> {
248        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
249        for block in &self.weight.blocks {
250            match block {
251                QuantizedBlock::F16(data) => {
252                    flat.extend(data.iter().map(|v| f32::from(*v)));
253                }
254                QuantizedBlock::F32(data) => {
255                    flat.extend_from_slice(data);
256                }
257                _ => {} // block-quantized handled separately
258            }
259        }
260        flat
261    }
262
263    /// Forward pass with Variable input/output (for integration with autograd).
264    ///
265    /// Note: Quantized inference is forward-only (no gradient tracking).
266    /// The output Variable has `requires_grad = false`.
267    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
268        let shape = input.shape();
269        let batch = if shape.len() > 1 { shape[0] } else { 1 };
270        let input_data = input.data().to_vec();
271
272        let output_data = self.forward_f32(&input_data, batch);
273
274        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
275            .expect("Failed to create output tensor");
276
277        axonml_autograd::Variable::new(output_tensor, false)
278    }
279
280    /// Memory usage in bytes (weights only, excludes bias).
281    pub fn weight_bytes(&self) -> usize {
282        self.weight.size_bytes()
283    }
284
285    /// Compression ratio vs f32 weights.
286    pub fn compression_ratio(&self) -> f32 {
287        self.weight.compression_ratio()
288    }
289
290    /// Dequantize weights back to f32 (for debugging/validation).
291    pub fn dequantize_weights(&self) -> Tensor<f32> {
292        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
293    }
294}
295
296// =============================================================================
297// Model Quantization — convert a full model
298// =============================================================================
299
300/// Quantize all parameters of a model, returning a flat list of quantized tensors.
301///
302/// This extracts all parameters from a Module, quantizes each one, and returns
303/// them in order. Use with `QuantizedModel` for full inference.
304pub fn quantize_parameters(
305    params: &[axonml_nn::Parameter],
306    quant_type: QuantType,
307) -> Vec<QuantizedTensor> {
308    params
309        .par_iter()
310        .map(|param| {
311            let tensor = param.data();
312            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
313        })
314        .collect()
315}
316
317// =============================================================================
318// QuantizedModel — generic quantized inference wrapper
319// =============================================================================
320
321/// A fully quantized model for inference.
322///
323/// Wraps a collection of quantized parameters and provides fast forward pass
324/// by dequantizing weights on-the-fly during computation.
325///
326/// # Usage
327/// ```ignore
328/// use axonml_quant::inference::QuantizedModel;
329/// use axonml_quant::QuantType;
330///
331/// let qmodel = QuantizedModel::from_module(&model, QuantType::Q8_0);
332/// println!("{}", qmodel.summary());
333/// qmodel.load_into_module(&model); // dequant weights back for inference
334/// ```
335pub struct QuantizedModel {
336    /// Quantized weight tensors (in parameter order).
337    pub quantized_params: Vec<QuantizedTensor>,
338    /// Quantization type used.
339    pub quant_type: QuantType,
340    /// Total original parameter count.
341    pub total_params: usize,
342    /// Total quantized size in bytes.
343    pub total_bytes: usize,
344    /// Original f32 size in bytes.
345    pub original_bytes: usize,
346}
347
348impl QuantizedModel {
349    /// Quantize a Module's parameters.
350    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
351        let params = module.parameters();
352        let total_params: usize = params.iter().map(|p| p.numel()).sum();
353        let original_bytes = total_params * 4;
354
355        let quantized_params = quantize_parameters(&params, quant_type);
356
357        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
358
359        QuantizedModel {
360            quantized_params,
361            quant_type,
362            total_params,
363            total_bytes,
364            original_bytes,
365        }
366    }
367
368    /// Load quantized weights back into a Module for inference.
369    ///
370    /// Dequantizes all parameters and updates the module's parameters in-place.
371    /// This is the simplest integration path — the model runs at full f32 speed
372    /// but loads from a compressed checkpoint.
373    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
374        let params = module.parameters();
375        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
376            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
377            param.update_data(tensor);
378        }
379    }
380
381    /// Compression ratio.
382    pub fn compression_ratio(&self) -> f32 {
383        self.original_bytes as f32 / self.total_bytes as f32
384    }
385
386    /// Print a summary of the quantized model.
387    pub fn summary(&self) -> String {
388        format!(
389            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
390            self.quant_type,
391            self.total_params,
392            self.original_bytes as f64 / 1024.0 / 1024.0,
393            self.total_bytes as f64 / 1024.0 / 1024.0,
394            self.compression_ratio(),
395        )
396    }
397}
398
399// =============================================================================
400// Serialization — save/load quantized models
401// =============================================================================
402
403/// Serialize a QuantizedModel to bytes (for .axonml files).
404pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
405    let mut buf = Vec::new();
406
407    // Magic: "AXQT" (AxonML Quantized)
408    buf.extend_from_slice(b"AXQT");
409    // Version
410    buf.push(1u8);
411    // Quant type
412    buf.push(match model.quant_type {
413        QuantType::Q8_0 => 0,
414        QuantType::Q4_0 => 1,
415        QuantType::Q4_1 => 2,
416        QuantType::Q5_0 => 3,
417        QuantType::Q5_1 => 4,
418        QuantType::F16 => 5,
419        QuantType::F32 => 6,
420    });
421    // Number of tensors
422    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
423    // Total params
424    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());
425
426    // Each tensor: shape_len + shape + num_blocks + blocks
427    for qt in &model.quantized_params {
428        // Shape
429        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
430        for &dim in &qt.shape {
431            buf.extend_from_slice(&(dim as u32).to_le_bytes());
432        }
433        // Number of blocks
434        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
435        // Block data
436        for block in &qt.blocks {
437            match block {
438                QuantizedBlock::Q8(b) => {
439                    buf.extend_from_slice(&b.to_bytes());
440                }
441                QuantizedBlock::Q4(b) => {
442                    buf.extend_from_slice(&b.to_bytes());
443                }
444                QuantizedBlock::Q4_1(b) => {
445                    buf.extend_from_slice(&b.to_bytes());
446                }
447                QuantizedBlock::Q5(b) => {
448                    buf.extend_from_slice(&b.to_bytes());
449                }
450                QuantizedBlock::Q5_1(b) => {
451                    buf.extend_from_slice(&b.to_bytes());
452                }
453                QuantizedBlock::F16(data) => {
454                    for &v in data {
455                        buf.extend_from_slice(&v.to_le_bytes());
456                    }
457                }
458                QuantizedBlock::F32(data) => {
459                    for &v in data {
460                        buf.extend_from_slice(&v.to_le_bytes());
461                    }
462                }
463            }
464        }
465    }
466
467    buf
468}
469
470/// Deserialize a QuantizedModel from bytes.
471pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
472    if data.len() < 18 || &data[0..4] != b"AXQT" {
473        return Err("Invalid quantized model file (bad magic)".to_string());
474    }
475
476    let version = data[4];
477    if version != 1 {
478        return Err(format!("Unsupported quantized model version: {version}"));
479    }
480
481    let quant_type = match data[5] {
482        0 => QuantType::Q8_0,
483        1 => QuantType::Q4_0,
484        2 => QuantType::Q4_1,
485        3 => QuantType::Q5_0,
486        4 => QuantType::Q5_1,
487        5 => QuantType::F16,
488        6 => QuantType::F32,
489        x => return Err(format!("Unknown quant type byte: {x}")),
490    };
491
492    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
493    let total_params = u64::from_le_bytes([
494        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
495    ]) as usize;
496
497    let mut offset = 18usize;
498    let mut quantized_params = Vec::with_capacity(num_tensors);
499
500    let block_bytes = quant_type.bytes_per_block();
501
502    for _ in 0..num_tensors {
503        if offset + 4 > data.len() {
504            return Err("Truncated quantized model file".to_string());
505        }
506
507        // Shape
508        let shape_len = u32::from_le_bytes([
509            data[offset],
510            data[offset + 1],
511            data[offset + 2],
512            data[offset + 3],
513        ]) as usize;
514        offset += 4;
515
516        let mut shape = Vec::with_capacity(shape_len);
517        for _ in 0..shape_len {
518            let dim = u32::from_le_bytes([
519                data[offset],
520                data[offset + 1],
521                data[offset + 2],
522                data[offset + 3],
523            ]) as usize;
524            shape.push(dim);
525            offset += 4;
526        }
527
528        // Number of blocks
529        let num_blocks = u32::from_le_bytes([
530            data[offset],
531            data[offset + 1],
532            data[offset + 2],
533            data[offset + 3],
534        ]) as usize;
535        offset += 4;
536
537        // Read blocks
538        let mut blocks = Vec::with_capacity(num_blocks);
539        for _ in 0..num_blocks {
540            if offset + block_bytes > data.len() {
541                return Err("Truncated block data".to_string());
542            }
543
544            let block = match quant_type {
545                QuantType::Q8_0 => {
546                    let b =
547                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
548                    QuantizedBlock::Q8(b)
549                }
550                QuantType::Q4_0 => {
551                    let b =
552                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
553                    QuantizedBlock::Q4(b)
554                }
555                QuantType::Q4_1 => {
556                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
557                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
558                    let mut block_data = [0u8; 16];
559                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
560                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
561                }
562                QuantType::F16 => {
563                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
564                    QuantizedBlock::F16(vec![v])
565                }
566                QuantType::F32 => {
567                    let v = f32::from_le_bytes([
568                        data[offset],
569                        data[offset + 1],
570                        data[offset + 2],
571                        data[offset + 3],
572                    ]);
573                    QuantizedBlock::F32(vec![v])
574                }
575                _ => return Err("Unsupported quant type for deserialization".to_string()),
576            };
577
578            blocks.push(block);
579            offset += block_bytes;
580        }
581
582        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
583    }
584
585    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
586    let original_bytes = total_params * 4;
587
588    Ok(QuantizedModel {
589        quantized_params,
590        quant_type,
591        total_params,
592        total_bytes,
593        original_bytes,
594    })
595}
596
597// =============================================================================
598// Tests
599// =============================================================================
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two equally sized slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared outputs differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        // Output should not be all zeros
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare with f32 reference
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        // Q8 should be very close to f32
        let max_err = max_abs_diff(&output, &ref_output);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_f16_matches_f32() {
        let in_f = 32;
        let out_f = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 0.5).collect();
        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.05).collect();

        let f16_out =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::F16)
                .forward_f32(&input, 1);
        let f32_out =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::F32)
                .forward_f32(&input, 1);

        // f16 keeps ~11 mantissa bits; over 32 terms the error stays far below this.
        let max_err = max_abs_diff(&f16_out, &f32_out);
        assert!(max_err < 0.05, "F16 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);

        // Each batch row must equal the same row run on its own.
        for b in 0..batch {
            let single = ql.forward_f32(&input[b * in_f..(b + 1) * in_f], 1);
            assert_eq!(&output[b * out_f..(b + 1) * out_f], single.as_slice());
        }
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.total_params, in_f * out_f);
        assert_eq!(deserialized.quantized_params.len(), 1);

        let original = &model.quantized_params[0];
        let restored = &deserialized.quantized_params[0];
        assert_eq!(restored.shape, vec![out_f, in_f]);
        assert_eq!(restored.blocks.len(), original.blocks.len());

        // The block payload must survive unchanged: both copies dequantize
        // to identical values.
        let before = dequantize_tensor(original).unwrap().to_vec();
        let after = dequantize_tensor(restored).unwrap().to_vec();
        assert_eq!(before, after);
    }

    #[test]
    fn test_deserialize_rejects_invalid_headers() {
        assert!(deserialize_quantized(&[]).is_err());
        assert!(deserialize_quantized(b"NOPE").is_err());

        let mut header = vec![0u8; 18];
        header[..4].copy_from_slice(b"AXQT");

        // Unsupported version.
        header[4] = 2;
        assert!(deserialize_quantized(&header).is_err());

        // Unknown quant-type tag.
        header[4] = 1;
        header[5] = 200;
        assert!(deserialize_quantized(&header).is_err());

        // One tensor declared, but no tensor data follows.
        header[5] = 0;
        header[6] = 1;
        assert!(deserialize_quantized(&header).is_err());
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}