burn_cubecl/kernel/quantization/
dequantize.rs1use crate::CubeRuntime;
2use crate::ops::numeric::empty_device_dtype;
3use crate::tensor::CubeTensor;
4use burn_backend::DType;
5
6pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
8where
9 R: CubeRuntime,
10{
11 let scheme = match tensor.dtype {
12 DType::QFloat(scheme) => scheme,
13 _ => return tensor,
14 };
15
16 let output = empty_device_dtype(
17 tensor.client.clone(),
18 tensor.device.clone(),
19 tensor.shape.clone(),
20 dtype,
21 );
22 let (values, params) = tensor.quantized_handles().unwrap();
23
24 cubek::quantization::dequantize::launch_ref(
25 &values.client,
26 &values.as_handle_ref(),
27 &output.as_handle_ref(),
28 ¶ms.as_handle_ref(),
29 &scheme,
30 dtype.into(),
31 )
32 .expect("Kernel to never fail");
33
34 output
35}