//! Quantized Inference — fast inference with quantized weights
//!
//! # File
//! `crates/axonml-quant/src/inference.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 16, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
17// =============================================================================
18// Imports
19// =============================================================================
20
21use crate::dequantize::dequantize_tensor;
22use crate::quantize::quantize_tensor;
23use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
24use axonml_tensor::Tensor;
25use half::f16;
26use rayon::prelude::*;
27
28// =============================================================================
29// Quantized Matmul Kernels
30// =============================================================================
31
32/// Compute dot product of f32 activation vector with Q8_0 quantized weight row.
33/// This is the inner kernel — called per output element.
34#[inline]
35fn dot_q8_block(block: &Q8Block, activations: &[f32]) -> f32 {
36    let scale = f32::from(block.scale);
37    let mut sum = 0.0f32;
38    for (d, a) in block.data.iter().zip(activations.iter()) {
39        sum += (*d as f32) * a;
40    }
41    sum * scale
42}
43
44/// Compute dot product of f32 activation vector with Q4_0 quantized weight row.
45#[inline]
46fn dot_q4_block(block: &Q4Block, activations: &[f32]) -> f32 {
47    let scale = f32::from(block.scale);
48    let unpacked = block.unpack();
49    let mut sum = 0.0f32;
50    for i in 0..unpacked.len().min(activations.len()) {
51        sum += (unpacked[i] as f32) * activations[i];
52    }
53    sum * scale
54}
55
56/// Compute dot product of f32 activation vector with Q4_1 quantized weight row.
57#[inline]
58fn dot_q4_1_block(block: &Q4_1Block, activations: &[f32]) -> f32 {
59    let scale = f32::from(block.scale);
60    let min = f32::from(block.min);
61    let unpacked = block.unpack();
62    let mut sum = 0.0f32;
63    for i in 0..unpacked.len().min(activations.len()) {
64        sum += (unpacked[i] as f32 * scale + min) * activations[i];
65    }
66    sum
67}
68
69/// Compute dot product of f32 activation vector with F16 weight block.
70#[inline]
71fn dot_f16_block(data: &[f16], activations: &[f32]) -> f32 {
72    let mut sum = 0.0f32;
73    for i in 0..data.len().min(activations.len()) {
74        sum += f32::from(data[i]) * activations[i];
75    }
76    sum
77}
78
79/// Dot product of an activation vector against a single quantized block.
80#[inline]
81fn dot_block(block: &QuantizedBlock, activations: &[f32]) -> f32 {
82    match block {
83        QuantizedBlock::Q8(b) => dot_q8_block(b, activations),
84        QuantizedBlock::Q4(b) => dot_q4_block(b, activations),
85        QuantizedBlock::Q4_1(b) => dot_q4_1_block(b, activations),
86        QuantizedBlock::F16(data) => dot_f16_block(data, activations),
87        QuantizedBlock::F32(data) => {
88            let mut sum = 0.0f32;
89            for i in 0..data.len().min(activations.len()) {
90                sum += data[i] * activations[i];
91            }
92            sum
93        }
94    }
95}
96
97// =============================================================================
98// QuantizedLinear — drop-in replacement for Linear
99// =============================================================================
100
/// A linear layer with quantized weights for fast inference.
///
/// Stores weights as `QuantizedTensor` (Q8/Q4/F16) and dequantizes on-the-fly
/// during the matrix multiply. Bias remains in f32.
///
/// # Usage
/// ```ignore
/// use axonml_quant::inference::QuantizedLinear;
/// use axonml_quant::QuantType;
///
/// let qlinear = QuantizedLinear::from_linear_params(&weights, Some(&bias), 512, 128, QuantType::Q8_0);
/// let output = qlinear.forward_f32(&input_data, 1);
/// ```
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weight matrix (out_features x in_features, row-major):
    /// the blocks for output row `o` occupy indices
    /// `o * blocks_per_row .. (o + 1) * blocks_per_row` in `weight.blocks`.
    weight: QuantizedTensor,
    /// Bias vector of length `out_features` (f32, not quantized).
    bias: Option<Vec<f32>>,
    /// Input feature dimension.
    pub in_features: usize,
    /// Output feature dimension.
    pub out_features: usize,
    /// Quantization type used.
    pub quant_type: QuantType,
    /// Number of blocks per output row, i.e. ceil(in_features / block_size)
    /// (cached at construction for fast access in the forward pass).
    blocks_per_row: usize,
}
129
impl QuantizedLinear {
    /// Create a QuantizedLinear from an axonml_nn::Linear layer.
    ///
    /// # Arguments
    /// * `weight_data` — row-major weight matrix, `out_features * in_features` values
    /// * `bias_data` — optional bias of length `out_features` (kept in f32)
    /// * `quant_type` — quantization scheme applied to the weights
    ///
    /// # Panics
    /// Panics if the weight tensor cannot be constructed (e.g.
    /// `weight_data.len() != out_features * in_features`) or fails to quantize.
    pub fn from_linear_params(
        weight_data: &[f32],
        bias_data: Option<&[f32]>,
        in_features: usize,
        out_features: usize,
        quant_type: QuantType,
    ) -> Self {
        // Weight shape is [out_features, in_features]
        let weight_tensor = Tensor::from_vec(weight_data.to_vec(), &[out_features, in_features])
            .expect("Failed to create weight tensor for quantization");

        let weight =
            quantize_tensor(&weight_tensor, quant_type).expect("Failed to quantize weight tensor");

        // Cache ceil(in_features / block_size) so forward_f32 can index
        // straight into weight.blocks without recomputing per row.
        let block_size = quant_type.block_size();
        let blocks_per_row = in_features.div_ceil(block_size);

        QuantizedLinear {
            weight,
            bias: bias_data.map(|b| b.to_vec()),
            in_features,
            out_features,
            quant_type,
            blocks_per_row,
        }
    }

    /// Forward pass: f32 input → f32 output.
    ///
    /// Input shape: `[batch, in_features]`
    /// Output shape: `[batch, out_features]`
    ///
    /// Performs quantized matrix multiplication: each output element is computed
    /// by iterating over the weight row's quantized blocks and accumulating
    /// dot products with the corresponding activation slices.
    ///
    /// Batch rows are processed in parallel (rayon `par_chunks_mut`); each
    /// parallel unit owns a disjoint output row, so no synchronization is needed.
    ///
    /// # Panics
    /// Panics if `input.len() < batch_size * self.in_features` (the per-row
    /// slicing below indexes `input` directly).
    pub fn forward_f32(&self, input: &[f32], batch_size: usize) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * self.out_features];

        // For non-block types (F16, F32), extract flat weight data and do direct matmul
        // NOTE(review): this materializes a full f32 copy of the weights on
        // every call — consider caching if this path turns out to be hot.
        if !self.quant_type.is_block_quantized() {
            let weight_flat = self.extract_flat_weights();
            output
                .par_chunks_mut(self.out_features)
                .enumerate()
                .for_each(|(b, out_row)| {
                    let input_row = &input[b * self.in_features..(b + 1) * self.in_features];
                    for o in 0..self.out_features {
                        // Row `o` of the flat weight matrix starts here (row-major).
                        let w_start = o * self.in_features;
                        let mut sum = 0.0f32;
                        for k in 0..self.in_features {
                            sum += weight_flat[w_start + k] * input_row[k];
                        }
                        if let Some(ref bias) = self.bias {
                            sum += bias[o];
                        }
                        out_row[o] = sum;
                    }
                });
            return output;
        }

        // Block-quantized path (Q8, Q4, Q4_1): iterate over blocks per weight row
        let block_size = self.quant_type.block_size();

        output
            .par_chunks_mut(self.out_features)
            .enumerate()
            .for_each(|(b, out_row)| {
                let input_row = &input[b * self.in_features..(b + 1) * self.in_features];

                for o in 0..self.out_features {
                    // Blocks for output row `o` are stored contiguously.
                    let row_block_start = o * self.blocks_per_row;
                    let mut sum = 0.0f32;

                    for blk_idx in 0..self.blocks_per_row {
                        // Activation window matching this block; the last block
                        // of a row may cover fewer than block_size activations.
                        let act_start = blk_idx * block_size;
                        let act_end = (act_start + block_size).min(self.in_features);
                        let act_slice = &input_row[act_start..act_end];

                        let block = &self.weight.blocks[row_block_start + blk_idx];
                        sum += dot_block(block, act_slice);
                    }

                    if let Some(ref bias) = self.bias {
                        sum += bias[o];
                    }

                    out_row[o] = sum;
                }
            });

        output
    }

    /// Extract flat f32 weights (for F16/F32 non-block types).
    ///
    /// Concatenates the data of every F16/F32 block in storage order;
    /// block-quantized variants are skipped (they never reach this path).
    fn extract_flat_weights(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.in_features * self.out_features);
        for block in &self.weight.blocks {
            match block {
                QuantizedBlock::F16(data) => {
                    flat.extend(data.iter().map(|v| f32::from(*v)));
                }
                QuantizedBlock::F32(data) => {
                    flat.extend_from_slice(data);
                }
                _ => {} // block-quantized handled separately
            }
        }
        flat
    }

    /// Forward pass with Variable input/output (for integration with autograd).
    ///
    /// Note: Quantized inference is forward-only (no gradient tracking).
    /// The output Variable has `requires_grad = false`.
    ///
    /// NOTE(review): for inputs with more than 2 dims, `batch` is taken from
    /// `shape[0]` alone, which mismatches `forward_f32`'s flat
    /// `[batch, in_features]` layout — confirm callers only pass 1-D/2-D input.
    pub fn forward_var(&self, input: &axonml_autograd::Variable) -> axonml_autograd::Variable {
        let shape = input.shape();
        // 1-D input is treated as a single row; otherwise dim 0 is the batch.
        let batch = if shape.len() > 1 { shape[0] } else { 1 };
        let input_data = input.data().to_vec();

        let output_data = self.forward_f32(&input_data, batch);

        let output_tensor = Tensor::from_vec(output_data, &[batch, self.out_features])
            .expect("Failed to create output tensor");

        axonml_autograd::Variable::new(output_tensor, false)
    }

    /// Memory usage in bytes (weights only, excludes bias).
    pub fn weight_bytes(&self) -> usize {
        self.weight.size_bytes()
    }

    /// Compression ratio vs f32 weights.
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

    /// Dequantize weights back to f32 (for debugging/validation).
    ///
    /// # Panics
    /// Panics if dequantization fails.
    pub fn dequantize_weights(&self) -> Tensor<f32> {
        dequantize_tensor(&self.weight).expect("Failed to dequantize weights")
    }
}
275
276// =============================================================================
277// Model Quantization — convert a full model
278// =============================================================================
279
280/// Quantize all parameters of a model, returning a flat list of quantized tensors.
281///
282/// This extracts all parameters from a Module, quantizes each one, and returns
283/// them in order. Use with `QuantizedModel` for full inference.
284pub fn quantize_parameters(
285    params: &[axonml_nn::Parameter],
286    quant_type: QuantType,
287) -> Vec<QuantizedTensor> {
288    params
289        .par_iter()
290        .map(|param| {
291            let tensor = param.data();
292            quantize_tensor(&tensor, quant_type).expect("Failed to quantize parameter")
293        })
294        .collect()
295}
296
297// =============================================================================
298// QuantizedModel — generic quantized inference wrapper
299// =============================================================================
300
/// A fully quantized model for inference.
///
/// Wraps a collection of quantized parameters and provides fast forward pass
/// by dequantizing weights on-the-fly during computation.
///
/// # Usage
/// ```ignore
/// use axonml_quant::inference::QuantizedModel;
/// use axonml_quant::QuantType;
///
/// let qmodel = QuantizedModel::from_module(&model, QuantType::Q8_0);
/// println!("{}", qmodel.summary());
/// qmodel.load_into_module(&model); // dequant weights back for inference
/// ```
pub struct QuantizedModel {
    /// Quantized weight tensors (in parameter order, matching
    /// `Module::parameters()` iteration order).
    pub quantized_params: Vec<QuantizedTensor>,
    /// Quantization type used.
    pub quant_type: QuantType,
    /// Total original parameter count (number of f32 scalars).
    pub total_params: usize,
    /// Total quantized size in bytes.
    pub total_bytes: usize,
    /// Original f32 size in bytes (total_params * 4).
    pub original_bytes: usize,
}
327
328impl QuantizedModel {
329    /// Quantize a Module's parameters.
330    pub fn from_module<M: axonml_nn::Module>(module: &M, quant_type: QuantType) -> Self {
331        let params = module.parameters();
332        let total_params: usize = params.iter().map(|p| p.numel()).sum();
333        let original_bytes = total_params * 4;
334
335        let quantized_params = quantize_parameters(&params, quant_type);
336
337        let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
338
339        QuantizedModel {
340            quantized_params,
341            quant_type,
342            total_params,
343            total_bytes,
344            original_bytes,
345        }
346    }
347
348    /// Load quantized weights back into a Module for inference.
349    ///
350    /// Dequantizes all parameters and updates the module's parameters in-place.
351    /// This is the simplest integration path — the model runs at full f32 speed
352    /// but loads from a compressed checkpoint.
353    pub fn load_into_module<M: axonml_nn::Module>(&self, module: &M) {
354        let params = module.parameters();
355        for (param, qparam) in params.iter().zip(self.quantized_params.iter()) {
356            let tensor = dequantize_tensor(qparam).expect("Failed to dequantize parameter");
357            param.update_data(tensor);
358        }
359    }
360
361    /// Compression ratio.
362    pub fn compression_ratio(&self) -> f32 {
363        self.original_bytes as f32 / self.total_bytes as f32
364    }
365
366    /// Print a summary of the quantized model.
367    pub fn summary(&self) -> String {
368        format!(
369            "QuantizedModel(type={}, params={}, f32={:.1}MB, quant={:.1}MB, ratio={:.1}x)",
370            self.quant_type,
371            self.total_params,
372            self.original_bytes as f64 / 1024.0 / 1024.0,
373            self.total_bytes as f64 / 1024.0 / 1024.0,
374            self.compression_ratio(),
375        )
376    }
377}
378
379// =============================================================================
380// Serialization — save/load quantized models
381// =============================================================================
382
383/// Serialize a QuantizedModel to bytes (for .axonml files).
384pub fn serialize_quantized(model: &QuantizedModel) -> Vec<u8> {
385    let mut buf = Vec::new();
386
387    // Magic: "AXQT" (AxonML Quantized)
388    buf.extend_from_slice(b"AXQT");
389    // Version
390    buf.push(1u8);
391    // Quant type
392    buf.push(match model.quant_type {
393        QuantType::Q8_0 => 0,
394        QuantType::Q4_0 => 1,
395        QuantType::Q4_1 => 2,
396        QuantType::Q5_0 => 3,
397        QuantType::Q5_1 => 4,
398        QuantType::F16 => 5,
399        QuantType::F32 => 6,
400    });
401    // Number of tensors
402    buf.extend_from_slice(&(model.quantized_params.len() as u32).to_le_bytes());
403    // Total params
404    buf.extend_from_slice(&(model.total_params as u64).to_le_bytes());
405
406    // Each tensor: shape_len + shape + num_blocks + blocks
407    for qt in &model.quantized_params {
408        // Shape
409        buf.extend_from_slice(&(qt.shape.len() as u32).to_le_bytes());
410        for &dim in &qt.shape {
411            buf.extend_from_slice(&(dim as u32).to_le_bytes());
412        }
413        // Number of blocks
414        buf.extend_from_slice(&(qt.blocks.len() as u32).to_le_bytes());
415        // Block data
416        for block in &qt.blocks {
417            match block {
418                QuantizedBlock::Q8(b) => {
419                    buf.extend_from_slice(&b.to_bytes());
420                }
421                QuantizedBlock::Q4(b) => {
422                    buf.extend_from_slice(&b.to_bytes());
423                }
424                QuantizedBlock::Q4_1(b) => {
425                    buf.extend_from_slice(&b.to_bytes());
426                }
427                QuantizedBlock::F16(data) => {
428                    for &v in data {
429                        buf.extend_from_slice(&v.to_le_bytes());
430                    }
431                }
432                QuantizedBlock::F32(data) => {
433                    for &v in data {
434                        buf.extend_from_slice(&v.to_le_bytes());
435                    }
436                }
437            }
438        }
439    }
440
441    buf
442}
443
444/// Deserialize a QuantizedModel from bytes.
445pub fn deserialize_quantized(data: &[u8]) -> Result<QuantizedModel, String> {
446    if data.len() < 18 || &data[0..4] != b"AXQT" {
447        return Err("Invalid quantized model file (bad magic)".to_string());
448    }
449
450    let version = data[4];
451    if version != 1 {
452        return Err(format!("Unsupported quantized model version: {version}"));
453    }
454
455    let quant_type = match data[5] {
456        0 => QuantType::Q8_0,
457        1 => QuantType::Q4_0,
458        2 => QuantType::Q4_1,
459        3 => QuantType::Q5_0,
460        4 => QuantType::Q5_1,
461        5 => QuantType::F16,
462        6 => QuantType::F32,
463        x => return Err(format!("Unknown quant type byte: {x}")),
464    };
465
466    let num_tensors = u32::from_le_bytes([data[6], data[7], data[8], data[9]]) as usize;
467    let total_params = u64::from_le_bytes([
468        data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17],
469    ]) as usize;
470
471    let mut offset = 18usize;
472    let mut quantized_params = Vec::with_capacity(num_tensors);
473
474    let block_bytes = quant_type.bytes_per_block();
475
476    for _ in 0..num_tensors {
477        if offset + 4 > data.len() {
478            return Err("Truncated quantized model file".to_string());
479        }
480
481        // Shape
482        let shape_len = u32::from_le_bytes([
483            data[offset],
484            data[offset + 1],
485            data[offset + 2],
486            data[offset + 3],
487        ]) as usize;
488        offset += 4;
489
490        let mut shape = Vec::with_capacity(shape_len);
491        for _ in 0..shape_len {
492            let dim = u32::from_le_bytes([
493                data[offset],
494                data[offset + 1],
495                data[offset + 2],
496                data[offset + 3],
497            ]) as usize;
498            shape.push(dim);
499            offset += 4;
500        }
501
502        // Number of blocks
503        let num_blocks = u32::from_le_bytes([
504            data[offset],
505            data[offset + 1],
506            data[offset + 2],
507            data[offset + 3],
508        ]) as usize;
509        offset += 4;
510
511        // Read blocks
512        let mut blocks = Vec::with_capacity(num_blocks);
513        for _ in 0..num_blocks {
514            if offset + block_bytes > data.len() {
515                return Err("Truncated block data".to_string());
516            }
517
518            let block = match quant_type {
519                QuantType::Q8_0 => {
520                    let b =
521                        Q8Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q8 block")?;
522                    QuantizedBlock::Q8(b)
523                }
524                QuantType::Q4_0 => {
525                    let b =
526                        Q4Block::from_bytes(&data[offset..]).ok_or("Failed to parse Q4 block")?;
527                    QuantizedBlock::Q4(b)
528                }
529                QuantType::Q4_1 => {
530                    let scale = f16::from_le_bytes([data[offset], data[offset + 1]]);
531                    let min = f16::from_le_bytes([data[offset + 2], data[offset + 3]]);
532                    let mut block_data = [0u8; 16];
533                    block_data.copy_from_slice(&data[offset + 4..offset + 20]);
534                    QuantizedBlock::Q4_1(Q4_1Block::new(scale, min, block_data))
535                }
536                QuantType::F16 => {
537                    let v = f16::from_le_bytes([data[offset], data[offset + 1]]);
538                    QuantizedBlock::F16(vec![v])
539                }
540                QuantType::F32 => {
541                    let v = f32::from_le_bytes([
542                        data[offset],
543                        data[offset + 1],
544                        data[offset + 2],
545                        data[offset + 3],
546                    ]);
547                    QuantizedBlock::F32(vec![v])
548                }
549                _ => return Err("Unsupported quant type for deserialization".to_string()),
550            };
551
552            blocks.push(block);
553            offset += block_bytes;
554        }
555
556        quantized_params.push(QuantizedTensor::new(shape, quant_type, blocks));
557    }
558
559    let total_bytes: usize = quantized_params.iter().map(|q| q.size_bytes()).sum();
560    let original_bytes = total_params * 4;
561
562    Ok(QuantizedModel {
563        quantized_params,
564        quant_type,
565        total_params,
566        total_bytes,
567        original_bytes,
568    })
569}
570
571// =============================================================================
572// Tests
573// =============================================================================
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Q8_0 forward pass: output is non-zero and tracks the f32 reference.
    #[test]
    fn test_quantized_linear_q8_forward() {
        let in_f = 64;
        let out_f = 16;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 5.0).collect();
        let bias: Vec<f32> = (0..out_f).map(|i| i as f32 * 0.1).collect();

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        // Output should not be all zeros
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero, got sum={sum}");

        // Compare with f32 reference (same weights, no quantization loss)
        let ref_ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::F32);
        let ref_output = ref_ql.forward_f32(&input, 1);

        // Q8 should be very close to f32
        let max_err: f32 = output
            .iter()
            .zip(ref_output.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1.0, "Q8 error too large: {max_err}");
    }

    /// Q4_0 forward pass produces non-zero output (coarser precision, so no
    /// reference comparison here).
    #[test]
    fn test_quantized_linear_q4_forward() {
        let in_f = 64;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.02) - 5.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        let input: Vec<f32> = (0..in_f).map(|i| i as f32 * 0.1).collect();
        let output = ql.forward_f32(&input, 1);

        assert_eq!(output.len(), out_f);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() > 0.01, "Output should be non-zero");
    }

    /// Batched forward returns batch * out_features values.
    #[test]
    fn test_quantized_linear_batch() {
        let in_f = 32;
        let out_f = 8;
        let batch = 4;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();

        let ql = QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);

        let input: Vec<f32> = (0..batch * in_f).map(|i| i as f32 * 0.01).collect();
        let output = ql.forward_f32(&input, batch);

        assert_eq!(output.len(), batch * out_f);
    }

    /// Q8 compresses ~4x and Q4 ~7-8x relative to f32 storage.
    #[test]
    fn test_quantized_linear_compression() {
        let in_f = 1024;
        let out_f = 512;
        let weight: Vec<f32> = vec![0.1; in_f * out_f];

        let ql_q8 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q8_0);
        let ql_q4 =
            QuantizedLinear::from_linear_params(&weight, None, in_f, out_f, QuantType::Q4_0);

        assert!(ql_q8.compression_ratio() > 3.5, "Q8 should compress ~4x");
        assert!(ql_q4.compression_ratio() > 6.0, "Q4 should compress ~7-8x");
    }

    /// serialize → deserialize round-trip preserves type, count, and shape.
    #[test]
    fn test_quantized_model_roundtrip() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let weight_tensor = Tensor::from_vec(weight.clone(), &[out_f, in_f]).unwrap();
        let qt = quantize_tensor(&weight_tensor, QuantType::Q8_0).unwrap();

        let model = QuantizedModel {
            quantized_params: vec![qt],
            quant_type: QuantType::Q8_0,
            total_params: in_f * out_f,
            total_bytes: 0,
            original_bytes: in_f * out_f * 4,
        };

        let serialized = serialize_quantized(&model);
        let deserialized = deserialize_quantized(&serialized).unwrap();

        assert_eq!(deserialized.quant_type, QuantType::Q8_0);
        assert_eq!(deserialized.quantized_params.len(), 1);
        assert_eq!(deserialized.quantized_params[0].shape, vec![out_f, in_f]);
    }

    /// Variable-based forward: shape is [batch, out_features] and the output
    /// carries no gradient (forward-only path).
    #[test]
    fn test_quantized_linear_variable_forward() {
        let in_f = 32;
        let out_f = 8;
        let weight: Vec<f32> = (0..in_f * out_f).map(|i| (i as f32 * 0.01) - 1.0).collect();
        let bias: Vec<f32> = vec![0.5; out_f];

        let ql =
            QuantizedLinear::from_linear_params(&weight, Some(&bias), in_f, out_f, QuantType::Q8_0);

        let input_tensor =
            Tensor::from_vec((0..2 * in_f).map(|i| i as f32 * 0.1).collect(), &[2, in_f]).unwrap();
        let input_var = axonml_autograd::Variable::new(input_tensor, false);

        let output = ql.forward_var(&input_var);

        assert_eq!(output.shape(), vec![2, out_f]);
    }
}
699}