1#![allow(missing_docs)] use cubecl::prelude::*;
4use cubecl_core::{
5 self as cubecl, calculate_cube_count_elemwise,
6 ir::{ElemType, FloatKind, IntKind},
7 tensor_line_size_parallel,
8};
9use cubecl_runtime::TypeUsage;
10
11use crate::{
12 layout::{ScalesView, scales_view},
13 scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
14};
15use cubecl_std::tensor::{
16 View,
17 layout::linear::{LinearView, linear_view},
18};
19
/// Dequantizes a line of values with a single symmetric scale.
///
/// Symmetric quantization has no zero-point, so the reconstruction is simply
/// `scale * value`. The scale is cast into the output float line type before
/// the multiply.
#[cube]
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
    Line::cast_from(scale) * value
}
26
/// Loads the packed line at `position` from `values` and dequantizes every
/// quantized value it contains.
///
/// Thin wrapper over [`dequantize_symmetric_packed_value_at`] that performs the
/// indexed load itself.
#[cube]
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: &View<Line<QI>, u32>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
}
40
/// Dequantizes an already-loaded packed line taken from position `position`.
///
/// Argument-order adapter over [`dequantize_symmetric_packed_value`]; `position`
/// is still needed so the matching scales can be fetched.
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: Line<QI>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}
54
/// Unpacks and dequantizes every quantized value packed inside `values`.
///
/// Each element of the `Line<QS>` is one storage word packing
/// `scheme.num_quants()` quantized values. The result is an array of
/// `values.line_size()` lines, each `num_quants` wide, holding the
/// dequantized floats.
///
/// `position` is the index of this packed line in the flattened input; one
/// scale is read per storage word at `position * line_size + i * num_quants`.
/// NOTE(review): this index formula presumes the scales view maps a per-value
/// position to the enclosing block's scale — confirm against `scales_view`.
#[cube]
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    values: Line<QS>,
    scales: &View<FS, u32>,
    position: u32,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    let line_size_values = values.line_size();
    let num_quants = comptime!(scheme.num_quants() as u32);
    // One output line (of width `num_quants`) per packed storage word.
    let mut tmp = Array::vectorized(line_size_values, num_quants);

    #[unroll]
    for i in 0..line_size_values {
        // Split the i-th storage word into its packed quantized values.
        let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
        let scale = scales[(position * line_size_values) + i * num_quants];
        let values = dequantize_symmetric::<F, FS>(floats, scale);
        tmp[i] = values;
    }

    tmp
}
80
/// Splits one storage word into its packed quantized values, decoding each as
/// a two's-complement signed integer and converting it to `F`.
///
/// The word holds `store.size_bits() / quant.size_bits()` fields of
/// `size_quant` bits each, laid out from the least-significant bits upward.
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    let size_quant = comptime!(quant.size_bits() as u32);
    let size_store = comptime!(store.size_bits(&quant) as u32);
    let num_quant = comptime!(size_store / size_quant);

    let mut output = Line::empty(num_quant);
    let mut position = comptime!(0);

    // Mask selecting the low `size_quant` bits of a field.
    let mask = QS::cast_from(comptime!((1 << size_quant) - 1));
    // Constants for manual sign extension of a `size_quant`-bit field.
    let sign_bit = QS::cast_from(comptime!(1 << (size_quant - 1)));
    let two_pow_n = comptime!(1 << size_quant);

    #[unroll]
    for _ in 0..num_quant {
        // Shift the current field down to the low bits and isolate it.
        let offset = QS::cast_from(comptime!(position * size_quant));
        let raw = (value >> offset) & mask;

        let raw_i32 = i32::cast_from(raw);
        // If the field's sign bit is set, subtract 2^size_quant to recover the
        // negative two's-complement value (branchless: is_negative is 0 or 1).
        let is_negative = i32::cast_from(raw >= sign_bit); let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
        comptime!(position += 1);
    }

    output
}
119
/// Kernel: dequantizes u32-packed symmetric quantized values.
///
/// Each invocation reads one `Line<u32>` of packed input and writes
/// `line_size_in` output lines, each `scheme.num_quants()` floats wide.
#[cube(launch_unchecked)]
fn dequantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<u32>>,
    scales: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    // Guard threads past the end of the (possibly non-divisible) input.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let line_size_in = input.line_size();
    let line_size_out = output.line_size();

    // The launch side must vectorize the output exactly one packed word wide.
    comptime! {
        assert_eq!(line_size_out, scheme.num_quants() as u32);
    }

    let values = input[ABSOLUTE_POS];
    // Index of the first quantized value covered by this thread's packed line.
    let packed_pos = ABSOLUTE_POS * comptime![scheme.num_quants() as u32];

    let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);

    // One output line per packed input word.
    #[unroll]
    for i in 0..line_size_in {
        output[ABSOLUTE_POS * line_size_in + i] = out[i];
    }
}
149
/// Kernel: dequantizes natively stored (one storage element per value)
/// symmetric quantized values, one input line per invocation.
#[cube(launch_unchecked)]
fn dequantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<Q>>,
    scale: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    // Guard threads past the end of the input.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // NOTE(review): `packing_factor` presumably accounts for sub-byte types
    // where one `Q` element holds several values — confirm its contract.
    let native_packing = Q::packing_factor();
    // Fetch the scale for the first value covered by this line; the view is
    // indexed in value space, hence the line-size * packing multiplier.
    let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];

    output[ABSOLUTE_POS] =
        dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
}
168
169#[allow(clippy::result_large_err)]
170pub fn launch_ref<R: Runtime>(
172 client: &ComputeClient<R>,
173 values: &TensorHandleRef<R>,
174 output: &TensorHandleRef<R>,
175 params: &TensorHandleRef<'_, R>,
176 scheme: &QuantScheme,
177 input_dtype: StorageType,
178) -> Result<(), LaunchError> {
179 let dtype_scale: StorageType = ElemType::from_quant_param(scheme.param).into();
180
181 match scheme {
182 QuantScheme {
183 store: QuantStore::U32,
184 ..
185 } => dequantize_packed(
186 client,
187 values,
188 *scheme,
189 params,
190 output,
191 input_dtype,
192 dtype_scale,
193 ),
194 QuantScheme {
195 value:
196 QuantValue::Q8F
197 | QuantValue::Q8S
198 | QuantValue::E4M3
199 | QuantValue::E5M2
200 | QuantValue::E2M1,
201 store: QuantStore::Native,
202 ..
203 } => {
204 if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
205 panic!(
206 "{:?} is not supported for native quantization",
207 scheme.value
208 );
209 }
210
211 dequantize_native(
212 client,
213 values,
214 *scheme,
215 params,
216 output,
217 input_dtype,
218 dtype_scale,
219 )
220 }
221 QuantScheme {
222 store: QuantStore::Native,
223 value,
224 ..
225 } => {
226 panic!("{value:?} is not supported for native quantization");
227 }
228 }
229}
230
231fn dequantize_packed<R: Runtime>(
232 client: &ComputeClient<R>,
233 input: &TensorHandleRef<R>,
234 scheme: QuantScheme,
235 scale: &TensorHandleRef<'_, R>,
236 output: &TensorHandleRef<R>,
237 input_dtype: StorageType,
238 scale_dtype: StorageType,
239) -> Result<(), LaunchError> {
240 let num_elems_input: usize = input.shape.iter().product();
241
242 let mut line_size_in = tensor_line_size_parallel(
243 client.io_optimized_line_sizes_unchecked(input.elem_size),
244 input.shape,
245 input.strides,
246 input.shape.len() - 1,
247 );
248 let num_quants = scheme.num_quants() as u8;
249 let line_size_out = num_quants;
250 let rank = output.shape.len();
251
252 if !output.shape[rank - 1].is_multiple_of(line_size_out as usize) {
253 line_size_in = 1;
254 }
255
256 let cube_dim = CubeDim::default();
257 let cube_count =
258 calculate_cube_count_elemwise(num_elems_input / line_size_in as usize, cube_dim);
259
260 match scheme {
261 QuantScheme {
262 level: QuantLevel::Tensor | QuantLevel::Block(_),
263 store: QuantStore::U32,
264 mode: QuantMode::Symmetric,
265 ..
266 } => unsafe {
267 dequantize_symmetric_packed_kernel::launch_unchecked(
268 client,
269 cube_count,
270 cube_dim,
271 linear_view(client, input, line_size_in),
272 scales_view(client, input, scale, 1, &scheme),
273 linear_view(client, output, line_size_out),
274 scheme,
275 [input_dtype, scale_dtype],
276 )
277 },
278 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
279 }
280}
281
282fn dequantize_native<R: Runtime>(
283 client: &ComputeClient<R>,
284 input: &TensorHandleRef<R>,
285 scheme: QuantScheme,
286 scale: &TensorHandleRef<'_, R>,
287 output: &TensorHandleRef<R>,
288 input_dtype: StorageType,
289 scale_dtype: StorageType,
290) -> Result<(), LaunchError> {
291 let num_elems: usize = input.shape.iter().product();
292 let line_size = tensor_line_size_parallel(
293 client.io_optimized_line_sizes_unchecked(input_dtype.size()),
294 input.shape,
295 input.strides,
296 input.shape.len() - 1,
297 );
298 let cube_dim = CubeDim::default();
299 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
300
301 match scheme {
302 QuantScheme {
303 level: QuantLevel::Tensor | QuantLevel::Block(_),
304 mode: QuantMode::Symmetric,
305 value,
306 store: QuantStore::Native,
307 ..
308 } => {
309 let quant_dtype: ElemType = match value {
310 QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
311 QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
312 QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
313 QuantValue::E2M1 => ElemType::Float(FloatKind::E2M1),
314 other => panic!("Unsupported quantization value {other:?}"),
315 };
316
317 println!("{input_dtype:?} {scale_dtype:?} {quant_dtype:?}");
318 unsafe {
319 dequantize_symmetric_native_kernel::launch_unchecked(
320 client,
321 cube_count,
322 cube_dim,
323 linear_view(client, input, line_size),
324 scales_view(client, input, scale, 1, &scheme),
325 linear_view(client, output, line_size),
326 [input_dtype, scale_dtype, quant_dtype.into()],
327 )
328 }
329 }
330 QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
331 }
332}