// burn_cubecl/kernel/quantization/quantize.rs
use crate::CubeRuntime;
use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};
use burn_backend::quantization::QuantScheme;

5pub fn quantize<R>(
7 tensor: CubeTensor<R>,
8 scheme: &QuantScheme,
9 scale: CubeTensor<R>,
10) -> CubeTensor<R>
11where
12 R: CubeRuntime,
13{
14 let output = empty_qtensor_optimized(tensor.shape.clone(), *scheme, &tensor.device);
15 let (out_values, out_params) = output.clone().quantized_handles().unwrap();
16
17 cubek::quantization::quantize::launch_ref(
18 &tensor.client,
19 &tensor.as_handle_ref(),
20 &out_values.as_handle_ref(),
21 &scale.as_handle_ref(),
22 &out_params.as_handle_ref(),
23 scheme,
24 tensor.dtype.into(),
25 )
26 .expect("Kernel to never fail");
27
28 output
29}