// burn_jit/kernel/quantization/
qtensor.rs1#![allow(missing_docs)] use burn_tensor::quantization::QuantizationScheme;
4use cubecl::prelude::*;
5
6#[derive(CubeLaunch)]
8pub struct QParams {
9 #[cube(comptime)]
10 scheme: QuantizationScheme,
11}
12
13pub type QTensor = Array<Line<u32>>;
15
16#[cube]
17impl QParams {
18 pub fn new(scheme: QuantizationScheme) -> Self {
20 QParams { scheme }
21 }
22
23 pub fn values(&self, tensor: &QTensor) -> (f32, i32) {
25 let len = tensor.len();
26 match comptime!(self.scheme) {
27 QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) {
28 1 => (
30 f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
31 i32::cast_from(tensor[len - 2][tensor.line_size() - 1]),
32 ),
33 _ => {
35 let values = tensor[len - 1];
36 (
37 f32::bitcast_from(values[tensor.line_size() - 1]),
38 i32::cast_from(values[tensor.line_size() - 2]),
39 )
40 }
41 },
42 QuantizationScheme::PerTensorSymmetric(_) => (
44 f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
45 0,
46 ),
47 }
48 }
49}