1use cubecl::calculate_cube_count_elemwise;
2use cubecl::prelude::*;
3use cubecl_core::ir::ElemType;
4use cubecl_core::tensor_line_size_parallel;
5use cubecl_core::{self as cubecl};
6use cubecl_runtime::TypeUsage;
7use cubecl_std::scalar::InputScalar;
8use cubecl_std::tensor::layout::linear::LinearView;
9use cubecl_std::tensor::{View, layout::linear::linear_view};
10
11use crate::{
12 layout::{ScalesLayout, scales_view},
13 utils::check_block_size_compat,
14};
15use crate::{
16 layout::{ScalesView, scales_layout},
17 scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
18};
19
20#[cube]
21fn quantize_symmetric<F: Float, FS: CubePrimitive>(
22 value: Line<F>,
23 scale: FS,
24 range_min: F,
25 range_max: F,
26) -> Line<F> {
27 Line::clamp(
28 Line::round(value / Line::cast_from(scale)),
29 Line::new(range_min),
30 Line::new(range_max),
31 )
32}
33
34#[cube]
35fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
36 value: Line<F>,
37 scale: FS,
38 range_min: F,
39 range_max: F,
40) -> Line<Q> {
41 Line::cast_from(quantize_symmetric::<F, FS>(
42 value, scale, range_min, range_max,
43 ))
44}
45
46#[cube]
47fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
48 value: Line<F>,
49 scale: FS,
50 range_min: F,
51 range_max: F,
52 #[comptime] scheme: QuantScheme,
53) -> QS {
54 let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
55 pack_q::<F, QS>(value, scheme.value)
56}
57
/// Pack the quantized values of `value` into a single integer of type `QS`.
///
/// `value` holds already-quantized values (still as floats); each one is cast
/// to `i32`, masked to its bit width, and OR-ed into `packed` at its slot
/// offset. `position` is a comptime counter so the index and shift amounts
/// are compile-time constants in the unrolled loop (hence the clippy allow).
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    // Bits per quantized value (e.g. 4 for a 4-bit quant value).
    let size_quant = comptime!(quant.size_bits() as u32);

    // Bits of the storage integer and how many quantized values fit in it.
    let size_store = comptime!(QS::size_bits().unwrap() as u32);
    let num_quants = comptime!(size_store / size_quant);

    // Low-bit mask selecting a single quantized value (e.g. 0xF for 4 bits).
    let mask = i32::cast_from(comptime!((1 << size_quant) - 1));
    let mut position = comptime!(0);
    let mut packed = QS::cast_from(0);

    #[unroll]
    for _ in 0..num_quants {
        let offset = QS::cast_from(comptime!(position * size_quant));
        // Truncate to the quantized width, then place the value in its slot.
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
        comptime!(position += 1);
    }

    packed
}
83
84#[cube]
85fn write_scale<F: Float, FS: CubePrimitive>(
86 in_pos: u32,
87 scale: &View<F, u32>,
88 out_scale: &mut View<FS, u32, ReadWrite>,
89 scales_layout: ScalesLayout,
90) -> FS {
91 let scale = FS::cast_from(scale[in_pos]);
92
93 if scales_layout.is_block_start(in_pos) {
95 out_scale[in_pos] = scale;
96 }
97
98 scale
99}
100
101#[cube(launch_unchecked)]
102fn quantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
103 input: &LinearView<Line<F>>,
104 scale: &ScalesView<F>,
105 range_min: InputScalar,
106 range_max: InputScalar,
107 output: &mut LinearView<Line<Q>, ReadWrite>,
108 out_scale: &mut ScalesView<FS, ReadWrite>,
109 scales_layout: ScalesLayout,
110 #[define(F, FS, Q)] _dtypes: [StorageType; 3],
111) {
112 if !output.is_in_bounds(ABSOLUTE_POS) {
113 terminate!();
114 }
115
116 let native_packing = Q::packing_factor();
117 let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
118 let scale = write_scale(in_pos, scale, out_scale, scales_layout);
119
120 output[ABSOLUTE_POS] = quantize_symmetric_q::<F, FS, Q>(
121 input[ABSOLUTE_POS],
122 scale,
123 range_min.get::<F>(),
124 range_max.get::<F>(),
125 );
126 sync_cube();
127}
128
/// Kernel: symmetrically quantize `input` and pack `scheme.num_quants()`
/// values into each `u32` element of `output`, forwarding block scales into
/// `out_scale`.
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    // Out-of-bounds units have nothing to write.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // Number of quantized values packed into each u32 output word.
    let num_quants = comptime!(scheme.num_quants() as u32);
    // First unpacked input element covered by this unit's output word.
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    if comptime!(input.line_size() == num_quants) {
        // Fast path: one input line maps exactly onto one packed word.
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    } else {
        // Gather path: collect `num_quants` scalars one line at a time. Reads
        // element 0 of each line — assumes the input line size is 1 on this
        // branch; TODO confirm against the launch site.
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values,
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    }
}
173
174#[allow(clippy::result_large_err)]
175pub fn launch_ref<R: Runtime>(
176 client: &ComputeClient<R>,
177 input: &TensorHandleRef<R>,
178 output: &TensorHandleRef<R>,
179 scale: &TensorHandleRef<'_, R>,
180 out_scale: &TensorHandleRef<'_, R>,
181 scheme: &QuantScheme,
182 input_elem: ElemType,
183) -> Result<(), LaunchError> {
184 let param_elem = ElemType::from_quant_param(scheme.param);
185
186 match scheme {
187 QuantScheme {
188 store: QuantStore::U32,
189 ..
190 } => quantize_packed(
191 client, input, scheme, scale, out_scale, output, input_elem, param_elem,
192 ),
193 QuantScheme {
194 value:
195 QuantValue::Q8F
196 | QuantValue::Q8S
197 | QuantValue::E4M3
198 | QuantValue::E5M2
199 | QuantValue::E2M1,
200 store: QuantStore::Native,
201 ..
202 } => {
203 if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
204 panic!(
205 "{:?} is not supported for native quantization",
206 scheme.value
207 );
208 }
209
210 quantize_native(
211 client, input, scheme, scale, out_scale, output, input_elem, param_elem,
212 )
213 }
214 QuantScheme {
215 store: QuantStore::Native,
216 value,
217 ..
218 } => {
219 panic!("{value:?} is not supported for native quantization");
220 }
221 }
222}
223
224#[allow(clippy::too_many_arguments)]
225fn quantize_native<R: Runtime>(
226 client: &ComputeClient<R>,
227 input: &TensorHandleRef<R>,
228 scheme: &QuantScheme,
229 scale: &TensorHandleRef<'_, R>,
230 out_scale: &TensorHandleRef<'_, R>,
231 output: &TensorHandleRef<R>,
232 input_dtype: ElemType,
233 scale_dtype: ElemType,
234) -> Result<(), LaunchError> {
235 let num_elems: usize = input.shape.iter().product();
236 let line_size = tensor_line_size_parallel(
237 client.io_optimized_line_sizes_unchecked(input.elem_size),
238 input.shape,
239 input.strides,
240 input.shape.len() - 1,
241 );
242 let cube_dim = CubeDim::default();
243 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
244 let (range_min, range_max) = scheme.value.range();
245
246 match scheme {
247 QuantScheme {
248 level: QuantLevel::Tensor | QuantLevel::Block(_),
249 mode: QuantMode::Symmetric,
250 store: QuantStore::Native,
251 ..
252 } => {
253 check_block_size_compat(scheme, line_size as usize);
255 let quant_type = ElemType::from_quant_value(scheme.value);
256
257 unsafe {
258 quantize_symmetric_native_kernel::launch_unchecked(
259 client,
260 cube_count,
261 cube_dim,
262 linear_view(client, input, line_size),
263 scales_view(client, output, scale, 1, scheme),
265 InputScalar::new(range_min, input_dtype),
266 InputScalar::new(range_max, input_dtype),
267 linear_view(client, output, line_size),
268 scales_view(client, output, out_scale, 1, scheme),
269 scales_layout(client, output, scale, 1, scheme),
270 [input_dtype.into(), scale_dtype.into(), quant_type.into()],
271 )
272 }
273 }
274 _ => panic!("Unsupported quantization scheme {scheme:?}"),
275 }
276}
277
278#[allow(clippy::too_many_arguments)]
279fn quantize_packed<R: Runtime>(
280 client: &ComputeClient<R>,
281 input: &TensorHandleRef<R>,
282 scheme: &QuantScheme,
283 scale: &TensorHandleRef<'_, R>,
284 out_scale: &TensorHandleRef<'_, R>,
285 output: &TensorHandleRef<R>,
286 dtype_input: ElemType,
287 dtype_param: ElemType,
288) -> Result<(), LaunchError> {
289 let num_elems: usize = input.shape.iter().product();
290
291 let num_quants = scheme.num_quants() as u8;
292 let line_size = num_quants;
293
294 let cube_dim = CubeDim::default();
295 let cube_count =
296 calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
297 let (range_min, range_max) = scheme.value.range();
298
299 match scheme {
300 QuantScheme {
301 level: QuantLevel::Tensor | QuantLevel::Block(_),
302 mode: QuantMode::Symmetric,
303 store: QuantStore::U32,
304 ..
305 } => {
306 check_block_size_compat(scheme, num_quants as usize); unsafe {
308 quantize_symmetric_packed_kernel::launch_unchecked(
309 client,
310 cube_count,
311 cube_dim,
312 linear_view(client, input, line_size),
313 scales_view(client, output, scale, 1, scheme),
315 InputScalar::new(range_min, dtype_input),
316 InputScalar::new(range_max, dtype_input),
317 linear_view(client, output, 1),
318 scales_view(client, output, out_scale, 1, scheme),
319 scales_layout(client, output, scale, 1, scheme),
320 *scheme,
321 [dtype_input.into(), dtype_param.into()],
322 )
323 }
324 }
325 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
326 }
327}