burn_cubecl/kernel/quantization/dequantize.rs

1use crate::CubeRuntime;
2use crate::ops::numeric::empty_device_dtype;
3use crate::tensor::CubeTensor;
4use burn_backend::DType;
5
6/// Convert the tensor back to a higher precision data type.
7pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
8where
9    R: CubeRuntime,
10{
11    let scheme = match tensor.dtype {
12        DType::QFloat(scheme) => scheme,
13        _ => return tensor,
14    };
15
16    let output = empty_device_dtype(
17        tensor.client.clone(),
18        tensor.device.clone(),
19        tensor.shape.clone(),
20        dtype,
21    );
22    let (values, params) = tensor.quantized_handles().unwrap();
23
24    cubek::quantization::dequantize::launch_ref(
25        &values.client,
26        &values.as_handle_ref(),
27        &output.as_handle_ref(),
28        &params.as_handle_ref(),
29        &scheme,
30        dtype.into(),
31    )
32    .expect("Kernel to never fail");
33
34    output
35}