use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
use cubecl_common::{e2m1x2, e4m3, e5m2, ue8m0};
use cubecl_core::tensor_line_size_parallel;
use cubecl_core::{self as cubecl};
use cubecl_runtime::TypeUsage;
use cubecl_std::tensor::layout::linear::LinearView;
use cubecl_std::tensor::{View, layout::linear::linear_view};

use crate::{
    layout::{ScalesLayout, scales_view},
    utils::check_block_size_compat,
};
use crate::{
    layout::{ScalesView, scales_layout},
    scheme::{QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue},
};
use half::{bf16, f16};
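/// Scale a line of values by the inverse of `scale`, round to the nearest
/// integer, and clamp the result to `[range_min, range_max]`, the representable
/// range of the target quantized type.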
#[cube]
fn quantize_symmetric<F: Float, FS: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<F> {
    Line::clamp(
        Line::round(value / Line::cast_from(scale)),
        Line::new(range_min),
        Line::new(range_max),
    )
}
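/// Quantize a line of values symmetrically and cast the result to the
/// quantized element type `Q`.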
#[cube]
fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<Q> {
    Line::cast_from(quantize_symmetric::<F, FS>(
        value, scale, range_min, range_max,
    ))
}
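/// Quantize a line of values symmetrically and pack them into a single
/// integer of the storage type `QS`.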
#[cube]
fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
    #[comptime] scheme: QuantScheme,
) -> QS {
    let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
    pack_q::<F, QS>(value, scheme.value)
}
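/// Pack a line of quantized values into a single integer of type `QS`.
///
/// Each value is truncated to its low `size_quant` bits and OR'd into the
/// result at bit offset `position * size_quant`, so the first element of the
/// line lands in the least significant bits. For example, packing the 8-bit
/// values `[0x01, 0x02, 0x03, 0x04]` into a `u32` yields `0x04030201`.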
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    let size_quant = comptime!(quant.size_bits() as u32);

    let size_store = comptime!(QS::size_bits().unwrap() as u32);
    let num_quants = comptime!(size_store / size_quant);

    let mask = i32::cast_from(comptime!((1 << size_quant) - 1));
    let mut position = comptime!(0);
    let mut packed = QS::cast_from(0);

    #[unroll]
    for _ in 0..num_quants {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
        comptime!(position += 1);
    }

    packed
}
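/// Convert the scale at `in_pos` to the parameter type `FS` and write it to
/// the output scales tensor, deduplicated so that only the first position of
/// each quantization block performs the write. Returns the converted scale.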
#[cube]
fn write_scale<F: Float, FS: CubePrimitive>(
    in_pos: u32,
    scale: &View<F, u32>,
    out_scale: &mut View<FS, u32, ReadWrite>,
    scales_layout: ScalesLayout,
) -> FS {
    let scale = FS::cast_from(scale[in_pos]);

    if scales_layout.is_block_start(in_pos) {
        out_scale[in_pos] = scale;
    }

    scale
}
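/// Kernel for symmetric quantization where each quantized value is stored as
/// a native element (e.g. `i8`, `e4m3`, or `e2m1x2`, which holds two sub-byte
/// values per element). Each unit quantizes one line of input and writes the
/// per-block scale through [`write_scale`].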
#[cube(launch_unchecked)]
fn quantize_symmetric_native_kernel<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: F,
    range_max: F,
    output: &mut LinearView<Line<Q>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
    let scale = write_scale(in_pos, scale, out_scale, scales_layout);

    output[ABSOLUTE_POS] =
        quantize_symmetric_q::<F, FS, Q>(input[ABSOLUTE_POS], scale, range_min, range_max);
}
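/// Kernel for symmetric quantization where the quantized values are packed
/// into `u32` words. When the input line size matches the packing factor, a
/// whole line is quantized and packed at once; otherwise the values are
/// gathered one by one before packing.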
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: F,
    range_max: F,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let num_quants = comptime!(scheme.num_quants() as u32);
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    if comptime!(input.line_size() == num_quants) {
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min,
            range_max,
            scheme,
        ));
    } else {
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values, scale, range_min, range_max, scheme,
        ));
    }
}
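/// Launch the quantization kernel matching `scheme`, dispatching on the
/// storage type (`u32`-packed or native) and the scale parameter precision.
///
/// Panics when the requested quantized value cannot be stored natively on the
/// current client.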
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, F: Float + CubeElement>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
) {
    match scheme {
        QuantScheme {
            store: QuantStore::U32,
            ..
        } => match scheme.param {
            QuantParam::F32 => {
                quantize_packed::<R, F, f32>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::F16 => {
                quantize_packed::<R, F, f16>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::BF16 => {
                quantize_packed::<R, F, bf16>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::UE8M0 => {
                quantize_packed::<R, F, ue8m0>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::UE4M3 => {
                quantize_packed::<R, F, e4m3>(client, input, scheme, scale, out_scale, output)
            }
        },
        QuantScheme {
            value:
                QuantValue::Q8F
                | QuantValue::Q8S
                | QuantValue::E4M3
                | QuantValue::E5M2
                | QuantValue::E2M1,
            store: QuantStore::Native,
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            match scheme.param {
                QuantParam::F32 => {
                    quantize_native::<R, F, f32>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::F16 => {
                    quantize_native::<R, F, f16>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::BF16 => {
                    quantize_native::<R, F, bf16>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::UE8M0 => {
                    quantize_native::<R, F, ue8m0>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::UE4M3 => {
                    quantize_native::<R, F, e4m3>(client, input, scheme, scale, out_scale, output)
                }
            }
        }
        QuantScheme {
            store: QuantStore::Native,
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}
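/// Quantize `input` into natively stored elements, writing the quantized
/// values to `output` and the converted scales to `out_scale`.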
fn quantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes_unchecked(size_of::<F>()),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            value,
            store: QuantStore::Native,
            ..
        } => {
            check_block_size_compat(scheme, line_size as usize);

            let launch = match value {
                QuantValue::Q8F | QuantValue::Q8S => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, i8, R>
                }
                QuantValue::E4M3 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e4m3, R>
                }
                QuantValue::E5M2 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e5m2, R>
                }
                QuantValue::E2M1 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e2m1x2, R>
                }
                other => panic!("Unsupported quant value {other:?}"),
            };

            unsafe {
                launch(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, output, scale, 1, scheme),
                    ScalarArg::new(F::from_int(range_min as i64)),
                    ScalarArg::new(F::from_int(range_max as i64)),
                    linear_view(client, output, line_size),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                )
            };
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}
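/// Quantize `input` into `u32`-packed storage, writing the packed words to
/// `output` and the converted scales to `out_scale`.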
fn quantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems: usize = input.shape.iter().product();

    let num_quants = scheme.num_quants() as u8;
    let line_size = num_quants;

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::U32,
            ..
        } => {
            check_block_size_compat(scheme, num_quants as usize);

            unsafe {
                quantize_symmetric_packed_kernel::launch_unchecked::<F, FS, R>(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, output, scale, 1, scheme),
                    ScalarArg::new(F::from_int(range_min as i64)),
                    ScalarArg::new(F::from_int(range_max as i64)),
                    linear_view(client, output, 1),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    *scheme,
                )
            };
        }
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}