1use cubecl_core::prelude::*;
2use std::{marker::PhantomData, ops::Deref, sync::Arc};
3
4use crate::tensor::{
5 View, ViewExpand, VirtualViewMutExpand,
6 layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand},
7 view::ViewType,
8};
9
10mod layout {
11 use core::{cell::RefCell, fmt::Debug, hash::Hash};
12
13 use alloc::rc::Rc;
14 use cubecl_core::{
15 self as cubecl,
16 format::DebugRaw,
17 hash::{StableHash, StableHasher},
18 prelude::*,
19 zspace::{Shape, Strides, metadata::Metadata},
20 };
21
22 use crate::tensor::layout::LayoutExpand;
23
24 use super::*;
25
/// Host-side description of a buffer that a layout may inspect while being registered.
///
/// Implemented in this module for `TensorArg`, `ArrayArg`, `TensorMapArg` and `Metadata`.
// `len` intentionally has no `is_empty` companion, hence the clippy allowance.
#[allow(clippy::len_without_is_empty)]
pub trait BufferArg: 'static {
    /// Length reported by the buffer.
    fn len(&self) -> usize;
    /// Shape of the buffer, one entry per dimension.
    fn shape(&self) -> &[usize];
    /// Strides of the buffer, one entry per dimension.
    fn strides(&self) -> &[usize];
}
32
/// `BufferArg` for tensor arguments: every method forwards to the same-named
/// accessor called on `self` (`size()` for `len`).
// NOTE(review): `self.shape()` / `self.strides()` rely on method resolution picking an
// accessor defined on `TensorArg` itself rather than this trait method — confirm.
impl<R: Runtime> BufferArg for TensorArg<R> {
    /// Returns `self.size()`.
    fn len(&self) -> usize {
        self.size()
    }

    /// Returns `self.shape()`.
    fn shape(&self) -> &[usize] {
        self.shape()
    }

    /// Returns `self.strides()`.
    fn strides(&self) -> &[usize] {
        self.strides()
    }
}
/// `BufferArg` for array arguments.
// NOTE(review): `self.shape()` relies on method resolution picking an accessor defined on
// `ArrayArg` itself rather than this trait method — confirm.
impl<R: Runtime> BufferArg for ArrayArg<R> {
    /// Returns `self.size()`.
    fn len(&self) -> usize {
        self.size()
    }

    /// Returns `self.shape()`.
    fn shape(&self) -> &[usize] {
        self.shape()
    }

    /// Always reports a single stride of `1`.
    fn strides(&self) -> &[usize] {
        &[1]
    }
}
/// `BufferArg` for tensor-map arguments: every method forwards to the wrapped
/// `tensor` field.
impl<R: Runtime, K: TensorMapKind> BufferArg for TensorMapArg<R, K> {
    /// Returns `self.tensor.size()`.
    fn len(&self) -> usize {
        self.tensor.size()
    }

    /// Returns `self.tensor.shape()`.
    fn shape(&self) -> &[usize] {
        self.tensor.shape()
    }

    /// Returns `self.tensor.strides()`.
    fn strides(&self) -> &[usize] {
        self.tensor.strides()
    }
}
72
/// `BufferArg` for bare shape/stride metadata, used when no buffer argument is
/// available (see `ConcreteLayout`'s `LaunchArg::register`, which passes its `Metadata`).
impl BufferArg for Metadata {
    /// Returns `self.shape.num_elements()`.
    fn len(&self) -> usize {
        self.shape.num_elements()
    }

    /// Borrows the `shape` field as a slice.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Borrows the `strides` field as a slice.
    fn strides(&self) -> &[usize] {
        &self.strides
    }
}
86
/// Launch-argument contract for layouts that are attached to a view.
///
/// Compared with `LaunchArg` (for which a blanket implementation exists in this module),
/// `register` additionally receives the buffer the layout applies to and a `Type`, and
/// the expand functions additionally receive that `Type`.
pub trait ViewLayoutLaunchArg: CubeType + Send + Sync + 'static {
    /// Runtime-side value consumed by [`Self::register`].
    type RuntimeArg<R: Runtime>: Send + Sync;
    /// Value returned by [`Self::register`] and later handed to
    /// [`Self::expand`] / [`Self::expand_output`].
    type CompilationArg: CompilationArg;

    /// Registers `arg` with `launcher` and returns the matching compilation argument.
    ///
    /// `buffer` describes the buffer the layout is used with and `ty` is the `Type`
    /// supplied by the caller.
    fn register<R: Runtime, B: BufferArg>(
        arg: Self::RuntimeArg<R>,
        buffer: &B,
        ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg;

    /// Builds the expand-time value of the layout from its compilation argument.
    fn expand(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType;

    /// Output-side counterpart of [`Self::expand`].
    ///
    /// The default implementation simply calls [`Self::expand`].
    fn expand_output(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        Self::expand(arg, ty, builder)
    }
}
119
/// Blanket implementation: every `LaunchArg` is usable as a `ViewLayoutLaunchArg`.
///
/// The associated types are taken from `LaunchArg` unchanged, and every method forwards
/// to the `LaunchArg` method of the same name; the extra buffer and `Type` parameters
/// are ignored.
impl<T: LaunchArg> ViewLayoutLaunchArg for T {
    type RuntimeArg<R: Runtime> = <T as LaunchArg>::RuntimeArg<R>;
    type CompilationArg = <T as LaunchArg>::CompilationArg;

    /// Forwards to `LaunchArg::register`; `_buffer` and `_ty` are unused.
    fn register<R: Runtime, B: BufferArg>(
        arg: Self::RuntimeArg<R>,
        _buffer: &B,
        _ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg {
        // The fully qualified path selects the `LaunchArg` method, not this one.
        <T as LaunchArg>::register(arg, launcher)
    }

    /// Forwards to `LaunchArg::expand`; `_ty` is unused.
    fn expand(
        arg: &Self::CompilationArg,
        _ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        <T as LaunchArg>::expand(arg, builder)
    }

    /// Forwards to `LaunchArg::expand_output`; `_ty` is unused.
    fn expand_output(
        arg: &Self::CompilationArg,
        _ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        <T as LaunchArg>::expand_output(arg, builder)
    }
}
149
/// Type-erased runtime argument for a layout mapping coordinates `C` to source
/// coordinates `S`, to be registered against a buffer of type `B` on runtime `R`.
///
/// The concrete layout type is hidden inside a boxed one-shot closure.
pub struct VirtualViewLayoutLaunch<C: Coordinates, S: Coordinates, B: BufferArg, R: Runtime> {
    // Marker for the runtime type parameter.
    _ty: core::marker::PhantomData<R>,
    // One-shot registration callback: given the buffer, the `Type` and the launcher, it
    // yields the type-erased compilation argument.
    #[allow(clippy::type_complexity)]
    register: Box<
        dyn FnOnce(&B, Type, &mut KernelLauncher<R>) -> VirtualViewLayoutCompilationArg<C, S>
            + Send
            + Sync,
    >,
}
159
160 impl<C: Coordinates, S: Coordinates, B: BufferArg, R: Runtime> VirtualViewLayoutLaunch<C, S, B, R> {
161 pub fn new<L: Layout<Coordinates = C, SourceCoordinates = S> + ViewLayoutLaunchArg>(
162 layout: L::RuntimeArg<R>,
163 ) -> Self {
164 Self {
165 _ty: PhantomData,
166 register: Box::new(move |buffer, ty, launcher| {
167 let comp_arg = L::register::<R, B>(layout, buffer, ty, launcher);
168 let comp_arg_2 = comp_arg.clone();
169 let expand = Rc::new(RefCell::new(
170 move |ty: Type, builder: &mut KernelBuilder, is_out: bool| {
171 let expand = match is_out {
172 true => L::expand_output(&comp_arg_2, ty, builder),
173 false => L::expand(&comp_arg_2, ty, builder),
174 };
175 VirtualLayoutExpand::new(expand)
176 },
177 ));
178 VirtualViewLayoutCompilationArg::new(comp_arg, expand)
179 }),
180 }
181 }
182
183 pub fn register(
184 self,
185 buffer: &B,
186 ty: Type,
187 launcher: &mut KernelLauncher<R>,
188 ) -> VirtualViewLayoutCompilationArg<C, S> {
189 (self.register)(buffer, ty, launcher)
190 }
191 }
192
/// Shared, interior-mutable callback that builds a `VirtualLayoutExpand`.
///
/// The `bool` argument selects the output expansion (`true`) or the regular
/// expansion (`false`).
type ExpandFn<C, S> =
    Rc<RefCell<dyn FnMut(Type, &mut KernelBuilder, bool) -> VirtualLayoutExpand<C, S> + Send>>;
195
/// Type-erased compilation argument of a layout.
///
/// Equality and hashing only consider `type_name` and `hash`; `debug` is used for
/// `Debug` output and `expand` for building the expand-time layout.
#[derive(Clone)]
pub struct VirtualViewLayoutCompilationArg<C: Coordinates, S: Coordinates> {
    // `core::any::type_name` of the concrete compilation argument.
    type_name: String,
    // The concrete compilation argument, kept only for its `Debug` output.
    debug: Rc<dyn core::fmt::Debug>,
    // Stable hash of the concrete compilation argument.
    hash: StableHash,
    // Callback producing the expanded layout.
    expand: ExpandFn<C, S>,
}
203
// SAFETY: NOTE(review): the struct stores `Rc<dyn Debug>` and `Rc<RefCell<..>>`, which are
// neither `Send` nor `Sync` by themselves, so these impls are a manual promise that the
// value is never used from several threads at once. The justification is not visible in
// this file — confirm it against how compilation arguments are shared, or document it here.
unsafe impl<C: Coordinates, S: Coordinates> Send for VirtualViewLayoutCompilationArg<C, S> {}
unsafe impl<C: Coordinates, S: Coordinates> Sync for VirtualViewLayoutCompilationArg<C, S> {}
207
208 impl<C: Coordinates, S: Coordinates> VirtualViewLayoutCompilationArg<C, S> {
209 pub fn new<L: CompilationArg + 'static>(arg: L, expand: ExpandFn<C, S>) -> Self {
210 let hash = StableHasher::hash_one(&arg);
213 Self {
214 type_name: core::any::type_name::<L>().to_string(),
215 debug: Rc::new(arg),
216 hash,
217 expand,
218 }
219 }
220
221 pub fn expand(&self, ty: Type, builder: &mut KernelBuilder) -> VirtualLayoutExpand<C, S> {
222 let mut expand = self.expand.borrow_mut();
223 (expand)(ty, builder, false)
224 }
225
226 pub fn expand_output(
227 &self,
228 ty: Type,
229 builder: &mut KernelBuilder,
230 ) -> VirtualLayoutExpand<C, S> {
231 let mut expand = self.expand.borrow_mut();
232 (expand)(ty, builder, true)
233 }
234 }
235
236 impl<C: Coordinates, S: Coordinates> PartialEq for VirtualViewLayoutCompilationArg<C, S> {
237 fn eq(&self, other: &Self) -> bool {
238 self.type_name == other.type_name && self.hash == other.hash
239 }
240 }
241 impl<C: Coordinates, S: Coordinates> Eq for VirtualViewLayoutCompilationArg<C, S> {}
242
243 impl<C: Coordinates, S: Coordinates> core::hash::Hash for VirtualViewLayoutCompilationArg<C, S> {
244 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
245 self.type_name.hash(state);
246 self.hash.hash(state);
247 }
248 }
249
250 impl<C: Coordinates, S: Coordinates> core::fmt::Debug for VirtualViewLayoutCompilationArg<C, S> {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 f.debug_struct(stringify!(VirtualLayout))
253 .field("type", &DebugRaw(&self.type_name))
254 .field("value", &self.debug)
255 .finish()
256 }
257 }
258
/// Wrapper around a layout `L` that implements `LaunchArg` by forwarding to `L`'s
/// `ViewLayoutLaunchArg` implementation.
#[derive(CubeType)]
pub struct ConcreteLayout<L: Layout + ViewLayoutLaunchArg> {
    // The wrapped layout; every `Layout` method and `Deref` forward to it.
    value: L,
}
263
// `Layout` for the wrapper: the coordinate types are those of `L`, and every method
// forwards to the same method on the wrapped `value`.
#[cube]
impl<L: Layout + ViewLayoutLaunchArg> Layout for ConcreteLayout<L> {
    type Coordinates = L::Coordinates;
    type SourceCoordinates = L::SourceCoordinates;

    // Forwards to `L::to_source_pos`.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        self.value.to_source_pos(pos)
    }

    // Forwards to `L::to_source_pos_checked`.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        self.value.to_source_pos_checked(pos)
    }

    // Forwards to `L::shape`.
    fn shape(&self) -> Self::Coordinates {
        self.value.shape()
    }

    // Forwards to `L::is_in_bounds`.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.value.is_in_bounds(pos)
    }
}
285
286 impl<L: Layout + ViewLayoutLaunchArg> Deref for ConcreteLayout<L> {
287 type Target = L;
288
289 fn deref(&self) -> &Self::Target {
290 &self.value
291 }
292 }
293
294 impl<L: Layout + ViewLayoutLaunchArg> Deref for ConcreteLayoutExpand<L> {
295 type Target = <L as CubeType>::ExpandType;
296
297 fn deref(&self) -> &Self::Target {
298 &self.value
299 }
300 }
301
/// Runtime argument of `ConcreteLayout<L>`.
///
/// Bundles the runtime argument of `L` with the shape/stride metadata and `Type` that
/// are passed to `L::register` when the layout is registered.
pub struct ConcreteLayoutLaunch<L: Layout + ViewLayoutLaunchArg, R: Runtime> {
    // Shape and strides handed to `L::register` as its buffer description.
    meta: Metadata,
    // `Type` forwarded to `L::register` and stored in the compilation argument.
    ty: Type,
    // Runtime argument of the wrapped layout.
    value: L::RuntimeArg<R>,
}
307
308 impl<L: Layout + ViewLayoutLaunchArg, R: Runtime> ConcreteLayoutLaunch<L, R> {
309 pub fn new(meta: Metadata, ty: Type, value: L::RuntimeArg<R>) -> Self {
310 Self { meta, ty, value }
311 }
312
313 pub fn from_handle(handle: &TensorBinding<R>, ty: Type, value: L::RuntimeArg<R>) -> Self {
314 Self {
315 meta: Metadata {
316 shape: handle.shape.clone(),
317 strides: handle.strides.clone(),
318 },
319 ty,
320 value,
321 }
322 }
323
324 pub fn from_shape_strides(
325 shape: Shape,
326 strides: Strides,
327 ty: Type,
328 value: L::RuntimeArg<R>,
329 ) -> Self {
330 Self {
331 meta: Metadata { shape, strides },
332 ty,
333 value,
334 }
335 }
336 }
337
/// Compilation argument of `ConcreteLayout<L>`: the compilation argument of `L` plus the
/// `Type` that is passed back to `L::expand` / `L::expand_output`.
///
/// `Debug`, `Hash`, `Eq`, `PartialEq` and `Clone` are implemented by hand below.
pub struct ConcreteLayoutCompilationArg<L: Layout + ViewLayoutLaunchArg> {
    // `Type` given at registration time.
    ty: Type,
    // Compilation argument of the wrapped layout.
    value: L::CompilationArg,
}
342
343 impl<L: Layout + ViewLayoutLaunchArg> Debug for ConcreteLayoutCompilationArg<L> {
344 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 f.debug_struct("ConcreteLayoutCompilationArg")
346 .field("ty", &self.ty)
347 .field("value", &self.value)
348 .finish()
349 }
350 }
351
352 impl<L: Layout + ViewLayoutLaunchArg> Hash for ConcreteLayoutCompilationArg<L> {
353 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
354 self.ty.hash(state);
355 self.value.hash(state);
356 }
357 }
358
359 impl<L: Layout + ViewLayoutLaunchArg> Eq for ConcreteLayoutCompilationArg<L> {}
360 impl<L: Layout + ViewLayoutLaunchArg> PartialEq for ConcreteLayoutCompilationArg<L> {
361 fn eq(&self, other: &Self) -> bool {
362 self.ty == other.ty && self.value == other.value
363 }
364 }
365
366 impl<L: Layout + ViewLayoutLaunchArg> Clone for ConcreteLayoutCompilationArg<L> {
367 fn clone(&self) -> Self {
368 Self {
369 ty: self.ty,
370 value: self.value.clone(),
371 }
372 }
373 }
374
375 impl<L: Layout + ViewLayoutLaunchArg> LaunchArg for ConcreteLayout<L> {
376 type RuntimeArg<R: Runtime> = ConcreteLayoutLaunch<L, R>;
377 type CompilationArg = ConcreteLayoutCompilationArg<L>;
378
379 fn register<R: Runtime>(
380 arg: Self::RuntimeArg<R>,
381 launcher: &mut KernelLauncher<R>,
382 ) -> Self::CompilationArg {
383 ConcreteLayoutCompilationArg {
384 value: L::register(arg.value, &arg.meta, arg.ty, launcher),
385 ty: arg.ty,
386 }
387 }
388
389 fn expand(
390 arg: &Self::CompilationArg,
391 builder: &mut KernelBuilder,
392 ) -> <Self as CubeType>::ExpandType {
393 ConcreteLayoutExpand {
394 value: L::expand(&arg.value, arg.ty, builder),
395 }
396 }
397
398 fn expand_output(
399 arg: &Self::CompilationArg,
400 builder: &mut KernelBuilder,
401 ) -> <Self as CubeType>::ExpandType {
402 ConcreteLayoutExpand {
403 value: L::expand_output(&arg.value, arg.ty, builder),
404 }
405 }
406 }
407}
408
409pub use layout::*;
410
411mod dynamic {
412 use cubecl_common::quant::scheme::QuantScheme;
413
414 use crate::{
415 quant::{
416 self,
417 view::{RegisterDynamic, run_with_quant_type},
418 },
419 tensor::{
420 VirtualViewExpand,
421 launch::layout::{ViewLayoutLaunchArg, VirtualViewLayoutLaunch},
422 layout::as_dyn::{IntoDyn, IntoDyn2Layout, IntoDynLayout},
423 },
424 };
425
426 use super::*;
427
/// Runtime argument of a `View` with coordinates `C` on runtime `R`.
///
/// Each buffer-backed variant pairs a buffer argument with the type-erased launch
/// argument of the layout applied to it.
#[allow(clippy::type_complexity)]
pub enum ViewArg<C: Coordinates, R: Runtime> {
    /// View over an array; the layout maps to `Coords1d`.
    Array(
        ArrayArg<R>,
        VirtualViewLayoutLaunch<C, Coords1d, ArrayArg<R>, R>,
    ),
    /// View over a tensor; the layout maps to `Coords1d`.
    Tensor(
        TensorArg<R>,
        VirtualViewLayoutLaunch<C, Coords1d, TensorArg<R>, R>,
    ),
    /// View over a tiled tensor map; the layout maps to `Sequence<i32>`.
    TensorMapTiled(
        TensorMapArg<R, Tiled>,
        VirtualViewLayoutLaunch<C, Sequence<i32>, TensorMapArg<R, Tiled>, R>,
    ),
    /// View over an im2col tensor map; the layout maps to a pair of `Sequence<i32>`.
    TensorMapIm2col(
        TensorMapArg<R, Im2col>,
        VirtualViewLayoutLaunch<C, (Sequence<i32>, Sequence<i32>), TensorMapArg<R, Im2col>, R>,
    ),
    /// Quantized view made of a `values` view, a `scales` view and the quantization scheme.
    Quantized {
        values: Box<ViewArg<C, R>>,
        scales: Box<ViewArg<C, R>>,
        scheme: QuantScheme,
    },
}
452
453 impl<C: Coordinates, R: Runtime> ViewArg<C, R> {
454 pub fn new_array<
455 L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + ViewLayoutLaunchArg,
456 >(
457 buffer: ArrayArg<R>,
458 layout: L::RuntimeArg<R>,
459 ) -> Self {
460 let layout = VirtualViewLayoutLaunch::new::<L>(layout);
461 ViewArg::Array(buffer, layout)
462 }
463
464 pub fn new_tensor<
465 L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + ViewLayoutLaunchArg,
466 >(
467 buffer: TensorArg<R>,
468 layout: L::RuntimeArg<R>,
469 ) -> Self {
470 let layout = VirtualViewLayoutLaunch::new::<L>(layout);
471 ViewArg::Tensor(buffer, layout)
472 }
473
474 pub fn new_tensor_map_tiled<
475 L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg,
476 >(
477 buffer: TensorMapArg<R, Tiled>,
478 layout: L::RuntimeArg<R>,
479 ) -> ViewArg<C, R> {
480 let layout = VirtualViewLayoutLaunch::new::<IntoDynLayout<L>>(layout);
481 ViewArg::TensorMapTiled(buffer, layout)
482 }
483
484 pub fn new_tensor_map_im2col<
485 L: Layout<Coordinates = C, SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
486 P: IntoDyn,
487 O: IntoDyn,
488 >(
489 buffer: TensorMapArg<R, Im2col>,
490 layout: L::RuntimeArg<R>,
491 ) -> ViewArg<C, R> {
492 let layout = VirtualViewLayoutLaunch::new::<IntoDyn2Layout<L, P, O>>(layout);
493 ViewArg::TensorMapIm2col(buffer, layout)
494 }
495
496 pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
499 Self::Quantized {
500 values: Box::new(values),
501 scales: Box::new(scales),
502 scheme,
503 }
504 }
505 }
/// Compilation argument of a `View` with coordinates `C`.
///
/// There is no tensor variant: `ViewArg::Tensor` is registered as
/// [`ViewCompilationArg::Array`]. `PartialEq`, `Eq`, `Hash` and `Debug` are implemented
/// by hand below.
#[derive(Clone)]
pub enum ViewCompilationArg<C: Coordinates> {
    /// Array-backed view: the array's compilation argument and the erased layout.
    Array {
        buffer: ArrayCompilationArg,
        layout: VirtualViewLayoutCompilationArg<C, Coords1d>,
    },
    /// Tiled tensor-map view; the buffer's compilation argument is the unit type.
    TensorMapTiled {
        buffer: (),
        layout: VirtualViewLayoutCompilationArg<C, Sequence<i32>>,
    },
    /// Im2col tensor-map view; the buffer's compilation argument is the unit type.
    TensorMapIm2col {
        buffer: (),
        layout: VirtualViewLayoutCompilationArg<C, (Sequence<i32>, Sequence<i32>)>,
    },
    /// Quantized view: compilation arguments of the values and scales views plus the scheme.
    Quantized {
        values: Box<ViewCompilationArg<C>>,
        scales: Box<ViewCompilationArg<C>>,
        scheme: QuantScheme,
    },
}
526
527 impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
528 impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
529 fn eq(&self, other: &Self) -> bool {
530 match (self, other) {
531 (
532 ViewCompilationArg::Array { buffer, layout },
533 ViewCompilationArg::Array {
534 buffer: buffer_other,
535 layout: layout_other,
536 },
537 ) => buffer == buffer_other && layout == layout_other,
538 (
539 ViewCompilationArg::TensorMapTiled { buffer, layout },
540 ViewCompilationArg::TensorMapTiled {
541 buffer: buffer_other,
542 layout: layout_other,
543 },
544 ) => buffer == buffer_other && layout == layout_other,
545 (
546 ViewCompilationArg::TensorMapIm2col { buffer, layout },
547 ViewCompilationArg::TensorMapIm2col {
548 buffer: buffer_other,
549 layout: layout_other,
550 },
551 ) => buffer == buffer_other && layout == layout_other,
552 (
553 ViewCompilationArg::Quantized {
554 values,
555 scales,
556 scheme,
557 },
558 ViewCompilationArg::Quantized {
559 values: values_other,
560 scales: scales_other,
561 scheme: scheme_other,
562 },
563 ) => values == values_other && scales == scales_other && scheme == scheme_other,
564 _ => false,
565 }
566 }
567 }
568 impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
569 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
570 match self {
571 ViewCompilationArg::Array { buffer, layout } => {
572 buffer.hash(ra_expand_state);
573 layout.hash(ra_expand_state);
574 }
575 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
576 buffer.hash(ra_expand_state);
577 layout.hash(ra_expand_state);
578 }
579 ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
580 buffer.hash(ra_expand_state);
581 layout.hash(ra_expand_state);
582 }
583 ViewCompilationArg::Quantized {
584 values,
585 scales,
586 scheme,
587 } => {
588 values.hash(ra_expand_state);
589 scales.hash(ra_expand_state);
590 scheme.hash(ra_expand_state);
591 }
592 }
593 }
594 }
595 impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
596 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
597 match self {
598 ViewCompilationArg::Array { buffer, layout } => f
599 .debug_struct("ArrayView")
600 .field("buffer", &buffer)
601 .field("layout", &layout)
602 .finish(),
603 ViewCompilationArg::TensorMapTiled { buffer, layout } => f
604 .debug_struct("TensorMapTiledView")
605 .field("buffer", &buffer)
606 .field("layout", &layout)
607 .finish(),
608 ViewCompilationArg::TensorMapIm2col { buffer, layout } => f
609 .debug_struct("TensorMapIm2colView")
610 .field("buffer", &buffer)
611 .field("layout", &layout)
612 .finish(),
613 ViewCompilationArg::Quantized {
614 values,
615 scales,
616 scheme,
617 } => f
618 .debug_struct("QuantizedView")
619 .field("values", &values)
620 .field("scales", &scales)
621 .field("scheme", &scheme)
622 .finish(),
623 }
624 }
625 }
626
627 impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
628 type RuntimeArg<R: Runtime> = ViewArg<C, R>;
629 type CompilationArg = ViewCompilationArg<C>;
630
631 fn register<R: Runtime>(
632 arg: Self::RuntimeArg<R>,
633 launcher: &mut KernelLauncher<R>,
634 ) -> Self::CompilationArg {
635 let ty = launcher.with_scope(|scope| E::as_type(scope));
636 match arg {
637 ViewArg::Array(buffer, layout) => ViewCompilationArg::Array {
638 layout: layout.register(&buffer, ty, launcher),
639 buffer: <Array<E> as LaunchArg>::register(buffer, launcher),
640 },
641 ViewArg::Tensor(buffer, layout) => ViewCompilationArg::Array {
642 layout: layout.register(&buffer, ty, launcher),
643 buffer: <Array<E> as LaunchArg>::register(buffer.into_array_arg(), launcher),
644 },
645 ViewArg::TensorMapTiled(buffer, layout) => ViewCompilationArg::TensorMapTiled {
646 layout: layout.register(&buffer, ty, launcher),
647 buffer: <TensorMap<E, Tiled> as LaunchArg>::register(buffer, launcher),
648 },
649 ViewArg::TensorMapIm2col(buffer, layout) => ViewCompilationArg::TensorMapIm2col {
650 layout: layout.register(&buffer, ty, launcher),
651 buffer: <TensorMap<E, Im2col> as LaunchArg>::register(buffer, launcher),
652 },
653 ViewArg::Quantized {
654 values,
655 scales,
656 scheme,
657 } => {
658 let register = RegisterDynamic {
659 values: *values,
660 scales: *scales,
661 scheme,
662 launcher,
663 _ty: PhantomData::<E>,
664 };
665 run_with_quant_type(register, scheme)
666 }
667 }
668 }
669 fn expand(
670 arg: &Self::CompilationArg,
671 builder: &mut KernelBuilder,
672 ) -> <Self as CubeType>::ExpandType {
673 let ty = E::as_type(&builder.scope);
674 match arg {
675 ViewCompilationArg::Array { buffer, layout } => {
676 let layout = layout.expand(ty, builder);
677 let buffer = <Array<E> as LaunchArg>::expand(buffer, builder);
678 let view =
679 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
680 ViewExpand::<E, C, IO> {
681 inner: ViewType::ReadWrite(Arc::new(view)),
682 _io: PhantomData,
683 }
684 }
685 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
686 let layout = layout.expand(ty, builder);
687 let buffer = <TensorMap<E, Tiled> as LaunchArg>::expand(buffer, builder);
688 let view =
689 VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
690 buffer, layout,
691 );
692 ViewExpand::<E, C, IO> {
693 inner: ViewType::ReadWrite(Arc::new(view)),
694 _io: PhantomData,
695 }
696 }
697 ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
698 let layout = layout.expand(ty, builder);
699 let buffer = <TensorMap<E, Im2col> as LaunchArg>::expand(buffer, builder);
700 let view = VirtualViewExpand::<
701 E,
702 C,
703 (Sequence<i32>, Sequence<i32>),
704 TensorMap<E, Im2col>,
705 >::new(buffer, layout);
706 ViewExpand::<E, C, IO> {
707 inner: ViewType::Read(Arc::new(view)),
708 _io: PhantomData,
709 }
710 }
711 ViewCompilationArg::Quantized {
712 values,
713 scales,
714 scheme,
715 } => quant::view::expand_dynamic(values, scales, *scheme, builder),
716 }
717 }
718 fn expand_output(
719 arg: &Self::CompilationArg,
720 builder: &mut KernelBuilder,
721 ) -> <Self as CubeType>::ExpandType {
722 let ty = E::as_type(&builder.scope);
723 match arg {
724 ViewCompilationArg::Array { buffer, layout } => {
725 let layout = layout.expand_output(ty, builder);
726 let buffer = <Array<E> as LaunchArg>::expand_output(buffer, builder);
727 let view =
728 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
729 ViewExpand::<E, C, IO> {
730 inner: ViewType::ReadWrite(Arc::new(view)),
731 _io: PhantomData,
732 }
733 }
734 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
735 let layout = layout.expand_output(ty, builder);
736 let buffer = <TensorMap<E, Tiled> as LaunchArg>::expand_output(buffer, builder);
737 let view =
738 VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
739 buffer, layout,
740 );
741 ViewExpand::<E, C, IO> {
742 inner: ViewType::ReadWrite(Arc::new(view)),
743 _io: PhantomData,
744 }
745 }
746 ViewCompilationArg::TensorMapIm2col { .. } => {
747 unimplemented!("Im2col tensor maps can't be used as outputs");
748 }
749 ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
750 }
751 }
752 }
753}
754
755pub use dynamic::*;