cubecl_matmul/components/global/
quantization.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::components::{InputIdent, MatmulPrecision};
5
6/// Store the quantization meta-parameters.
7/// For now, we only support symmetric quantization,
8/// thus we only store the scaling.
9#[derive(CubeType, Clone, Copy)]
10pub struct Quantization<MP: MatmulPrecision> {
11    pub scaling_lhs: MP::ES,
12    pub scaling_rhs: MP::ES,
13}
14
15#[cube]
16impl<MP: MatmulPrecision> Quantization<MP> {
17    pub fn dequantize(&self, line: Line<MP::EI>, #[comptime] ident: InputIdent) -> Line<MP::ES> {
18        match ident {
19            InputIdent::Lhs => Line::<MP::ES>::new(self.scaling_lhs) * Line::cast_from(line),
20            InputIdent::Rhs => Line::<MP::ES>::new(self.scaling_rhs) * Line::cast_from(line),
21        }
22    }
23}