cubecl_linalg/matmul/components/global/
quantization.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::matmul::components::{InputIdent, MatmulPrecision};
5
6#[derive(CubeType, Clone, Copy)]
10pub struct Quantization<MP: MatmulPrecision> {
11 pub scaling_lhs: MP::ES,
13 pub scaling_rhs: MP::ES,
14}
15
16#[cube]
17impl<MP: MatmulPrecision> Quantization<MP> {
18 pub fn dequantize(&self, line: Line<MP::EI>, #[comptime] ident: InputIdent) -> Line<MP::ES> {
19 match ident {
20 InputIdent::Lhs => Line::<MP::ES>::new(self.scaling_lhs) * Line::cast_from(line),
21 InputIdent::Rhs => Line::<MP::ES>::new(self.scaling_rhs) * Line::cast_from(line),
22 }
23 }
24}