// burn_cubecl/kernel/quantization/qtensor.rs
#![allow(missing_docs)]

use burn_tensor::quantization::{QuantizationMode, QuantizationScheme};
use cubecl::prelude::*;

6#[derive(CubeLaunch, CubeType)]
8pub struct QParams {
9 #[cube(comptime)]
10 scheme: QuantizationScheme,
11}
12
13pub type QTensor = Array<Line<u32>>;
15
16#[cube]
17impl QParams {
18 pub fn new(scheme: QuantizationScheme) -> Self {
20 QParams { scheme }
21 }
22
23 pub fn values(&self, tensor: &QTensor) -> (f32, i32) {
25 let len = tensor.len();
26 match comptime!(self.scheme) {
27 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => {
29 (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0)
30 }
31 }
32 }
33}