cubecl_quant/dequantize.rs

#![allow(missing_docs)] // pub cube modules

use cubecl::prelude::*;
use cubecl_common::{e2m1x2, e4m3, e5m2, ue8m0};
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};
use cubecl_runtime::TypeUsage;

use crate::{
    layout::{ScalesView, scales_view},
    scheme::{QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue},
};
use cubecl_std::tensor::{
    View,
    layout::linear::{LinearView, linear_view},
};
use half::{bf16, f16};

/// Dequantize a line of values into floating-point values using the provided scale.
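///
/// For example, with `scale = 0.5`, a quantized line `[2, -4]` dequantizes to `[1.0, -2.0]`.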
#[cube]
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
    // x = scale * x_q
    Line::cast_from(scale) * value
}

/// Dequantize the packed value at a specified position using the provided quantization scheme.
///
/// Returns an array of lines of floating-point values. The number of values in each line matches
/// the number of packed values in the stored quantization type.
#[cube]
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: &View<Line<QI>, u32>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
}

/// Dequantize a line of packed values using the scales at the specified position.
///
/// Returns an array of lines of floating-point values. The number of values in each line matches
/// the number of packed values in the stored quantization type.
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: Line<QI>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}

/// Dequantize a line of packed values using the provided scales.
///
/// Returns an array of lines of floating-point values. The number of values in each line matches
/// the number of packed values in the stored quantization type.
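///
/// For example, four 8-bit values packed into a `u32` yield lines of four floats each.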
#[cube]
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    values: Line<QS>,
    scales: &View<FS, u32>,
    position: u32,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    let line_size_values = values.line_size();
    let num_quants = comptime!(scheme.num_quants() as u32);
    let mut tmp = Array::vectorized(line_size_values, num_quants);

    #[unroll]
    for i in 0..line_size_values {
        let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
        let scale = scales[(position * line_size_values) + i * num_quants];
        let values = dequantize_symmetric::<F, FS>(floats, scale);
        tmp[i] = values;
    }

    tmp
}

/// Unpack a quantized integer into a line of floating-point values, according to the specified
/// quantization value type.
///
/// This handles types where multiple quantized values are packed into a single integer (the
/// stored quantization type).
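///
/// For example, with 4-bit values packed into a `u32`: `mask = 0xF`, `sign_bit = 8`, so a raw
/// nibble of `0b1111` (15) sign-extends to `15 - 16 = -1`.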
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    let size_quant = comptime!(quant.size_bits() as u32);
    let size_store = comptime!(store.size_bits(&quant) as u32);
    let num_quant = comptime!(size_store / size_quant);

    let mut output = Line::empty(num_quant);
    let mut position = comptime!(0);

    let mask = QS::cast_from(comptime!((1 << size_quant) - 1));
    let sign_bit = QS::cast_from(comptime!(1 << (size_quant - 1)));
    let two_pow_n = comptime!(1 << size_quant);

    #[unroll]
    for _ in 0..num_quant {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let raw = (value >> offset) & mask;

        // Branchless two's complement conversion:
        // if raw >= 2^(n-1), then result = raw - 2^n.
        let raw_i32 = i32::cast_from(raw);
        let is_negative = i32::cast_from(raw >= sign_bit); // 1 if negative, 0 if positive
        let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
        comptime!(position += 1);
    }

    output
}

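/// Kernel for dequantizing values packed into `u32` storage; each stored element expands into
/// `scheme.num_quants()` output values.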
#[cube(launch_unchecked)]
fn dequantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
    input: &LinearView<Line<u32>>,
    scales: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[comptime] scheme: QuantScheme,
) {
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let line_size_in = input.line_size();
    let line_size_out = output.line_size();

    comptime! {
        assert_eq!(line_size_out, scheme.num_quants() as u32);
    }

    let values = input[ABSOLUTE_POS];
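    // Each stored u32 expands into `num_quants` values, so the logical position used for
    // scale lookups advances by that factor.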
    let packed_pos = ABSOLUTE_POS * comptime![scheme.num_quants() as u32];

    let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);

    #[unroll]
    for i in 0..line_size_in {
        output[ABSOLUTE_POS * line_size_in + i] = out[i];
    }
}

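/// Kernel for dequantizing tensors stored in a natively supported quantized type (e.g. `i8`).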
#[cube(launch_unchecked)]
fn dequantize_symmetric_native_kernel<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    input: &LinearView<Line<Q>>,
    scale: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
) {
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    // ABSOLUTE_POS indexes the logical element whose scale we look up, not the physical layout.
    let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];

    output[ABSOLUTE_POS] =
        dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
}

#[allow(clippy::result_large_err)]
/// Convert the tensor back to a higher-precision data type.
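///
/// Dispatches on the scheme: values packed into `u32` storage go through `dequantize_packed`,
/// while natively stored 8-bit integer and FP8/FP4 values go through `dequantize_native`.
/// Unsupported combinations panic.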
pub fn launch_ref<R: Runtime, F: Float>(
    client: &ComputeClient<R::Server>,
    values: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    params: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
) {
    match scheme {
        QuantScheme {
            store: QuantStore::U32,
            ..
        } => match scheme.param {
            QuantParam::F32 => {
                dequantize_packed::<R, F, f32>(client, values, *scheme, params, output)
            }
            QuantParam::F16 => {
                dequantize_packed::<R, F, f16>(client, values, *scheme, params, output)
            }
            QuantParam::BF16 => {
                dequantize_packed::<R, F, bf16>(client, values, *scheme, params, output)
            }
            QuantParam::UE8M0 => {
                dequantize_packed::<R, F, ue8m0>(client, values, *scheme, params, output)
            }
            QuantParam::UE4M3 => {
                dequantize_packed::<R, F, e4m3>(client, values, *scheme, params, output)
            }
        },
        QuantScheme {
            value:
                QuantValue::Q8F
                | QuantValue::Q8S
                | QuantValue::E4M3
                | QuantValue::E5M2
                | QuantValue::E2M1,
            store: QuantStore::Native,
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            match scheme.param {
                QuantParam::F32 => {
                    dequantize_native::<R, F, f32>(client, values, *scheme, params, output)
                }
                QuantParam::F16 => {
                    dequantize_native::<R, F, f16>(client, values, *scheme, params, output)
                }
                QuantParam::BF16 => {
                    dequantize_native::<R, F, bf16>(client, values, *scheme, params, output)
                }
                QuantParam::UE8M0 => {
                    dequantize_native::<R, F, ue8m0>(client, values, *scheme, params, output)
                }
                QuantParam::UE4M3 => {
                    dequantize_native::<R, F, e4m3>(client, values, *scheme, params, output)
                }
            }
        }
        QuantScheme {
            store: QuantStore::Native,
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}

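/// Dequantize a tensor whose values are packed into `u32` storage, choosing an input line size
/// compatible with the output shape.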
fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems_input: usize = input.shape.iter().product();

    let mut line_size_in = tensor_line_size_parallel(
        R::io_optimized_line_sizes_unchecked(size_of::<F>()),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let num_quants = scheme.num_quants() as u8;
    let line_size_out = num_quants;
    let rank = output.shape.len();

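    // Fall back to unvectorized input reads when the output's innermost dimension is not a
    // multiple of the output line size.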
    if !output.shape[rank - 1].is_multiple_of(line_size_out as usize) {
        line_size_in = 1;
    }

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems_input / line_size_in as usize, cube_dim);

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            store: QuantStore::U32,
            mode: QuantMode::Symmetric,
            ..
        } => {
            unsafe {
                dequantize_symmetric_packed_kernel::launch_unchecked::<F, FS, R>(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size_in),
                    scales_view(client, input, scale, 1, &scheme),
                    linear_view(client, output, line_size_out),
                    scheme,
                )
            };
        }
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}

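/// Dequantize a tensor stored in a natively supported quantized type.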
fn dequantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes_unchecked(size_of::<F>()),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            value,
            store: QuantStore::Native,
            ..
        } => {
            let launch = match value {
                QuantValue::Q8F | QuantValue::Q8S => {
                    dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, i8, R>
                }
                QuantValue::E4M3 => {
                    dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e4m3, R>
                }
                QuantValue::E5M2 => {
                    dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e5m2, R>
                }
                QuantValue::E2M1 => {
                    dequantize_symmetric_native_kernel::launch_unchecked::<F, FS, e2m1x2, R>
                }
                other => panic!("Unsupported quantization value {other:?}"),
            };

            unsafe {
                launch(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, input, scale, 1, &scheme),
                    linear_view(client, output, line_size),
                )
            };
        }
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}