cubecl_quant/
dequantize.rs

#![allow(missing_docs)] // pub cube modules

use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise,
    ir::{ElemType, FloatKind, IntKind},
    tensor_line_size_parallel,
};
use cubecl_runtime::TypeUsage;

use crate::{
    layout::{ScalesView, scales_view},
    scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
};
use cubecl_std::tensor::{
    View,
    layout::linear::{LinearView, linear_view},
};

/// Dequantize a line of values into floating-point values using the provided scale.
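///
/// # Example
///
/// With `scale = 0.5` and a quantized line `[-4, 2]`, the result is `[-2.0, 1.0]`:
/// `x = scale * x_q`, applied to each element of the line.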
#[cube]
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
    // x = scale * x_q
    Line::cast_from(scale) * value
}

/// Dequantize the values at a specified position using the provided quantization scheme.
///
/// Returns an array of floating-point lines, one line per stored integer in the input line.
/// The number of values in each line depends on the number of values packed into the stored
/// quantization type.
#[cube]
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: &View<Line<QI>, u32>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
}

/// Dequantize a single line of packed values using the scales at the specified position.
///
/// Returns an array of floating-point lines, one line per stored integer in the input line.
/// The number of values in each line depends on the number of values packed into the stored
/// quantization type.
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: u32,
    values: Line<QI>,
    scales: &View<FS, u32>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}

/// Dequantize a single line of packed values using the scales provided.
///
/// Returns an array of floating-point lines, one line per stored integer in the input line.
/// The number of values in each line depends on the number of values packed into the stored
/// quantization type.
#[cube]
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    values: Line<QS>,
    scales: &View<FS, u32>,
    position: u32,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    let line_size_values = values.line_size();
    let num_quants = comptime!(scheme.num_quants() as u32);
    let mut tmp = Array::vectorized(line_size_values, num_quants);

    #[unroll]
    for i in 0..line_size_values {
        let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
        // Look up the scale at the position of the first element unpacked from word `i`.
        let scale = scales[(position * line_size_values) + i * num_quants];
        tmp[i] = dequantize_symmetric::<F, FS>(floats, scale);
    }

    tmp
}
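
// Worked instance of the scale lookup above, assuming `num_quants = 4` (8-bit
// values packed four to a `u32`), `line_size_values = 2`, and `position =
// ABSOLUTE_POS * num_quants` as passed by the packed kernel below: for
// `ABSOLUTE_POS = 3` (`position = 12`), the line holds packed words 6 and 7,
// and `i = 1` yields the scale index `12 * 2 + 1 * 4 = 28`, the position of
// the first element unpacked from word 7.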

/// Unpack a quantized integer into a line of floating-point values, according to the specified quantization input type.
///
/// This handles types where multiple quantized values are packed into a single integer (the stored quantization type).
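///
/// # Example
///
/// For 4-bit values, `mask = 0xF` and `sign_bit = 0x8`. A stored word whose low
/// byte is `0xF2` unpacks to `[2.0, -1.0, ...]`: the low nibble `0x2` stays `2`,
/// while the next nibble `0xF` is `15 >= 8`, so it maps to `15 - 16 = -1`.
/// A host-side sketch of the same bit manipulation lives in the tests module at
/// the bottom of this file.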
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    let size_quant = comptime!(quant.size_bits() as u32);
    let size_store = comptime!(store.size_bits(&quant) as u32);
    let num_quant = comptime!(size_store / size_quant);

    let mut output = Line::empty(num_quant);
    let mut position = comptime!(0);

    let mask = QS::cast_from(comptime!((1 << size_quant) - 1));
    let sign_bit = QS::cast_from(comptime!(1 << (size_quant - 1)));
    let two_pow_n = comptime!(1 << size_quant);

    #[unroll]
    for _ in 0..num_quant {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let raw = (value >> offset) & mask;

        // Branchless two's complement conversion:
        // if raw >= 2^(n-1), then result = raw - 2^n.
        let raw_i32 = i32::cast_from(raw);
        let is_negative = i32::cast_from(raw >= sign_bit); // 1 if negative, 0 if positive
        let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
        comptime!(position += 1);
    }

    output
}

#[cube(launch_unchecked)]
fn dequantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<u32>>,
    scales: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let line_size_in = input.line_size();
    let line_size_out = output.line_size();

    comptime! {
        assert_eq!(line_size_out, scheme.num_quants() as u32);
    }

    let values = input[ABSOLUTE_POS];
    let packed_pos = ABSOLUTE_POS * comptime![scheme.num_quants() as u32];

    let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);

    // Write one output line of `num_quants` dequantized values per packed input word.
    #[unroll]
    for i in 0..line_size_in {
        output[ABSOLUTE_POS * line_size_in + i] = out[i];
    }
}

#[cube(launch_unchecked)]
fn dequantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<Q>>,
    scale: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    // ABSOLUTE_POS selects the logical element position used to look up the scale,
    // independent of the memory layout.
    let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];

    output[ABSOLUTE_POS] =
        dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
}
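
// Worked instance of the native scale lookup, assuming `line_size = 4` and a
// packing factor of 1 (e.g. `i8` stored natively): the thread with
// `ABSOLUTE_POS = 5` loads elements 20..24 and uses the scale at element 20.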

#[allow(clippy::result_large_err)]
/// Dequantize the tensor back into a higher-precision floating-point data type.
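///
/// # Example
///
/// ```ignore
/// // Hypothetical caller sketch: `client`, `values`, `output`, `params`,
/// // `scheme`, and `input_dtype` are assumed to already exist for some
/// // runtime `R`; only the call shape is shown.
/// launch_ref::<R>(&client, &values, &output, &params, &scheme, input_dtype)?;
/// ```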
pub fn launch_ref<R: Runtime>(
    client: &ComputeClient<R>,
    values: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    params: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
    input_dtype: StorageType,
) -> Result<(), LaunchError> {
    let dtype_scale: StorageType = ElemType::from_quant_param(scheme.param).into();

    match scheme {
        QuantScheme {
            store: QuantStore::U32,
            ..
        } => dequantize_packed(
            client,
            values,
            *scheme,
            params,
            output,
            input_dtype,
            dtype_scale,
        ),
        QuantScheme {
            value:
                QuantValue::Q8F
                | QuantValue::Q8S
                | QuantValue::E4M3
                | QuantValue::E5M2
                | QuantValue::E2M1,
            store: QuantStore::Native,
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            dequantize_native(
                client,
                values,
                *scheme,
                params,
                output,
                input_dtype,
                dtype_scale,
            )
        }
        QuantScheme {
            store: QuantStore::Native,
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}

fn dequantize_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    input_dtype: StorageType,
    scale_dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems_input: usize = input.shape.iter().product();

    let mut line_size_in = tensor_line_size_parallel(
        client.io_optimized_line_sizes_unchecked(input.elem_size),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let num_quants = scheme.num_quants() as u8;
    let line_size_out = num_quants;
    let rank = output.shape.len();

    // Fall back to unvectorized input loads when the output's last dimension is
    // not a multiple of the output line size.
    if !output.shape[rank - 1].is_multiple_of(line_size_out as usize) {
        line_size_in = 1;
    }

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems_input / line_size_in as usize, cube_dim);

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            store: QuantStore::U32,
            mode: QuantMode::Symmetric,
            ..
        } => unsafe {
            dequantize_symmetric_packed_kernel::launch_unchecked(
                client,
                cube_count,
                cube_dim,
                linear_view(client, input, line_size_in),
                scales_view(client, input, scale, 1, &scheme),
                linear_view(client, output, line_size_out),
                scheme,
                [input_dtype, scale_dtype],
            )
        },
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}

fn dequantize_native<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    input_dtype: StorageType,
    scale_dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes_unchecked(input_dtype.size()),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            value,
            store: QuantStore::Native,
            ..
        } => {
            let quant_dtype: ElemType = match value {
                QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
                QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
                QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
                QuantValue::E2M1 => ElemType::Float(FloatKind::E2M1),
                other => panic!("Unsupported quantization value {other:?}"),
            };

            unsafe {
                dequantize_symmetric_native_kernel::launch_unchecked(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    scales_view(client, input, scale, 1, &scheme),
                    linear_view(client, output, line_size),
                    [input_dtype, scale_dtype, quant_dtype.into()],
                )
            }
        }
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}
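
#[cfg(test)]
mod tests {
    //! A minimal host-side sketch of the branchless two's-complement unpacking
    //! in `unpack_q`, assuming eight 4-bit values packed into a `u32`. It only
    //! mirrors the bit manipulation; the cube function runs per lane on device.

    /// Hypothetical host-side counterpart of `unpack_q` for 4-bit values.
    fn unpack_q4(value: u32) -> [i32; 8] {
        core::array::from_fn(|i| {
            // Extract the i-th nibble.
            let raw = ((value >> (i * 4)) & 0xF) as i32;
            // If raw >= 2^(n-1), the two's-complement value is raw - 2^n.
            let is_negative = (raw >= 0x8) as i32;
            raw - is_negative * 16
        })
    }

    #[test]
    fn unpacks_signed_4_bit_values() {
        // The low byte 0xF2 packs `2` (low nibble) and `-1` (0xF -> 15 - 16).
        let unpacked = unpack_q4(0x0000_00F2);
        assert_eq!(unpacked[0], 2);
        assert_eq!(unpacked[1], -1);
        assert!(unpacked[2..].iter().all(|&v| v == 0));
    }
}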