burn_cubecl/kernel/quantization/
quantize.rs

1use crate::CubeRuntime;
2use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
3use burn_backend::quantization::QuantScheme;
4
5/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
6pub fn quantize<R>(
7    tensor: CubeTensor<R>,
8    scheme: &QuantScheme,
9    scale: CubeTensor<R>,
10) -> CubeTensor<R>
11where
12    R: CubeRuntime,
13{
14    let output = empty_qtensor_optimized(tensor.shape.clone(), *scheme, &tensor.device);
15    let (out_values, out_params) = output.clone().quantized_handles().unwrap();
16
17    cubek::quantization::quantize::launch_ref(
18        &tensor.client,
19        &tensor.as_handle_ref(),
20        &out_values.as_handle_ref(),
21        &scale.as_handle_ref(),
22        &out_params.as_handle_ref(),
23        scheme,
24        tensor.dtype.into(),
25    )
26    .expect("Kernel to never fail");
27
28    output
29}