cubecl_linalg/matmul/components/global/
quantization.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::matmul::components::{InputIdent, MatmulPrecision};
5
/// Quantization meta-parameters used by the matmul components.
///
/// Only symmetric quantization is supported for now, so the scaling factors
/// are the only parameters that need to be stored.
#[derive(CubeType, Clone, Copy)]
pub struct Quantization<MP: MatmulPrecision> {
    // Generic over the whole `MatmulPrecision` rather than just `MP::ES`
    // to stay future proof.
    /// Scaling factor used by `dequantize` for `InputIdent::Lhs` lines.
    pub scaling_lhs: MP::ES,
    /// Scaling factor used by `dequantize` for `InputIdent::Rhs` lines.
    pub scaling_rhs: MP::ES,
}
15
16#[cube]
17impl<MP: MatmulPrecision> Quantization<MP> {
18    pub fn dequantize(&self, line: Line<MP::EI>, #[comptime] ident: InputIdent) -> Line<MP::ES> {
19        match ident {
20            InputIdent::Lhs => Line::<MP::ES>::new(self.scaling_lhs) * Line::cast_from(line),
21            InputIdent::Rhs => Line::<MP::ES>::new(self.scaling_rhs) * Line::cast_from(line),
22        }
23    }
24}