cubecl_std/quant/
dequantize.rs1use cubecl::prelude::*;
2use cubecl_common::quant::scheme::*;
3use cubecl_common::{e2m1x2, e4m3, e5m2};
4use cubecl_core as cubecl;
5
6#[cube]
9pub fn dequantize_aligned<Q: CubePrimitive, S: CubePrimitive, F: Numeric>(
10 value: Line<Q>,
11 scale: S,
12 #[comptime] scheme: QuantScheme,
13) -> Line<F> {
14 let q_values = match scheme.store {
15 QuantStore::Native => Line::<F>::cast_from(value),
16 QuantStore::U32 => unpack_cast_u32::<F>(Line::cast_from(value), scheme),
17 };
18 let scale = Line::<F>::cast_from(scale);
19
20 match scheme.mode {
21 QuantMode::Symmetric => q_values * scale,
22 }
23}
24
25#[cube]
27pub fn unpack_cast_u32<F: Numeric>(value: Line<u32>, #[comptime] scheme: QuantScheme) -> Line<F> {
28 let num_quants = comptime![scheme.num_quants() as u32];
29 let native_packing = comptime![scheme.native_packing() as u32];
30 let out_line_size = comptime![value.line_size() * num_quants];
31 let size_bits = comptime![scheme.size_bits_value() as u32];
32 let mask = comptime![packing_mask(scheme)];
33
34 let mut out = Line::<F>::empty(out_line_size);
35
36 #[unroll]
37 for line_idx in 0..value.line_size() {
38 let packed_val = value[line_idx];
39 let out_offset = comptime![line_idx * num_quants];
40 #[unroll]
41 for packed_idx in range_stepped(0, num_quants, native_packing) {
42 let shift = packed_idx * size_bits;
43 let value = (packed_val >> shift) & mask;
44
45 let float_value = cast_masked::<F>(value, scheme);
46
47 #[unroll]
48 for native_idx in 0..native_packing {
49 let out_offset = comptime![out_offset + packed_idx + native_idx];
50 out[out_offset] = float_value[native_idx];
51 }
52 }
53 }
54
55 out
56}
57
58fn packing_mask(scheme: QuantScheme) -> u32 {
61 let bits = match scheme.value {
62 QuantValue::E2M1 => 8, other => other.size_bits(),
64 };
65 (1u32 << bits) - 1
66}
67
68#[cube]
76fn cast_masked<F: Numeric>(value: u32, #[comptime] scheme: QuantScheme) -> Line<F> {
77 match scheme.value {
78 QuantValue::E5M2 => Line::<F>::cast_from(e5m2::reinterpret(value as u8)),
80 QuantValue::E4M3 => Line::<F>::cast_from(e4m3::reinterpret(value as u8)),
81 QuantValue::E2M1 => Line::<F>::cast_from(e2m1x2::reinterpret(value as u8)),
82 QuantValue::Q8F
83 | QuantValue::Q4F
84 | QuantValue::Q2F
85 | QuantValue::Q8S
86 | QuantValue::Q4S
87 | QuantValue::Q2S => {
88 let size_quant = comptime!(scheme.size_bits_value() as u32);
89 let sign_bit = comptime!(1u32 << (size_quant - 1));
90 let two_pow_n = comptime!(1 << size_quant);
91
92 let raw_i32 = value as i32;
95 let is_negative = (value >= sign_bit) as i32; let signed_value = raw_i32 - (is_negative * two_pow_n);
97 Line::<F>::cast_from(signed_value)
98 }
99 }
100}