//! Quantized Inference — fast inference with quantized weights
//!
//! # File
//! `crates/axonml-quant/src/inference.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 16, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// Imports
// =============================================================================

use crate::dequantize::dequantize_tensor;
use crate::quantize::quantize_tensor;
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

// =============================================================================
// Quantized Matmul Kernels
// =============================================================================

/// Compute dot product of f32 activation vector with Q8_0 quantized weight row.
/// This is the inner kernel — called per output element.
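/// Each stored `i8` value `q` represents `scale * q`, so the block contributes
/// `scale * sum(q[i] * a[i])`. A small worked example (hypothetical numbers, not
/// tied to the real block size):
///
/// ```ignore
/// // q = [2, -3], scale = 0.5, a = [1.0, 4.0]
/// // dot = 0.5 * (2 * 1.0 + (-3) * 4.0) = -5.0
/// ```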
#[inline]
fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let mut sum = 0.0f32;
    for (d, a) in block.data.iter().zip(activations.iter()) {
        sum += (*d as f32) * a;
    }
    sum * scale
}

/// Compute dot product of f32 activation vector with Q4_0 quantized weight row.
#[inline]
fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32) * activations[i];
    }
    sum * scale
}

/// Compute dot product of f32 activation vector with Q4_1 quantized weight row.
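/// Unlike Q4_0, which applies a single `scale` to the whole block sum, Q4_1 is
/// affine: each unpacked value `q` maps to `q * scale + min`, so the offset is
/// applied per element before multiplying by the activation.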
#[inline]
fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
    let scale = f32::from(block.scale);
    let min = f32::from(block.min);
    let unpacked = block.unpack();
    let mut sum = 0.0f32;
    for i in 0..unpacked.len().min(activations.len()) {
        sum += (unpacked[i] as f32 * scale + min) * activations[i];
    }
    sum
}

/// Compute dot product of f32 activation vector with F16 weight block.
#[inline]
fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for i in 0..data.len().min(activations.len()) {
        sum += f32::from(data[i]) * activations[i];
    }
    sum
}

/// Dot product of an activation vector against a single quantized block.
#[inline]
fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
    match block {
        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            let values = b.unpack();
            values.iter().zip(activations).map(|(&v, &a)| v as f32 * scale * a).sum()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            let values = b.unpack();
            values.iter().zip(activations).map(|(&v, &a)| (v as f32 * scale + min) * a).sum()
        }
        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
        QuantizedBlock::F32(data) => {
            let mut sum = 0.0f32;
            for i in 0..data.len().min(activations.len()) {
                sum += data[i] * activations[i];
            }
            sum
        }
    }
}

// =============================================================================
// QuantizedLinear — drop-in replacement for Linear
// =============================================================================

/// A linear layer with quantized weights for fast inference.
///
/// Stores weights as `QuantizedTensor` (Q8/Q4/F16) and dequantizes on-the-fly
/// during the matrix multiply. Bias remains in f32.
///
/// # Usage
/// ```ignore
/// use axonml_quant::inference::QuantizedLinear;
/// use axonml_quant::QuantType;
///
/// let qlinear = QuantizedLinear::from_linear_params(&weights, Some(&bias), 512, 128, QuantType::Q8_0);
/// let output = qlinear.forward_f32(&input_data, 1);
/// ```
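///
/// Weight blocks are stored row-major: the blocks for output row `o` occupy
/// indices `o * blocks_per_row .. (o + 1) * blocks_per_row`, where
/// `blocks_per_row = in_features.div_ceil(block_size)`.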
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weight matrix (out_features x in_features, row-major).
    weight: QuantizedTensor,
    /// Bias vector (f32, not quantized).
    bias: Option<Vec<f32>>,
    /// Input feature dimension.
    pub in_features: usize,
    /// Output feature dimension.
    pub out_features: usize,
    /// Quantization type used.
    pub quant_type: QuantType,
    /// Number of blocks per output row (cached for fast access).
    blocks_per_row: usize,
}

impl QuantizedLinear {
    /// Create a `QuantizedLinear` from an `axonml_nn::Linear` layer's raw weight and optional bias data.
    pub fn from_linear_params(
        weight_data: &[f32],
        bias_data: Option<&[f32]>,
        in_features: usize,
        out_features: usize,
        quant_type: QuantType,
    ) -> Self {
        // Weight shape is [out_features, in_features]
        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
            .expect("Failed to create weight tensor for quantization");

        let weight =
            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");

        let block_size = quant_type.block_size();
        let blocks_per_row = in_features.div_ceil(block_size);

        QuantizedLinear {
            weight,
            bias: bias_data.map(|b| b.to_vec()),
            in_features,
            out_features,
            quant_type,
            blocks_per_row,
        }
    }

    /// Forward pass: f32 input → f32 output.
    ///
    /// Input shape: `[batch, in_features]`
    /// Output shape: `[batch, out_features]`
    ///
    /// Performs quantized matrix multiplication: each output element is computed
    /// by iterating over the weight row's quantized blocks and accumulating
    /// dot products with the corresponding activation slices.
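    ///
    /// A minimal usage sketch (hypothetical sizes and data; mirrors the unit
    /// tests at the bottom of this file):
    ///
    /// ```ignore
    /// let ql = QuantizedLinear::from_linear_params(&weights, None, 64, 16, QuantType::Q8_0);
    /// let input = vec![0.0f32; 2 * 64];        // batch = 2, flattened [2, 64]
    /// let output = ql.forward_f32(&input, 2);  // flattened [2, 16]
    /// assert_eq!(output.len(), 2 * 16);
    /// ```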
    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * self.out_features];

        // For non-block types (F16, F32), extract flat weight data and do direct matmul
        if !self.quant_type.is_block_quantized() {
            let weight_flat = self.extract_flat_weights();
            output
                .par_chunks_mut(self.out_features)
                .enumerate()
                .for_each(|(b, out_row)| {
                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
                    for o in 0..self.out_features {
                        let w_start = o * self.in_features;
                        let mut sum = 0.0f32;
                        for k in 0..self.in_features {
                            sum += weight_flat[w_start + k] * input_row[k];
                        }
                        if let Some(ref bias) = self.bias {
                            sum += bias[o];
                        }
                        out_row[o] = sum;
                    }
                });
            return output;
        }

        // Block-quantized path (Q8, Q4, Q4_1): iterate over blocks per weight row
        let block_size = self.quant_type.block_size();

        output
            .par_chunks_mut(self.out_features)
            .enumerate()
            .for_each(|(b, out_row)| {
                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];

                for o in 0..self.out_features {
                    let row_block_start = o * self.blocks_per_row;
                    let mut sum = 0.0f32;

                    for blk_idx in 0..self.blocks_per_row {
                        let act_start = blk_idx * block_size;
                        let act_end = (act_start + block_size).min(self.in_features);
                        let act_slice = &input_row[act_start..act_end];

                        let block = &self.weight.blocks[row_block_start + blk_idx];
                        sum += dot_block(block, act_slice);
                    }

                    if let Some(ref bias) = self.bias {
                        sum += bias[o];
                    }

                    out_row[o] = sum;
                }
            });

        output
    }

    /// Extract flat f32 weights (for F16/F32 non-block types).
    fn extract_flat_weights(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
        for block in &self.weight.blocks {
            match block {
                QuantizedBlock::F16(data) => {
                    flat.extend(data.iter().map(|v| f32::from(*v)));
                }
                QuantizedBlock::F32(data) => {
                    flat.extend_from_slice(data);
                }
                _ => {} // block-quantized handled separately
            }
        }
        flat
    }

    /// Forward pass with Variable input/output (for integration with autograd).
    ///
    /// Note: Quantized inference is forward-only (no gradient tracking).
    /// The output Variable has `requires_grad = false`.
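    ///
    /// A minimal sketch (mirrors `test_quantized_linear_variable_forward` below;
    /// `input_var` and `batch` are placeholders):
    ///
    /// ```ignore
    /// let out = qlinear.forward_var(&input_var);  // input_var: [batch, in_features]
    /// assert_eq!(out.shape(), vec![batch, qlinear.out_features]);
    /// ```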
    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
        let shape = input.shape();
        let batch = if shape.len() > 1 { shape[0] } else { 1 };
        let input_data = input.data().to_vec();

        let output_data = self.forward_f32(&input_data, batch);

        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
            .expect("Failed to create output tensor");

        axonml_autograd::Variable::new(output_tensor, false)
    }

    /// Memory usage in bytes (weights only, excludes bias).
    pub fn weight_bytes(&self) -> usize {
        self.weight.size_bytes()
    }

    /// Compression ratio vs f32 weights.
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

    /// Dequantize weights back to f32 (for debugging/validation).
    pub fn dequantize_weights(&self) -> Tensor<f32> {
        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
    }
}

// =============================================================================
// Model Quantization — convert a full model
// =============================================================================

/// Quantize a slice of model parameters, returning a flat list of quantized tensors.
///
/// Each parameter is quantized (in parallel, via rayon) and returned in the same
/// order as the input slice. Use with `QuantizedModel` for full inference.
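///
/// A minimal sketch (assuming a `model` that implements `axonml_nn::Module`):
///
/// ```ignore
/// let qparams = quantize_parameters(&model.parameters(), QuantType::Q8_0);
/// assert_eq!(qparams.len(), model.parameters().len());
/// ```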
pub fn quantize_parameters(
    params: &[axonml_nn::Parameter],
    quant_type: QuantType,
) -> Vec<QuantizedTensor> {
    params
        .par_iter()
        .map(|param| {
            let tensor = param.data();
            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
        })
        .collect()
}

// =============================================================================
// QuantizedModel — generic quantized inference wrapper
// =============================================================================

/// A fully quantized model for inference.
///
/// Wraps a collection of quantized parameters and provides fast forward pass
/// by dequantizing weights on-the-fly during computation.
///
/// # Usage
/// ```ignore
/// use axonml_quant::inference::QuantizedModel;
/// use axonml_quant::QuantType;
///
/// let qmodel = QuantizedModel::from_module(&model, QuantType::Q8_0);
/// println!("{}", qmodel.summary());
/// qmodel.load_into_module(&model); // dequant weights back for inference
/// ```
pub struct QuantizedModel {
    /// Quantized weight tensors (in parameter order).
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization type used.
    pub quant_type: QuantType,
    /// Total original parameter count.
    pub total_params: usize,
    /// Total quantized size in bytes.
    pub total_bytes: usize,
    /// Original f32 size in bytes.
    pub original_bytes: usize,
}

impl QuantizedModel {
    /// Quantize a Module's parameters.
    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
        let params = module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let original_bytes = total_params * 4;

        let quantized_params = quantize_parameters(&params, quant_type);

        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();

        QuantizedModel {
            quantized_params,
            quant_type,
            total_params,
            total_bytes,
            original_bytes,
        }
    }

    /// Load quantized weights back into a Module for inference.
    ///
    /// Dequantizes all parameters and updates the module's parameters in-place.
    /// This is the simplest integration path — the model runs at full f32 speed
    /// but loads from a compressed checkpoint.
    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
        let params = module.parameters();
        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
            param.update_data(tensor);
        }
    }

    /// Compression ratio.
    pub fn compression_ratio(&self) -> f32 {
        self.original_bytes as f32 / self.total_bytes as f32
    }

    /// Return a one-line summary string for the quantized model.
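    ///
    /// The output has the form
    /// `QuantizedModel(type=<quant>, params=..., f32=...MB, quant=...MB, ratio=...x)`.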
    pub fn summary(&self) -> String {
        format!(
            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
            self.quant_type,
            self.total_params,
            self.original_bytes as f64 / 1024.0 / 1024.0,
            self.total_bytes as f64 / 1024.0 / 1024.0,
            self.compression_ratio(),
        )
    }
}

// =============================================================================
// Serialization — save/load quantized models
// =============================================================================

/// Serialize a QuantizedModel to bytes (for .axonml files).
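///
/// On-disk layout (all multi-byte integers little-endian), as produced below:
///
/// ```ignore
/// [0..4)    magic b"AXQT"
/// [4]       format version (currently 1)
/// [5]       quant type tag (0=Q8_0, 1=Q4_0, 2=Q4_1, 3=Q5_0, 4=Q5_1, 5=F16, 6=F32)
/// [6..10)   number of tensors (u32)
/// [10..18)  total parameter count (u64)
/// then, per tensor:
///   shape_len (u32), each dim (u32), num_blocks (u32), raw block bytes
/// ```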
pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
    let mut buf = Vec::new();

    // Magic: "AXQT" (AxonML Quantized)
    buf.extend_from_slice(b"AXQT");
    // Version
    buf.push(1u8);
    // Quant type
    buf.push(match model.quant_type {
        QuantType::Q8_0 => 0,
        QuantType::Q4_0 => 1,
        QuantType::Q4_1 => 2,
        QuantType::Q5_0 => 3,
        QuantType::Q5_1 => 4,
        QuantType::F16 => 5,
        QuantType::F32 => 6,
    });
    // Number of tensors
    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
    // Total params
    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());

    // Each tensor: shape_len + shape + num_blocks + blocks
    for qt in &model.quantized_params {
        // Shape
        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
        for &dim in &qt.shape {
            buf.extend_from_slice(&(dim as u32).to_le_bytes());
        }
        // Number of blocks
        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
        // Block data
        for block in &qt.blocks {
            match block {
                QuantizedBlock::Q8(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q4_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::Q5_1(b) => {
                    buf.extend_from_slice(&b.to_bytes());
                }
                QuantizedBlock::F16(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
                QuantizedBlock::F32(data) => {
                    for &v in data {
                        buf.extend_from_slice(&v.to_le_bytes());
                    }
                }
            }
        }
    }

    buf
}

/// Deserialize a QuantizedModel from bytes.
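///
/// A round-trip sketch (see `test_quantized_model_roundtrip` below):
///
/// ```ignore
/// let bytes = serialize_quantized(&qmodel);
/// let restored = deserialize_quantized(&bytes)?;
/// assert_eq!(restored.quant_type, qmodel.quant_type);
/// ```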
pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
    if data.len() < 18 || &data[0..4] != b"AXQT" {
        return Err("Invalid quantized model file (bad magic)".to_string());
    }

    let version = data[4];
    if version != 1 {
        return Err(format!("Unsupported quantized model version: {version}"));
    }

    let quant_type = match data[5] {
        0 => QuantType::Q8_0,
        1 => QuantType::Q4_0,
        2 => QuantType::Q4_1,
        3 => QuantType::Q5_0,
        4 => QuantType::Q5_1,
        5 => QuantType::F16,
        6 => QuantType::F32,
        x => return Err(format!("Unknown quant type byte: {x}")),
    };

    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
    let total_params = u64::from_le_bytes([
        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
    ]) as usize;

    let mut offset = 18usize;
    let mut quantized_params = Vec::with_capacity(num_tensors);

    let block_bytes = quant_type.bytes_per_block();

    for _ in 0..num_tensors {
        if offset + 4 > data.len() {
            return Err("Truncated quantized model file".to_string());
        }

        // Shape
        let shape_len = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        let mut shape = Vec::with_capacity(shape_len);
        for _ in 0..shape_len {
            let dim = u32::from_le_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]) as usize;
            shape.push(dim);
            offset += 4;
        }

        // Number of blocks
        let num_blocks = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        offset += 4;

        // Read blocks
        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            if offset + block_bytes > data.len() {
                return Err("Truncated block data".to_string());
            }

            let block = match quant_type {
                QuantType::Q8_0 => {
                    let b =
                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
                    QuantizedBlock::Q8(b)
                }
                QuantType::Q4_0 => {
                    let b =
                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
                    QuantizedBlock::Q4(b)
                }
                QuantType::Q4_1 => {
                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
                    let mut block_data = [0u8; 16];
                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
                }
                QuantType::F16 => {
                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
                    QuantizedBlock::F16(vec![v])
                }
                QuantType::F32 => {
                    let v = f32::from_le_bytes([
                        data[offset],
                        data[offset + 1],
                        data[offset + 2],
                        data[offset + 3],
                    ]);
                    QuantizedBlock::F32(vec![v])
                }
                _ => return Err("Unsupported quant type for deserialization".to_string()),
            };

            blocks.push(block);
            offset += block_bytes;
        }

        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
    }

    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
    let original_bytes = total_params * 4;

    Ok(QuantizedModel {
        quantized_params,
        quant_type,
        total_params,
        total_bytes,
        original_bytes,
    })
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        // Output should not be all zeros
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare with f32 reference
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        // Q8 should be very close to f32
        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}