Skip to main content

cubecl_std/tensor/view/
launch.rs

1use cubecl_core::prelude::*;
2use std::{marker::PhantomData, ops::Deref, sync::Arc};
3
4use crate::tensor::{
5    View, ViewExpand, VirtualViewMutExpand,
6    layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand},
7    view::ViewType,
8};
9
10mod layout {
11    use core::{cell::RefCell, fmt::Debug, hash::Hash};
12
13    use alloc::rc::Rc;
14    use cubecl_core::{
15        self as cubecl,
16        format::DebugRaw,
17        hash::{StableHash, StableHasher},
18        prelude::*,
19        zspace::{Shape, Strides, metadata::Metadata},
20    };
21
22    use crate::tensor::layout::LayoutExpand;
23
24    use super::*;
25
/// Host-side metadata of the buffer a view layout is registered against.
///
/// A `&impl BufferArg` is handed to [`ViewLayoutLaunchArg::register`] so a layout can read the
/// buffer's length, shape and strides while it registers itself.
// Only `len` is needed by the layouts, so no matching `is_empty` is provided.
#[allow(clippy::len_without_is_empty)]
pub trait BufferArg: 'static {
    /// Length of the buffer, as reported by the implementor.
    fn len(&self) -> usize;
    /// Shape of the buffer, one entry per dimension.
    fn shape(&self) -> &[usize];
    /// Strides of the buffer, one entry per dimension.
    fn strides(&self) -> &[usize];
}
32
33    impl<R: Runtime> BufferArg for TensorArg<R> {
34        fn len(&self) -> usize {
35            self.size()
36        }
37
38        fn shape(&self) -> &[usize] {
39            self.shape()
40        }
41
42        fn strides(&self) -> &[usize] {
43            self.strides()
44        }
45    }
46    impl<R: Runtime> BufferArg for ArrayArg<R> {
47        fn len(&self) -> usize {
48            self.size()
49        }
50
51        fn shape(&self) -> &[usize] {
52            self.shape()
53        }
54
55        fn strides(&self) -> &[usize] {
56            &[1]
57        }
58    }
59    impl<R: Runtime, K: TensorMapKind> BufferArg for TensorMapArg<R, K> {
60        fn len(&self) -> usize {
61            self.tensor.size()
62        }
63
64        fn shape(&self) -> &[usize] {
65            self.tensor.shape()
66        }
67
68        fn strides(&self) -> &[usize] {
69            self.tensor.strides()
70        }
71    }
72
73    impl BufferArg for Metadata {
74        fn len(&self) -> usize {
75            self.shape.num_elements()
76        }
77
78        fn shape(&self) -> &[usize] {
79            &self.shape
80        }
81
82        fn strides(&self) -> &[usize] {
83            &self.strides
84        }
85    }
86
/// Special launch arg that gets the handle and types of the view, to allow inferring launch
/// state based on type/handle metadata, avoiding duplication. All `LaunchArg`s also implement
/// this trait.
pub trait ViewLayoutLaunchArg: CubeType + Send + Sync + 'static {
    /// The runtime argument for the kernel.
    type RuntimeArg<R: Runtime>: Send + Sync;
    /// Compilation argument.
    type CompilationArg: CompilationArg;

    /// Registers the runtime argument on the launcher and returns the matching compilation
    /// argument.
    ///
    /// `buffer` describes the buffer the layout is attached to and `ty` is the type passed
    /// along by the view that owns the layout.
    fn register<R: Runtime, B: BufferArg>(
        arg: Self::RuntimeArg<R>,
        buffer: &B,
        ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg;

    /// Registers an input variable during compilation, filling the [`KernelBuilder`].
    fn expand(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType;

    /// Registers an output variable during compilation, filling the [`KernelBuilder`].
    ///
    /// The default implementation expands the argument exactly like an input.
    fn expand_output(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        Self::expand(arg, ty, builder)
    }
}
119
120    impl<T: LaunchArg> ViewLayoutLaunchArg for T {
121        type RuntimeArg<R: Runtime> = <T as LaunchArg>::RuntimeArg<R>;
122        type CompilationArg = <T as LaunchArg>::CompilationArg;
123
124        fn register<R: Runtime, B: BufferArg>(
125            arg: Self::RuntimeArg<R>,
126            _buffer: &B,
127            _ty: Type,
128            launcher: &mut KernelLauncher<R>,
129        ) -> Self::CompilationArg {
130            <T as LaunchArg>::register(arg, launcher)
131        }
132
133        fn expand(
134            arg: &Self::CompilationArg,
135            _ty: Type,
136            builder: &mut KernelBuilder,
137        ) -> <Self as CubeType>::ExpandType {
138            <T as LaunchArg>::expand(arg, builder)
139        }
140
141        fn expand_output(
142            arg: &Self::CompilationArg,
143            _ty: Type,
144            builder: &mut KernelBuilder,
145        ) -> <Self as CubeType>::ExpandType {
146            <T as LaunchArg>::expand_output(arg, builder)
147        }
148    }
149
150    pub struct VirtualViewLayoutLaunch<C: Coordinates, S: Coordinates, B: BufferArg, R: Runtime> {
151        _ty: core::marker::PhantomData<R>,
152        #[allow(clippy::type_complexity)]
153        register: Box<
154            dyn FnOnce(&B, Type, &mut KernelLauncher<R>) -> VirtualViewLayoutCompilationArg<C, S>
155                + Send
156                + Sync,
157        >,
158    }
159
160    impl<C: Coordinates, S: Coordinates, B: BufferArg, R: Runtime> VirtualViewLayoutLaunch<C, S, B, R> {
161        pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S> + ViewLayoutLaunchArg>(
162            layout: L::RuntimeArg<R>,
163        ) -> Self {
164            Self {
165                _ty: PhantomData,
166                register: Box::new(move |buffer, ty, launcher| {
167                    let comp_arg = L::register::<R, B>(layout, buffer, ty, launcher);
168                    let comp_arg_2 = comp_arg.clone();
169                    let expand = Rc::new(RefCell::new(
170                        move |ty: Type, builder: &mut KernelBuilder, is_out: bool| {
171                            let expand = match is_out {
172                                true => L::expand_output(&comp_arg_2, ty, builder),
173                                false => L::expand(&comp_arg_2, ty, builder),
174                            };
175                            VirtualLayoutExpand::new(expand)
176                        },
177                    ));
178                    VirtualViewLayoutCompilationArg::new(comp_arg, expand)
179                }),
180            }
181        }
182
183        pub fn register(
184            self,
185            buffer: &B,
186            ty: Type,
187            launcher: &mut KernelLauncher<R>,
188        ) -> VirtualViewLayoutCompilationArg<C, S> {
189            (self.register)(buffer, ty, launcher)
190        }
191    }
192
/// Shared, type-erased expansion callback.
///
/// Arguments are the type, the kernel builder and a flag that is `true` when the layout is
/// expanded as an output.
type ExpandFn<C, S> =
    Rc<RefCell<dyn FnMut(Type, &mut KernelBuilder, bool) -> VirtualLayoutExpand<C, S> + Send>>;
195
/// Type-erased compilation argument of a view layout.
///
/// Identity (equality and hashing) is defined by the concrete argument's type name and its
/// precomputed stable hash; the argument itself is only kept for debug output.
#[derive(Clone)]
pub struct VirtualViewLayoutCompilationArg<C: Coordinates, S: Coordinates> {
    // `core::any::type_name` of the concrete compilation argument.
    type_name: String,
    // The concrete compilation argument, retained only so it can be printed.
    debug: Rc<dyn core::fmt::Debug>,
    // Stable hash of the concrete compilation argument, computed once in `new`.
    hash: StableHash,
    // Callback that expands the concrete layout into a `VirtualLayoutExpand`.
    expand: ExpandFn<C, S>,
}

// SAFETY: The struct is readonly, so `Send` and `Sync` are safe to implement.
// NOTE(review): the fields are `Rc` handles and `expand`/`expand_output` take a mutable
// `RefCell` borrow, so this relies on callers never cloning, dropping or expanding the same
// value from several threads at the same time. Confirm that invariant holds for all users.
unsafe impl<C: Coordinates, S: Coordinates> Send for VirtualViewLayoutCompilationArg<C, S> {}
unsafe impl<C: Coordinates, S: Coordinates> Sync for VirtualViewLayoutCompilationArg<C, S> {}
207
208    impl<C: Coordinates, S: Coordinates> VirtualViewLayoutCompilationArg<C, S> {
209        pub fn new<L: CompilationArg + 'static>(arg: L, expand: ExpandFn<C, S>) -> Self {
210            // Hash ahead of time so we don't need to store the actual data, which would be far
211            // more complex
212            let hash = StableHasher::hash_one(&arg);
213            Self {
214                type_name: core::any::type_name::<L>().to_string(),
215                debug: Rc::new(arg),
216                hash,
217                expand,
218            }
219        }
220
221        pub fn expand(&self, ty: Type, builder: &mut KernelBuilder) -> VirtualLayoutExpand<C, S> {
222            let mut expand = self.expand.borrow_mut();
223            (expand)(ty, builder, false)
224        }
225
226        pub fn expand_output(
227            &self,
228            ty: Type,
229            builder: &mut KernelBuilder,
230        ) -> VirtualLayoutExpand<C, S> {
231            let mut expand = self.expand.borrow_mut();
232            (expand)(ty, builder, true)
233        }
234    }
235
236    impl<C: Coordinates, S: Coordinates> PartialEq for VirtualViewLayoutCompilationArg<C, S> {
237        fn eq(&self, other: &Self) -> bool {
238            self.type_name == other.type_name && self.hash == other.hash
239        }
240    }
241    impl<C: Coordinates, S: Coordinates> Eq for VirtualViewLayoutCompilationArg<C, S> {}
242
243    impl<C: Coordinates, S: Coordinates> core::hash::Hash for VirtualViewLayoutCompilationArg<C, S> {
244        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
245            self.type_name.hash(state);
246            self.hash.hash(state);
247        }
248    }
249
250    impl<C: Coordinates, S: Coordinates> core::fmt::Debug for VirtualViewLayoutCompilationArg<C, S> {
251        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252            f.debug_struct(stringify!(VirtualLayout))
253                .field("type", &DebugRaw(&self.type_name))
254                .field("value", &self.debug)
255                .finish()
256        }
257    }
258
// Wrapper around a single layout `L`. Unlike `L` itself it has a plain `LaunchArg`
// implementation in this module, whose runtime argument carries the buffer metadata and type
// that `ViewLayoutLaunchArg::register` expects.
#[derive(CubeType)]
pub struct ConcreteLayout<L: Layout + ViewLayoutLaunchArg> {
    // The wrapped layout; every `Layout` method and `Deref` forward to it.
    value: L,
}
263
// `ConcreteLayout` is a transparent layout: coordinate types and every operation are taken
// unchanged from the wrapped layout.
#[cube]
impl<L: Layout + ViewLayoutLaunchArg> Layout for ConcreteLayout<L> {
    type Coordinates = L::Coordinates;
    type SourceCoordinates = L::SourceCoordinates;

    // Forwards to the wrapped layout.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        self.value.to_source_pos(pos)
    }

    // Forwards to the wrapped layout.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        self.value.to_source_pos_checked(pos)
    }

    // Forwards to the wrapped layout.
    fn shape(&self) -> Self::Coordinates {
        self.value.shape()
    }

    // Forwards to the wrapped layout.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.value.is_in_bounds(pos)
    }
}
285
286    impl<L: Layout + ViewLayoutLaunchArg> Deref for ConcreteLayout<L> {
287        type Target = L;
288
289        fn deref(&self) -> &Self::Target {
290            &self.value
291        }
292    }
293
294    impl<L: Layout + ViewLayoutLaunchArg> Deref for ConcreteLayoutExpand<L> {
295        type Target = <L as CubeType>::ExpandType;
296
297        fn deref(&self) -> &Self::Target {
298            &self.value
299        }
300    }
301
302    pub struct ConcreteLayoutLaunch<L: Layout + ViewLayoutLaunchArg, R: Runtime> {
303        meta: Metadata,
304        ty: Type,
305        value: L::RuntimeArg<R>,
306    }
307
308    impl<L: Layout + ViewLayoutLaunchArg, R: Runtime> ConcreteLayoutLaunch<L, R> {
309        pub fn new(meta: Metadata, ty: Type, value: L::RuntimeArg<R>) -> Self {
310            Self { meta, ty, value }
311        }
312
313        pub fn from_handle(handle: &TensorBinding<R>, ty: Type, value: L::RuntimeArg<R>) -> Self {
314            Self {
315                meta: Metadata {
316                    shape: handle.shape.clone(),
317                    strides: handle.strides.clone(),
318                },
319                ty,
320                value,
321            }
322        }
323
324        pub fn from_shape_strides(
325            shape: Shape,
326            strides: Strides,
327            ty: Type,
328            value: L::RuntimeArg<R>,
329        ) -> Self {
330            Self {
331                meta: Metadata { shape, strides },
332                ty,
333                value,
334            }
335        }
336    }
337
338    pub struct ConcreteLayoutCompilationArg<L: Layout + ViewLayoutLaunchArg> {
339        ty: Type,
340        value: L::CompilationArg,
341    }
342
343    impl<L: Layout + ViewLayoutLaunchArg> Debug for ConcreteLayoutCompilationArg<L> {
344        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345            f.debug_struct("ConcreteLayoutCompilationArg")
346                .field("ty", &self.ty)
347                .field("value", &self.value)
348                .finish()
349        }
350    }
351
352    impl<L: Layout + ViewLayoutLaunchArg> Hash for ConcreteLayoutCompilationArg<L> {
353        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
354            self.ty.hash(state);
355            self.value.hash(state);
356        }
357    }
358
359    impl<L: Layout + ViewLayoutLaunchArg> Eq for ConcreteLayoutCompilationArg<L> {}
360    impl<L: Layout + ViewLayoutLaunchArg> PartialEq for ConcreteLayoutCompilationArg<L> {
361        fn eq(&self, other: &Self) -> bool {
362            self.ty == other.ty && self.value == other.value
363        }
364    }
365
366    impl<L: Layout + ViewLayoutLaunchArg> Clone for ConcreteLayoutCompilationArg<L> {
367        fn clone(&self) -> Self {
368            Self {
369                ty: self.ty,
370                value: self.value.clone(),
371            }
372        }
373    }
374
375    impl<L: Layout + ViewLayoutLaunchArg> LaunchArg for ConcreteLayout<L> {
376        type RuntimeArg<R: Runtime> = ConcreteLayoutLaunch<L, R>;
377        type CompilationArg = ConcreteLayoutCompilationArg<L>;
378
379        fn register<R: Runtime>(
380            arg: Self::RuntimeArg<R>,
381            launcher: &mut KernelLauncher<R>,
382        ) -> Self::CompilationArg {
383            ConcreteLayoutCompilationArg {
384                value: L::register(arg.value, &arg.meta, arg.ty, launcher),
385                ty: arg.ty,
386            }
387        }
388
389        fn expand(
390            arg: &Self::CompilationArg,
391            builder: &mut KernelBuilder,
392        ) -> <Self as CubeType>::ExpandType {
393            ConcreteLayoutExpand {
394                value: L::expand(&arg.value, arg.ty, builder),
395            }
396        }
397
398        fn expand_output(
399            arg: &Self::CompilationArg,
400            builder: &mut KernelBuilder,
401        ) -> <Self as CubeType>::ExpandType {
402            ConcreteLayoutExpand {
403                value: L::expand_output(&arg.value, arg.ty, builder),
404            }
405        }
406    }
407}
408
409pub use layout::*;
410
411mod dynamic {
412    use cubecl_common::quant::scheme::QuantScheme;
413
414    use crate::{
415        quant::{
416            self,
417            view::{RegisterDynamic, run_with_quant_type},
418        },
419        tensor::{
420            VirtualViewExpand,
421            launch::layout::{ViewLayoutLaunchArg, VirtualViewLayoutLaunch},
422            layout::as_dyn::{IntoDyn, IntoDyn2Layout, IntoDynLayout},
423        },
424    };
425
426    use super::*;
427
/// Runtime argument of a [`View`] with coordinates `C`.
///
/// Each buffer-backed variant pairs the buffer argument with the type-erased layout that maps
/// `C` to that buffer's source coordinates.
#[allow(clippy::type_complexity)]
pub enum ViewArg<C: Coordinates, R: Runtime> {
    /// View over an array, addressed with 1D coordinates.
    Array(
        ArrayArg<R>,
        VirtualViewLayoutLaunch<C, Coords1d, ArrayArg<R>, R>,
    ),
    /// View over a tensor, addressed with 1D coordinates.
    Tensor(
        TensorArg<R>,
        VirtualViewLayoutLaunch<C, Coords1d, TensorArg<R>, R>,
    ),
    /// View over a tiled tensor map, addressed with a sequence of `i32` coordinates.
    TensorMapTiled(
        TensorMapArg<R, Tiled>,
        VirtualViewLayoutLaunch<C, Sequence<i32>, TensorMapArg<R, Tiled>, R>,
    ),
    /// View over an im2col tensor map, addressed with a pair of `i32` coordinate sequences.
    TensorMapIm2col(
        TensorMapArg<R, Im2col>,
        VirtualViewLayoutLaunch<C, (Sequence<i32>, Sequence<i32>), TensorMapArg<R, Im2col>, R>,
    ),
    /// Quantized view made of a values view, a scales view and the quantization scheme.
    Quantized {
        values: Box<ViewArg<C, R>>,
        scales: Box<ViewArg<C, R>>,
        scheme: QuantScheme,
    },
}
452
453    impl<C: Coordinates, R: Runtime> ViewArg<C, R> {
454        pub fn new_array<
455            L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + ViewLayoutLaunchArg,
456        >(
457            buffer: ArrayArg<R>,
458            layout: L::RuntimeArg<R>,
459        ) -> Self {
460            let layout = VirtualViewLayoutLaunch::new::<L>(layout);
461            ViewArg::Array(buffer, layout)
462        }
463
464        pub fn new_tensor<
465            L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + ViewLayoutLaunchArg,
466        >(
467            buffer: TensorArg<R>,
468            layout: L::RuntimeArg<R>,
469        ) -> Self {
470            let layout = VirtualViewLayoutLaunch::new::<L>(layout);
471            ViewArg::Tensor(buffer, layout)
472        }
473
474        pub fn new_tensor_map_tiled<
475            L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg,
476        >(
477            buffer: TensorMapArg<R, Tiled>,
478            layout: L::RuntimeArg<R>,
479        ) -> ViewArg<C, R> {
480            let layout = VirtualViewLayoutLaunch::new::<IntoDynLayout<L>>(layout);
481            ViewArg::TensorMapTiled(buffer, layout)
482        }
483
484        pub fn new_tensor_map_im2col<
485            L: Layout<Coordinates = C, SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
486            P: IntoDyn,
487            O: IntoDyn,
488        >(
489            buffer: TensorMapArg<R, Im2col>,
490            layout: L::RuntimeArg<R>,
491        ) -> ViewArg<C, R> {
492            let layout = VirtualViewLayoutLaunch::new::<IntoDyn2Layout<L, P, O>>(layout);
493            ViewArg::TensorMapIm2col(buffer, layout)
494        }
495
496        /// Create a new view arg that dequantizes on read.
497        /// The scales layout should take values indices and map them to the corresponding scale.
498        pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
499            Self::Quantized {
500                values: Box::new(values),
501                scales: Box::new(scales),
502                scheme,
503            }
504        }
505    }
/// Compilation argument of a [`View`] with coordinates `C`.
///
/// There is no tensor variant: `ViewArg::Tensor` is registered as an array and therefore
/// produces [`ViewCompilationArg::Array`].
#[derive(Clone)]
pub enum ViewCompilationArg<C: Coordinates> {
    /// Array-backed view (also used for tensors).
    Array {
        buffer: ArrayCompilationArg,
        layout: VirtualViewLayoutCompilationArg<C, Coords1d>,
    },
    /// View over a tiled tensor map.
    TensorMapTiled {
        buffer: (),
        layout: VirtualViewLayoutCompilationArg<C, Sequence<i32>>,
    },
    /// View over an im2col tensor map.
    TensorMapIm2col {
        buffer: (),
        layout: VirtualViewLayoutCompilationArg<C, (Sequence<i32>, Sequence<i32>)>,
    },
    /// Quantized view: compilation arguments of the values and scales views plus the scheme.
    Quantized {
        values: Box<ViewCompilationArg<C>>,
        scales: Box<ViewCompilationArg<C>>,
        scheme: QuantScheme,
    },
}
526
527    impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
528    impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
529        fn eq(&self, other: &Self) -> bool {
530            match (self, other) {
531                (
532                    ViewCompilationArg::Array { buffer, layout },
533                    ViewCompilationArg::Array {
534                        buffer: buffer_other,
535                        layout: layout_other,
536                    },
537                ) => buffer == buffer_other && layout == layout_other,
538                (
539                    ViewCompilationArg::TensorMapTiled { buffer, layout },
540                    ViewCompilationArg::TensorMapTiled {
541                        buffer: buffer_other,
542                        layout: layout_other,
543                    },
544                ) => buffer == buffer_other && layout == layout_other,
545                (
546                    ViewCompilationArg::TensorMapIm2col { buffer, layout },
547                    ViewCompilationArg::TensorMapIm2col {
548                        buffer: buffer_other,
549                        layout: layout_other,
550                    },
551                ) => buffer == buffer_other && layout == layout_other,
552                (
553                    ViewCompilationArg::Quantized {
554                        values,
555                        scales,
556                        scheme,
557                    },
558                    ViewCompilationArg::Quantized {
559                        values: values_other,
560                        scales: scales_other,
561                        scheme: scheme_other,
562                    },
563                ) => values == values_other && scales == scales_other && scheme == scheme_other,
564                _ => false,
565            }
566        }
567    }
568    impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
569        fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
570            match self {
571                ViewCompilationArg::Array { buffer, layout } => {
572                    buffer.hash(ra_expand_state);
573                    layout.hash(ra_expand_state);
574                }
575                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
576                    buffer.hash(ra_expand_state);
577                    layout.hash(ra_expand_state);
578                }
579                ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
580                    buffer.hash(ra_expand_state);
581                    layout.hash(ra_expand_state);
582                }
583                ViewCompilationArg::Quantized {
584                    values,
585                    scales,
586                    scheme,
587                } => {
588                    values.hash(ra_expand_state);
589                    scales.hash(ra_expand_state);
590                    scheme.hash(ra_expand_state);
591                }
592            }
593        }
594    }
595    impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
596        fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
597            match self {
598                ViewCompilationArg::Array { buffer, layout } => f
599                    .debug_struct("ArrayView")
600                    .field("buffer", &buffer)
601                    .field("layout", &layout)
602                    .finish(),
603                ViewCompilationArg::TensorMapTiled { buffer, layout } => f
604                    .debug_struct("TensorMapTiledView")
605                    .field("buffer", &buffer)
606                    .field("layout", &layout)
607                    .finish(),
608                ViewCompilationArg::TensorMapIm2col { buffer, layout } => f
609                    .debug_struct("TensorMapIm2colView")
610                    .field("buffer", &buffer)
611                    .field("layout", &layout)
612                    .finish(),
613                ViewCompilationArg::Quantized {
614                    values,
615                    scales,
616                    scheme,
617                } => f
618                    .debug_struct("QuantizedView")
619                    .field("values", &values)
620                    .field("scales", &scales)
621                    .field("scheme", &scheme)
622                    .finish(),
623            }
624        }
625    }
626
627    impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
628        type RuntimeArg<R: Runtime> = ViewArg<C, R>;
629        type CompilationArg = ViewCompilationArg<C>;
630
631        fn register<R: Runtime>(
632            arg: Self::RuntimeArg<R>,
633            launcher: &mut KernelLauncher<R>,
634        ) -> Self::CompilationArg {
635            let ty = launcher.with_scope(|scope| E::as_type(scope));
636            match arg {
637                ViewArg::Array(buffer, layout) => ViewCompilationArg::Array {
638                    layout: layout.register(&buffer, ty, launcher),
639                    buffer: <Array<E> as LaunchArg>::register(buffer, launcher),
640                },
641                ViewArg::Tensor(buffer, layout) => ViewCompilationArg::Array {
642                    layout: layout.register(&buffer, ty, launcher),
643                    buffer: <Array<E> as LaunchArg>::register(buffer.into_array_arg(), launcher),
644                },
645                ViewArg::TensorMapTiled(buffer, layout) => ViewCompilationArg::TensorMapTiled {
646                    layout: layout.register(&buffer, ty, launcher),
647                    buffer: <TensorMap<E, Tiled> as LaunchArg>::register(buffer, launcher),
648                },
649                ViewArg::TensorMapIm2col(buffer, layout) => ViewCompilationArg::TensorMapIm2col {
650                    layout: layout.register(&buffer, ty, launcher),
651                    buffer: <TensorMap<E, Im2col> as LaunchArg>::register(buffer, launcher),
652                },
653                ViewArg::Quantized {
654                    values,
655                    scales,
656                    scheme,
657                } => {
658                    let register = RegisterDynamic {
659                        values: *values,
660                        scales: *scales,
661                        scheme,
662                        launcher,
663                        _ty: PhantomData::<E>,
664                    };
665                    run_with_quant_type(register, scheme)
666                }
667            }
668        }
669        fn expand(
670            arg: &Self::CompilationArg,
671            builder: &mut KernelBuilder,
672        ) -> <Self as CubeType>::ExpandType {
673            let ty = E::as_type(&builder.scope);
674            match arg {
675                ViewCompilationArg::Array { buffer, layout } => {
676                    let layout = layout.expand(ty, builder);
677                    let buffer = <Array<E> as LaunchArg>::expand(buffer, builder);
678                    let view =
679                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
680                    ViewExpand::<E, C, IO> {
681                        inner: ViewType::ReadWrite(Arc::new(view)),
682                        _io: PhantomData,
683                    }
684                }
685                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
686                    let layout = layout.expand(ty, builder);
687                    let buffer = <TensorMap<E, Tiled> as LaunchArg>::expand(buffer, builder);
688                    let view =
689                        VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
690                            buffer, layout,
691                        );
692                    ViewExpand::<E, C, IO> {
693                        inner: ViewType::ReadWrite(Arc::new(view)),
694                        _io: PhantomData,
695                    }
696                }
697                ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
698                    let layout = layout.expand(ty, builder);
699                    let buffer = <TensorMap<E, Im2col> as LaunchArg>::expand(buffer, builder);
700                    let view = VirtualViewExpand::<
701                        E,
702                        C,
703                        (Sequence<i32>, Sequence<i32>),
704                        TensorMap<E, Im2col>,
705                    >::new(buffer, layout);
706                    ViewExpand::<E, C, IO> {
707                        inner: ViewType::Read(Arc::new(view)),
708                        _io: PhantomData,
709                    }
710                }
711                ViewCompilationArg::Quantized {
712                    values,
713                    scales,
714                    scheme,
715                } => quant::view::expand_dynamic(values, scales, *scheme, builder),
716            }
717        }
718        fn expand_output(
719            arg: &Self::CompilationArg,
720            builder: &mut KernelBuilder,
721        ) -> <Self as CubeType>::ExpandType {
722            let ty = E::as_type(&builder.scope);
723            match arg {
724                ViewCompilationArg::Array { buffer, layout } => {
725                    let layout = layout.expand_output(ty, builder);
726                    let buffer = <Array<E> as LaunchArg>::expand_output(buffer, builder);
727                    let view =
728                        VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
729                    ViewExpand::<E, C, IO> {
730                        inner: ViewType::ReadWrite(Arc::new(view)),
731                        _io: PhantomData,
732                    }
733                }
734                ViewCompilationArg::TensorMapTiled { buffer, layout } => {
735                    let layout = layout.expand_output(ty, builder);
736                    let buffer = <TensorMap<E, Tiled> as LaunchArg>::expand_output(buffer, builder);
737                    let view =
738                        VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
739                            buffer, layout,
740                        );
741                    ViewExpand::<E, C, IO> {
742                        inner: ViewType::ReadWrite(Arc::new(view)),
743                        _io: PhantomData,
744                    }
745                }
746                ViewCompilationArg::TensorMapIm2col { .. } => {
747                    unimplemented!("Im2col tensor maps can't be used as outputs");
748                }
749                ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
750            }
751        }
752    }
753}
754
755pub use dynamic::*;