1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::{
5 View, ViewExpand, ViewOperations, ViewOperationsExpand, launch::ViewCompilationArg,
6 layout::Coordinates,
7};
8use cubecl::prelude::*;
9use cubecl_common::{
10 e2m1x2, e4m3, e5m2,
11 quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
12 ue8m0,
13};
14use cubecl_core::{
15 self as cubecl,
16 ir::{ElemType, FloatKind, StorageType},
17 prelude::barrier::BarrierExpand,
18 unexpanded,
19};
20use half::{bf16, f16};
21
/// Read-only view over quantized data that pairs stored values with their scales.
///
/// A read fetches a `Line<Q>` from `values` and an `S` from `scales` using the same coordinates,
/// then hands both to `dequantize_aligned` together with `scheme` to produce a `Line<F>`.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> {
    // Stored quantized values, read one `Line<Q>` at a time.
    values: View<Line<Q>, C>,
    // Scales; addressed with the same coordinates as `values`.
    scales: View<S, C>,
    // Compile-time quantization scheme forwarded to the dequantization routine.
    #[cube(comptime)]
    scheme: QuantScheme,
    // Marker binding the view to the dequantized element type `F`; holds no data.
    #[cube(comptime)]
    _ty: PhantomData<F>,
}
39
#[cube]
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
    QuantizedView<Q, S, F, C>
{
    /// Creates a quantized view from a view of stored values, a view of scales and the
    /// compile-time quantization scheme.
    ///
    /// The dequantized element type `F` is not derived from the arguments; it comes from the
    /// type the caller names (or that inference picks) for `Self`.
    pub fn new(
        values: View<Line<Q>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        QuantizedView::<Q, S, F, C> {
            values,
            scales,
            scheme,
            _ty: PhantomData,
        }
    }
}
57
58impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
59 QuantizedView<Q, S, F, C>
60{
61 pub fn view(self) -> View<Line<F>, C> {
62 unexpanded!()
63 }
64
65 pub fn __expand_view(
66 scope: &mut Scope,
67 this: QuantizedViewExpand<Q, S, F, C>,
68 ) -> ViewExpand<Line<F>, C, ReadOnly> {
69 this.__expand_view_method(scope)
70 }
71}
72
73impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
74 QuantizedViewExpand<Q, S, F, C>
75{
76 pub fn new(
77 values: ViewExpand<Line<Q>, C>,
78 scales: ViewExpand<S, C>,
79 scheme: QuantScheme,
80 ) -> Self {
81 QuantizedViewExpand::<Q, S, F, C> {
82 values,
83 scales,
84 scheme,
85 _ty: PhantomData,
86 }
87 }
88
89 pub fn __expand_view_method(self, _scope: &mut Scope) -> ViewExpand<Line<F>, C, ReadOnly> {
90 ViewExpand::new(self)
91 }
92}
93
94impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> Lined
95 for QuantizedView<Q, S, F, C>
96{
97}
98impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> LinedExpand
99 for QuantizedViewExpand<Q, S, F, C>
100{
101 fn line_size(&self) -> u32 {
102 self.values.line_size() * self.scheme.num_quants() as u32
103 }
104}
105
106impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
107 ViewOperations<Line<F>, C> for QuantizedView<Q, S, F, C>
108{
109}
110
111impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
112 ViewOperationsExpand<Line<F>, C> for QuantizedViewExpand<Q, S, F, C>
113{
114 fn __expand_read_method(
115 &self,
116 scope: &mut Scope,
117 pos: <C>::ExpandType,
118 ) -> ExpandElementTyped<Line<F>> {
119 let value = self.values.clone().__expand_read_method(scope, pos.clone());
120 let scale = self.scales.clone().__expand_read_method(scope, pos);
121
122 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
123 }
124
125 fn __expand_read_checked_method(
126 &self,
127 scope: &mut Scope,
128 pos: <C>::ExpandType,
129 ) -> ExpandElementTyped<Line<F>> {
130 let value = self
131 .values
132 .clone()
133 .__expand_read_checked_method(scope, pos.clone());
134 let scale = self
135 .scales
136 .clone()
137 .__expand_read_checked_method(scope, pos.clone());
138
139 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
140 }
141
142 fn __expand_read_masked_method(
143 &self,
144 scope: &mut Scope,
145 pos: <C>::ExpandType,
146 mask_value: ExpandElementTyped<Line<F>>,
147 ) -> ExpandElementTyped<Line<F>> {
148 let value = self
149 .values
150 .clone()
151 .__expand_read_checked_method(scope, pos.clone());
152 let scale = self
153 .scales
154 .clone()
155 .__expand_read_checked_method(scope, pos.clone());
156 let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
157
158 let value = dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme);
159 select::expand::<Line<F>>(scope, in_bounds, value, mask_value)
160 }
161
162 fn __expand_read_unchecked_method(
163 &self,
164 scope: &mut Scope,
165 pos: <C>::ExpandType,
166 ) -> ExpandElementTyped<Line<F>> {
167 let value = self
168 .values
169 .clone()
170 .__expand_read_unchecked_method(scope, pos.clone());
171 let scale = self
172 .scales
173 .clone()
174 .__expand_read_unchecked_method(scope, pos);
175
176 dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
177 }
178
179 fn __expand_to_linear_slice_method(
180 &self,
181 _scope: &mut Scope,
182 _pos: <C>::ExpandType,
183 _end: <C>::ExpandType,
184 ) -> SliceExpand<Line<F>, ReadOnly> {
185 panic!("Can't create raw slice for quantized view")
186 }
187
188 fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
189 self.values.clone().__expand_shape_method(scope)
190 }
191
192 fn __expand_is_in_bounds_method(
193 &self,
194 scope: &mut Scope,
195 pos: C::ExpandType,
196 ) -> ExpandElementTyped<bool> {
197 self.values.clone().__expand_is_in_bounds_method(scope, pos)
198 }
199
200 fn __expand_tensor_map_load_method(
201 &self,
202 _scope: &mut Scope,
203 _barrier: BarrierExpand,
204 _shared_memory: SliceExpand<Line<F>, ReadWrite>,
205 _pos: C::ExpandType,
206 ) {
207 panic!("Can't use tensor map functions on quantized view");
208 }
209}
210
211struct ExpandDynamic<'a, E: Numeric, C: Coordinates + 'static> {
212 values: &'a ViewCompilationArg<C>,
213 scales: &'a ViewCompilationArg<C>,
214 scheme: QuantScheme,
215 builder: &'a mut KernelBuilder,
216 _ty: PhantomData<E>,
217}
218
219impl<'a, E: Numeric, C: Coordinates + 'static> RunWithQuantType for ExpandDynamic<'a, E, C> {
220 type Output = ViewExpand<Line<E>, C>;
221
222 fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
223 let values = View::<Line<Q>, C>::expand(self.values, self.builder);
224 let scales = View::<S, C>::expand(self.scales, self.builder);
225 let view = QuantizedViewExpand::new(values, scales, self.scheme);
226 ViewExpand::new(view)
227 }
228}
229
230pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
233 fn run_with_q<F: RunWithQuantType, Q: CubePrimitive>(
234 func: F,
235 scheme: QuantScheme,
236 ) -> F::Output {
237 match scheme.param {
238 QuantParam::F32 => func.execute::<Q, f32>(),
239 QuantParam::F16 => func.execute::<Q, f16>(),
240 QuantParam::BF16 => func.execute::<Q, bf16>(),
241 QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
242 QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
243 }
244 }
245
246 let run_q = match scheme.store {
247 QuantStore::Native => match scheme.value {
248 QuantValue::Q8F => run_with_q::<F, i8>,
249 QuantValue::Q8S => run_with_q::<F, i8>,
250 QuantValue::E5M2 => run_with_q::<F, e5m2>,
251 QuantValue::E4M3 => run_with_q::<F, e4m3>,
252 QuantValue::E2M1 => run_with_q::<F, e2m1x2>,
253 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
254 panic!("Sub-byte quantization can't be native")
255 }
256 },
257 QuantStore::U32 => run_with_q::<F, u32>,
258 };
259 run_q(func, scheme)
260}
261
262pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
265 values: &ViewCompilationArg<C>,
266 scales: &ViewCompilationArg<C>,
267 scheme: QuantScheme,
268 builder: &mut KernelBuilder,
269) -> ViewExpand<E, C, IO> {
270 use core::mem::transmute as t;
271
272 fn expand_dynamic_f<F: Numeric, C: Coordinates + 'static>(
274 values: &ViewCompilationArg<C>,
275 scales: &ViewCompilationArg<C>,
276 scheme: QuantScheme,
277 builder: &mut KernelBuilder,
278 ) -> ViewExpand<Line<F>, C> {
279 let func = ExpandDynamic {
280 values,
281 scales,
282 scheme,
283 builder,
284 _ty: PhantomData::<F>,
285 };
286 run_with_quant_type(func, scheme)
287 }
288
289 #[allow(clippy::missing_transmute_annotations)]
290 unsafe {
291 match E::as_type(&builder.scope) {
292 StorageType::Scalar(ElemType::Float(ty)) => match ty {
293 FloatKind::F16 => t(expand_dynamic_f::<f16, C>(values, scales, scheme, builder)),
294 FloatKind::BF16 => t(expand_dynamic_f::<bf16, C>(values, scales, scheme, builder)),
295 FloatKind::Flex32 => t(expand_dynamic_f::<flex32, C>(
296 values, scales, scheme, builder,
297 )),
298 FloatKind::F32 => t(expand_dynamic_f::<f32, C>(values, scales, scheme, builder)),
299 FloatKind::TF32 => t(expand_dynamic_f::<tf32, C>(values, scales, scheme, builder)),
300 FloatKind::F64 => t(expand_dynamic_f::<f64, C>(values, scales, scheme, builder)),
301 FloatKind::E2M1
302 | FloatKind::E2M3
303 | FloatKind::E3M2
304 | FloatKind::E4M3
305 | FloatKind::E5M2
306 | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
307 },
308 _ => unreachable!("Quantized view should only be used with floats"),
309 }
310 }
311}