1#![allow(missing_docs)] use cubecl::features::TypeUsage;
4use cubecl::prelude::*;
5use cubecl::{
6 calculate_cube_count_elemwise,
7 ir::{ElemType, FloatKind, IntKind},
8 tensor_line_size_parallel,
9};
10
11use crate::{
12 layout::{ScalesView, scales_view},
13 scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
14};
15use cubecl::std::tensor::{
16 View,
17 layout::linear::{LinearView, linear_view},
18};
19
/// Dequantizes a line of values with a single symmetric scale: `scale * value`.
///
/// Symmetric quantization has no zero-point, so dequantization is a plain
/// multiply. The scale is broadcast across the whole line via `Line::cast_from`.
#[cube]
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
    Line::cast_from(scale) * value
}
26
/// Loads the packed line at `position` from `values` and dequantizes it.
///
/// Convenience wrapper around [`dequantize_symmetric_packed_value_at`] that
/// performs the indexed load itself.
#[cube]
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
    position: usize,
    values: &View<Line<QI>, usize>,
    scales: &View<FS, usize>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
}
40
/// Dequantizes an already-loaded packed line, with `position` used to locate
/// the matching scales.
///
/// Thin argument-reordering wrapper over [`dequantize_symmetric_packed_value`].
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: usize,
    values: Line<QI>,
    scales: &View<FS, usize>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}
54
/// Dequantizes one line of packed storage words.
///
/// Each element of `values` is a storage word (e.g. a `u32`) holding
/// `scheme.num_quants()` quantized sub-values. For every word, the sub-values
/// are unpacked into floats and scaled symmetrically, producing an `Array` of
/// `line_size` lines of `num_quants` floats each.
#[cube]
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    values: Line<QS>,
    scales: &View<FS, usize>,
    position: usize,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    let line_size_values = values.line_size();
    let num_quants = scheme.num_quants();
    // One output line (of `num_quants` floats) per packed word in the input line.
    let mut tmp = Array::lined(line_size_values, num_quants);

    #[unroll]
    for i in 0..line_size_values {
        // Unpack the i-th storage word into `num_quants` float values.
        let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
        // Scale lookup: indexes the scales view at the first element covered
        // by this word. NOTE(review): the mapping from element index to scale
        // is delegated to the `ScalesView` layout (per-tensor or per-block) —
        // confirm against `scales_view` for block-level schemes.
        let scale = scales[(position * line_size_values) + i * num_quants];
        let values = dequantize_symmetric::<F, FS>(floats, scale);
        tmp[i] = values;
    }

    tmp
}
80
/// Unpacks the quantized sub-values stored in a single storage word into a
/// line of floats.
///
/// The word holds `store.size_bits() / quant.size_bits()` values, packed
/// little-end first (value 0 in the lowest bits). Each extracted bit-field is
/// reinterpreted as a two's-complement signed integer before conversion to `F`.
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    let size_quant = quant.size_bits();
    let size_store = store.size_bits(&quant);
    // Number of quantized values packed into one storage word.
    let num_quant = size_store / size_quant;

    let mut output = Line::empty(num_quant);

    // Bit mask selecting one `size_quant`-bit field.
    let mask = QS::from_int((1 << size_quant) - 1);
    // Highest bit of a field; set => the field encodes a negative value.
    let sign_bit = QS::from_int(1 << (size_quant - 1));
    let two_pow_n = 1 << size_quant;

    #[unroll]
    for position in 0..num_quant {
        let offset = QS::cast_from(position * size_quant);
        let raw = (value >> offset) & mask;

        let raw_i32 = i32::cast_from(raw);
        // Manual two's-complement sign extension: if the sign bit is set,
        // subtract 2^n (the bool comparison casts to 0 or 1).
        let is_negative = i32::cast_from(raw >= sign_bit); let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
    }

    output
}
117
/// Kernel: dequantizes a tensor whose values are packed into `u32` storage
/// words, one line of words per unit.
///
/// Each unit reads `line_size_in` packed words, unpacks and scales them, then
/// writes `line_size_in` output lines of `num_quants` floats each.
#[cube(launch_unchecked)]
fn dequantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<u32>>,
    scales: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    // Guard against units past the end of the (possibly non-divisible) input.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let line_size_in = input.line_size();
    let line_size_out = output.line_size();

    // The launcher must vectorize the output by exactly one packed word's
    // worth of values; checked at compile (expansion) time.
    comptime! {
        assert_eq!(line_size_out, scheme.num_quants());
    }

    let values = input[ABSOLUTE_POS];
    // Index of the first unpacked element produced by this unit, in units of
    // output lines.
    let packed_pos = ABSOLUTE_POS * scheme.num_quants();

    let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);

    // One output line per packed input word.
    #[unroll]
    for i in 0..line_size_in {
        output[ABSOLUTE_POS * line_size_in + i] = out[i];
    }
}
147
/// Kernel: dequantizes a tensor stored in a natively-supported quantized type
/// `Q` (e.g. `i8`, fp8/fp4 formats) with a straight cast-and-scale per line.
#[cube(launch_unchecked)]
fn dequantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<Q>>,
    scale: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    // Guard against units past the end of the input.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // For packed-native types, one `Q` element covers several logical values.
    let native_packing = Q::packing_factor();
    // Scale lookup for the first logical element this unit covers.
    // NOTE(review): block-wise scale selection is delegated to the
    // `ScalesView` layout — confirm it maps element index -> block scale.
    let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];

    output[ABSOLUTE_POS] =
        dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
}
166
167#[allow(clippy::result_large_err)]
168pub fn launch_ref<R: Runtime>(
170 client: &ComputeClient<R>,
171 values: &TensorHandleRef<R>,
172 output: &TensorHandleRef<R>,
173 params: &TensorHandleRef<'_, R>,
174 scheme: &QuantScheme,
175 input_dtype: StorageType,
176) -> Result<(), LaunchError> {
177 let dtype_scale: StorageType = ElemType::from_quant_param(scheme.param).into();
178
179 match scheme {
180 QuantScheme {
181 store: QuantStore::PackedU32(_),
182 ..
183 } => dequantize_packed(
184 client,
185 values,
186 *scheme,
187 params,
188 output,
189 input_dtype,
190 dtype_scale,
191 ),
192 QuantScheme {
193 value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2,
194 store: QuantStore::Native,
195 ..
196 }
197 | QuantScheme {
198 value: QuantValue::E2M1,
199 store: QuantStore::PackedNative(_),
200 ..
201 } => {
202 if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
203 panic!(
204 "{:?} is not supported for native quantization",
205 scheme.value
206 );
207 }
208
209 dequantize_native(
210 client,
211 values,
212 *scheme,
213 params,
214 output,
215 input_dtype,
216 dtype_scale,
217 )
218 }
219 QuantScheme {
220 store: QuantStore::Native | QuantStore::PackedNative(_),
221 value,
222 ..
223 } => {
224 panic!("{value:?} is not supported for native quantization");
225 }
226 }
227}
228
229fn dequantize_packed<R: Runtime>(
230 client: &ComputeClient<R>,
231 input: &TensorHandleRef<R>,
232 scheme: QuantScheme,
233 scale: &TensorHandleRef<'_, R>,
234 output: &TensorHandleRef<R>,
235 input_dtype: StorageType,
236 scale_dtype: StorageType,
237) -> Result<(), LaunchError> {
238 let num_elems_input: usize = input.shape.iter().product();
239
240 let mut line_size_in = tensor_line_size_parallel(
241 client.io_optimized_line_sizes_unchecked(input.elem_size),
242 input.shape,
243 input.strides,
244 input.shape.len() - 1,
245 );
246 let num_quants = scheme.num_quants();
247 let line_size_out = num_quants;
248 let rank = output.shape.len();
249
250 if !output.shape[rank - 1].is_multiple_of(line_size_out) {
251 line_size_in = 1;
252 }
253
254 let num_elems = num_elems_input / line_size_in as usize;
255 let cube_dim = CubeDim::new(client, num_elems);
256 let cube_count = calculate_cube_count_elemwise(client, num_elems, cube_dim);
257
258 match scheme {
259 QuantScheme {
260 level: QuantLevel::Tensor | QuantLevel::Block(_),
261 store: QuantStore::PackedU32(_),
262 mode: QuantMode::Symmetric,
263 ..
264 } => unsafe {
265 dequantize_symmetric_packed_kernel::launch_unchecked(
266 client,
267 cube_count,
268 cube_dim,
269 linear_view(client, input, line_size_in),
270 scales_view(client, input, scale, 1, &scheme),
271 linear_view(client, output, line_size_out),
272 scheme,
273 [input_dtype, scale_dtype],
274 )
275 },
276 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
277 }
278}
279
280fn dequantize_native<R: Runtime>(
281 client: &ComputeClient<R>,
282 input: &TensorHandleRef<R>,
283 scheme: QuantScheme,
284 scale: &TensorHandleRef<'_, R>,
285 output: &TensorHandleRef<R>,
286 input_dtype: StorageType,
287 scale_dtype: StorageType,
288) -> Result<(), LaunchError> {
289 let num_elems: usize = input.shape.iter().product();
290 let line_size = tensor_line_size_parallel(
291 client.io_optimized_line_sizes_unchecked(input_dtype.size()),
292 input.shape,
293 input.strides,
294 input.shape.len() - 1,
295 );
296 let working_units = num_elems / line_size as usize;
297 let cube_dim = CubeDim::new(client, working_units);
298 let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
299
300 match scheme {
301 QuantScheme {
302 level: QuantLevel::Tensor | QuantLevel::Block(_),
303 mode: QuantMode::Symmetric,
304 value,
305 store: QuantStore::Native,
306 ..
307 } => {
308 let quant_dtype: ElemType = match value {
309 QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
310 QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
311 QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
312 QuantValue::E2M1 => ElemType::Float(FloatKind::E2M1),
313 other => panic!("Unsupported quantization value {other:?}"),
314 };
315
316 unsafe {
317 dequantize_symmetric_native_kernel::launch_unchecked(
318 client,
319 cube_count,
320 cube_dim,
321 linear_view(client, input, line_size),
322 scales_view(client, input, scale, 1, &scheme),
323 linear_view(client, output, line_size),
324 [input_dtype, scale_dtype, quant_dtype.into()],
325 )
326 }
327 }
328 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
329 }
330}