1#![allow(missing_docs)] use cubecl::prelude::*;
4use cubecl_common::{e2m1x2, e4m3, e5m2, ue8m0};
5use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};
6use cubecl_runtime::TypeUsage;
7
8use crate::{
9 layout::{ScalesView, scales_view},
10 scheme::{QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue},
11};
12use cubecl_std::tensor::{
13 View,
14 layout::linear::{LinearView, linear_view},
15};
16use half::{bf16, f16};
17
18#[cube]
20pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
21 Line::cast_from(scale) * value
23}
24
25#[cube]
30pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
31 position: u32,
32 values: &View<Line<QI>, u32>,
33 scales: &View<FS, u32>,
34 #[comptime] scheme: QuantScheme,
35) -> Array<Line<F>> {
36 dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
37}
38
/// Dequantizes an already-loaded line of packed values, reading scales from `scales`.
///
/// This is a thin forwarder to `dequantize_symmetric_packed_value` that takes `position`
/// as its first argument; see that function for how `position` is used to index `scales`.
///
/// Returns an array with `values.line_size()` entries, each a line of
/// `scheme.num_quants()` floats.
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: Line<QI>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}
52
53#[cube]
58pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
59 values: Line<QS>,
60 scales: &View<FS, u32>,
61 position: u32,
62 #[comptime] scheme: QuantScheme,
63) -> Array<Line<F>> {
64 let line_size_values = values.line_size();
65 let num_quants = comptime!(scheme.num_quants() as u32);
66 let mut tmp = Array::vectorized(line_size_values, num_quants);
67
68 #[unroll]
69 for i in 0..line_size_values {
70 let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
71 let scale = scales[(position * line_size_values) + i * num_quants];
72 let values = dequantize_symmetric::<F, FS>(floats, scale);
73 tmp[i] = values;
74 }
75
76 tmp
77}
78
/// Unpacks the sub-word quantized integers stored in `value` into a line of floats.
///
/// `value` holds `store.size_bits(&quant) / quant.size_bits()` fields of
/// `quant.size_bits()` bits each. Field `k` is read from bit offset `k * size_quant`, so the
/// least-significant field becomes lane 0 of the output line. Each field is decoded as a
/// two's-complement signed integer and cast to `F`; no scale is applied here.
///
/// The clippy allow is needed because the loop uses a manual counter (`position`). That
/// counter is a `comptime!` value used for lane indexing and shift offsets.
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    let size_quant = comptime!(quant.size_bits() as u32);
    let size_store = comptime!(store.size_bits(&quant) as u32);
    let num_quant = comptime!(size_store / size_quant);

    let mut output = Line::empty(num_quant);
    // Compile-time lane counter, advanced once per unrolled iteration.
    let mut position = comptime!(0);

    // Selects one `size_quant`-bit field.
    let mask = QS::cast_from(comptime!((1 << size_quant) - 1));
    // 2^(n-1): any raw field value at or above this has its sign bit set.
    let sign_bit = QS::cast_from(comptime!(1 << (size_quant - 1)));
    // 2^n: subtracted from negative fields to sign-extend them.
    let two_pow_n = comptime!(1 << size_quant);

    #[unroll]
    for _ in 0..num_quant {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let raw = (value >> offset) & mask;

        // Branchless two's-complement decode: if raw >= 2^(n-1), result = raw - 2^n.
        let raw_i32 = i32::cast_from(raw);
        let is_negative = i32::cast_from(raw >= sign_bit); // 1 if negative, 0 otherwise
        let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
        comptime!(position += 1);
    }

    output
}
117
118#[cube(launch_unchecked)]
119fn dequantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
120 input: &LinearView<Line<u32>>,
121 scales: &ScalesView<FS>,
122 output: &mut LinearView<Line<F>, ReadWrite>,
123 #[comptime] scheme: QuantScheme,
124) {
125 if !input.is_in_bounds(ABSOLUTE_POS) {
126 terminate!();
127 }
128
129 let line_size_in = input.line_size();
130 let line_size_out = output.line_size();
131
132 comptime! {
133 assert_eq!(line_size_out, scheme.num_quants() as u32);
134 }
135
136 let values = input[ABSOLUTE_POS];
137 let packed_pos = ABSOLUTE_POS * comptime![scheme.num_quants() as u32];
138
139 let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);
140
141 #[unroll]
142 for i in 0..line_size_in {
143 output[ABSOLUTE_POS * line_size_in + i] = out[i];
144 }
145}
146
147#[cube(launch_unchecked)]
148fn dequantize_symmetric_native_kernel<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
149 input: &LinearView<Line<Q>>,
150 scale: &ScalesView<FS>,
151 output: &mut LinearView<Line<F>, ReadWrite>,
152) {
153 if !input.is_in_bounds(ABSOLUTE_POS) {
154 terminate!();
155 }
156
157 let native_packing = Q::packing_factor();
158 let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];
160
161 output[ABSOLUTE_POS] =
162 dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
163}
164
165#[allow(clippy::result_large_err)]
166pub fn launch_ref<R: Runtime, F: Float>(
168 client: &ComputeClient<R::Server>,
169 values: &TensorHandleRef<R>,
170 output: &TensorHandleRef<R>,
171 params: &TensorHandleRef<'_, R>,
172 scheme: &QuantScheme,
173) {
174 match scheme {
175 QuantScheme {
176 store: QuantStore::U32,
177 ..
178 } => match scheme.param {
179 QuantParam::F32 => {
180 dequantize_packed::<R, F, f32>(client, values, *scheme, params, output)
181 }
182 QuantParam::F16 => {
183 dequantize_packed::<R, F, f16>(client, values, *scheme, params, output)
184 }
185 QuantParam::BF16 => {
186 dequantize_packed::<R, F, bf16>(client, values, *scheme, params, output)
187 }
188 QuantParam::UE8M0 => {
189 dequantize_packed::<R, F, ue8m0>(client, values, *scheme, params, output)
190 }
191 QuantParam::UE4M3 => {
192 dequantize_packed::<R, F, e4m3>(client, values, *scheme, params, output)
193 }
194 },
195 QuantScheme {
196 value:
197 QuantValue::Q8F
198 | QuantValue::Q8S
199 | QuantValue::E4M3
200 | QuantValue::E5M2
201 | QuantValue::E2M1,
202 store: QuantStore::Native,
203 ..
204 } => {
205 if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
206 panic!(
207 "{:?} is not supported for native quantization",
208 scheme.value
209 );
210 }
211
212 match scheme.param {
213 QuantParam::F32 => {
214 dequantize_native::<R, F, f32>(client, values, *scheme, params, output)
215 }
216 QuantParam::F16 => {
217 dequantize_native::<R, F, f16>(client, values, *scheme, params, output)
218 }
219 QuantParam::BF16 => {
220 dequantize_native::<R, F, bf16>(client, values, *scheme, params, output)
221 }
222 QuantParam::UE8M0 => {
223 dequantize_native::<R, F, ue8m0>(client, values, *scheme, params, output)
224 }
225 QuantParam::UE4M3 => {
226 dequantize_native::<R, F, e4m3>(client, values, *scheme, params, output)
227 }
228 }
229 }
230 QuantScheme {
231 store: QuantStore::Native,
232 value,
233 ..
234 } => {
235 panic!("{value:?} is not supported for native quantization");
236 }
237 }
238}
239
240fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
241 client: &ComputeClient<R::Server>,
242 input: &TensorHandleRef<R>,
243 scheme: QuantScheme,
244 scale: &TensorHandleRef<'_, R>,
245 output: &TensorHandleRef<R>,
246) {
247 let num_elems_input: usize = input.shape.iter().product();
248
249 let mut line_size_in = tensor_line_size_parallel(
250 R::io_optimized_line_sizes_unchecked(size_of::<F>()),
251 input.shape,
252 input.strides,
253 input.shape.len() - 1,
254 );
255 let num_quants = scheme.num_quants() as u8;
256 let line_size_out = num_quants;
257 let rank = output.shape.len();
258
259 if !output.shape[rank - 1].is_multiple_of(line_size_out as usize) {
260 line_size_in = 1;
261 }
262
263 let cube_dim = CubeDim::default();
264 let cube_count =
265 calculate_cube_count_elemwise(num_elems_input / line_size_in as usize, cube_dim);
266
267 match scheme {
268 QuantScheme {
269 level: QuantLevel::Tensor | QuantLevel::Block(_),
270 store: QuantStore::U32,
271 mode: QuantMode::Symmetric,
272 ..
273 } => {
274 unsafe {
275 dequantize_symmetric_packed_kernel::launch_unchecked::<F, FS, R>(
276 client,
277 cube_count,
278 cube_dim,
279 linear_view(client, input, line_size_in),
280 scales_view(client, input, scale, 1, &scheme),
281 linear_view(client, output, line_size_out),
282 scheme,
283 )
284 };
285 }
286 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
287 }
288}
289
290fn dequantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
291 client: &ComputeClient<R::Server>,
292 input: &TensorHandleRef<R>,
293 scheme: QuantScheme,
294 scale: &TensorHandleRef<'_, R>,
295 output: &TensorHandleRef<R>,
296) {
297 let num_elems: usize = input.shape.iter().product();
298 let line_size = tensor_line_size_parallel(
299 R::io_optimized_line_sizes_unchecked(size_of::<F>()),
300 input.shape,
301 input.strides,
302 input.shape.len() - 1,
303 );
304 let cube_dim = CubeDim::default();
305 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
306
307 match scheme {
308 QuantScheme {
309 level: QuantLevel::Tensor | QuantLevel::Block(_),
310 mode: QuantMode::Symmetric,
311 value,
312 store: QuantStore::Native,
313 ..
314 } => {
315 let launch = match value {
316 QuantValue::Q8F | QuantValue::Q8S => {
317 dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, i8, R>
318 }
319 QuantValue::E4M3 => {
320 dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e4m3, R>
321 }
322 QuantValue::E5M2 => {
323 dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e5m2, R>
324 }
325 QuantValue::E2M1 => {
326 dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e2m1x2, R>
327 }
328 other => panic!("Unsupported quantization value {other:?}"),
329 };
330
331 unsafe {
332 launch(
333 client,
334 cube_count,
335 cube_dim,
336 linear_view(client, input, line_size),
337 scales_view(client, input, scale, 1, &scheme),
338 linear_view(client, output, line_size),
339 )
340 };
341 }
342 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
343 }
344}