use crate::tensor::CubeTensor;
use crate::{CubeRuntime, ops::numeric::empty_device_dtype};
use burn_backend::{DType, TensorMetadata};
/// Dequantizes `tensor` into a new tensor of element type `dtype`.
///
/// If `tensor` is not a quantized tensor (its dtype is not `DType::QFloat`), it is
/// returned unchanged and `dtype` is ignored.
///
/// Otherwise, an output tensor is allocated with the same client, device, and
/// shape as `tensor` and with element type `dtype`. The
/// `cubek::quantization::dequantize` kernel is then launched. It reads the
/// quantized values and their quantization parameters and writes the result
/// into that output.
///
/// # Panics
///
/// - If `quantized_handles()` returns `None` for a `QFloat` tensor. A QFloat
///   tensor is presumed to always carry value and parameter handles
///   (TODO confirm against `CubeTensor::quantized_handles`).
/// - If the kernel launch returns an error.
pub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>
where
    R: CubeRuntime,
{
    // Non-quantized tensors need no work; hand them back untouched.
    let DType::QFloat(scheme) = tensor.dtype else {
        return tensor;
    };

    let output = empty_device_dtype(
        tensor.client.clone(),
        tensor.device.clone(),
        tensor.shape(),
        dtype,
    );

    let (values, params) = tensor
        .quantized_handles()
        .expect("QFloat tensor should provide quantized value and parameter handles");

    cubek::quantization::dequantize::launch_ref(
        &values.client,
        &values.as_handle_ref(),
        &output.as_handle_ref(),
        &params.as_handle_ref(),
        &scheme,
        dtype.into(),
    )
    .expect("Kernel to never fail");

    output
}