cubecl_quant/layout/
scales.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::{
4    FastDivmod, FastDivmodArgs,
5    tensor::{
6        launch::{TypedView, TypedViewLaunch},
7        layout::{Coords1d, Layout, LayoutExpand},
8    },
9};
10
11use crate::scheme::{QuantLevel, QuantScheme};
12
/// Layout for quantization scales.
///
/// It is indexed by the position of a quantized element and maps that position to the position
/// of the scale that applies to it. How the mapping is done depends on the variant.
#[derive(CubeType, CubeLaunch)]
pub enum ScalesLayout {
    /// Every element position maps to scale position `0` (see [`PerTensorLayout`]).
    PerTensor(PerTensorLayout),
    /// The scale position is derived from the element's per-dimension block index
    /// (see [`BlockScaledLayout`]).
    BlockScaled(BlockScaledLayout),
}
20
21#[cube]
22impl Layout for ScalesLayout {
23    type Coordinates = Coords1d;
24    type SourceCoordinates = Coords1d;
25
26    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
27        match self {
28            ScalesLayout::PerTensor(layout) => layout.to_source_pos(pos),
29            ScalesLayout::BlockScaled(layout) => layout.to_source_pos(pos),
30        }
31    }
32
33    fn shape(&self) -> Self::Coordinates {
34        match self {
35            ScalesLayout::PerTensor(layout) => layout.shape(),
36            ScalesLayout::BlockScaled(layout) => layout.shape(),
37        }
38    }
39
40    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
41        match self {
42            ScalesLayout::PerTensor(layout) => layout.is_in_bounds(pos),
43            ScalesLayout::BlockScaled(layout) => layout.is_in_bounds(pos),
44        }
45    }
46
47    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
48        match self {
49            ScalesLayout::PerTensor(layout) => layout.to_source_pos_checked(pos),
50            ScalesLayout::BlockScaled(layout) => layout.to_source_pos_checked(pos),
51        }
52    }
53}
54
55#[cube]
56impl ScalesLayout {
57    /// Whether the position is at the start of a new block. Used for electing a unit to write each
58    /// scale.
59    pub fn is_block_start(&self, pos: u32) -> bool {
60        match self {
61            ScalesLayout::PerTensor(layout) => layout.is_block_start(pos),
62            ScalesLayout::BlockScaled(layout) => layout.is_block_start(pos),
63        }
64    }
65}
66
/// Scales layout in which every element position maps to scale position `0`.
#[derive(CubeType, CubeLaunch)]
pub struct PerTensorLayout {
    /// Value reported by `shape()` and used as the exclusive upper bound in `is_in_bounds()`.
    tensor_len: u32,
}
71
#[cube]
impl PerTensorLayout {
    /// Creates a per-tensor layout whose shape and bounds check are both given by `tensor_len`.
    pub fn new(tensor_len: u32) -> Self {
        PerTensorLayout { tensor_len }
    }
}
78
79#[cube]
80impl Layout for PerTensorLayout {
81    type Coordinates = Coords1d;
82    type SourceCoordinates = Coords1d;
83
84    fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates {
85        0u32.runtime()
86    }
87
88    fn shape(&self) -> Self::Coordinates {
89        self.tensor_len
90    }
91
92    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
93        pos < self.tensor_len
94    }
95
96    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
97        (self.to_source_pos(pos), self.is_in_bounds(pos))
98    }
99}
100
#[cube]
impl PerTensorLayout {
    /// Whether the position is at the start of a new block. Used for electing a unit to write each
    /// scale.
    ///
    /// For this layout only position `0` qualifies; `self` is not consulted.
    pub fn is_block_start(&self, pos: u32) -> bool {
        pos == 0
    }
}
109
/// Scales layout in which the scale position is derived from the per-dimension block index of
/// an element.
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    /// One divisor per dimension, used to split a linear element position into per-dimension
    /// coordinates (processed from the last dimension to the first).
    tensor_shape: Sequence<FastDivmod>,
    /// Value reported by `shape()` and used as the exclusive upper bound in `is_in_bounds()`.
    tensor_len: u32,
    /// One stride per dimension; each is multiplied by the block index along that dimension.
    /// The number of entries also determines how many dimensions are processed.
    scales_strides: Sequence<u32>,
    /// Compile-time block size per dimension, indexed by dimension.
    #[cube(comptime)]
    block_size: Vec<u8>,
    /// Compile-time divisor applied to the accumulated scale offset in `to_source_pos()`.
    #[cube(comptime)]
    scales_line_size: u32,
}
120
#[cube]
impl BlockScaledLayout {
    /// Creates a block-scaled layout by storing the arguments as given.
    ///
    /// * `tensor_shape` - per-dimension divisors used to decompose an element position.
    /// * `tensor_len` - value reported as the shape and used for bounds checks.
    /// * `scales_strides` - per-dimension strides applied to the block indices.
    /// * `block_size` - compile-time block size per dimension.
    /// * `scales_line_size` - compile-time divisor applied to the final scale offset.
    pub fn new(
        tensor_shape: Sequence<FastDivmod>,
        tensor_len: u32,
        scales_strides: Sequence<u32>,
        #[comptime] block_size: Vec<u8>,
        #[comptime] scales_line_size: u32,
    ) -> Self {
        BlockScaledLayout {
            tensor_shape,
            tensor_len,
            scales_strides,
            block_size,
            scales_line_size,
        }
    }
}
139
140#[cube]
141impl Layout for BlockScaledLayout {
142    type Coordinates = Coords1d;
143    type SourceCoordinates = Coords1d;
144
145    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
146        let rank = comptime![self.scales_strides.len()];
147        let mut offs = pos;
148        let mut scale_offs = 0;
149
150        #[unroll]
151        for i in 0..rank {
152            let dim = comptime![rank - i - 1];
153            let block_size_local = comptime![self.block_size[dim as usize] as u32];
154            let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs);
155
156            offs = rem;
157            scale_offs += (offs_local / block_size_local) * *self.scales_strides.index(dim);
158        }
159
160        scale_offs / self.scales_line_size
161    }
162
163    fn shape(&self) -> Self::Coordinates {
164        self.tensor_len
165    }
166
167    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
168        pos < self.tensor_len
169    }
170
171    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
172        (self.to_source_pos(pos), self.is_in_bounds(pos))
173    }
174}
175
176#[cube]
177impl BlockScaledLayout {
178    /// Whether the position is at the start of a new block. Used for electing a unit to write each
179    /// scale.
180    pub fn is_block_start(&self, pos: u32) -> bool {
181        let rank = comptime![self.scales_strides.len()];
182        let mut offs = pos;
183        let mut is_start = true;
184
185        #[unroll]
186        for i in 0..rank {
187            let dim = comptime![rank - i - 1];
188            let block_size_local = comptime![self.block_size[dim as usize] as u32];
189            let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs);
190            offs = rem;
191            is_start &= offs_local.is_multiple_of(block_size_local);
192        }
193
194        is_start
195    }
196}
197
/// [`TypedView`] over scale values of type `E`, addressed through a [`ScalesLayout`].
/// The view is read-only unless `IO` is overridden.
pub type ScalesView<E, IO = ReadOnly> = TypedView<E, ScalesLayout, IO>;
/// Launch type for [`ScalesView`].
pub type ScalesViewLaunch<'a, R> = TypedViewLaunch<'a, ScalesLayout, R>;
203
204/// Create a scales view from the values and scales handle, line size and quantization scheme.
205/// `values` should be *the quantized tensor*, and will be adjusted by `num_quants`.
206pub fn scales_view<'a, R: Runtime>(
207    client: &ComputeClient<R::Server>,
208    values: &'a TensorHandleRef<'a, R>,
209    scales: &'a TensorHandleRef<'a, R>,
210    scales_line_size: u8,
211    quant_scheme: &QuantScheme,
212) -> ScalesViewLaunch<'a, R> {
213    let layout = scales_layout(client, values, scales, scales_line_size, quant_scheme);
214    let len = scales.shape.iter().product::<usize>();
215    let buffer = unsafe {
216        ArrayArg::from_raw_parts_and_size(scales.handle, len, scales_line_size, scales.elem_size)
217    };
218    ScalesViewLaunch::new(buffer, layout)
219}
220
221pub fn scales_layout<'a, R: Runtime>(
222    client: &ComputeClient<R::Server>,
223    values: &'a TensorHandleRef<'a, R>,
224    scales: &'a TensorHandleRef<'a, R>,
225    scales_line_size: u8,
226    scheme: &QuantScheme,
227) -> ScalesLayoutArgs<'a, R> {
228    let values_len = values.shape.iter().product::<usize>() * scheme.num_quants();
229    let values_len = ScalarArg::new(values_len as u32);
230
231    match &scheme.level {
232        QuantLevel::Tensor => ScalesLayoutArgs::PerTensor(PerTensorLayoutLaunch::new(values_len)),
233        QuantLevel::Block(block_size) => {
234            let tensor_shape = shape_divmod_quant(client, values.shape, scheme.num_quants());
235            let scales_strides = strides_seq(scales.strides);
236            ScalesLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
237                tensor_shape,
238                values_len,
239                scales_strides,
240                block_size.to_dim_vec(values.shape.len()),
241                scales_line_size as u32,
242            ))
243        }
244    }
245}
246
247fn shape_divmod_quant<'a, R: Runtime>(
248    client: &ComputeClient<R::Server>,
249    shape: &'a [usize],
250    num_quants: usize,
251) -> SequenceArg<'a, R, FastDivmod> {
252    let mut out_seq = SequenceArg::new();
253    for s in &shape[..shape.len() - 1] {
254        out_seq.push(FastDivmodArgs::new(client, *s as u32));
255    }
256    let last = *shape.last().unwrap() * num_quants;
257    out_seq.push(FastDivmodArgs::new(client, last as u32));
258    out_seq
259}
260
261fn strides_seq<'a, R: Runtime>(strides: &'a [usize]) -> SequenceArg<'a, R, u32> {
262    let mut out_seq = SequenceArg::new();
263    for s in strides {
264        out_seq.push(ScalarArg::new(*s as u32));
265    }
266    out_seq
267}