cubecl_std/tensor/view/
launch.rs

1use cubecl_core::{prelude::*, unexpanded};
2use std::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5    sync::Arc,
6};
7
8use crate::tensor::{
9    View, ViewExpand, ViewOperationsMut, VirtualViewMut, VirtualViewMutExpand,
10    layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand, VirtualLayoutOperationsExpand},
11    view::ViewType,
12};
13
14/// Launchable tensor view for ease of use.
15#[derive(Clone)]
16pub struct TypedView<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility = ReadOnly> {
17    _ty: PhantomData<(E, L, IO)>,
18}
19
20impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> CubeType for TypedView<E, L, IO> {
21    type ExpandType = ViewExpand<E, L::Coordinates, IO>;
22}
23
24impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> Deref for TypedView<E, L, IO> {
25    type Target = View<E, L::Coordinates, IO>;
26
27    fn deref(&self) -> &Self::Target {
28        unexpanded!()
29    }
30}
31
32impl<E: CubePrimitive, L: LaunchLayout> DerefMut for TypedView<E, L, ReadWrite> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        unexpanded!()
35    }
36}
37
38pub struct TypedViewLaunch<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> {
39    buffer: ArrayArg<'a, R>,
40    layout: L::RuntimeArg<'a, R>,
41}
42impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> TypedViewLaunch<'a, L, R> {
43    #[allow(clippy::too_many_arguments)]
44    pub fn new(buffer: ArrayArg<'a, R>, layout: L::RuntimeArg<'a, R>) -> Self {
45        Self { buffer, layout }
46    }
47}
48impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> ArgSettings<R>
49    for TypedViewLaunch<'a, L, R>
50{
51    fn register(&self, launcher: &mut KernelLauncher<R>) {
52        self.buffer.register(launcher);
53        self.layout.register(launcher);
54    }
55}
56
57pub struct TypedViewCompilationArg<L: LaunchLayout<SourceCoordinates = Coords1d>> {
58    buffer: ArrayCompilationArg,
59    layout: L::CompilationArg,
60}
61impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Clone for TypedViewCompilationArg<L> {
62    fn clone(&self) -> Self {
63        Self {
64            buffer: self.buffer.clone(),
65            layout: self.layout.clone(),
66        }
67    }
68}
69impl<L: LaunchLayout<SourceCoordinates = Coords1d>> CompilationArg for TypedViewCompilationArg<L> {}
70
71impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::hash::Hash
72    for TypedViewCompilationArg<L>
73{
74    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75        self.buffer.hash(state);
76        self.layout.hash(state);
77    }
78}
79impl<L: LaunchLayout<SourceCoordinates = Coords1d>> PartialEq for TypedViewCompilationArg<L> {
80    fn eq(&self, other: &Self) -> bool {
81        self.buffer.eq(&other.buffer) && self.layout.eq(&other.layout)
82    }
83}
84impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::fmt::Debug
85    for TypedViewCompilationArg<L>
86{
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct(stringify!(TensorViewTyped))
89            .field("buffer", &self.buffer)
90            .field("layout", &self.layout)
91            .finish()
92    }
93}
94impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Eq for TypedViewCompilationArg<L> {}
95
96impl<E: CubePrimitive, L: LaunchLayout<SourceCoordinates = Coords1d>, IO: SliceVisibility> LaunchArg
97    for TypedView<E, L, IO>
98{
99    type RuntimeArg<'a, R: Runtime> = TypedViewLaunch<'a, L, R>;
100    type CompilationArg = TypedViewCompilationArg<L>;
101
102    fn compilation_arg<'a, R: Runtime>(
103        runtime_arg: &Self::RuntimeArg<'a, R>,
104    ) -> Self::CompilationArg {
105        TypedViewCompilationArg {
106            buffer: <Array<Line<E>> as LaunchArg>::compilation_arg(&runtime_arg.buffer),
107            layout: L::compilation_arg(&runtime_arg.layout),
108        }
109    }
110
111    fn expand(
112        arg: &Self::CompilationArg,
113        builder: &mut KernelBuilder,
114    ) -> <Self as CubeType>::ExpandType {
115        let buffer = <Array<E> as LaunchArg>::expand(&arg.buffer, builder);
116        L::apply::<E, Array<E>, IO>(L::expand(&arg.layout, builder), buffer)
117    }
118    fn expand_output(
119        arg: &Self::CompilationArg,
120        builder: &mut KernelBuilder,
121    ) -> <Self as CubeType>::ExpandType {
122        let buffer = <Array<E> as LaunchArg>::expand_output(&arg.buffer, builder);
123        L::apply::<E, Array<E>, IO>(L::expand_output(&arg.layout, builder), buffer)
124    }
125}
126
/// Private module holding the sealing trait required by `LaunchLayout`.
mod seal {
    /// Supertrait of `LaunchLayout`; since this module is private, only this file can name it.
    pub trait Sealed {}
}
130
131pub trait LaunchLayout: LaunchArg + seal::Sealed {
132    type SourceCoordinates: Coordinates;
133    type Coordinates: Coordinates;
134
135    fn apply<
136        E: CubePrimitive,
137        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
138        IO: SliceVisibility,
139    >(
140        value: <Self as CubeType>::ExpandType,
141        view: V::ExpandType,
142    ) -> ViewExpand<E, Self::Coordinates, IO>;
143}
144
145// These unfortunately need to be manually implemented due to the dependencies of each layout on
146// the coordinates of the next. Just stick with two layouts for now and add more implementations as
147// needed.
148
149impl<
150    L: Layout
151        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
152        + LaunchArg,
153> seal::Sealed for L
154{
155}
156impl<
157    L: Layout
158        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
159        + LaunchArg,
160> LaunchLayout for L
161{
162    type SourceCoordinates = L::SourceCoordinates;
163    type Coordinates = L::Coordinates;
164
165    fn apply<
166        E: CubePrimitive,
167        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
168        IO: SliceVisibility,
169    >(
170        value: L::ExpandType,
171        view: V::ExpandType,
172    ) -> ViewExpand<E, Self::Coordinates, IO> {
173        let l0 = value;
174        let l0 = VirtualLayoutExpand::new::<L::ExpandType>(l0);
175        let view =
176            VirtualViewMutExpand::<E, L::Coordinates, L::SourceCoordinates, V>::new(view, l0);
177        ViewExpand::<E, L::Coordinates, IO> {
178            inner: ViewType::ReadWrite(Arc::new(view)),
179            _io: PhantomData,
180        }
181    }
182}
183
184impl<
185    L0: Layout
186        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
187        + LaunchArg,
188    L1: Layout<SourceCoordinates = L0::Coordinates>
189        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
190        + LaunchArg,
191> seal::Sealed for (L0, L1)
192{
193}
194impl<
195    L0: Layout
196        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
197        + LaunchArg,
198    L1: Layout<SourceCoordinates = L0::Coordinates>
199        + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
200        + LaunchArg,
201> LaunchLayout for (L0, L1)
202{
203    type SourceCoordinates = L0::SourceCoordinates;
204    type Coordinates = L1::Coordinates;
205
206    fn apply<
207        E: CubePrimitive,
208        V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
209        IO: SliceVisibility,
210    >(
211        value: (L0::ExpandType, L1::ExpandType),
212        view: V::ExpandType,
213    ) -> ViewExpand<E, Self::Coordinates, IO> {
214        let (l0, l1) = value;
215        let l0 = VirtualLayoutExpand::new::<L0::ExpandType>(l0);
216        let view =
217            VirtualViewMutExpand::<E, L0::Coordinates, L0::SourceCoordinates, V>::new(view, l0);
218        let l1 = VirtualLayoutExpand::new::<L1::ExpandType>(l1);
219        let view = VirtualViewMutExpand::<
220            E,
221            L1::Coordinates,
222            L1::SourceCoordinates,
223            VirtualViewMut<E, L0::Coordinates, L0::SourceCoordinates, V>,
224        >::new(view, l1);
225        ViewExpand::<E, L1::Coordinates, IO> {
226            inner: ViewType::ReadWrite(Arc::new(view)),
227            _io: PhantomData,
228        }
229    }
230}
231
232mod dynamic {
233    use cubecl_common::quant::scheme::QuantScheme;
234
235    use crate::{
236        quant,
237        tensor::layout::{
238            VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch,
239            as_dyn::{IntoDyn, IntoDynLayout, IntoDynLayoutLaunch},
240        },
241    };
242
243    use super::*;
244
245    pub enum ViewArg<'a, C: Coordinates, R: Runtime> {
246        Array(ArrayArg<'a, R>, VirtualLayoutLaunch<'a, C, Coords1d, R>),
247        TensorMap(
248            TensorMapArg<'a, R>,
249            VirtualLayoutLaunch<'a, C, Sequence<i32>, R>,
250        ),
251        Quantized {
252            values: Box<ViewArg<'a, C, R>>,
253            scales: Box<ViewArg<'a, C, R>>,
254            scheme: QuantScheme,
255        },
256    }
257    impl<'a, C: Coordinates, R: Runtime> ViewArg<'a, C, R> {
258        pub fn new<L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + LaunchArg>(
259            buffer: ArrayArg<'a, R>,
260            layout: L::RuntimeArg<'a, R>,
261        ) -> Self {
262            ViewArg::Array(buffer, VirtualLayoutLaunch::new::<L>(layout))
263        }
264
265        pub fn new_tensor_map<
266            L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + LaunchArg,
267        >(
268            buffer: TensorMapArg<'a, R>,
269            layout: L::RuntimeArg<'a, R>,
270        ) -> Self {
271            let layout = IntoDynLayoutLaunch::new(layout);
272            ViewArg::TensorMap(buffer, VirtualLayoutLaunch::new::<IntoDynLayout<L>>(layout))
273        }
274
275        /// Create a new view arg that dequantizes on read.
276        /// The scales layout should take values indices and map them to the corresponding scale.
277        pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
278            Self::Quantized {
279                values: Box::new(values),
280                scales: Box::new(scales),
281                scheme,
282            }
283        }
284    }
285    impl<'a, C: Coordinates, R: Runtime> ArgSettings<R> for ViewArg<'a, C, R> {
286        fn register(&self, launcher: &mut KernelLauncher<R>) {
287            match self {
288                ViewArg::Array(buffer, layout) => {
289                    buffer.register(launcher);
290                    layout.register(launcher);
291                }
292                ViewArg::TensorMap(buffer, layout) => {
293                    buffer.register(launcher);
294                    layout.register(launcher);
295                }
296                ViewArg::Quantized { values, scales, .. } => {
297                    values.register(launcher);
298                    scales.register(launcher);
299                }
300            }
301        }
302    }
303    #[derive(Clone)]
304    pub enum ViewCompilationArg<C: Coordinates> {
305        Array {
306            buffer: ArrayCompilationArg,
307            layout: VirtualLayoutCompilationArg<C, Coords1d>,
308        },
309        TensorMap {
310            buffer: TensorMapCompilationArg,
311            layout: VirtualLayoutCompilationArg<C, Sequence<i32>>,
312        },
313        Quantized {
314            values: Box<ViewCompilationArg<C>>,
315            scales: Box<ViewCompilationArg<C>>,
316            scheme: QuantScheme,
317        },
318    }
319
320    impl<C: Coordinates + 'static> CompilationArg for ViewCompilationArg<C> {}
321    impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
322    impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
323        fn eq(&self, other: &Self) -> bool {
324            match (self, other) {
325                (
326                    ViewCompilationArg::Array { buffer, layout },
327                    ViewCompilationArg::Array {
328                        buffer: buffer_other,
329                        layout: layout_other,
330                    },
331                ) => buffer == buffer_other && layout == layout_other,
332                (
333                    ViewCompilationArg::TensorMap { buffer, layout },
334                    ViewCompilationArg::TensorMap {
335                        buffer: buffer_other,
336                        layout: layout_other,
337                    },
338                ) => buffer == buffer_other && layout == layout_other,
339                (
340                    ViewCompilationArg::Quantized {
341                        values,
342                        scales,
343                        scheme,
344                    },
345                    ViewCompilationArg::Quantized {
346                        values: values_other,
347                        scales: scales_other,
348                        scheme: scheme_other,
349                    },
350                ) => values == values_other && scales == scales_other && scheme == scheme_other,
351                _ => false,
352            }
353        }
354    }
355    impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
356        fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
357            match self {
358                ViewCompilationArg::Array { buffer, layout } => {
359                    buffer.hash(ra_expand_state);
360                    layout.hash(ra_expand_state);
361                }
362                ViewCompilationArg::TensorMap { buffer, layout } => {
363                    buffer.hash(ra_expand_state);
364                    layout.hash(ra_expand_state);
365                }
366                ViewCompilationArg::Quantized {
367                    values,
368                    scales,
369                    scheme,
370                } => {
371                    values.hash(ra_expand_state);
372                    scales.hash(ra_expand_state);
373                    scheme.hash(ra_expand_state);
374                }
375            }
376        }
377    }
378    impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
379        fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
380            match self {
381                ViewCompilationArg::Array { buffer, layout } => f
382                    .debug_struct("ArrayView")
383                    .field("buffer", &buffer)
384                    .field("layout", &layout)
385                    .finish(),
386                ViewCompilationArg::TensorMap { buffer, layout } => f
387                    .debug_struct("TensorMapView")
388                    .field("buffer", &buffer)
389                    .field("layout", &layout)
390                    .finish(),
391                ViewCompilationArg::Quantized {
392                    values,
393                    scales,
394                    scheme,
395                } => f
396                    .debug_struct("QuantizedView")
397                    .field("values", &values)
398                    .field("scales", &scales)
399                    .field("scheme", &scheme)
400                    .finish(),
401            }
402        }
403    }
404
405    impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
406        type RuntimeArg<'a, R: Runtime> = ViewArg<'a, C, R>;
407        type CompilationArg = ViewCompilationArg<C>;
408
409        fn compilation_arg<'a, R: Runtime>(
410            runtime_arg: &Self::RuntimeArg<'a, R>,
411        ) -> Self::CompilationArg {
412            match runtime_arg {
413                ViewArg::Array(buffer, layout) => {
414                    let buffer = Array::<E>::compilation_arg(buffer);
415                    let layout = VirtualLayout::<C, Coords1d>::compilation_arg(layout);
416                    ViewCompilationArg::Array { buffer, layout }
417                }
418                ViewArg::TensorMap(buffer, layout) => {
419                    let buffer = TensorMap::<E>::compilation_arg(buffer);
420                    let layout = VirtualLayout::<C, Sequence<i32>>::compilation_arg(layout);
421                    ViewCompilationArg::TensorMap { buffer, layout }
422                }
423                ViewArg::Quantized {
424                    values,
425                    scales,
426                    scheme,
427                } => {
428                    // Type isn't real, but doesn't matter for compilation arg
429                    let values = View::<E, C, IO>::compilation_arg(values);
430                    let scales = View::<E, C, IO>::compilation_arg(scales);
431                    ViewCompilationArg::Quantized {
432                        values: Box::new(values),
433                        scales: Box::new(scales),
434                        scheme: *scheme,
435                    }
436                }
437            }
438        }
439        fn expand(
440            arg: &Self::CompilationArg,
441            builder: &mut KernelBuilder,
442        ) -> <Self as CubeType>::ExpandType {
443            match arg {
444                ViewCompilationArg::Array { buffer, layout } => {
445                    let buffer = Array::<E>::expand(buffer, builder);
446                    let layout = VirtualLayout::<C, Coords1d>::expand(layout, builder);
447                    let view =
448                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
449                    ViewExpand::<E, C, IO> {
450                        inner: ViewType::ReadWrite(Arc::new(view)),
451                        _io: PhantomData,
452                    }
453                }
454                ViewCompilationArg::TensorMap { buffer, layout } => {
455                    let buffer = TensorMap::<E>::expand(buffer, builder);
456                    let layout = VirtualLayout::<C, Sequence<i32>>::expand(layout, builder);
457                    let view = VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E>>::new(
458                        buffer, layout,
459                    );
460                    ViewExpand::<E, C, IO> {
461                        inner: ViewType::ReadWrite(Arc::new(view)),
462                        _io: PhantomData,
463                    }
464                }
465                ViewCompilationArg::Quantized {
466                    values,
467                    scales,
468                    scheme,
469                } => quant::view::expand_dynamic(values, scales, *scheme, builder),
470            }
471        }
472        fn expand_output(
473            arg: &Self::CompilationArg,
474            builder: &mut KernelBuilder,
475        ) -> <Self as CubeType>::ExpandType {
476            match arg {
477                ViewCompilationArg::Array { buffer, layout } => {
478                    let buffer = Array::<E>::expand_output(buffer, builder);
479                    let layout = VirtualLayout::<C, Coords1d>::expand_output(layout, builder);
480                    let view =
481                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
482                    ViewExpand::<E, C, IO> {
483                        inner: ViewType::ReadWrite(Arc::new(view)),
484                        _io: PhantomData,
485                    }
486                }
487                ViewCompilationArg::TensorMap { buffer, layout } => {
488                    let buffer = TensorMap::<E>::expand_output(buffer, builder);
489                    let layout = VirtualLayout::<C, Sequence<i32>>::expand_output(layout, builder);
490                    let view = VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E>>::new(
491                        buffer, layout,
492                    );
493                    ViewExpand::<E, C, IO> {
494                        inner: ViewType::ReadWrite(Arc::new(view)),
495                        _io: PhantomData,
496                    }
497                }
498                ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
499            }
500        }
501    }
502}
503
504pub use dynamic::*;