cubecl_std/tensor/view/
launch.rs

1use cubecl_core::{prelude::*, unexpanded};
2use std::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5    sync::Arc,
6};
7
8use crate::tensor::{
9    View, ViewExpand, ViewOperationsMut, VirtualViewMut, VirtualViewMutExpand,
10    layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand, VirtualLayoutOperationsExpand},
11    view::ViewType,
12};
13
14/// Launchable tensor view for ease of use.
15#[derive(Clone)]
16pub struct TypedView<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility = ReadOnly> {
17    _ty: PhantomData<(E, L, IO)>,
18}
19
20impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> CubeType for TypedView<E, L, IO> {
21    type ExpandType = ViewExpand<E, L::Coordinates, IO>;
22}
23
24impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> Deref for TypedView<E, L, IO> {
25    type Target = View<E, L::Coordinates, IO>;
26
27    fn deref(&self) -> &Self::Target {
28        unexpanded!()
29    }
30}
31
32impl<E: CubePrimitive, L: LaunchLayout> DerefMut for TypedView<E, L, ReadWrite> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        unexpanded!()
35    }
36}
37
38pub struct TypedViewLaunch<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> {
39    buffer: ArrayArg<'a, R>,
40    layout: L::RuntimeArg<'a, R>,
41}
42impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> TypedViewLaunch<'a, L, R> {
43    #[allow(clippy::too_many_arguments)]
44    pub fn new(buffer: ArrayArg<'a, R>, layout: L::RuntimeArg<'a, R>) -> Self {
45        Self { buffer, layout }
46    }
47}
48impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> ArgSettings<R>
49    for TypedViewLaunch<'a, L, R>
50{
51    fn register(&self, launcher: &mut KernelLauncher<R>) {
52        self.buffer.register(launcher);
53        self.layout.register(launcher);
54    }
55}
56
57pub struct TypedViewCompilationArg<L: LaunchLayout<SourceCoordinates = Coords1d>> {
58    buffer: ArrayCompilationArg,
59    layout: L::CompilationArg,
60}
61impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Clone for TypedViewCompilationArg<L> {
62    fn clone(&self) -> Self {
63        Self {
64            buffer: self.buffer.clone(),
65            layout: self.layout.clone(),
66        }
67    }
68}
69impl<L: LaunchLayout<SourceCoordinates = Coords1d>> CompilationArg for TypedViewCompilationArg<L> {}
70
71impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::hash::Hash
72    for TypedViewCompilationArg<L>
73{
74    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75        self.buffer.hash(state);
76        self.layout.hash(state);
77    }
78}
79impl<L: LaunchLayout<SourceCoordinates = Coords1d>> PartialEq for TypedViewCompilationArg<L> {
80    fn eq(&self, other: &Self) -> bool {
81        self.buffer.eq(&other.buffer) && self.layout.eq(&other.layout)
82    }
83}
84impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::fmt::Debug
85    for TypedViewCompilationArg<L>
86{
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct(stringify!(TensorViewTyped))
89            .field("buffer", &self.buffer)
90            .field("layout", &self.layout)
91            .finish()
92    }
93}
94impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Eq for TypedViewCompilationArg<L> {}
95
96impl<E: CubePrimitive, L: LaunchLayout<SourceCoordinates = Coords1d>, IO: SliceVisibility> LaunchArg
97    for TypedView<E, L, IO>
98{
99    type RuntimeArg<'a, R: Runtime> = TypedViewLaunch<'a, L, R>;
100    type CompilationArg = TypedViewCompilationArg<L>;
101
102    fn compilation_arg<'a, R: Runtime>(
103        runtime_arg: &Self::RuntimeArg<'a, R>,
104    ) -> Self::CompilationArg {
105        TypedViewCompilationArg {
106            buffer: <Array<Line<E>> as LaunchArg>::compilation_arg(&runtime_arg.buffer),
107            layout: L::compilation_arg(&runtime_arg.layout),
108        }
109    }
110
111    fn expand(
112        arg: &Self::CompilationArg,
113        builder: &mut KernelBuilder,
114    ) -> <Self as CubeType>::ExpandType {
115        let buffer = <Array<E> as LaunchArg>::expand(&arg.buffer, builder);
116        L::apply::<E, Array<E>, IO>(L::expand(&arg.layout, builder), buffer)
117    }
118    fn expand_output(
119        arg: &Self::CompilationArg,
120        builder: &mut KernelBuilder,
121    ) -> <Self as CubeType>::ExpandType {
122        let buffer = <Array<E> as LaunchArg>::expand_output(&arg.buffer, builder);
123        L::apply::<E, Array<E>, IO>(L::expand_output(&arg.layout, builder), buffer)
124    }
125}
126
/// Private module holding the marker trait that `LaunchLayout` requires as a supertrait.
mod seal {
    /// Marker trait with no items; implemented in this file for single layouts and for pairs
    /// of layouts.
    pub trait Sealed {}
}
130
131pub trait LaunchLayout: LaunchArg + seal::Sealed {
132    type SourceCoordinates: Coordinates;
133    type Coordinates: Coordinates;
134
135    fn apply<
136        E: CubePrimitive,
137        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
138        IO: SliceVisibility,
139    >(
140        value: <Self as CubeType>::ExpandType,
141        view: V::ExpandType,
142    ) -> ViewExpand<E, Self::Coordinates, IO>;
143}
144
145// These unfortunately need to be manually implemented due to the dependencies of each layout on
146// the coordinates of the next. Just stick with two layouts for now and add more implementations as
147// needed.
148
149impl<
150    L: Layout
151        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
152        + LaunchArg,
153> seal::Sealed for L
154{
155}
156impl<
157    L: Layout
158        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
159        + LaunchArg,
160> LaunchLayout for L
161{
162    type SourceCoordinates = L::SourceCoordinates;
163    type Coordinates = L::Coordinates;
164
165    fn apply<
166        E: CubePrimitive,
167        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
168        IO: SliceVisibility,
169    >(
170        value: L::ExpandType,
171        view: V::ExpandType,
172    ) -> ViewExpand<E, Self::Coordinates, IO> {
173        let l0 = value;
174        let l0 = VirtualLayoutExpand::new::<L::ExpandType>(l0);
175        let view =
176            VirtualViewMutExpand::<E, L::Coordinates, L::SourceCoordinates, V>::new(view, l0);
177        ViewExpand::<E, L::Coordinates, IO> {
178            inner: ViewType::ReadWrite(Arc::new(view)),
179            _io: PhantomData,
180        }
181    }
182}
183
184impl<
185    L0: Layout
186        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
187        + LaunchArg,
188    L1: Layout<SourceCoordinates = L0::Coordinates>
189        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
190        + LaunchArg,
191> seal::Sealed for (L0, L1)
192{
193}
194impl<
195    L0: Layout
196        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
197        + LaunchArg,
198    L1: Layout<SourceCoordinates = L0::Coordinates>
199        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
200        + LaunchArg,
201> LaunchLayout for (L0, L1)
202{
203    type SourceCoordinates = L0::SourceCoordinates;
204    type Coordinates = L1::Coordinates;
205
206    fn apply<
207        E: CubePrimitive,
208        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
209        IO: SliceVisibility,
210    >(
211        value: (L0::ExpandType, L1::ExpandType),
212        view: V::ExpandType,
213    ) -> ViewExpand<E, Self::Coordinates, IO> {
214        let (l0, l1) = value;
215        let l0 = VirtualLayoutExpand::new::<L0::ExpandType>(l0);
216        let view =
217            VirtualViewMutExpand::<E, L0::Coordinates, L0::SourceCoordinates, V>::new(view, l0);
218        let l1 = VirtualLayoutExpand::new::<L1::ExpandType>(l1);
219        let view = VirtualViewMutExpand::<
220            E,
221            L1::Coordinates,
222            L1::SourceCoordinates,
223            VirtualViewMut<E, L0::Coordinates, L0::SourceCoordinates, V>,
224        >::new(view, l1);
225        ViewExpand::<E, L1::Coordinates, IO> {
226            inner: ViewType::ReadWrite(Arc::new(view)),
227            _io: PhantomData,
228        }
229    }
230}
231
232mod dynamic {
233    use cubecl_common::quant::scheme::QuantScheme;
234
235    use crate::{
236        quant,
237        tensor::{
238            VirtualViewExpand,
239            layout::{
240                VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch,
241                as_dyn::{
242                    IntoDyn, IntoDyn2Layout, IntoDyn2LayoutLaunch, IntoDynLayout,
243                    IntoDynLayoutLaunch,
244                },
245            },
246        },
247    };
248
249    use super::*;
250
251    pub enum ViewArg<'a, C: Coordinates, R: Runtime> {
252        Array(ArrayArg<'a, R>, VirtualLayoutLaunch<'a, C, Coords1d, R>),
253        TensorMapTiled(
254            TensorMapArg<'a, R, Tiled>,
255            VirtualLayoutLaunch<'a, C, Sequence<i32>, R>,
256        ),
257        TensorMapIm2col(
258            TensorMapArg<'a, R, Im2col>,
259            VirtualLayoutLaunch<'a, C, (Sequence<i32>, Sequence<i32>), R>,
260        ),
261        Quantized {
262            values: Box<ViewArg<'a, C, R>>,
263            scales: Box<ViewArg<'a, C, R>>,
264            scheme: QuantScheme,
265        },
266    }
267    impl<'a, C: Coordinates, R: Runtime> ViewArg<'a, C, R> {
268        pub fn new<L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + LaunchArg>(
269            buffer: ArrayArg<'a, R>,
270            layout: L::RuntimeArg<'a, R>,
271        ) -> Self {
272            ViewArg::Array(buffer, VirtualLayoutLaunch::new::<L>(layout))
273        }
274
275        pub fn new_tensor_map_tiled<
276            L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + LaunchArg,
277        >(
278            buffer: TensorMapArg<'a, R, Tiled>,
279            layout: L::RuntimeArg<'a, R>,
280        ) -> Self {
281            let layout = IntoDynLayoutLaunch::new(layout);
282            ViewArg::TensorMapTiled(buffer, VirtualLayoutLaunch::new::<IntoDynLayout<L>>(layout))
283        }
284
285        pub fn new_tensor_map_im2col<
286            L: Layout<Coordinates = C, SourceCoordinates = (P, O)> + LaunchArg,
287            P: IntoDyn,
288            O: IntoDyn,
289        >(
290            buffer: TensorMapArg<'a, R, Im2col>,
291            layout: L::RuntimeArg<'a, R>,
292        ) -> Self {
293            let layout = IntoDyn2LayoutLaunch::new(layout);
294            ViewArg::TensorMapIm2col(
295                buffer,
296                VirtualLayoutLaunch::new::<IntoDyn2Layout<L, P, O>>(layout),
297            )
298        }
299
300        /// Create a new view arg that dequantizes on read.
301        /// The scales layout should take values indices and map them to the corresponding scale.
302        pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
303            Self::Quantized {
304                values: Box::new(values),
305                scales: Box::new(scales),
306                scheme,
307            }
308        }
309    }
310    impl<'a, C: Coordinates, R: Runtime> ArgSettings<R> for ViewArg<'a, C, R> {
311        fn register(&self, launcher: &mut KernelLauncher<R>) {
312            match self {
313                ViewArg::Array(buffer, layout) => {
314                    buffer.register(launcher);
315                    layout.register(launcher);
316                }
317                ViewArg::TensorMapTiled(buffer, layout) => {
318                    buffer.register(launcher);
319                    layout.register(launcher);
320                }
321                ViewArg::TensorMapIm2col(buffer, layout) => {
322                    buffer.register(launcher);
323                    layout.register(launcher);
324                }
325                ViewArg::Quantized { values, scales, .. } => {
326                    values.register(launcher);
327                    scales.register(launcher);
328                }
329            }
330        }
331    }
332    #[derive(Clone)]
333    pub enum ViewCompilationArg<C: Coordinates> {
334        Array {
335            buffer: ArrayCompilationArg,
336            layout: VirtualLayoutCompilationArg<C, Coords1d>,
337        },
338        TensorMapTiled {
339            buffer: TensorMapCompilationArg,
340            layout: VirtualLayoutCompilationArg<C, Sequence<i32>>,
341        },
342        TensorMapIm2col {
343            buffer: TensorMapCompilationArg,
344            layout: VirtualLayoutCompilationArg<C, (Sequence<i32>, Sequence<i32>)>,
345        },
346        Quantized {
347            values: Box<ViewCompilationArg<C>>,
348            scales: Box<ViewCompilationArg<C>>,
349            scheme: QuantScheme,
350        },
351    }
352
353    impl<C: Coordinates + 'static> CompilationArg for ViewCompilationArg<C> {}
354    impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
355    impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
356        fn eq(&self, other: &Self) -> bool {
357            match (self, other) {
358                (
359                    ViewCompilationArg::Array { buffer, layout },
360                    ViewCompilationArg::Array {
361                        buffer: buffer_other,
362                        layout: layout_other,
363                    },
364                ) => buffer == buffer_other && layout == layout_other,
365                (
366                    ViewCompilationArg::TensorMapTiled { buffer, layout },
367                    ViewCompilationArg::TensorMapTiled {
368                        buffer: buffer_other,
369                        layout: layout_other,
370                    },
371                ) => buffer == buffer_other && layout == layout_other,
372                (
373                    ViewCompilationArg::TensorMapIm2col { buffer, layout },
374                    ViewCompilationArg::TensorMapIm2col {
375                        buffer: buffer_other,
376                        layout: layout_other,
377                    },
378                ) => buffer == buffer_other && layout == layout_other,
379                (
380                    ViewCompilationArg::Quantized {
381                        values,
382                        scales,
383                        scheme,
384                    },
385                    ViewCompilationArg::Quantized {
386                        values: values_other,
387                        scales: scales_other,
388                        scheme: scheme_other,
389                    },
390                ) => values == values_other && scales == scales_other && scheme == scheme_other,
391                _ => false,
392            }
393        }
394    }
395    impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
396        fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
397            match self {
398                ViewCompilationArg::Array { buffer, layout } => {
399                    buffer.hash(ra_expand_state);
400                    layout.hash(ra_expand_state);
401                }
402                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
403                    buffer.hash(ra_expand_state);
404                    layout.hash(ra_expand_state);
405                }
406                ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
407                    buffer.hash(ra_expand_state);
408                    layout.hash(ra_expand_state);
409                }
410                ViewCompilationArg::Quantized {
411                    values,
412                    scales,
413                    scheme,
414                } => {
415                    values.hash(ra_expand_state);
416                    scales.hash(ra_expand_state);
417                    scheme.hash(ra_expand_state);
418                }
419            }
420        }
421    }
422    impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
423        fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
424            match self {
425                ViewCompilationArg::Array { buffer, layout } => f
426                    .debug_struct("ArrayView")
427                    .field("buffer", &buffer)
428                    .field("layout", &layout)
429                    .finish(),
430                ViewCompilationArg::TensorMapTiled { buffer, layout } => f
431                    .debug_struct("TensorMapTiledView")
432                    .field("buffer", &buffer)
433                    .field("layout", &layout)
434                    .finish(),
435                ViewCompilationArg::TensorMapIm2col { buffer, layout } => f
436                    .debug_struct("TensorMapIm2colView")
437                    .field("buffer", &buffer)
438                    .field("layout", &layout)
439                    .finish(),
440                ViewCompilationArg::Quantized {
441                    values,
442                    scales,
443                    scheme,
444                } => f
445                    .debug_struct("QuantizedView")
446                    .field("values", &values)
447                    .field("scales", &scales)
448                    .field("scheme", &scheme)
449                    .finish(),
450            }
451        }
452    }
453
454    impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
455        type RuntimeArg<'a, R: Runtime> = ViewArg<'a, C, R>;
456        type CompilationArg = ViewCompilationArg<C>;
457
458        fn compilation_arg<'a, R: Runtime>(
459            runtime_arg: &Self::RuntimeArg<'a, R>,
460        ) -> Self::CompilationArg {
461            match runtime_arg {
462                ViewArg::Array(buffer, layout) => {
463                    let buffer = Array::<E>::compilation_arg(buffer);
464                    let layout = VirtualLayout::<C, Coords1d>::compilation_arg(layout);
465                    ViewCompilationArg::Array { buffer, layout }
466                }
467                ViewArg::TensorMapTiled(buffer, layout) => {
468                    let buffer = TensorMap::<E, Tiled>::compilation_arg(buffer);
469                    let layout = VirtualLayout::<C, Sequence<i32>>::compilation_arg(layout);
470                    ViewCompilationArg::TensorMapTiled { buffer, layout }
471                }
472                ViewArg::TensorMapIm2col(buffer, layout) => {
473                    let buffer = TensorMap::<E, Im2col>::compilation_arg(buffer);
474                    let layout =
475                        VirtualLayout::<C, (Sequence<i32>, Sequence<i32>)>::compilation_arg(layout);
476                    ViewCompilationArg::TensorMapIm2col { buffer, layout }
477                }
478                ViewArg::Quantized {
479                    values,
480                    scales,
481                    scheme,
482                } => {
483                    // Type isn't real, but doesn't matter for compilation arg
484                    let values = View::<E, C, IO>::compilation_arg(values);
485                    let scales = View::<E, C, IO>::compilation_arg(scales);
486                    ViewCompilationArg::Quantized {
487                        values: Box::new(values),
488                        scales: Box::new(scales),
489                        scheme: *scheme,
490                    }
491                }
492            }
493        }
494        fn expand(
495            arg: &Self::CompilationArg,
496            builder: &mut KernelBuilder,
497        ) -> <Self as CubeType>::ExpandType {
498            match arg {
499                ViewCompilationArg::Array { buffer, layout } => {
500                    let buffer = Array::<E>::expand(buffer, builder);
501                    let layout = VirtualLayout::<C, Coords1d>::expand(layout, builder);
502                    let view =
503                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
504                    ViewExpand::<E, C, IO> {
505                        inner: ViewType::ReadWrite(Arc::new(view)),
506                        _io: PhantomData,
507                    }
508                }
509                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
510                    let buffer = TensorMap::<E, Tiled>::expand(buffer, builder);
511                    let layout = VirtualLayout::<C, Sequence<i32>>::expand(layout, builder);
512                    let view =
513                        VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
514                            buffer, layout,
515                        );
516                    ViewExpand::<E, C, IO> {
517                        inner: ViewType::ReadWrite(Arc::new(view)),
518                        _io: PhantomData,
519                    }
520                }
521                ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
522                    let buffer = TensorMap::<E, Im2col>::expand(buffer, builder);
523                    let layout =
524                        VirtualLayout::<C, (Sequence<i32>, Sequence<i32>)>::expand(layout, builder);
525                    let view = VirtualViewExpand::<
526                        E,
527                        C,
528                        (Sequence<i32>, Sequence<i32>),
529                        TensorMap<E, Im2col>,
530                    >::new(buffer, layout);
531                    ViewExpand::<E, C, IO> {
532                        inner: ViewType::Read(Arc::new(view)),
533                        _io: PhantomData,
534                    }
535                }
536                ViewCompilationArg::Quantized {
537                    values,
538                    scales,
539                    scheme,
540                } => quant::view::expand_dynamic(values, scales, *scheme, builder),
541            }
542        }
543        fn expand_output(
544            arg: &Self::CompilationArg,
545            builder: &mut KernelBuilder,
546        ) -> <Self as CubeType>::ExpandType {
547            match arg {
548                ViewCompilationArg::Array { buffer, layout } => {
549                    let buffer = Array::<E>::expand_output(buffer, builder);
550                    let layout = VirtualLayout::<C, Coords1d>::expand_output(layout, builder);
551                    let view =
552                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
553                    ViewExpand::<E, C, IO> {
554                        inner: ViewType::ReadWrite(Arc::new(view)),
555                        _io: PhantomData,
556                    }
557                }
558                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
559                    let buffer = TensorMap::<E, Tiled>::expand_output(buffer, builder);
560                    let layout = VirtualLayout::<C, Sequence<i32>>::expand_output(layout, builder);
561                    let view =
562                        VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
563                            buffer, layout,
564                        );
565                    ViewExpand::<E, C, IO> {
566                        inner: ViewType::ReadWrite(Arc::new(view)),
567                        _io: PhantomData,
568                    }
569                }
570                ViewCompilationArg::TensorMapIm2col { .. } => {
571                    unimplemented!("Im2col tensor maps can't be used as outputs");
572                }
573                ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
574            }
575        }
576    }
577}
578
579pub use dynamic::*;