// cubek_quant/dequantize.rs

1#![allow(missing_docs)] // pub cube modules
2
3use cubecl::features::TypeUsage;
4use cubecl::prelude::*;
5use cubecl::{
6    calculate_cube_count_elemwise,
7    ir::{ElemType, FloatKind, IntKind},
8    tensor_line_size_parallel,
9};
10
11use crate::{
12    layout::{ScalesView, scales_view},
13    scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
14};
15use cubecl::std::tensor::{
16    View,
17    layout::linear::{LinearView, linear_view},
18};
19
/// Dequantize a line of values into floating-point values using the provided scale.
///
/// Symmetric quantization has no zero-point, so the reconstruction is simply
/// `x = scale * x_q`, with the scalar scale cast and broadcast over the line.
#[cube]
pub fn dequantize_symmetric<F: Float, FS: CubePrimitive>(value: Line<F>, scale: FS) -> Line<F> {
    // x = scale * x_q
    Line::cast_from(scale) * value
}
26
/// Dequantize the value at a specified position using the provided quantization scheme.
///
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
/// values in the stored quantization type.
///
/// Loads `values[position]` and delegates to [`dequantize_symmetric_packed_value_at`].
#[cube]
pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
    position: usize,
    values: &View<Line<QI>, usize>,
    scales: &View<FS, usize>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
}
40
/// Dequantize a single value using the scale at the specified position.
///
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
/// values in the stored quantization type.
///
/// Thin adapter over [`dequantize_symmetric_packed_value`] with reordered arguments.
#[cube]
pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int>(
    position: usize,
    values: Line<QI>,
    scales: &View<FS, usize>,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
}
54
/// Dequantize a single packed value using the scale provided.
///
/// Returns a line of floating-point values. The number of values in the line depends on the number of packed
/// values in the stored quantization type.
#[cube]
pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    values: Line<QS>,
    scales: &View<FS, usize>,
    position: usize,
    #[comptime] scheme: QuantScheme,
) -> Array<Line<F>> {
    let line_size_values = values.line_size();
    // Number of quantized values packed into one stored integer (comptime).
    let num_quants = scheme.num_quants();
    // One output line of `num_quants` floats per packed input element.
    let mut tmp = Array::lined(line_size_values, num_quants);

    #[unroll]
    for i in 0..line_size_values {
        // Unpack the i-th stored integer into a line of floats.
        let floats = unpack_q::<F, QS>(values[i], scheme.value, scheme.store);
        // NOTE(review): assumes the scales view maps one scale per group of
        // `num_quants` unpacked values, starting at `position * line_size_values`
        // — confirm against the ScalesView layout used by the launcher.
        let scale = scales[(position * line_size_values) + i * num_quants];
        let values = dequantize_symmetric::<F, FS>(floats, scale);
        tmp[i] = values;
    }

    tmp
}
80
/// Unpack a quantized integer into a line of floating-point values, according to the specified quantization input type.
///
/// This handles types where multiple quantized values are packed into a single integer (the stored quantization type).
///
/// NOTE(review): fields are decoded as signed two's-complement integers; packed
/// float formats (e.g. E2M1) would need a different decode — confirm the packed
/// path only carries integer quant values.
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn unpack_q<F: Float, QS: Int>(
    value: QS,
    #[comptime] quant: QuantValue,
    #[comptime] store: QuantStore,
) -> Line<F> {
    // All bit widths are comptime, e.g. 4-bit values in a 32-bit store -> 8 per word.
    let size_quant = quant.size_bits();
    let size_store = store.size_bits(&quant);
    let num_quant = size_store / size_quant;

    let mut output = Line::empty(num_quant);

    // Mask of the low `size_quant` bits, e.g. 0xF for 4-bit values.
    let mask = QS::from_int((1 << size_quant) - 1);
    // Fields at or above 2^(n-1) encode negative numbers in two's complement.
    let sign_bit = QS::from_int(1 << (size_quant - 1));
    let two_pow_n = 1 << size_quant;

    #[unroll]
    for position in 0..num_quant {
        // Shift the target field down to the low bits and isolate it.
        let offset = QS::cast_from(position * size_quant);
        let raw = (value >> offset) & mask;

        // Branchless two's complement conversion
        // If raw >= 2^(n-1), then result = raw - 2^n
        let raw_i32 = i32::cast_from(raw);
        let is_negative = i32::cast_from(raw >= sign_bit); // 1 if negative, 0 if positive
        let signed_value = raw_i32 - (is_negative * two_pow_n);

        output[position] = F::cast_from(signed_value);
    }

    output
}
117
/// Kernel: dequantize `u32`-packed symmetric values into floats.
///
/// Each unit reads one input line of packed words, unpacks every word into
/// `num_quants` floats, and writes one output line per word.
#[cube(launch_unchecked)]
fn dequantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<u32>>,
    scales: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    // Launch is unchecked: guard out-of-bounds units explicitly.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let line_size_in = input.line_size();
    let line_size_out = output.line_size();

    // The host launcher must size output lines to hold exactly one unpacked word.
    comptime! {
        assert_eq!(line_size_out, scheme.num_quants());
    }

    let values = input[ABSOLUTE_POS];
    // Logical position of this unit's first packed word, used for scale lookup.
    let packed_pos = ABSOLUTE_POS * scheme.num_quants();

    let out = dequantize_symmetric_packed_value::<F, FS, u32>(values, scales, packed_pos, scheme);

    // Write one output line (num_quants floats) per packed input word.
    #[unroll]
    for i in 0..line_size_in {
        output[ABSOLUTE_POS * line_size_in + i] = out[i];
    }
}
147
/// Kernel: dequantize natively-stored values (one stored element per quant value,
/// or several for natively-packed sub-byte types).
#[cube(launch_unchecked)]
fn dequantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<Q>>,
    scale: &ScalesView<FS>,
    output: &mut LinearView<Line<F>, ReadWrite>,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    // Launch is unchecked: guard out-of-bounds units explicitly.
    if !input.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // Logical values carried by each stored element of Q (1 for plain types).
    let native_packing = Q::packing_factor();
    // Absolute pos represents the logical block (scale) used to dequantize, not layout
    let scale = scale[ABSOLUTE_POS * input.line_size() * native_packing];

    output[ABSOLUTE_POS] =
        dequantize_symmetric::<F, FS>(Line::cast_from(input[ABSOLUTE_POS]), scale);
}
166
167#[allow(clippy::result_large_err)]
168/// Convert the tensor back to a higher precision data type.
169pub fn launch_ref<R: Runtime>(
170    client: &ComputeClient<R>,
171    values: &TensorHandleRef<R>,
172    output: &TensorHandleRef<R>,
173    params: &TensorHandleRef<'_, R>,
174    scheme: &QuantScheme,
175    input_dtype: StorageType,
176) -> Result<(), LaunchError> {
177    let dtype_scale: StorageType = ElemType::from_quant_param(scheme.param).into();
178
179    match scheme {
180        QuantScheme {
181            store: QuantStore::PackedU32(_),
182            ..
183        } => dequantize_packed(
184            client,
185            values,
186            *scheme,
187            params,
188            output,
189            input_dtype,
190            dtype_scale,
191        ),
192        QuantScheme {
193            value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2,
194            store: QuantStore::Native,
195            ..
196        }
197        | QuantScheme {
198            value: QuantValue::E2M1,
199            store: QuantStore::PackedNative(_),
200            ..
201        } => {
202            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
203                panic!(
204                    "{:?} is not supported for native quantization",
205                    scheme.value
206                );
207            }
208
209            dequantize_native(
210                client,
211                values,
212                *scheme,
213                params,
214                output,
215                input_dtype,
216                dtype_scale,
217            )
218        }
219        QuantScheme {
220            store: QuantStore::Native | QuantStore::PackedNative(_),
221            value,
222            ..
223        } => {
224            panic!("{value:?} is not supported for native quantization");
225        }
226    }
227}
228
229fn dequantize_packed<R: Runtime>(
230    client: &ComputeClient<R>,
231    input: &TensorHandleRef<R>,
232    scheme: QuantScheme,
233    scale: &TensorHandleRef<'_, R>,
234    output: &TensorHandleRef<R>,
235    input_dtype: StorageType,
236    scale_dtype: StorageType,
237) -> Result<(), LaunchError> {
238    let num_elems_input: usize = input.shape.iter().product();
239
240    let mut line_size_in = tensor_line_size_parallel(
241        client.io_optimized_line_sizes_unchecked(input.elem_size),
242        input.shape,
243        input.strides,
244        input.shape.len() - 1,
245    );
246    let num_quants = scheme.num_quants();
247    let line_size_out = num_quants;
248    let rank = output.shape.len();
249
250    if !output.shape[rank - 1].is_multiple_of(line_size_out) {
251        line_size_in = 1;
252    }
253
254    let num_elems = num_elems_input / line_size_in as usize;
255    let cube_dim = CubeDim::new(client, num_elems);
256    let cube_count = calculate_cube_count_elemwise(client, num_elems, cube_dim);
257
258    match scheme {
259        QuantScheme {
260            level: QuantLevel::Tensor | QuantLevel::Block(_),
261            store: QuantStore::PackedU32(_),
262            mode: QuantMode::Symmetric,
263            ..
264        } => unsafe {
265            dequantize_symmetric_packed_kernel::launch_unchecked(
266                client,
267                cube_count,
268                cube_dim,
269                linear_view(client, input, line_size_in),
270                scales_view(client, input, scale, 1, &scheme),
271                linear_view(client, output, line_size_out),
272                scheme,
273                [input_dtype, scale_dtype],
274            )
275        },
276        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
277    }
278}
279
280fn dequantize_native<R: Runtime>(
281    client: &ComputeClient<R>,
282    input: &TensorHandleRef<R>,
283    scheme: QuantScheme,
284    scale: &TensorHandleRef<'_, R>,
285    output: &TensorHandleRef<R>,
286    input_dtype: StorageType,
287    scale_dtype: StorageType,
288) -> Result<(), LaunchError> {
289    let num_elems: usize = input.shape.iter().product();
290    let line_size = tensor_line_size_parallel(
291        client.io_optimized_line_sizes_unchecked(input_dtype.size()),
292        input.shape,
293        input.strides,
294        input.shape.len() - 1,
295    );
296    let working_units = num_elems / line_size as usize;
297    let cube_dim = CubeDim::new(client, working_units);
298    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
299
300    match scheme {
301        QuantScheme {
302            level: QuantLevel::Tensor | QuantLevel::Block(_),
303            mode: QuantMode::Symmetric,
304            value,
305            store: QuantStore::Native,
306            ..
307        } => {
308            let quant_dtype: ElemType = match value {
309                QuantValue::Q8F | QuantValue::Q8S => ElemType::Int(IntKind::I8),
310                QuantValue::E4M3 => ElemType::Float(FloatKind::E4M3),
311                QuantValue::E5M2 => ElemType::Float(FloatKind::E5M2),
312                QuantValue::E2M1 => ElemType::Float(FloatKind::E2M1),
313                other => panic!("Unsupported quantization value {other:?}"),
314            };
315
316            unsafe {
317                dequantize_symmetric_native_kernel::launch_unchecked(
318                    client,
319                    cube_count,
320                    cube_dim,
321                    linear_view(client, input, line_size),
322                    scales_view(client, input, scale, 1, &scheme),
323                    linear_view(client, output, line_size),
324                    [input_dtype, scale_dtype, quant_dtype.into()],
325                )
326            }
327        }
328        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
329    }
330}