use cubecl::calculate_cube_count_elemwise;
use cubecl::features::TypeUsage;
use cubecl::ir::ElemType;
use cubecl::prelude::*;
use cubecl::std::tensor::View;
use cubecl::std::tensor::layout::linear::{LinearView, linear_view};
use cubecl::tensor_line_size_parallel;

use crate::{
    layout::{ScalesLayout, ScalesView, scales_layout, scales_view},
    scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
    utils::check_block_size_compat,
};

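/// Quantizes a line of values symmetrically: divides by the scale, rounds to
/// the nearest integer, and clamps to `[range_min, range_max]`.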
#[cube]
fn quantize_symmetric<F: Float, FS: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<F> {
    clamp(
        Line::round(value / Line::cast_from(scale)),
        Line::new(range_min),
        Line::new(range_max),
    )
}

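/// Like [`quantize_symmetric`], but casts the result to the quantized
/// primitive `Q`.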
#[cube]
fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<Q> {
    Line::cast_from(quantize_symmetric::<F, FS>(
        value, scale, range_min, range_max,
    ))
}

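/// Quantizes a line of values and packs them into a single integer of the
/// storage type `QS`.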
#[cube]
fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
    #[comptime] scheme: QuantScheme,
) -> QS {
    let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
    pack_q::<F, QS>(value, scheme.value)
}

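/// Packs a line of already-quantized values into one storage integer `QS`,
/// giving each value a `size_quant`-bit slot, starting from the least
/// significant bits.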
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    let size_quant = quant.size_bits();

    // Number of quantized values that fit in one storage word.
    let size_store = QS::type_size_bits().comptime();
    let num_quants = size_store / size_quant;

    // Mask selecting the low `size_quant` bits of each quantized value.
    let mask = (1 << size_quant) - 1;
    let mut packed = QS::from_int(0);

    #[unroll]
    for position in 0..num_quants {
        let offset = QS::cast_from(position * size_quant);
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
    }

    packed
}

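/// Casts the scale at `in_pos` to the output scale type and, when `in_pos` is
/// the start of a quantization block, writes it to `out_scale`. Returns the
/// cast scale in both cases.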
#[cube]
fn write_scale<F: Float, FS: CubePrimitive>(
    in_pos: usize,
    scale: &View<F, usize>,
    out_scale: &mut View<FS, usize, ReadWrite>,
    scales_layout: ScalesLayout,
) -> FS {
    let scale = FS::cast_from(scale[in_pos]);

    if scales_layout.is_block_start(in_pos) {
        out_scale[in_pos] = scale;
    }

    scale
}

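/// Kernel for `Native` and `PackedNative` storage: each output element holds
/// one `Q`, whose packing factor is accounted for when locating the scale.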
#[cube(launch_unchecked)]
fn quantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<Q>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
    let scale = write_scale(in_pos, scale, out_scale, scales_layout);

    output[ABSOLUTE_POS] = quantize_symmetric_q::<F, FS, Q>(
        input[ABSOLUTE_POS],
        scale,
        range_min.get::<F>(),
        range_max.get::<F>(),
    );
}

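/// Kernel for `PackedU32` storage: quantizes `num_quants` values at a time and
/// packs them into each `u32` of the output. When the input line size differs
/// from the packing factor, the values are first gathered one by one.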
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let num_quants = scheme.num_quants();
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    if input.line_size().comptime() == num_quants {
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    } else {
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values,
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    }
}

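/// Quantizes `input` into `output` according to `scheme`, dispatching to the
/// packed or native path and panicking on unsupported schemes.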
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
    input_elem: ElemType,
) -> Result<(), LaunchError> {
    let param_elem = ElemType::from_quant_param(scheme.param);

    match scheme {
        QuantScheme {
            store: QuantStore::PackedU32(_),
            ..
        } => quantize_packed(
            client, input, scheme, scale, out_scale, output, input_elem, param_elem,
        ),
        QuantScheme {
            value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2,
            store: QuantStore::Native,
            ..
        }
        | QuantScheme {
            value: QuantValue::E2M1,
            store: QuantStore::PackedNative(_),
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            quantize_native(
                client, input, scheme, scale, out_scale, output, input_elem, param_elem,
            )
        }
        QuantScheme {
            store: QuantStore::Native | QuantStore::PackedNative(_),
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}

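/// Launches the native quantization kernel for symmetric per-tensor or
/// per-block schemes stored in a native element type.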
#[allow(clippy::too_many_arguments)]
fn quantize_native<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    input_dtype: ElemType,
    scale_dtype: ElemType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes_unchecked(input.elem_size),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let working_units = num_elems / line_size as usize;
    let cube_dim = CubeDim::new(client, working_units);
    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::Native,
            ..
        } => {
            check_block_size_compat(scheme, line_size as usize);
            let quant_type = ElemType::from_quant_value(scheme.value);

            unsafe {
                quantize_symmetric_native_kernel::launch_unchecked(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, output, scale, 1, scheme),
                    InputScalar::new(range_min, input_dtype),
                    InputScalar::new(range_max, input_dtype),
                    linear_view(client, output, line_size),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    [input_dtype.into(), scale_dtype.into(), quant_type.into()],
                )
            }
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}

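/// Launches the packed quantization kernel for symmetric per-tensor or
/// per-block schemes stored as packed `u32` words.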
#[allow(clippy::too_many_arguments)]
fn quantize_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    dtype_input: ElemType,
    dtype_param: ElemType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let num_quants = scheme.num_quants();
    let line_size = num_quants;

    let working_units = num_elems.div_ceil(line_size);
    let cube_dim = CubeDim::new(client, working_units);
    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::PackedU32(_),
            ..
        } => {
            check_block_size_compat(scheme, num_quants);

            unsafe {
                quantize_symmetric_packed_kernel::launch_unchecked(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, output, scale, 1, scheme),
                    InputScalar::new(range_min, dtype_input),
                    InputScalar::new(range_max, dtype_input),
                    linear_view(client, output, 1),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    *scheme,
                    [dtype_input.into(), dtype_param.into()],
                )
            }
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}