1use cubecl_core::{prelude::*, unexpanded};
2use std::{
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5 sync::Arc,
6};
7
8use crate::tensor::{
9 View, ViewExpand, ViewOperationsMut, VirtualViewMut, VirtualViewMutExpand,
10 layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand, VirtualLayoutOperationsExpand},
11 view::ViewType,
12};
13
14#[derive(Clone)]
16pub struct TypedView<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility = ReadOnly> {
17 _ty: PhantomData<(E, L, IO)>,
18}
19
20impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> CubeType for TypedView<E, L, IO> {
21 type ExpandType = ViewExpand<E, L::Coordinates, IO>;
22}
23
24impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> Deref for TypedView<E, L, IO> {
25 type Target = View<E, L::Coordinates, IO>;
26
27 fn deref(&self) -> &Self::Target {
28 unexpanded!()
29 }
30}
31
32impl<E: CubePrimitive, L: LaunchLayout> DerefMut for TypedView<E, L, ReadWrite> {
33 fn deref_mut(&mut self) -> &mut Self::Target {
34 unexpanded!()
35 }
36}
37
38pub struct TypedViewLaunch<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> {
39 buffer: ArrayArg<'a, R>,
40 layout: L::RuntimeArg<'a, R>,
41}
42impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> TypedViewLaunch<'a, L, R> {
43 #[allow(clippy::too_many_arguments)]
44 pub fn new(buffer: ArrayArg<'a, R>, layout: L::RuntimeArg<'a, R>) -> Self {
45 Self { buffer, layout }
46 }
47}
48impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> ArgSettings<R>
49 for TypedViewLaunch<'a, L, R>
50{
51 fn register(&self, launcher: &mut KernelLauncher<R>) {
52 self.buffer.register(launcher);
53 self.layout.register(launcher);
54 }
55}
56
57pub struct TypedViewCompilationArg<L: LaunchLayout<SourceCoordinates = Coords1d>> {
58 buffer: ArrayCompilationArg,
59 layout: L::CompilationArg,
60}
61impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Clone for TypedViewCompilationArg<L> {
62 fn clone(&self) -> Self {
63 Self {
64 buffer: self.buffer.clone(),
65 layout: self.layout.clone(),
66 }
67 }
68}
69impl<L: LaunchLayout<SourceCoordinates = Coords1d>> CompilationArg for TypedViewCompilationArg<L> {}
70
71impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::hash::Hash
72 for TypedViewCompilationArg<L>
73{
74 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75 self.buffer.hash(state);
76 self.layout.hash(state);
77 }
78}
79impl<L: LaunchLayout<SourceCoordinates = Coords1d>> PartialEq for TypedViewCompilationArg<L> {
80 fn eq(&self, other: &Self) -> bool {
81 self.buffer.eq(&other.buffer) && self.layout.eq(&other.layout)
82 }
83}
84impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::fmt::Debug
85 for TypedViewCompilationArg<L>
86{
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct(stringify!(TensorViewTyped))
89 .field("buffer", &self.buffer)
90 .field("layout", &self.layout)
91 .finish()
92 }
93}
94impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Eq for TypedViewCompilationArg<L> {}
95
96impl<E: CubePrimitive, L: LaunchLayout<SourceCoordinates = Coords1d>, IO: SliceVisibility> LaunchArg
97 for TypedView<E, L, IO>
98{
99 type RuntimeArg<'a, R: Runtime> = TypedViewLaunch<'a, L, R>;
100 type CompilationArg = TypedViewCompilationArg<L>;
101
102 fn compilation_arg<'a, R: Runtime>(
103 runtime_arg: &Self::RuntimeArg<'a, R>,
104 ) -> Self::CompilationArg {
105 TypedViewCompilationArg {
106 buffer: <Array<Line<E>> as LaunchArg>::compilation_arg(&runtime_arg.buffer),
107 layout: L::compilation_arg(&runtime_arg.layout),
108 }
109 }
110
111 fn expand(
112 arg: &Self::CompilationArg,
113 builder: &mut KernelBuilder,
114 ) -> <Self as CubeType>::ExpandType {
115 let buffer = <Array<E> as LaunchArg>::expand(&arg.buffer, builder);
116 L::apply::<E, Array<E>, IO>(L::expand(&arg.layout, builder), buffer)
117 }
118 fn expand_output(
119 arg: &Self::CompilationArg,
120 builder: &mut KernelBuilder,
121 ) -> <Self as CubeType>::ExpandType {
122 let buffer = <Array<E> as LaunchArg>::expand_output(&arg.buffer, builder);
123 L::apply::<E, Array<E>, IO>(L::expand_output(&arg.layout, builder), buffer)
124 }
125}
126
/// Private module for the sealed-trait pattern. Because `seal` is private to this module,
/// `Sealed` cannot be named, and therefore `LaunchLayout` cannot be implemented, outside of it.
mod seal {
    /// Supertrait of `LaunchLayout` that only this module can implement.
    pub trait Sealed {}
}
130
131pub trait LaunchLayout: LaunchArg + seal::Sealed {
132 type SourceCoordinates: Coordinates;
133 type Coordinates: Coordinates;
134
135 fn apply<
136 E: CubePrimitive,
137 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
138 IO: SliceVisibility,
139 >(
140 value: <Self as CubeType>::ExpandType,
141 view: V::ExpandType,
142 ) -> ViewExpand<E, Self::Coordinates, IO>;
143}
144
145impl<
150 L: Layout
151 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
152 + LaunchArg,
153> seal::Sealed for L
154{
155}
156impl<
157 L: Layout
158 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
159 + LaunchArg,
160> LaunchLayout for L
161{
162 type SourceCoordinates = L::SourceCoordinates;
163 type Coordinates = L::Coordinates;
164
165 fn apply<
166 E: CubePrimitive,
167 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
168 IO: SliceVisibility,
169 >(
170 value: L::ExpandType,
171 view: V::ExpandType,
172 ) -> ViewExpand<E, Self::Coordinates, IO> {
173 let l0 = value;
174 let l0 = VirtualLayoutExpand::new::<L::ExpandType>(l0);
175 let view =
176 VirtualViewMutExpand::<E, L::Coordinates, L::SourceCoordinates, V>::new(view, l0);
177 ViewExpand::<E, L::Coordinates, IO> {
178 inner: ViewType::ReadWrite(Arc::new(view)),
179 _io: PhantomData,
180 }
181 }
182}
183
184impl<
185 L0: Layout
186 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
187 + LaunchArg,
188 L1: Layout<SourceCoordinates = L0::Coordinates>
189 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
190 + LaunchArg,
191> seal::Sealed for (L0, L1)
192{
193}
194impl<
195 L0: Layout
196 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
197 + LaunchArg,
198 L1: Layout<SourceCoordinates = L0::Coordinates>
199 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
200 + LaunchArg,
201> LaunchLayout for (L0, L1)
202{
203 type SourceCoordinates = L0::SourceCoordinates;
204 type Coordinates = L1::Coordinates;
205
206 fn apply<
207 E: CubePrimitive,
208 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
209 IO: SliceVisibility,
210 >(
211 value: (L0::ExpandType, L1::ExpandType),
212 view: V::ExpandType,
213 ) -> ViewExpand<E, Self::Coordinates, IO> {
214 let (l0, l1) = value;
215 let l0 = VirtualLayoutExpand::new::<L0::ExpandType>(l0);
216 let view =
217 VirtualViewMutExpand::<E, L0::Coordinates, L0::SourceCoordinates, V>::new(view, l0);
218 let l1 = VirtualLayoutExpand::new::<L1::ExpandType>(l1);
219 let view = VirtualViewMutExpand::<
220 E,
221 L1::Coordinates,
222 L1::SourceCoordinates,
223 VirtualViewMut<E, L0::Coordinates, L0::SourceCoordinates, V>,
224 >::new(view, l1);
225 ViewExpand::<E, L1::Coordinates, IO> {
226 inner: ViewType::ReadWrite(Arc::new(view)),
227 _io: PhantomData,
228 }
229 }
230}
231
232mod dynamic {
233 use cubecl_common::quant::scheme::QuantScheme;
234
235 use crate::{
236 quant,
237 tensor::layout::{
238 VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch,
239 as_dyn::{IntoDyn, IntoDynLayout, IntoDynLayoutLaunch},
240 },
241 };
242
243 use super::*;
244
245 pub enum ViewArg<'a, C: Coordinates, R: Runtime> {
246 Array(ArrayArg<'a, R>, VirtualLayoutLaunch<'a, C, Coords1d, R>),
247 TensorMap(
248 TensorMapArg<'a, R>,
249 VirtualLayoutLaunch<'a, C, Sequence<i32>, R>,
250 ),
251 Quantized {
252 values: Box<ViewArg<'a, C, R>>,
253 scales: Box<ViewArg<'a, C, R>>,
254 scheme: QuantScheme,
255 },
256 }
257 impl<'a, C: Coordinates, R: Runtime> ViewArg<'a, C, R> {
258 pub fn new<L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + LaunchArg>(
259 buffer: ArrayArg<'a, R>,
260 layout: L::RuntimeArg<'a, R>,
261 ) -> Self {
262 ViewArg::Array(buffer, VirtualLayoutLaunch::new::<L>(layout))
263 }
264
265 pub fn new_tensor_map<
266 L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + LaunchArg,
267 >(
268 buffer: TensorMapArg<'a, R>,
269 layout: L::RuntimeArg<'a, R>,
270 ) -> Self {
271 let layout = IntoDynLayoutLaunch::new(layout);
272 ViewArg::TensorMap(buffer, VirtualLayoutLaunch::new::<IntoDynLayout<L>>(layout))
273 }
274
275 pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
278 Self::Quantized {
279 values: Box::new(values),
280 scales: Box::new(scales),
281 scheme,
282 }
283 }
284 }
285 impl<'a, C: Coordinates, R: Runtime> ArgSettings<R> for ViewArg<'a, C, R> {
286 fn register(&self, launcher: &mut KernelLauncher<R>) {
287 match self {
288 ViewArg::Array(buffer, layout) => {
289 buffer.register(launcher);
290 layout.register(launcher);
291 }
292 ViewArg::TensorMap(buffer, layout) => {
293 buffer.register(launcher);
294 layout.register(launcher);
295 }
296 ViewArg::Quantized { values, scales, .. } => {
297 values.register(launcher);
298 scales.register(launcher);
299 }
300 }
301 }
302 }
303 #[derive(Clone)]
304 pub enum ViewCompilationArg<C: Coordinates> {
305 Array {
306 buffer: ArrayCompilationArg,
307 layout: VirtualLayoutCompilationArg<C, Coords1d>,
308 },
309 TensorMap {
310 buffer: TensorMapCompilationArg,
311 layout: VirtualLayoutCompilationArg<C, Sequence<i32>>,
312 },
313 Quantized {
314 values: Box<ViewCompilationArg<C>>,
315 scales: Box<ViewCompilationArg<C>>,
316 scheme: QuantScheme,
317 },
318 }
319
320 impl<C: Coordinates + 'static> CompilationArg for ViewCompilationArg<C> {}
321 impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
322 impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
323 fn eq(&self, other: &Self) -> bool {
324 match (self, other) {
325 (
326 ViewCompilationArg::Array { buffer, layout },
327 ViewCompilationArg::Array {
328 buffer: buffer_other,
329 layout: layout_other,
330 },
331 ) => buffer == buffer_other && layout == layout_other,
332 (
333 ViewCompilationArg::TensorMap { buffer, layout },
334 ViewCompilationArg::TensorMap {
335 buffer: buffer_other,
336 layout: layout_other,
337 },
338 ) => buffer == buffer_other && layout == layout_other,
339 (
340 ViewCompilationArg::Quantized {
341 values,
342 scales,
343 scheme,
344 },
345 ViewCompilationArg::Quantized {
346 values: values_other,
347 scales: scales_other,
348 scheme: scheme_other,
349 },
350 ) => values == values_other && scales == scales_other && scheme == scheme_other,
351 _ => false,
352 }
353 }
354 }
355 impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
356 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
357 match self {
358 ViewCompilationArg::Array { buffer, layout } => {
359 buffer.hash(ra_expand_state);
360 layout.hash(ra_expand_state);
361 }
362 ViewCompilationArg::TensorMap { buffer, layout } => {
363 buffer.hash(ra_expand_state);
364 layout.hash(ra_expand_state);
365 }
366 ViewCompilationArg::Quantized {
367 values,
368 scales,
369 scheme,
370 } => {
371 values.hash(ra_expand_state);
372 scales.hash(ra_expand_state);
373 scheme.hash(ra_expand_state);
374 }
375 }
376 }
377 }
378 impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
379 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
380 match self {
381 ViewCompilationArg::Array { buffer, layout } => f
382 .debug_struct("ArrayView")
383 .field("buffer", &buffer)
384 .field("layout", &layout)
385 .finish(),
386 ViewCompilationArg::TensorMap { buffer, layout } => f
387 .debug_struct("TensorMapView")
388 .field("buffer", &buffer)
389 .field("layout", &layout)
390 .finish(),
391 ViewCompilationArg::Quantized {
392 values,
393 scales,
394 scheme,
395 } => f
396 .debug_struct("QuantizedView")
397 .field("values", &values)
398 .field("scales", &scales)
399 .field("scheme", &scheme)
400 .finish(),
401 }
402 }
403 }
404
405 impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
406 type RuntimeArg<'a, R: Runtime> = ViewArg<'a, C, R>;
407 type CompilationArg = ViewCompilationArg<C>;
408
409 fn compilation_arg<'a, R: Runtime>(
410 runtime_arg: &Self::RuntimeArg<'a, R>,
411 ) -> Self::CompilationArg {
412 match runtime_arg {
413 ViewArg::Array(buffer, layout) => {
414 let buffer = Array::<E>::compilation_arg(buffer);
415 let layout = VirtualLayout::<C, Coords1d>::compilation_arg(layout);
416 ViewCompilationArg::Array { buffer, layout }
417 }
418 ViewArg::TensorMap(buffer, layout) => {
419 let buffer = TensorMap::<E>::compilation_arg(buffer);
420 let layout = VirtualLayout::<C, Sequence<i32>>::compilation_arg(layout);
421 ViewCompilationArg::TensorMap { buffer, layout }
422 }
423 ViewArg::Quantized {
424 values,
425 scales,
426 scheme,
427 } => {
428 let values = View::<E, C, IO>::compilation_arg(values);
430 let scales = View::<E, C, IO>::compilation_arg(scales);
431 ViewCompilationArg::Quantized {
432 values: Box::new(values),
433 scales: Box::new(scales),
434 scheme: *scheme,
435 }
436 }
437 }
438 }
439 fn expand(
440 arg: &Self::CompilationArg,
441 builder: &mut KernelBuilder,
442 ) -> <Self as CubeType>::ExpandType {
443 match arg {
444 ViewCompilationArg::Array { buffer, layout } => {
445 let buffer = Array::<E>::expand(buffer, builder);
446 let layout = VirtualLayout::<C, Coords1d>::expand(layout, builder);
447 let view =
448 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
449 ViewExpand::<E, C, IO> {
450 inner: ViewType::ReadWrite(Arc::new(view)),
451 _io: PhantomData,
452 }
453 }
454 ViewCompilationArg::TensorMap { buffer, layout } => {
455 let buffer = TensorMap::<E>::expand(buffer, builder);
456 let layout = VirtualLayout::<C, Sequence<i32>>::expand(layout, builder);
457 let view = VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E>>::new(
458 buffer, layout,
459 );
460 ViewExpand::<E, C, IO> {
461 inner: ViewType::ReadWrite(Arc::new(view)),
462 _io: PhantomData,
463 }
464 }
465 ViewCompilationArg::Quantized {
466 values,
467 scales,
468 scheme,
469 } => quant::view::expand_dynamic(values, scales, *scheme, builder),
470 }
471 }
472 fn expand_output(
473 arg: &Self::CompilationArg,
474 builder: &mut KernelBuilder,
475 ) -> <Self as CubeType>::ExpandType {
476 match arg {
477 ViewCompilationArg::Array { buffer, layout } => {
478 let buffer = Array::<E>::expand_output(buffer, builder);
479 let layout = VirtualLayout::<C, Coords1d>::expand_output(layout, builder);
480 let view =
481 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
482 ViewExpand::<E, C, IO> {
483 inner: ViewType::ReadWrite(Arc::new(view)),
484 _io: PhantomData,
485 }
486 }
487 ViewCompilationArg::TensorMap { buffer, layout } => {
488 let buffer = TensorMap::<E>::expand_output(buffer, builder);
489 let layout = VirtualLayout::<C, Sequence<i32>>::expand_output(layout, builder);
490 let view = VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E>>::new(
491 buffer, layout,
492 );
493 ViewExpand::<E, C, IO> {
494 inner: ViewType::ReadWrite(Arc::new(view)),
495 _io: PhantomData,
496 }
497 }
498 ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
499 }
500 }
501 }
502}
503
504pub use dynamic::*;