1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::{
4 FastDivmod, FastDivmodArgs,
5 tensor::{
6 launch::{TypedView, TypedViewLaunch},
7 layout::{Coords1d, Layout, LayoutExpand},
8 },
9};
10
11use crate::scheme::{QuantLevel, QuantScheme};
12
#[derive(CubeType, CubeLaunch)]
/// Layout mapping a position in the (dequantized) values tensor to the
/// position of its scale in the scales buffer. The variant is chosen from the
/// [`QuantLevel`] of the scheme (see `scales_layout`).
pub enum ScalesLayout {
    /// One scale shared by the whole tensor.
    PerTensor(PerTensorLayout),
    /// One scale per block of values, with per-dimension block sizes.
    BlockScaled(BlockScaledLayout),
}
20
#[cube]
impl Layout for ScalesLayout {
    // Both the view side and the underlying scales buffer are addressed
    // with a single linear coordinate.
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Every method simply dispatches to the active variant, so a
    // `ScalesView` works uniformly regardless of the quantization level.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos(pos),
        }
    }

    fn shape(&self) -> Self::Coordinates {
        match self {
            ScalesLayout::PerTensor(layout) => layout.shape(),
            ScalesLayout::BlockScaled(layout) => layout.shape(),
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_in_bounds(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_in_bounds(pos),
        }
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        match self {
            ScalesLayout::PerTensor(layout) => layout.to_source_pos_checked(pos),
            ScalesLayout::BlockScaled(layout) => layout.to_source_pos_checked(pos),
        }
    }
}
54
#[cube]
impl ScalesLayout {
    // Returns `true` when `pos` (a linear position in the values tensor) is
    // the first element of its quantization block, i.e. the position whose
    // scale would be written by this thread. Delegates to the active variant.
    pub fn is_block_start(&self, pos: u32) -> bool {
        match self {
            ScalesLayout::PerTensor(layout) => layout.is_block_start(pos),
            ScalesLayout::BlockScaled(layout) => layout.is_block_start(pos),
        }
    }
}
66
#[derive(CubeType, CubeLaunch)]
/// Layout for per-tensor quantization: a single scale at source position 0
/// covers every element of the values tensor.
pub struct PerTensorLayout {
    // Total number of elements in the values tensor; used only for the view
    // shape and bounds checking, since every position maps to scale 0.
    tensor_len: u32,
}
71
#[cube]
impl PerTensorLayout {
    // Creates a per-tensor layout over a values tensor of `tensor_len` elements.
    pub fn new(tensor_len: u32) -> Self {
        PerTensorLayout { tensor_len }
    }
}
78
#[cube]
impl Layout for PerTensorLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Every value shares the single scale at index 0. `.runtime()` forces the
    // constant to be a runtime value in the generated kernel rather than a
    // comptime literal.
    fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates {
        0u32.runtime()
    }

    // The view is sized like the values tensor, not the scales buffer.
    fn shape(&self) -> Self::Coordinates {
        self.tensor_len
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.tensor_len
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
100
#[cube]
impl PerTensorLayout {
    // With a single block spanning the whole tensor, only position 0 starts
    // a block.
    pub fn is_block_start(&self, pos: u32) -> bool {
        pos == 0
    }
}
109
#[derive(CubeType, CubeLaunch)]
/// Layout for block quantization: the values tensor is partitioned into
/// blocks (with a per-dimension block size) and each block owns one scale in
/// the scales tensor.
pub struct BlockScaledLayout {
    // Per-dimension fast division helpers used to decompose a linear values
    // position into multi-dimensional coordinates. The innermost dimension is
    // pre-scaled by the number of quants packed per storage element (see
    // `shape_divmod_quant`).
    tensor_shape: Sequence<FastDivmod>,
    // Total number of (unpacked) values; defines the view shape/bounds.
    tensor_len: u32,
    // Strides of the scales tensor, one per dimension.
    scales_strides: Sequence<u32>,
    // Block size per dimension, fixed at compile time so divisions unroll.
    #[cube(comptime)]
    block_size: Vec<u8>,
    // Line (vectorization) width of the scales buffer, also comptime.
    #[cube(comptime)]
    scales_line_size: u32,
}
120
#[cube]
impl BlockScaledLayout {
    // Plain constructor; see the struct fields for the meaning of each
    // argument. `block_size` and `scales_line_size` are comptime so the
    // address computation in `to_source_pos` can fully unroll.
    pub fn new(
        tensor_shape: Sequence<FastDivmod>,
        tensor_len: u32,
        scales_strides: Sequence<u32>,
        #[comptime] block_size: Vec<u8>,
        #[comptime] scales_line_size: u32,
    ) -> Self {
        BlockScaledLayout {
            tensor_shape,
            tensor_len,
            scales_strides,
            block_size,
            scales_line_size,
        }
    }
}
139
#[cube]
impl Layout for BlockScaledLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Maps a linear position in the values tensor to the line index of its
    // scale: decompose `pos` into per-dimension coordinates (innermost
    // dimension first), turn each coordinate into a block index, and
    // accumulate against the scales strides.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let rank = comptime![self.scales_strides.len()];
        let mut offs = pos;
        let mut scale_offs = 0;

        #[unroll]
        for i in 0..rank {
            // Walk dimensions from innermost to outermost.
            let dim = comptime![rank - i - 1];
            let block_size_local = comptime![self.block_size[dim as usize] as u32];
            // NOTE(review): assumes `FastDivmod::div_mod` returns
            // (quotient, remainder), so `rem` is the remaining outer index and
            // `offs_local` the coordinate in `dim` — confirm against cubecl_std.
            let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs);

            offs = rem;
            // Coordinate -> block index -> element offset in the scales tensor.
            scale_offs += (offs_local / block_size_local) * *self.scales_strides.index(dim);
        }

        // Convert the element offset into a line (vectorized) offset.
        scale_offs / self.scales_line_size
    }

    // The view is sized like the values tensor, not the scales buffer.
    fn shape(&self) -> Self::Coordinates {
        self.tensor_len
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.tensor_len
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}
175
#[cube]
impl BlockScaledLayout {
    // Returns `true` when `pos` is the first element of its block, i.e. its
    // coordinate in every dimension is a multiple of that dimension's block
    // size. Mirrors the decomposition loop in `to_source_pos`.
    pub fn is_block_start(&self, pos: u32) -> bool {
        let rank = comptime![self.scales_strides.len()];
        let mut offs = pos;
        let mut is_start = true;

        #[unroll]
        for i in 0..rank {
            // Innermost dimension first, as in `to_source_pos`.
            let dim = comptime![rank - i - 1];
            let block_size_local = comptime![self.block_size[dim as usize] as u32];
            let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs);
            offs = rem;
            is_start &= offs_local.is_multiple_of(block_size_local);
        }

        is_start
    }
}
197
/// Kernel-side view over the scales buffer, addressed with values-tensor
/// positions; the [`ScalesLayout`] performs the position translation.
pub type ScalesView<E, IO = ReadOnly> = TypedView<E, ScalesLayout, IO>;
/// Host-side launch argument for a [`ScalesView`].
pub type ScalesViewLaunch<'a, R> = TypedViewLaunch<'a, ScalesLayout, R>;
203
204pub fn scales_view<'a, R: Runtime>(
207 client: &ComputeClient<R::Server>,
208 values: &'a TensorHandleRef<'a, R>,
209 scales: &'a TensorHandleRef<'a, R>,
210 scales_line_size: u8,
211 quant_scheme: &QuantScheme,
212) -> ScalesViewLaunch<'a, R> {
213 let layout = scales_layout(client, values, scales, scales_line_size, quant_scheme);
214 let len = scales.shape.iter().product::<usize>();
215 let buffer = unsafe {
216 ArrayArg::from_raw_parts_and_size(scales.handle, len, scales_line_size, scales.elem_size)
217 };
218 ScalesViewLaunch::new(buffer, layout)
219}
220
221pub fn scales_layout<'a, R: Runtime>(
222 client: &ComputeClient<R::Server>,
223 values: &'a TensorHandleRef<'a, R>,
224 scales: &'a TensorHandleRef<'a, R>,
225 scales_line_size: u8,
226 scheme: &QuantScheme,
227) -> ScalesLayoutArgs<'a, R> {
228 let values_len = values.shape.iter().product::<usize>() * scheme.num_quants();
229 let values_len = ScalarArg::new(values_len as u32);
230
231 match &scheme.level {
232 QuantLevel::Tensor => ScalesLayoutArgs::PerTensor(PerTensorLayoutLaunch::new(values_len)),
233 QuantLevel::Block(block_size) => {
234 let tensor_shape = shape_divmod_quant(client, values.shape, scheme.num_quants());
235 let scales_strides = strides_seq(scales.strides);
236 ScalesLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
237 tensor_shape,
238 values_len,
239 scales_strides,
240 block_size.to_dim_vec(values.shape.len()),
241 scales_line_size as u32,
242 ))
243 }
244 }
245}
246
247fn shape_divmod_quant<'a, R: Runtime>(
248 client: &ComputeClient<R::Server>,
249 shape: &'a [usize],
250 num_quants: usize,
251) -> SequenceArg<'a, R, FastDivmod> {
252 let mut out_seq = SequenceArg::new();
253 for s in &shape[..shape.len() - 1] {
254 out_seq.push(FastDivmodArgs::new(client, *s as u32));
255 }
256 let last = *shape.last().unwrap() * num_quants;
257 out_seq.push(FastDivmodArgs::new(client, last as u32));
258 out_seq
259}
260
261fn strides_seq<'a, R: Runtime>(strides: &'a [usize]) -> SequenceArg<'a, R, u32> {
262 let mut out_seq = SequenceArg::new();
263 for s in strides {
264 out_seq.push(ScalarArg::new(*s as u32));
265 }
266 out_seq
267}