burn_cubecl/kernel/quantization/
dequantize.rs1use crate::tensor::CubeTensor;
2use crate::{CubeRuntime, ops::numeric::empty_device_dtype};
3use burn_backend::{DType, TensorMetadata};
4
5pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
7where
8 R: CubeRuntime,
9{
10 let scheme = match tensor.dtype {
11 DType::QFloat(scheme) => scheme,
12 _ => return tensor,
13 };
14
15 let output = empty_device_dtype(
16 tensor.client.clone(),
17 tensor.device.clone(),
18 tensor.shape(),
19 dtype,
20 );
21 let (values, params) = tensor.quantized_handles().unwrap();
22
23 cubek::quantization::dequantize::launch_ref(
24 &output.client,
25 values.binding(),
26 output.clone().binding(),
27 params.binding(),
28 &scheme,
29 dtype.into(),
30 )
31 .expect("Kernel to never fail");
32
33 output
34}