burn_cubecl/kernel/quantization/
quantize.rs1use crate::CubeRuntime;
2use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
3use burn_backend::{TensorMetadata, quantization::QuantScheme};
4
5pub fn quantize<R>(
7 tensor: CubeTensor<R>,
8 scheme: &QuantScheme,
9 scale: CubeTensor<R>,
10) -> CubeTensor<R>
11where
12 R: CubeRuntime,
13{
14 let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);
15 let (out_values, out_params) = output.clone().quantized_handles().unwrap();
16 let dtype = tensor.dtype;
17
18 cubek::quantization::quantize::launch_ref(
19 &output.client,
20 tensor.binding(),
21 out_values.binding(),
22 scale.binding(),
23 out_params.binding(),
24 scheme,
25 dtype.into(),
26 )
27 .expect("Kernel to never fail");
28
29 output
30}