use crate::CubeRuntime;
use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
use burn_backend::{TensorMetadata, quantization::QuantScheme};
/// Quantizes `tensor` according to `scheme`, using the supplied `scale` tensor.
///
/// A fresh output tensor is allocated through `empty_qtensor_optimized` with
/// the shape of `tensor` on `tensor.device`. The `cubek` quantization kernel is
/// then launched on `tensor.client`, receiving the input, the output's value
/// and parameter handles, and `scale`. The allocated output tensor is returned.
///
/// # Panics
///
/// Panics if `quantized_handles()` on the newly allocated output cannot be
/// unwrapped, or if the kernel launch reports an error.
pub fn quantize<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    scheme: &QuantScheme,
    scale: CubeTensor<R>,
) -> CubeTensor<R> {
    let quantized = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);

    // `quantized` is returned at the end, so the handles are split off a clone.
    let (values, params) = quantized.clone().quantized_handles().unwrap();

    let input_ref = tensor.as_handle_ref();
    let values_ref = values.as_handle_ref();
    let scale_ref = scale.as_handle_ref();
    let params_ref = params.as_handle_ref();

    cubek::quantization::quantize::launch_ref(
        &tensor.client,
        &input_ref,
        &values_ref,
        &scale_ref,
        &params_ref,
        scheme,
        tensor.dtype.into(),
    )
    .expect("Kernel to never fail");

    quantized
}