cubecl_std/quant/
dequantize.rs1use cubecl::prelude::*;
2use cubecl_common::quant::scheme::*;
3use cubecl_common::{e2m1x2, e4m3, e5m2};
4use cubecl_core as cubecl;
5
6#[cube]
9pub fn dequantize_aligned<Q: Scalar, S: CubePrimitive, F: Numeric, NQ: Size, NF: Size>(
10 value: Vector<Q, NQ>,
11 scale: S,
12 #[comptime] scheme: QuantScheme,
13) -> Vector<F, NF> {
14 let q_values = match scheme.store {
15 QuantStore::Native | QuantStore::PackedNative(_) => Vector::<F, NF>::cast_from(value),
16 QuantStore::PackedU32(_) => unpack_cast_u32::<F, NQ, NF>(Vector::cast_from(value), scheme),
17 };
18 let scale = Vector::<F, NF>::cast_from(scale);
19
20 match scheme.mode {
21 QuantMode::Symmetric => q_values * scale,
22 }
23}
24
25#[cube]
27pub fn unpack_cast_u32<F: Numeric, NQ: Size, NF: Size>(
28 value: Vector<u32, NQ>,
29 #[comptime] scheme: QuantScheme,
30) -> Vector<F, NF> {
31 let num_quants = scheme.num_quants();
32 let native_packing = scheme.native_packing();
33 let size_bits = scheme.size_bits_value();
34 let mask = comptime![packing_mask(scheme)];
35 let size!(NP) = native_packing;
36
37 let mut out = Vector::<F, NF>::empty();
38
39 #[unroll]
40 for vector_idx in 0..value.size() {
41 let packed_val = value[vector_idx];
42 let out_offset = vector_idx * num_quants;
43 #[unroll]
44 for packed_idx in range_stepped(0, num_quants, native_packing) {
45 let shift = packed_idx * size_bits;
46 let value = (packed_val >> shift as u32) & mask;
47
48 let float_value = cast_masked::<F, NP>(value, scheme);
49
50 #[unroll]
51 for native_idx in 0..native_packing {
52 let out_offset = out_offset + packed_idx + native_idx;
53 out[out_offset] = float_value[native_idx];
54 }
55 }
56 }
57
58 out
59}
60
61fn packing_mask(scheme: QuantScheme) -> u32 {
64 let bits = match scheme.value {
65 QuantValue::E2M1 => 8, other => other.size_bits(),
67 };
68 (1u32 << bits) - 1
69}
70
71#[cube]
79fn cast_masked<F: Numeric, N: Size>(value: u32, #[comptime] scheme: QuantScheme) -> Vector<F, N> {
80 match scheme.value {
81 QuantValue::E5M2 => Vector::<F, N>::cast_from(e5m2::from_bits(value as u8)),
83 QuantValue::E4M3 => Vector::<F, N>::cast_from(e4m3::from_bits(value as u8)),
84 QuantValue::E2M1 => Vector::<F, N>::cast_from(e2m1x2::from_bits(value as u8)),
85 QuantValue::Q8F
86 | QuantValue::Q4F
87 | QuantValue::Q2F
88 | QuantValue::Q8S
89 | QuantValue::Q4S
90 | QuantValue::Q2S => {
91 let size_quant = scheme.size_bits_value() as u32;
92 let sign_bit = 1u32 << (size_quant - 1);
93 let two_pow_n = 1 << size_quant;
94
95 let raw_i32 = value as i32;
98 let is_negative = (value >= sign_bit) as i32; let signed_value = raw_i32 - (is_negative * two_pow_n);
100 Vector::<F, N>::cast_from(signed_value)
101 }
102 }
103}