// cubecl_matmul/components/global/quantization.rs
use cubecl::prelude::*;
use cubecl_core as cubecl;

use crate::components::{InputIdent, MatmulPrecision};

6#[derive(CubeType, Clone, Copy)]
10pub struct Quantization<MP: MatmulPrecision> {
11 pub scaling_lhs: MP::ES,
12 pub scaling_rhs: MP::ES,
13}
14
15#[cube]
16impl<MP: MatmulPrecision> Quantization<MP> {
17 pub fn dequantize(&self, line: Line<MP::EI>, #[comptime] ident: InputIdent) -> Line<MP::ES> {
18 match ident {
19 InputIdent::Lhs => Line::<MP::ES>::new(self.scaling_lhs) * Line::cast_from(line),
20 InputIdent::Rhs => Line::<MP::ES>::new(self.scaling_rhs) * Line::cast_from(line),
21 }
22 }
23}