Skip to main content

burn_cubecl/kernel/quantization/dequantize.rs

1use crate::tensor::CubeTensor;
2use crate::{CubeRuntime, ops::numeric::empty_device_dtype};
3use burn_backend::{DType, TensorMetadata};
4
5/// Convert the tensor back to a higher precision data type.
6pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
7where
8    R: CubeRuntime,
9{
10    let scheme = match tensor.dtype {
11        DType::QFloat(scheme) => scheme,
12        _ => return tensor,
13    };
14
15    let output = empty_device_dtype(
16        tensor.client.clone(),
17        tensor.device.clone(),
18        tensor.shape(),
19        dtype,
20    );
21    let (values, params) = tensor.quantized_handles().unwrap();
22
23    cubek::quantization::dequantize::launch_ref(
24        &output.client,
25        values.binding(),
26        output.clone().binding(),
27        params.binding(),
28        &scheme,
29        dtype.into(),
30    )
31    .expect("Kernel to never fail");
32
33    output
34}