1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5 CubeOption, CubeOptionExpand,
6 tensor::{
7 View, ViewExpand, ViewOperations, ViewOperationsExpand, launch::ViewCompilationArg,
8 layout::Coordinates,
9 },
10};
11use cubecl::prelude::*;
12use cubecl_common::{
13 e2m1x2, e4m3, e5m2,
14 quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
15 ue8m0,
16};
17use cubecl_core::{
18 self as cubecl,
19 ir::{ElemType, FloatKind, StorageType},
20 prelude::barrier::BarrierExpand,
21 unexpanded,
22};
23use half::{bf16, f16};
24
/// View that dequantizes on load.
///
/// Each read fetches a line of quantized `values` and the entry of `scales` at the same
/// position. It combines them with `dequantize_aligned` into a `Line<F>` (see the
/// `ViewOperationsExpand` impl for `QuantizedViewExpand`).
///
/// NOTE(review): `scales` is read at the same coordinate as `values`. Its layout therefore
/// presumably maps value coordinates to the matching scale. Confirm this against the layouts
/// callers construct.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> {
    // Quantized storage, read one `Line<Q>` at a time.
    values: View<Line<Q>, C>,
    // Quantization parameters (scales), read at the same position as `values`.
    scales: View<S, C>,
    // Compile-time scheme describing how `values` and `scales` are encoded.
    #[cube(comptime)]
    scheme: QuantScheme,
    // Dequantized output element type; carries no runtime data.
    #[cube(comptime)]
    _ty: PhantomData<F>,
}
42
#[cube]
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
    QuantizedView<Q, S, F, C>
{
    /// Creates a quantized view.
    ///
    /// * `values` - view over the quantized data, read one `Line<Q>` at a time.
    /// * `scales` - view over the quantization scales.
    /// * `scheme` - compile-time quantization scheme used when dequantizing.
    pub fn new(
        values: View<Line<Q>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        QuantizedView::<Q, S, F, C> {
            values,
            scales,
            scheme,
            // Marker for the dequantized output type `F`; holds no data.
            _ty: PhantomData,
        }
    }
}
60
61impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
62 QuantizedView<Q, S, F, C>
63{
64 pub fn view(self) -> View<Line<F>, C> {
65 unexpanded!()
66 }
67
68 pub fn __expand_view(
69 scope: &mut Scope,
70 this: QuantizedViewExpand<Q, S, F, C>,
71 ) -> ViewExpand<Line<F>, C, ReadOnly> {
72 this.__expand_view_method(scope)
73 }
74}
75
76impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
77 QuantizedViewExpand<Q, S, F, C>
78{
79 pub fn new(
80 values: ViewExpand<Line<Q>, C>,
81 scales: ViewExpand<S, C>,
82 scheme: QuantScheme,
83 ) -> Self {
84 QuantizedViewExpand::<Q, S, F, C> {
85 values,
86 scales,
87 scheme,
88 _ty: PhantomData,
89 }
90 }
91
92 pub fn __expand_view_method(self, _scope: &mut Scope) -> ViewExpand<Line<F>, C, ReadOnly> {
93 ViewExpand::new(self)
94 }
95}
96
97impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> Lined
98 for QuantizedView<Q, S, F, C>
99{
100}
101impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> LinedExpand
102 for QuantizedViewExpand<Q, S, F, C>
103{
104 fn line_size(&self) -> u32 {
105 self.values.line_size() * self.scheme.num_quants() as u32
106 }
107}
108
109impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
110 ViewOperations<Line<F>, C> for QuantizedView<Q, S, F, C>
111{
112}
113
114impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
115 ViewOperationsExpand<Line<F>, C> for QuantizedViewExpand<Q, S, F, C>
116{
117 fn __expand_read_method(
118 &self,
119 scope: &mut Scope,
120 pos: <C>::ExpandType,
121 ) -> ExpandElementTyped<Line<F>> {
122 let value = self.values.clone().__expand_read_method(scope, pos.clone());
123 let scale = self.scales.clone().__expand_read_method(scope, pos);
124
125 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
126 }
127
128 fn __expand_read_checked_method(
129 &self,
130 scope: &mut Scope,
131 pos: <C>::ExpandType,
132 ) -> ExpandElementTyped<Line<F>> {
133 let value = self
134 .values
135 .clone()
136 .__expand_read_checked_method(scope, pos.clone());
137 let scale = self
138 .scales
139 .clone()
140 .__expand_read_checked_method(scope, pos.clone());
141
142 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
143 }
144
145 fn __expand_read_masked_method(
146 &self,
147 scope: &mut Scope,
148 pos: <C>::ExpandType,
149 mask_value: ExpandElementTyped<Line<F>>,
150 ) -> ExpandElementTyped<Line<F>> {
151 let value = self
152 .values
153 .clone()
154 .__expand_read_checked_method(scope, pos.clone());
155 let scale = self
156 .scales
157 .clone()
158 .__expand_read_checked_method(scope, pos.clone());
159 let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
160
161 let value = dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme);
162 select::expand::<Line<F>>(scope, in_bounds, value, mask_value)
163 }
164
165 fn __expand_read_unchecked_method(
166 &self,
167 scope: &mut Scope,
168 pos: <C>::ExpandType,
169 ) -> ExpandElementTyped<Line<F>> {
170 let value = self
171 .values
172 .clone()
173 .__expand_read_unchecked_method(scope, pos.clone());
174 let scale = self
175 .scales
176 .clone()
177 .__expand_read_unchecked_method(scope, pos);
178
179 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
180 }
181
182 fn __expand_to_linear_slice_method(
183 &self,
184 _scope: &mut Scope,
185 _pos: <C>::ExpandType,
186 _end: <C>::ExpandType,
187 ) -> SliceExpand<Line<F>, ReadOnly> {
188 panic!("Can't create raw slice for quantized view")
189 }
190
191 fn __expand_as_tensor_map_method(
192 &self,
193 scope: &mut Scope,
194 ) -> CubeOptionExpand<TensorMap<Line<F>>> {
195 CubeOption::__expand_new_None(scope)
196 }
197
198 fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
199 self.values.clone().__expand_shape_method(scope)
200 }
201
202 fn __expand_is_in_bounds_method(
203 &self,
204 scope: &mut Scope,
205 pos: C::ExpandType,
206 ) -> ExpandElementTyped<bool> {
207 self.values.clone().__expand_is_in_bounds_method(scope, pos)
208 }
209
210 fn __expand_tensor_map_load_method(
211 &self,
212 _scope: &mut Scope,
213 _barrier: BarrierExpand,
214 _shared_memory: SliceExpand<Line<F>, ReadWrite>,
215 _pos: C::ExpandType,
216 ) {
217 panic!("Can't use tensor map functions on quantized view");
218 }
219}
220
221struct ExpandDynamic<'a, E: Numeric, C: Coordinates + 'static> {
222 values: &'a ViewCompilationArg<C>,
223 scales: &'a ViewCompilationArg<C>,
224 scheme: QuantScheme,
225 builder: &'a mut KernelBuilder,
226 _ty: PhantomData<E>,
227}
228
229impl<'a, E: Numeric, C: Coordinates + 'static> RunWithQuantType for ExpandDynamic<'a, E, C> {
230 type Output = ViewExpand<Line<E>, C>;
231
232 fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
233 let values = View::<Line<Q>, C>::expand(self.values, self.builder);
234 let scales = View::<S, C>::expand(self.scales, self.builder);
235 let view = QuantizedViewExpand::new(values, scales, self.scheme);
236 ViewExpand::new(view)
237 }
238}
239
240pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
243 fn run_with_q<F: RunWithQuantType, Q: CubePrimitive>(
244 func: F,
245 scheme: QuantScheme,
246 ) -> F::Output {
247 match scheme.param {
248 QuantParam::F32 => func.execute::<Q, f32>(),
249 QuantParam::F16 => func.execute::<Q, f16>(),
250 QuantParam::BF16 => func.execute::<Q, bf16>(),
251 QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
252 QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
253 }
254 }
255
256 let run_q = match scheme.store {
257 QuantStore::Native => match scheme.value {
258 QuantValue::Q8F => run_with_q::<F, i8>,
259 QuantValue::Q8S => run_with_q::<F, i8>,
260 QuantValue::E5M2 => run_with_q::<F, e5m2>,
261 QuantValue::E4M3 => run_with_q::<F, e4m3>,
262 QuantValue::E2M1 => run_with_q::<F, e2m1x2>,
263 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
264 panic!("Sub-byte quantization can't be native")
265 }
266 },
267 QuantStore::U32 => run_with_q::<F, u32>,
268 };
269 run_q(func, scheme)
270}
271
272pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
275 values: &ViewCompilationArg<C>,
276 scales: &ViewCompilationArg<C>,
277 scheme: QuantScheme,
278 builder: &mut KernelBuilder,
279) -> ViewExpand<E, C, IO> {
280 use core::mem::transmute as t;
281
282 fn expand_dynamic_f<F: Numeric, C: Coordinates + 'static>(
284 values: &ViewCompilationArg<C>,
285 scales: &ViewCompilationArg<C>,
286 scheme: QuantScheme,
287 builder: &mut KernelBuilder,
288 ) -> ViewExpand<Line<F>, C> {
289 let func = ExpandDynamic {
290 values,
291 scales,
292 scheme,
293 builder,
294 _ty: PhantomData::<F>,
295 };
296 run_with_quant_type(func, scheme)
297 }
298
299 #[allow(clippy::missing_transmute_annotations)]
300 unsafe {
301 match E::as_type(&builder.scope) {
302 StorageType::Scalar(ElemType::Float(ty)) => match ty {
303 FloatKind::F16 => t(expand_dynamic_f::<f16, C>(values, scales, scheme, builder)),
304 FloatKind::BF16 => t(expand_dynamic_f::<bf16, C>(values, scales, scheme, builder)),
305 FloatKind::Flex32 => t(expand_dynamic_f::<flex32, C>(
306 values, scales, scheme, builder,
307 )),
308 FloatKind::F32 => t(expand_dynamic_f::<f32, C>(values, scales, scheme, builder)),
309 FloatKind::TF32 => t(expand_dynamic_f::<tf32, C>(values, scales, scheme, builder)),
310 FloatKind::F64 => t(expand_dynamic_f::<f64, C>(values, scales, scheme, builder)),
311 FloatKind::E2M1
312 | FloatKind::E2M3
313 | FloatKind::E3M2
314 | FloatKind::E4M3
315 | FloatKind::E5M2
316 | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
317 },
318 _ => unreachable!("Quantized view should only be used with floats"),
319 }
320 }
321}