use cubecl::prelude::*;
use cubecl_common::quant::scheme::*;
use cubecl_common::{e2m1x2, e4m3, e5m2};
use cubecl_core as cubecl;
/// Dequantizes a vector of stored values into a `Vector<F, NF>`.
///
/// The stored values are first brought into the output numeric type `F`
/// according to `scheme.store`:
///
/// * `QuantStore::Native` and `QuantStore::PackedNative(_)`: `value` is cast
///   directly to `Vector<F, NF>`.
/// * `QuantStore::PackedU32(_)`: `value` is cast to a vector of `u32` and
///   expanded with [`unpack_cast_u32`].
///
/// The result is then combined with `scale` according to `scheme.mode`; for
/// `QuantMode::Symmetric` this is a plain multiplication.
///
/// # Parameters
///
/// * `value` - the stored (possibly packed) quantized values.
/// * `scale` - the scale factor; it is cast to `Vector<F, NF>` before use
///   (presumably broadcast across all elements — confirm against the
///   `cast_from` semantics of the vector type).
/// * `scheme` - compile-time quantization scheme selecting the storage layout
///   and the quantization mode.
#[cube]
pub fn dequantize_aligned<Q: Scalar, S: CubePrimitive, F: Numeric, NQ: Size, NF: Size>(
    value: Vector<Q, NQ>,
    scale: S,
    #[comptime] scheme: QuantScheme,
) -> Vector<F, NF> {
    // Bring the stored representation into the output numeric type.
    let unscaled = match scheme.store {
        QuantStore::PackedU32(_) => {
            let packed = Vector::<u32, NQ>::cast_from(value);
            unpack_cast_u32::<F, NQ, NF>(packed, scheme)
        }
        QuantStore::Native | QuantStore::PackedNative(_) => Vector::<F, NF>::cast_from(value),
    };

    let factor = Vector::<F, NF>::cast_from(scale);

    match scheme.mode {
        QuantMode::Symmetric => unscaled * factor,
    }
}
/// Unpacks the quantized values stored in each `u32` element of `value` and
/// converts them to `F`, returning everything as one `Vector<F, NF>`.
///
/// Every input element contributes `scheme.num_quants()` consecutive output
/// slots. Values are extracted in groups of `scheme.native_packing()`: each
/// group is isolated with a single shift-and-mask and converted in one call to
/// `cast_masked`, whose result is then scattered into the output.
///
/// NOTE(review): `NF` is presumably expected to equal
/// `NQ * scheme.num_quants()`; nothing in this function checks it — confirm
/// against the callers.
#[cube]
pub fn unpack_cast_u32<F: Numeric, NQ: Size, NF: Size>(
value: Vector<u32, NQ>,
#[comptime] scheme: QuantScheme,
) -> Vector<F, NF> {
// Number of output values produced from one packed `u32`.
let num_quants = scheme.num_quants();
// Number of values converted together by one `cast_masked` call.
let native_packing = scheme.native_packing();
// Bit width used to compute the shift of each value inside the `u32`.
let size_bits = scheme.size_bits_value();
// Mask selecting one extracted group; evaluated at compile time.
let mask = comptime![packing_mask(scheme)];
// Expose `native_packing` as the size type `NP` used for the `cast_masked` result.
let size!(NP) = native_packing;
// Output vector; every slot written below is filled by the loops.
let mut out = Vector::<F, NF>::empty();
#[unroll]
for vector_idx in 0..value.size() {
let packed_val = value[vector_idx];
// First output slot belonging to this input element.
let out_offset = vector_idx * num_quants;
#[unroll]
for packed_idx in range_stepped(0, num_quants, native_packing) {
// Move the current group down to the low bits, then drop everything above it.
let shift = packed_idx * size_bits;
let value = (packed_val >> shift as u32) & mask;
let float_value = cast_masked::<F, NP>(value, scheme);
#[unroll]
for native_idx in 0..native_packing {
// Scatter each converted value of the group to its output slot.
let out_offset = out_offset + packed_idx + native_idx;
out[out_offset] = float_value[native_idx];
}
}
}
out
}
/// Builds the bit mask that isolates one extracted group in the low bits of a
/// packed `u32`.
///
/// `QuantValue::E2M1` uses a fixed width of 8 bits; every other value kind
/// uses the width reported by its `size_bits()`.
fn packing_mask(scheme: QuantScheme) -> u32 {
    let width = match scheme.value {
        QuantValue::E2M1 => 8,
        kind => kind.size_bits(),
    };
    // A single bit at position `width`, minus one, sets exactly the `width` low bits.
    let one_past_top = 1u32 << width;
    one_past_top - 1
}
/// Converts an already masked value, held in the low bits of a `u32`, to a
/// `Vector<F, N>`.
///
/// * Minifloat kinds (`E5M2`, `E4M3`, `E2M1`): the low 8 bits are
///   reinterpreted through the matching `from_bits` constructor and the result
///   is cast to the output vector type.
/// * Integer kinds (`Q8F`, `Q4F`, `Q2F`, `Q8S`, `Q4S`, `Q2S`): the value is
///   treated as an `n`-bit two's complement number, with
///   `n = scheme.size_bits_value()`, sign-extended to `i32`, then cast. All
///   integer kinds share this path.
#[cube]
fn cast_masked<F: Numeric, N: Size>(value: u32, #[comptime] scheme: QuantScheme) -> Vector<F, N> {
    match scheme.value {
        QuantValue::E5M2 => Vector::<F, N>::cast_from(e5m2::from_bits(value as u8)),
        QuantValue::E4M3 => Vector::<F, N>::cast_from(e4m3::from_bits(value as u8)),
        QuantValue::E2M1 => Vector::<F, N>::cast_from(e2m1x2::from_bits(value as u8)),
        QuantValue::Q8F
        | QuantValue::Q4F
        | QuantValue::Q2F
        | QuantValue::Q8S
        | QuantValue::Q4S
        | QuantValue::Q2S => {
            let width = scheme.size_bits_value() as u32;
            // 2^n: the amount subtracted from encodings that represent negatives.
            let modulus = 1 << width;
            // 2^(n-1): the smallest encoding that represents a negative number.
            let threshold = 1u32 << (width - 1);

            // Branch-free sign extension: subtract 2^n exactly when the sign bit is set.
            let unsigned = value as i32;
            let wraps = (value >= threshold) as i32;
            let decoded = unsigned - (wraps * modulus);
            Vector::<F, N>::cast_from(decoded)
        }
    }
}