Skip to main content

axonml_quant/
inference.rs

1//! Quantized Inference — fast inference with quantized weights
2//!
3//! # File
4//! `crates/axonml-quant/src/inference.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 16, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17// =============================================================================
18// Imports
19// =============================================================================
20
21use crate::dequantize::dequantize_tensor;
22use crate::quantize::quantize_tensor;
23use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
24use axonml_tensor::Tensor;
25use half::f16;
26use rayon::prelude::*;
27
28// =============================================================================
29// Quantized Matmul Kernels
30// =============================================================================
31
32/// Compute dot product of f32 activation vector with Q8_0 quantized weight row.
33/// This is the inner kernel — called per output element.
34#[inline]
35fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
36    let scale = f32::from(block.scale);
37    let mut sum = 0.0f32;
38    for (d, a) in block.data.iter().zip(activations.iter()) {
39        sum += (*d as f32) * a;
40    }
41    sum * scale
42}
43
44/// Compute dot product of f32 activation vector with Q4_0 quantized weight row.
45#[inline]
46fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
47    let scale = f32::from(block.scale);
48    let unpacked = block.unpack();
49    let mut sum = 0.0f32;
50    for i in 0..unpacked.len().min(activations.len()) {
51        sum += (unpacked[i] as f32) * activations[i];
52    }
53    sum * scale
54}
55
56/// Compute dot product of f32 activation vector with Q4_1 quantized weight row.
57#[inline]
58fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
59    let scale = f32::from(block.scale);
60    let min = f32::from(block.min);
61    let unpacked = block.unpack();
62    let mut sum = 0.0f32;
63    for i in 0..unpacked.len().min(activations.len()) {
64        sum += (unpacked[i] as f32 * scale + min) * activations[i];
65    }
66    sum
67}
68
69/// Compute dot product of f32 activation vector with F16 weight block.
70#[inline]
71fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
72    let mut sum = 0.0f32;
73    for i in 0..data.len().min(activations.len()) {
74        sum += f32::from(data[i]) * activations[i];
75    }
76    sum
77}
78
79/// Dot product of an activation vector against a single quantized block.
80#[inline]
81fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
82    match block {
83        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
84        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
85        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
86        QuantizedBlock::Q5(b) => {
87            let scale = b.scale.to_f32();
88            let values = b.unpack();
89            values
90                .iter()
91                .zip(activations)
92                .map(|(&v, &a)| v as f32 * scale * a)
93                .sum()
94        }
95        QuantizedBlock::Q5_1(b) => {
96            let scale = b.scale.to_f32();
97            let min = b.min.to_f32();
98            let values = b.unpack();
99            values
100                .iter()
101                .zip(activations)
102                .map(|(&v, &a)| (v as f32 * scale + min) * a)
103                .sum()
104        }
105        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
106        QuantizedBlock::F32(data) => {
107            let mut sum = 0.0f32;
108            for i in 0..data.len().min(activations.len()) {
109                sum += data[i] * activations[i];
110            }
111            sum
112        }
113    }
114}
115
116// =============================================================================
117// QuantizedLinear — drop-in replacement for Linear
118// =============================================================================
119
120/// A linear layer with quantized weights for fast inference.
121///
122/// Stores weights as `QuantizedTensor` (Q8/Q4/F16) and dequantizes on-the-fly
123/// during the matrix multiply. Bias remains in f32.
124///
125/// # Usage
126/// ```ignore
127/// use axonml_quant::inference::QuantizedLinear;
128/// use axonml_quant::QuantType;
129///
130/// let qlinear = QuantizedLinear::from_linear_params(&weights, Some(&bias), 512, 128, QuantType::Q8_0);
131/// let output = qlinear.forward_f32(&input_data, 1);
132/// ```
133#[derive(Debug, Clone)]
134pub struct QuantizedLinear {
135    /// Quantized weight matrix (out_features x in_features, row-major).
136    weight: QuantizedTensor,
137    /// Bias vector (f32, not quantized).
138    bias: Option<Vec<f32>>,
139    /// Input feature dimension.
140    pub in_features: usize,
141    /// Output feature dimension.
142    pub out_features: usize,
143    /// Quantization type used.
144    pub quant_type: QuantType,
145    /// Number of blocks per output row (cached for fast access).
146    blocks_per_row: usize,
147}
148
149impl QuantizedLinear {
150    /// Create a QuantizedLinear from an axonml_nn::Linear layer.
151    pub fn from_linear_params(
152        weight_data: &[f32],
153        bias_data: Option<&[f32]>,
154        in_features: usize,
155        out_features: usize,
156        quant_type: QuantType,
157    ) -> Self {
158        // Weight shape is [out_features, in_features]
159        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
160            .expect("Failed to create weight tensor for quantization");
161
162        let weight =
163            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");
164
165        let block_size = quant_type.block_size();
166        let blocks_per_row = in_features.div_ceil(block_size);
167
168        QuantizedLinear {
169            weight,
170            bias: bias_data.map(|b| b.to_vec()),
171            in_features,
172            out_features,
173            quant_type,
174            blocks_per_row,
175        }
176    }
177
178    /// Forward pass: f32 input → f32 output.
179    ///
180    /// Input shape: `[batch, in_features]`
181    /// Output shape: `[batch, out_features]`
182    ///
183    /// Performs quantized matrix multiplication: each output element is computed
184    /// by iterating over the weight row's quantized blocks and accumulating
185    /// dot products with the corresponding activation slices.
186    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
187        let mut output = vec![0.0f32; batch_size * self.out_features];
188
189        // For non-block types (F16, F32), extract flat weight data and do direct matmul
190        if !self.quant_type.is_block_quantized() {
191            let weight_flat = self.extract_flat_weights();
192            output
193                .par_chunks_mut(self.out_features)
194                .enumerate()
195                .for_each(|(b, out_row)| {
196                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
197                    for o in 0..self.out_features {
198                        let w_start = o * self.in_features;
199                        let mut sum = 0.0f32;
200                        for k in 0..self.in_features {
201                            sum += weight_flat[w_start + k] * input_row[k];
202                        }
203                        if let Some(ref bias) = self.bias {
204                            sum += bias[o];
205                        }
206                        out_row[o] = sum;
207                    }
208                });
209            return output;
210        }
211
212        // Block-quantized path (Q8, Q4, Q4_1): iterate over blocks per weight row
213        let block_size = self.quant_type.block_size();
214
215        output
216            .par_chunks_mut(self.out_features)
217            .enumerate()
218            .for_each(|(b, out_row)| {
219                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
220
221                for o in 0..self.out_features {
222                    let row_block_start = o * self.blocks_per_row;
223                    let mut sum = 0.0f32;
224
225                    for blk_idx in 0..self.blocks_per_row {
226                        let act_start = blk_idx * block_size;
227                        let act_end = (act_start + block_size).min(self.in_features);
228                        let act_slice = &input_row[act_start..act_end];
229
230                        let block = &self.weight.blocks[row_block_start + blk_idx];
231                        sum += dot_block(block, act_slice);
232                    }
233
234                    if let Some(ref bias) = self.bias {
235                        sum += bias[o];
236                    }
237
238                    out_row[o] = sum;
239                }
240            });
241
242        output
243    }
244
245    /// Extract flat f32 weights (for F16/F32 non-block types).
246    fn extract_flat_weights(&self) -> Vec<f32> {
247        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
248        for block in &self.weight.blocks {
249            match block {
250                QuantizedBlock::F16(data) => {
251                    flat.extend(data.iter().map(|v| f32::from(*v)));
252                }
253                QuantizedBlock::F32(data) => {
254                    flat.extend_from_slice(data);
255                }
256                _ => {} // block-quantized handled separately
257            }
258        }
259        flat
260    }
261
262    /// Forward pass with Variable input/output (for integration with autograd).
263    ///
264    /// Note: Quantized inference is forward-only (no gradient tracking).
265    /// The output Variable has `requires_grad = false`.
266    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
267        let shape = input.shape();
268        let batch = if shape.len() > 1 { shape[0] } else { 1 };
269        let input_data = input.data().to_vec();
270
271        let output_data = self.forward_f32(&input_data, batch);
272
273        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
274            .expect("Failed to create output tensor");
275
276        axonml_autograd::Variable::new(output_tensor, false)
277    }
278
279    /// Memory usage in bytes (weights only, excludes bias).
280    pub fn weight_bytes(&self) -> usize {
281        self.weight.size_bytes()
282    }
283
284    /// Compression ratio vs f32 weights.
285    pub fn compression_ratio(&self) -> f32 {
286        self.weight.compression_ratio()
287    }
288
289    /// Dequantize weights back to f32 (for debugging/validation).
290    pub fn dequantize_weights(&self) -> Tensor<f32> {
291        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
292    }
293}
294
295// =============================================================================
296// Model Quantization — convert a full model
297// =============================================================================
298
299/// Quantize all parameters of a model, returning a flat list of quantized tensors.
300///
301/// This extracts all parameters from a Module, quantizes each one, and returns
302/// them in order. Use with `QuantizedModel` for full inference.
303pub fn quantize_parameters(
304    params: &[axonml_nn::Parameter],
305    quant_type: QuantType,
306) -> Vec<QuantizedTensor> {
307    params
308        .par_iter()
309        .map(|param| {
310            let tensor = param.data();
311            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
312        })
313        .collect()
314}
315
316// =============================================================================
317// QuantizedModel — generic quantized inference wrapper
318// =============================================================================
319
320/// A fully quantized model for inference.
321///
322/// Wraps a collection of quantized parameters and provides fast forward pass
323/// by dequantizing weights on-the-fly during computation.
324///
325/// # Usage
326/// ```ignore
327/// use axonml_quant::inference::QuantizedModel;
328/// use axonml_quant::QuantType;
329///
330/// let qmodel = QuantizedModel::from_module(&model, QuantType::Q8_0);
331/// println!("{}", qmodel.summary());
332/// qmodel.load_into_module(&model); // dequant weights back for inference
333/// ```
334pub struct QuantizedModel {
335    /// Quantized weight tensors (in parameter order).
336    pub quantized_params: Vec<QuantizedTensor>,
337    /// Quantization type used.
338    pub quant_type: QuantType,
339    /// Total original parameter count.
340    pub total_params: usize,
341    /// Total quantized size in bytes.
342    pub total_bytes: usize,
343    /// Original f32 size in bytes.
344    pub original_bytes: usize,
345}
346
347impl QuantizedModel {
348    /// Quantize a Module's parameters.
349    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
350        let params = module.parameters();
351        let total_params: usize = params.iter().map(|p| p.numel()).sum();
352        let original_bytes = total_params * 4;
353
354        let quantized_params = quantize_parameters(&params, quant_type);
355
356        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
357
358        QuantizedModel {
359            quantized_params,
360            quant_type,
361            total_params,
362            total_bytes,
363            original_bytes,
364        }
365    }
366
367    /// Load quantized weights back into a Module for inference.
368    ///
369    /// Dequantizes all parameters and updates the module's parameters in-place.
370    /// This is the simplest integration path — the model runs at full f32 speed
371    /// but loads from a compressed checkpoint.
372    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
373        let params = module.parameters();
374        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
375            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
376            param.update_data(tensor);
377        }
378    }
379
380    /// Compression ratio.
381    pub fn compression_ratio(&self) -> f32 {
382        self.original_bytes as f32 / self.total_bytes as f32
383    }
384
385    /// Print a summary of the quantized model.
386    pub fn summary(&self) -> String {
387        format!(
388            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
389            self.quant_type,
390            self.total_params,
391            self.original_bytes as f64 / 1024.0 / 1024.0,
392            self.total_bytes as f64 / 1024.0 / 1024.0,
393            self.compression_ratio(),
394        )
395    }
396}
397
398// =============================================================================
399// Serialization — save/load quantized models
400// =============================================================================
401
402/// Serialize a QuantizedModel to bytes (for .axonml files).
403pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
404    let mut buf = Vec::new();
405
406    // Magic: "AXQT" (AxonML Quantized)
407    buf.extend_from_slice(b"AXQT");
408    // Version
409    buf.push(1u8);
410    // Quant type
411    buf.push(match model.quant_type {
412        QuantType::Q8_0 => 0,
413        QuantType::Q4_0 => 1,
414        QuantType::Q4_1 => 2,
415        QuantType::Q5_0 => 3,
416        QuantType::Q5_1 => 4,
417        QuantType::F16 => 5,
418        QuantType::F32 => 6,
419    });
420    // Number of tensors
421    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
422    // Total params
423    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());
424
425    // Each tensor: shape_len + shape + num_blocks + blocks
426    for qt in &model.quantized_params {
427        // Shape
428        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
429        for &dim in &qt.shape {
430            buf.extend_from_slice(&(dim as u32).to_le_bytes());
431        }
432        // Number of blocks
433        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
434        // Block data
435        for block in &qt.blocks {
436            match block {
437                QuantizedBlock::Q8(b) => {
438                    buf.extend_from_slice(&b.to_bytes());
439                }
440                QuantizedBlock::Q4(b) => {
441                    buf.extend_from_slice(&b.to_bytes());
442                }
443                QuantizedBlock::Q4_1(b) => {
444                    buf.extend_from_slice(&b.to_bytes());
445                }
446                QuantizedBlock::Q5(b) => {
447                    buf.extend_from_slice(&b.to_bytes());
448                }
449                QuantizedBlock::Q5_1(b) => {
450                    buf.extend_from_slice(&b.to_bytes());
451                }
452                QuantizedBlock::F16(data) => {
453                    for &v in data {
454                        buf.extend_from_slice(&v.to_le_bytes());
455                    }
456                }
457                QuantizedBlock::F32(data) => {
458                    for &v in data {
459                        buf.extend_from_slice(&v.to_le_bytes());
460                    }
461                }
462            }
463        }
464    }
465
466    buf
467}
468
469/// Deserialize a QuantizedModel from bytes.
470pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
471    if data.len() < 18 || &data[0..4] != b"AXQT" {
472        return Err("Invalid quantized model file (bad magic)".to_string());
473    }
474
475    let version = data[4];
476    if version != 1 {
477        return Err(format!("Unsupported quantized model version: {version}"));
478    }
479
480    let quant_type = match data[5] {
481        0 => QuantType::Q8_0,
482        1 => QuantType::Q4_0,
483        2 => QuantType::Q4_1,
484        3 => QuantType::Q5_0,
485        4 => QuantType::Q5_1,
486        5 => QuantType::F16,
487        6 => QuantType::F32,
488        x => return Err(format!("Unknown quant type byte: {x}")),
489    };
490
491    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
492    let total_params = u64::from_le_bytes([
493        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
494    ]) as usize;
495
496    let mut offset = 18usize;
497    let mut quantized_params = Vec::with_capacity(num_tensors);
498
499    let block_bytes = quant_type.bytes_per_block();
500
501    for _ in 0..num_tensors {
502        if offset + 4 > data.len() {
503            return Err("Truncated quantized model file".to_string());
504        }
505
506        // Shape
507        let shape_len = u32::from_le_bytes([
508            data[offset],
509            data[offset + 1],
510            data[offset + 2],
511            data[offset + 3],
512        ]) as usize;
513        offset += 4;
514
515        let mut shape = Vec::with_capacity(shape_len);
516        for _ in 0..shape_len {
517            let dim = u32::from_le_bytes([
518                data[offset],
519                data[offset + 1],
520                data[offset + 2],
521                data[offset + 3],
522            ]) as usize;
523            shape.push(dim);
524            offset += 4;
525        }
526
527        // Number of blocks
528        let num_blocks = u32::from_le_bytes([
529            data[offset],
530            data[offset + 1],
531            data[offset + 2],
532            data[offset + 3],
533        ]) as usize;
534        offset += 4;
535
536        // Read blocks
537        let mut blocks = Vec::with_capacity(num_blocks);
538        for _ in 0..num_blocks {
539            if offset + block_bytes > data.len() {
540                return Err("Truncated block data".to_string());
541            }
542
543            let block = match quant_type {
544                QuantType::Q8_0 => {
545                    let b =
546                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
547                    QuantizedBlock::Q8(b)
548                }
549                QuantType::Q4_0 => {
550                    let b =
551                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
552                    QuantizedBlock::Q4(b)
553                }
554                QuantType::Q4_1 => {
555                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
556                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
557                    let mut block_data = [0u8; 16];
558                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
559                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
560                }
561                QuantType::F16 => {
562                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
563                    QuantizedBlock::F16(vec![v])
564                }
565                QuantType::F32 => {
566                    let v = f32::from_le_bytes([
567                        data[offset],
568                        data[offset + 1],
569                        data[offset + 2],
570                        data[offset + 3],
571                    ]);
572                    QuantizedBlock::F32(vec![v])
573                }
574                _ => return Err("Unsupported quant type for deserialization".to_string()),
575            };
576
577            blocks.push(block);
578            offset += block_bytes;
579        }
580
581        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
582    }
583
584    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
585    let original_bytes = total_params * 4;
586
587    Ok(QuantizedModel {
588        quantized_params,
589        quant_type,
590        total_params,
591        total_bytes,
592        original_bytes,
593    })
594}
595
596// =============================================================================
597// Tests
598// =============================================================================
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal, valid AXQT v1 header that declares zero tensors.
    fn empty_model_header(quant_tag: u8) -> Vec<u8> {
        let mut bytes = b"AXQT".to_vec();
        bytes.push(1);
        bytes.push(quant_tag);
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u64.to_le_bytes());
        bytes
    }

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        // Output should not be all zeros
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare with f32 reference
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        // Q8 should be very close to f32
        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);

        // Each batched row must match running that row on its own.
        for b in 0..batch {
            let single = ql.forward_f32(&input[b * in_f..(b + 1) * in_f], 1);
            assert_eq!(&output[b * out_f..(b + 1) * out_f], single.as_slice());
        }
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.total_params, in_f * out_f);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
        assert_eq!(
            deserialized.quantized_params[0].blocks.len(),
            model.quantized_params[0].blocks.len()
        );
    }

    #[test]
    fn test_deserialize_header_only_model() {
        let model = deserialize_quantized(&empty_model_header(0)).unwrap();
        assert_eq!(model.quant_type, QuantType::Q8_0);
        assert!(model.quantized_params.is_empty());
        assert_eq!(model.total_params, 0);
    }

    #[test]
    fn test_deserialize_rejects_malformed_headers() {
        assert!(deserialize_quantized(&[]).is_err());

        let mut bad_magic = empty_model_header(0);
        bad_magic[0] = b'X';
        assert!(deserialize_quantized(&bad_magic).is_err());

        let mut bad_version = empty_model_header(0);
        bad_version[4] = 2;
        assert!(deserialize_quantized(&bad_version).is_err());

        assert!(deserialize_quantized(&empty_model_header(42)).is_err());
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_data: Vec<f32> = (0..2 * in_f).map(|i| i as f32 * 0.1).collect();
        let input_tensor = Tensor::from_vec(input_data.clone(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
        // The Variable path must produce exactly what the flat path produces.
        assert_eq!(output.data().to_vec(), ql.forward_f32(&input_data, 2));
    }
}