cubek_quant/layout/
scales.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::{
5        launch::{TypedView, TypedViewLaunch},
6        layout::{Coords1d, Layout, LayoutExpand},
7    },
8};
9
10use crate::scheme::{QuantLevel, QuantScheme};
11
/// Layout for quantization scales. It is indexed by quantized element index and resolves to the
/// position, in the scales buffer, of the scale that applies to that element. The resolution
/// depends on the quantization level: [`PerTensorLayout`] always resolves to the single
/// tensor-wide scale, while [`BlockScaledLayout`] resolves to the scale of the block that contains
/// the element.
#[derive(CubeType, CubeLaunch)]
pub enum ScalesLayout {
    PerTensor(PerTensorLayout),
    BlockScaled(BlockScaledLayout),
}
19
// `Layout` implementation that forwards every call to the layout held by the active variant.
#[cube]
impl Layout for ScalesLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Position in the scales buffer of the scale for quantized element `pos`.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos(pos),
        }
    }

    /// Number of quantized elements addressed by the active layout.
    fn shape(&self) -> Self::Coordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.shape(),
            ScalesLayout::BlockScaled(layout) => layout.shape(),
        }
    }

    /// Whether `pos` lies within the active layout's shape.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_in_bounds(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_in_bounds(pos),
        }
    }

    /// Source position of `pos` together with its bounds-check result.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos_checked(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos_checked(pos),
        }
    }
}
53
#[cube]
impl ScalesLayout {
    /// Whether the position is at the start of a new block. Used for electing a unit to write each
    /// scale.
    ///
    /// Forwards to `is_block_start` of the layout held by the active variant.
    pub fn is_block_start(&self, pos: usize) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_block_start(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_block_start(pos),
        }
    }
}
65
/// Scales layout for per-tensor quantization: every element maps to the same (single) scale.
#[derive(CubeType, CubeLaunch)]
pub struct PerTensorLayout {
    /// Number of quantized elements addressed by this layout; positions at or beyond it are out
    /// of bounds.
    tensor_len: usize,
}
70
#[cube]
impl PerTensorLayout {
    /// Create a per-tensor layout covering `tensor_len` quantized elements.
    pub fn new(tensor_len: usize) -> Self {
        PerTensorLayout { tensor_len }
    }
}
77
78#[cube]
79impl Layout for PerTensorLayout {
80    type Coordinates = Coords1d;
81    type SourceCoordinates = Coords1d;
82
83    fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates {
84        0usize.runtime()
85    }
86
87    fn shape(&self) -> Self::Coordinates {
88        self.tensor_len
89    }
90
91    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
92        pos < self.tensor_len
93    }
94
95    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
96        (self.to_source_pos(pos), self.is_in_bounds(pos))
97    }
98}
99
#[cube]
impl PerTensorLayout {
    /// Whether the position is at the start of a new block. Used for electing a unit to write each
    /// scale.
    ///
    /// The whole tensor forms a single block, so only position 0 is a block start.
    pub fn is_block_start(&self, pos: usize) -> bool {
        pos == 0
    }
}
108
/// Scales layout for block quantization: each block of elements (sized per dimension by
/// `block_size`) shares one scale, stored in a separate, strided scales tensor.
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    /// Fast divisors for each dimension of the element-space tensor shape, used to split a linear
    /// element index into per-dimension coordinates.
    tensor_shape: Sequence<FastDivmod<usize>>,
    /// Total number of elements addressed by this layout (the bounds-check limit).
    tensor_len: usize,
    /// Strides of the scales tensor, one per dimension. Its length also determines the rank that
    /// is iterated over.
    scales_strides: Sequence<usize>,
    /// Block size along each dimension (compile-time).
    #[cube(comptime)]
    block_size: Vec<u8>,
    /// Line size of the scales buffer; scale offsets are divided by it to get a line index
    /// (compile-time).
    #[cube(comptime)]
    scales_line_size: usize,
}
119
#[cube]
impl BlockScaledLayout {
    /// Create a block-scaled layout from its per-dimension divisors, element count, scales
    /// strides, and the compile-time (`#[comptime]`) block size and scales line size.
    pub fn new(
        tensor_shape: Sequence<FastDivmod<usize>>,
        tensor_len: usize,
        scales_strides: Sequence<usize>,
        #[comptime] block_size: Vec<u8>,
        #[comptime] scales_line_size: usize,
    ) -> Self {
        BlockScaledLayout {
            tensor_shape,
            tensor_len,
            scales_strides,
            block_size,
            scales_line_size,
        }
    }
}
138
139#[cube]
140impl Layout for BlockScaledLayout {
141    type Coordinates = Coords1d;
142    type SourceCoordinates = Coords1d;
143
144    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
145        let rank = self.scales_strides.len().comptime();
146        let mut offs = pos;
147        let mut scale_offs = 0;
148
149        #[unroll]
150        for i in 0..rank {
151            let dim = rank - i - 1;
152            let block_size_local = comptime![self.block_size[dim] as usize];
153            let (rem, offs_local) = self.tensor_shape[dim].div_mod(offs);
154
155            offs = rem;
156            scale_offs += (offs_local / block_size_local) * self.scales_strides[dim];
157        }
158
159        scale_offs / self.scales_line_size
160    }
161
162    fn shape(&self) -> Self::Coordinates {
163        self.tensor_len
164    }
165
166    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
167        pos < self.tensor_len
168    }
169
170    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
171        (self.to_source_pos(pos), self.is_in_bounds(pos))
172    }
173}
174
175#[cube]
176impl BlockScaledLayout {
177    /// Whether the position is at the start of a new block. Used for electing a unit to write each
178    /// scale.
179    pub fn is_block_start(&self, pos: usize) -> bool {
180        let rank = self.scales_strides.len().comptime();
181        let mut offs = pos;
182        let mut is_start = true;
183
184        #[unroll]
185        for i in 0..rank {
186            let dim = rank - i - 1;
187            let block_size_local = comptime![self.block_size[dim] as usize];
188            let (rem, offs_local) = self.tensor_shape[dim].div_mod(offs);
189            offs = rem;
190            is_start &= offs_local.is_multiple_of(block_size_local);
191        }
192
193        is_start
194    }
195}
196
197/// TensorView with a linear layout inferred from the shape/strides at launch.
198/// Useful for elementwise kernels.
199pub type ScalesView<E, IO = ReadOnly> = TypedView<E, ScalesLayout, IO>;
200/// Launch type for LinearTensorView.
201pub type ScalesViewLaunch<'a, R> = TypedViewLaunch<'a, ScalesLayout, R>;
202
203/// Create a scales view from the values and scales handle, line size and quantization scheme.
204/// `values` should be *the quantized tensor*, and will be adjusted by `num_quants`.
205pub fn scales_view<'a, R: Runtime>(
206    client: &ComputeClient<R>,
207    values: &'a TensorHandleRef<'a, R>,
208    scales: &'a TensorHandleRef<'a, R>,
209    scales_line_size: usize,
210    quant_scheme: &QuantScheme,
211) -> ScalesViewLaunch<'a, R> {
212    let layout = scales_layout(client, values, scales, scales_line_size, quant_scheme);
213    let len = scales.shape.iter().product::<usize>();
214    let buffer = unsafe {
215        ArrayArg::from_raw_parts_and_size(scales.handle, len, scales_line_size, scales.elem_size)
216    };
217    ScalesViewLaunch::new(buffer, layout)
218}
219
220pub fn scales_layout<'a, R: Runtime>(
221    client: &ComputeClient<R>,
222    values: &'a TensorHandleRef<'a, R>,
223    scales: &'a TensorHandleRef<'a, R>,
224    scales_line_size: usize,
225    scheme: &QuantScheme,
226) -> ScalesLayoutArgs<'a, R> {
227    let values_len = values.shape.iter().product::<usize>() * scheme.num_quants();
228    let values_len = ScalarArg::new(values_len);
229
230    match &scheme.level {
231        QuantLevel::Tensor => ScalesLayoutArgs::PerTensor(PerTensorLayoutLaunch::new(values_len)),
232        QuantLevel::Block(block_size) => {
233            let tensor_shape = shape_divmod_quant(client, values.shape, scheme.num_quants());
234            let scales_strides = strides_seq(scales.strides);
235            ScalesLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
236                tensor_shape,
237                values_len,
238                scales_strides,
239                block_size.to_dim_vec(values.shape.len()),
240                scales_line_size,
241            ))
242        }
243    }
244}
245
246fn shape_divmod_quant<'a, R: Runtime>(
247    client: &ComputeClient<R>,
248    shape: &'a [usize],
249    num_quants: usize,
250) -> SequenceArg<'a, R, FastDivmod<usize>> {
251    let mut out_seq = SequenceArg::new();
252    for s in &shape[..shape.len() - 1] {
253        out_seq.push(FastDivmodArgs::<usize>::new(client, *s));
254    }
255    let last = *shape.last().unwrap() * num_quants;
256    out_seq.push(FastDivmodArgs::<usize>::new(client, last));
257    out_seq
258}
259
260fn strides_seq<'a, R: Runtime>(strides: &'a [usize]) -> SequenceArg<'a, R, usize> {
261    let mut out_seq = SequenceArg::new();
262    for s in strides {
263        out_seq.push(ScalarArg::new(*s));
264    }
265    out_seq
266}