Skip to main content

cubecl_std/quant/
dequantize.rs

1use cubecl::prelude::*;
2use cubecl_common::quant::scheme::*;
3use cubecl_common::{e2m1x2, e4m3, e5m2};
4use cubecl_core as cubecl;
5
6/// Dequantize a vector of values, where `vector_size * num_quants` is a power of two.
7/// Unaligned values can't be dequantized in place.
8#[cube]
9pub fn dequantize_aligned<Q: Scalar, S: CubePrimitive, F: Numeric, NQ: Size, NF: Size>(
10    value: Vector<Q, NQ>,
11    scale: S,
12    #[comptime] scheme: QuantScheme,
13) -> Vector<F, NF> {
14    let q_values = match scheme.store {
15        QuantStore::Native | QuantStore::PackedNative(_) => Vector::<F, NF>::cast_from(value),
16        QuantStore::PackedU32(_) => unpack_cast_u32::<F, NQ, NF>(Vector::cast_from(value), scheme),
17    };
18    let scale = Vector::<F, NF>::cast_from(scale);
19
20    match scheme.mode {
21        QuantMode::Symmetric => q_values * scale,
22    }
23}
24
25/// Unpack a set of values from u32, and convert to the specified floating point format.
26#[cube]
27pub fn unpack_cast_u32<F: Numeric, NQ: Size, NF: Size>(
28    value: Vector<u32, NQ>,
29    #[comptime] scheme: QuantScheme,
30) -> Vector<F, NF> {
31    let num_quants = scheme.num_quants();
32    let native_packing = scheme.native_packing();
33    let size_bits = scheme.size_bits_value();
34    let mask = comptime![packing_mask(scheme)];
35    let size!(NP) = native_packing;
36
37    let mut out = Vector::<F, NF>::empty();
38
39    #[unroll]
40    for vector_idx in 0..value.size() {
41        let packed_val = value[vector_idx];
42        let out_offset = vector_idx * num_quants;
43        #[unroll]
44        for packed_idx in range_stepped(0, num_quants, native_packing) {
45            let shift = packed_idx * size_bits;
46            let value = (packed_val >> shift as u32) & mask;
47
48            let float_value = cast_masked::<F, NP>(value, scheme);
49
50            #[unroll]
51            for native_idx in 0..native_packing {
52                let out_offset = out_offset + packed_idx + native_idx;
53                out[out_offset] = float_value[native_idx];
54            }
55        }
56    }
57
58    out
59}
60
61/// The mask required for each packed value, taking into account the native packing required for
62/// `e2m1`.
63fn packing_mask(scheme: QuantScheme) -> u32 {
64    let bits = match scheme.value {
65        QuantValue::E2M1 => 8, // Packed conversion
66        other => other.size_bits(),
67    };
68    (1u32 << bits) - 1
69}
70
71/// Cast a masked-out value in the low `n` bits of a `u32` to the specified float type.
72/// Applies sign conversion for integer quantization before casting to the float type,
73/// while minifloats are simply truncated to `u8`, reinterpreted and then cast.
74/// For `e2m1`, casting is done on the packed `e2m1x2` representation.
75///
76/// # Returns
77/// Two floating point numbers for `e2m1`, one for all other formats.
78#[cube]
79fn cast_masked<F: Numeric, N: Size>(value: u32, #[comptime] scheme: QuantScheme) -> Vector<F, N> {
80    match scheme.value {
81        // For minifloat we can assume if they're supported then u8 is supported
82        QuantValue::E5M2 => Vector::<F, N>::cast_from(e5m2::from_bits(value as u8)),
83        QuantValue::E4M3 => Vector::<F, N>::cast_from(e4m3::from_bits(value as u8)),
84        QuantValue::E2M1 => Vector::<F, N>::cast_from(e2m1x2::from_bits(value as u8)),
85        QuantValue::Q8F
86        | QuantValue::Q4F
87        | QuantValue::Q2F
88        | QuantValue::Q8S
89        | QuantValue::Q4S
90        | QuantValue::Q2S => {
91            let size_quant = scheme.size_bits_value() as u32;
92            let sign_bit = 1u32 << (size_quant - 1);
93            let two_pow_n = 1 << size_quant;
94
95            // Branchless two's complement conversion
96            // If raw >= 2^(n-1), then result = raw - 2^n
97            let raw_i32 = value as i32;
98            let is_negative = (value >= sign_bit) as i32; // 1 if negative, 0 if positive
99            let signed_value = raw_i32 - (is_negative * two_pow_n);
100            Vector::<F, N>::cast_from(signed_value)
101        }
102    }
103}