// cubecl_std/quant/dequantize.rs

1use cubecl::prelude::*;
2use cubecl_common::quant::scheme::*;
3use cubecl_common::{e2m1x2, e4m3, e5m2};
4use cubecl_core as cubecl;
5
6/// Dequantize a line of values, where `line_size * num_quants` is a power of two.
7/// Unaligned values can't be dequantized in place.
8#[cube]
9pub fn dequantize_aligned<Q: CubePrimitive, S: CubePrimitive, F: Numeric>(
10    value: Line<Q>,
11    scale: S,
12    #[comptime] scheme: QuantScheme,
13) -> Line<F> {
14    let q_values = match scheme.store {
15        QuantStore::Native => Line::<F>::cast_from(value),
16        QuantStore::U32 => unpack_cast_u32::<F>(Line::cast_from(value), scheme),
17    };
18    let scale = Line::<F>::cast_from(scale);
19
20    match scheme.mode {
21        QuantMode::Symmetric => q_values * scale,
22    }
23}
24
25/// Unpack a set of values from u32, and convert to the specified floating point format.
26#[cube]
27pub fn unpack_cast_u32<F: Numeric>(value: Line<u32>, #[comptime] scheme: QuantScheme) -> Line<F> {
28    let num_quants = comptime![scheme.num_quants() as u32];
29    let native_packing = comptime![scheme.native_packing() as u32];
30    let out_line_size = comptime![value.line_size() * num_quants];
31    let size_bits = comptime![scheme.size_bits_value() as u32];
32    let mask = comptime![packing_mask(scheme)];
33
34    let mut out = Line::<F>::empty(out_line_size);
35
36    #[unroll]
37    for line_idx in 0..value.line_size() {
38        let packed_val = value[line_idx];
39        let out_offset = comptime![line_idx * num_quants];
40        #[unroll]
41        for packed_idx in range_stepped(0, num_quants, native_packing) {
42            let shift = packed_idx * size_bits;
43            let value = (packed_val >> shift) & mask;
44
45            let float_value = cast_masked::<F>(value, scheme);
46
47            #[unroll]
48            for native_idx in 0..native_packing {
49                let out_offset = comptime![out_offset + packed_idx + native_idx];
50                out[out_offset] = float_value[native_idx];
51            }
52        }
53    }
54
55    out
56}
57
58/// The mask required for each packed value, taking into account the native packing required for
59/// `e2m1`.
60fn packing_mask(scheme: QuantScheme) -> u32 {
61    let bits = match scheme.value {
62        QuantValue::E2M1 => 8, // Packed conversion
63        other => other.size_bits(),
64    };
65    (1u32 << bits) - 1
66}
67
68/// Cast a masked-out value in the low `n` bits of a `u32` to the specified float type.
69/// Applies sign conversion for integer quantization before casting to the float type,
70/// while minifloats are simply truncated to `u8`, reinterpreted and then cast.
71/// For `e2m1`, casting is done on the packed `e2m1x2` representation.
72///
73/// # Returns
74/// Two floating point numbers for `e2m1`, one for all other formats.
75#[cube]
76fn cast_masked<F: Numeric>(value: u32, #[comptime] scheme: QuantScheme) -> Line<F> {
77    match scheme.value {
78        // For minifloat we can assume if they're supported then u8 is supported
79        QuantValue::E5M2 => Line::<F>::cast_from(e5m2::reinterpret(value as u8)),
80        QuantValue::E4M3 => Line::<F>::cast_from(e4m3::reinterpret(value as u8)),
81        QuantValue::E2M1 => Line::<F>::cast_from(e2m1x2::reinterpret(value as u8)),
82        QuantValue::Q8F
83        | QuantValue::Q4F
84        | QuantValue::Q2F
85        | QuantValue::Q8S
86        | QuantValue::Q4S
87        | QuantValue::Q2S => {
88            let size_quant = comptime!(scheme.size_bits_value() as u32);
89            let sign_bit = comptime!(1u32 << (size_quant - 1));
90            let two_pow_n = comptime!(1 << size_quant);
91
92            // Branchless two's complement conversion
93            // If raw >= 2^(n-1), then result = raw - 2^n
94            let raw_i32 = value as i32;
95            let is_negative = (value >= sign_bit) as i32; // 1 if negative, 0 if positive
96            let signed_value = raw_i32 - (is_negative * two_pow_n);
97            Line::<F>::cast_from(signed_value)
98        }
99    }
100}