1use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod, FastDivmodArgs,
4 tensor::{
5 launch::{TypedView, TypedViewLaunch},
6 layout::{Coords1d, Layout, LayoutExpand},
7 },
8};
9
10use crate::scheme::{QuantLevel, QuantScheme};
11
/// Layout that maps a linear position in the quantized values tensor to the
/// position of its corresponding scale, for each supported quantization level.
#[derive(CubeType, CubeLaunch)]
pub enum ScalesLayout {
    /// A single scale shared by the whole tensor.
    PerTensor(PerTensorLayout),
    /// One scale per block of elements.
    BlockScaled(BlockScaledLayout),
}
19
#[cube]
impl Layout for ScalesLayout {
    // Both the view coordinates and the underlying scales buffer are addressed
    // with a single linear index.
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Dispatches to the concrete layout variant to map a values-tensor
    /// position to the index of its scale.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos(pos),
        }
    }

    /// Logical length of the view (the values tensor), not the scales buffer.
    fn shape(&self) -> Self::Coordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.shape(),
            ScalesLayout::BlockScaled(layout) => layout.shape(),
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_in_bounds(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_in_bounds(pos),
        }
    }

    /// Combined position + bounds check, delegated to the variant.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos_checked(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos_checked(pos),
        }
    }
}
53
#[cube]
impl ScalesLayout {
    /// Returns `true` when `pos` is the first values-tensor position covered
    /// by its scale, i.e. where a new scale value starts to apply.
    pub fn is_block_start(&self, pos: usize) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_block_start(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_block_start(pos),
        }
    }
}
65
/// Layout for per-tensor quantization: a single scale shared by all elements.
#[derive(CubeType, CubeLaunch)]
pub struct PerTensorLayout {
    // Number of elements in the values tensor; used only for shape/bounds.
    tensor_len: usize,
}
70
#[cube]
impl PerTensorLayout {
    /// Creates a per-tensor layout covering `tensor_len` value elements.
    pub fn new(tensor_len: usize) -> Self {
        PerTensorLayout { tensor_len }
    }
}
77
#[cube]
impl Layout for PerTensorLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Every element maps to the single scale at index 0.
    fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates {
        // `.runtime()` lifts the compile-time constant into a runtime value.
        0usize.runtime()
    }

    /// The view spans the values tensor, not the one-element scales buffer.
    fn shape(&self) -> Self::Coordinates {
        self.tensor_len
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.tensor_len
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
99
#[cube]
impl PerTensorLayout {
    /// With one scale for the whole tensor, only position 0 starts a "block".
    pub fn is_block_start(&self, pos: usize) -> bool {
        pos == 0
    }
}
108
/// Layout for block quantization: one scale per block of elements, where the
/// block size may differ per dimension.
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    // Fast div/mod helpers, one per dimension of the (unpacked) values shape.
    tensor_shape: Sequence<FastDivmod<usize>>,
    // Total number of (unpacked) value elements; used for shape/bounds.
    tensor_len: usize,
    // Strides of the scales tensor, one per dimension.
    scales_strides: Sequence<usize>,
    // Block size along each dimension; comptime, so it is baked into the kernel.
    #[cube(comptime)]
    block_size: Vec<u8>,
    // Line (vectorization) width of the scales buffer.
    #[cube(comptime)]
    scales_line_size: usize,
}
119
#[cube]
impl BlockScaledLayout {
    /// Creates a block-scaled layout.
    ///
    /// `tensor_shape`/`tensor_len` describe the (unpacked) values tensor,
    /// `scales_strides` the scales tensor, and `block_size` holds one entry
    /// per dimension. `scales_line_size` is the scales buffer line width.
    pub fn new(
        tensor_shape: Sequence<FastDivmod<usize>>,
        tensor_len: usize,
        scales_strides: Sequence<usize>,
        #[comptime] block_size: Vec<u8>,
        #[comptime] scales_line_size: usize,
    ) -> Self {
        BlockScaledLayout {
            tensor_shape,
            tensor_len,
            scales_strides,
            block_size,
            scales_line_size,
        }
    }
}
138
#[cube]
impl Layout for BlockScaledLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Maps a linear position in the values tensor to the line-indexed
    /// position of the scale for the block containing that element.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let rank = self.scales_strides.len().comptime();
        let mut offs = pos;
        let mut scale_offs = 0;

        // Decompose the linear offset into per-dimension coordinates, from the
        // innermost dimension outward. The loop is fully unrolled since `rank`
        // is comptime.
        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let block_size_local = comptime![self.block_size[dim] as usize];
            // NOTE(review): despite its name, `rem` appears to hold the
            // quotient (offset left for the outer dims) and `offs_local` this
            // dimension's coordinate, consistent with linear-offset
            // decomposition — confirm `FastDivmod::div_mod`'s return order.
            let (rem, offs_local) = self.tensor_shape[dim].div_mod(offs);

            offs = rem;
            // Block index along this dim times this dim's scale stride.
            scale_offs += (offs_local / block_size_local) * self.scales_strides[dim];
        }

        // Convert the element offset into a line offset in the scales buffer.
        scale_offs / self.scales_line_size
    }

    /// The view is as long as the values tensor, not the scales buffer.
    fn shape(&self) -> Self::Coordinates {
        self.tensor_len
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.tensor_len
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
174
#[cube]
impl BlockScaledLayout {
    /// Returns `true` when `pos` lies on a block boundary along every
    /// dimension, i.e. it is the first element of its scale block.
    pub fn is_block_start(&self, pos: usize) -> bool {
        let rank = self.scales_strides.len().comptime();
        let mut offs = pos;
        let mut is_start = true;

        // Same innermost-first decomposition as `to_source_pos`; `rem` holds
        // the remaining offset for the outer dims, `offs_local` this
        // dimension's coordinate.
        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let block_size_local = comptime![self.block_size[dim] as usize];
            let (rem, offs_local) = self.tensor_shape[dim].div_mod(offs);
            offs = rem;
            // A block start must be aligned to the block size in every dim.
            is_start &= offs_local.is_multiple_of(block_size_local);
        }

        is_start
    }
}
196
/// View over the scales buffer, indexed by positions in the values tensor.
pub type ScalesView<E, IO = ReadOnly> = TypedView<E, ScalesLayout, IO>;
/// Launch-side argument type for [`ScalesView`].
pub type ScalesViewLaunch<'a, R> = TypedViewLaunch<'a, ScalesLayout, R>;
202
203pub fn scales_view<'a, R: Runtime>(
206 client: &ComputeClient<R>,
207 values: &'a TensorHandleRef<'a, R>,
208 scales: &'a TensorHandleRef<'a, R>,
209 scales_line_size: usize,
210 quant_scheme: &QuantScheme,
211) -> ScalesViewLaunch<'a, R> {
212 let layout = scales_layout(client, values, scales, scales_line_size, quant_scheme);
213 let len = scales.shape.iter().product::<usize>();
214 let buffer = unsafe {
215 ArrayArg::from_raw_parts_and_size(scales.handle, len, scales_line_size, scales.elem_size)
216 };
217 ScalesViewLaunch::new(buffer, layout)
218}
219
220pub fn scales_layout<'a, R: Runtime>(
221 client: &ComputeClient<R>,
222 values: &'a TensorHandleRef<'a, R>,
223 scales: &'a TensorHandleRef<'a, R>,
224 scales_line_size: usize,
225 scheme: &QuantScheme,
226) -> ScalesLayoutArgs<'a, R> {
227 let values_len = values.shape.iter().product::<usize>() * scheme.num_quants();
228 let values_len = ScalarArg::new(values_len);
229
230 match &scheme.level {
231 QuantLevel::Tensor => ScalesLayoutArgs::PerTensor(PerTensorLayoutLaunch::new(values_len)),
232 QuantLevel::Block(block_size) => {
233 let tensor_shape = shape_divmod_quant(client, values.shape, scheme.num_quants());
234 let scales_strides = strides_seq(scales.strides);
235 ScalesLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
236 tensor_shape,
237 values_len,
238 scales_strides,
239 block_size.to_dim_vec(values.shape.len()),
240 scales_line_size,
241 ))
242 }
243 }
244}
245
246fn shape_divmod_quant<'a, R: Runtime>(
247 client: &ComputeClient<R>,
248 shape: &'a [usize],
249 num_quants: usize,
250) -> SequenceArg<'a, R, FastDivmod<usize>> {
251 let mut out_seq = SequenceArg::new();
252 for s in &shape[..shape.len() - 1] {
253 out_seq.push(FastDivmodArgs::<usize>::new(client, *s));
254 }
255 let last = *shape.last().unwrap() * num_quants;
256 out_seq.push(FastDivmodArgs::<usize>::new(client, last));
257 out_seq
258}
259
260fn strides_seq<'a, R: Runtime>(strides: &'a [usize]) -> SequenceArg<'a, R, usize> {
261 let mut out_seq = SequenceArg::new();
262 for s in strides {
263 out_seq.push(ScalarArg::new(*s));
264 }
265 out_seq
266}