Skip to main content

burn_cubecl/kernel/quantization/
quantize.rs

1use crate::CubeRuntime;
2use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
3use burn_backend::{TensorMetadata, quantization::QuantScheme};
4
5/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
6pub fn quantize<R>(
7    tensor: CubeTensor<R>,
8    scheme: &QuantScheme,
9    scale: CubeTensor<R>,
10) -> CubeTensor<R>
11where
12    R: CubeRuntime,
13{
14    let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);
15    let (out_values, out_params) = output.clone().quantized_handles().unwrap();
16    let dtype = tensor.dtype;
17
18    cubek::quantization::quantize::launch_ref(
19        &output.client,
20        tensor.binding(),
21        out_values.binding(),
22        scale.binding(),
23        out_params.binding(),
24        scheme,
25        dtype.into(),
26    )
27    .expect("Kernel to never fail");
28
29    output
30}