1use cubecl_core::{prelude::*, unexpanded};
2use std::{
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5 sync::Arc,
6};
7
8use crate::tensor::{
9 View, ViewExpand, ViewOperationsMut, VirtualViewMut, VirtualViewMutExpand,
10 layout::{Coordinates, Coords1d, Layout, VirtualLayoutExpand, VirtualLayoutOperationsExpand},
11 view::ViewType,
12};
13
14#[derive(Clone)]
16pub struct TypedView<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility = ReadOnly> {
17 _ty: PhantomData<(E, L, IO)>,
18}
19
20impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> CubeType for TypedView<E, L, IO> {
21 type ExpandType = ViewExpand<E, L::Coordinates, IO>;
22}
23
24impl<E: CubePrimitive, L: LaunchLayout, IO: SliceVisibility> Deref for TypedView<E, L, IO> {
25 type Target = View<E, L::Coordinates, IO>;
26
27 fn deref(&self) -> &Self::Target {
28 unexpanded!()
29 }
30}
31
32impl<E: CubePrimitive, L: LaunchLayout> DerefMut for TypedView<E, L, ReadWrite> {
33 fn deref_mut(&mut self) -> &mut Self::Target {
34 unexpanded!()
35 }
36}
37
38pub struct TypedViewLaunch<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> {
39 buffer: ArrayArg<'a, R>,
40 layout: L::RuntimeArg<'a, R>,
41}
42impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> TypedViewLaunch<'a, L, R> {
43 #[allow(clippy::too_many_arguments)]
44 pub fn new(buffer: ArrayArg<'a, R>, layout: L::RuntimeArg<'a, R>) -> Self {
45 Self { buffer, layout }
46 }
47}
48impl<'a, L: LaunchLayout<SourceCoordinates = Coords1d>, R: Runtime> ArgSettings<R>
49 for TypedViewLaunch<'a, L, R>
50{
51 fn register(&self, launcher: &mut KernelLauncher<R>) {
52 self.buffer.register(launcher);
53 self.layout.register(launcher);
54 }
55}
56
57pub struct TypedViewCompilationArg<L: LaunchLayout<SourceCoordinates = Coords1d>> {
58 buffer: ArrayCompilationArg,
59 layout: L::CompilationArg,
60}
61impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Clone for TypedViewCompilationArg<L> {
62 fn clone(&self) -> Self {
63 Self {
64 buffer: self.buffer.clone(),
65 layout: self.layout.clone(),
66 }
67 }
68}
69impl<L: LaunchLayout<SourceCoordinates = Coords1d>> CompilationArg for TypedViewCompilationArg<L> {}
70
71impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::hash::Hash
72 for TypedViewCompilationArg<L>
73{
74 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75 self.buffer.hash(state);
76 self.layout.hash(state);
77 }
78}
79impl<L: LaunchLayout<SourceCoordinates = Coords1d>> PartialEq for TypedViewCompilationArg<L> {
80 fn eq(&self, other: &Self) -> bool {
81 self.buffer.eq(&other.buffer) && self.layout.eq(&other.layout)
82 }
83}
84impl<L: LaunchLayout<SourceCoordinates = Coords1d>> core::fmt::Debug
85 for TypedViewCompilationArg<L>
86{
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct(stringify!(TensorViewTyped))
89 .field("buffer", &self.buffer)
90 .field("layout", &self.layout)
91 .finish()
92 }
93}
94impl<L: LaunchLayout<SourceCoordinates = Coords1d>> Eq for TypedViewCompilationArg<L> {}
95
96impl<E: CubePrimitive, L: LaunchLayout<SourceCoordinates = Coords1d>, IO: SliceVisibility> LaunchArg
97 for TypedView<E, L, IO>
98{
99 type RuntimeArg<'a, R: Runtime> = TypedViewLaunch<'a, L, R>;
100 type CompilationArg = TypedViewCompilationArg<L>;
101
102 fn compilation_arg<'a, R: Runtime>(
103 runtime_arg: &Self::RuntimeArg<'a, R>,
104 ) -> Self::CompilationArg {
105 TypedViewCompilationArg {
106 buffer: <Array<Line<E>> as LaunchArg>::compilation_arg(&runtime_arg.buffer),
107 layout: L::compilation_arg(&runtime_arg.layout),
108 }
109 }
110
111 fn expand(
112 arg: &Self::CompilationArg,
113 builder: &mut KernelBuilder,
114 ) -> <Self as CubeType>::ExpandType {
115 let buffer = <Array<E> as LaunchArg>::expand(&arg.buffer, builder);
116 L::apply::<E, Array<E>, IO>(L::expand(&arg.layout, builder), buffer)
117 }
118 fn expand_output(
119 arg: &Self::CompilationArg,
120 builder: &mut KernelBuilder,
121 ) -> <Self as CubeType>::ExpandType {
122 let buffer = <Array<E> as LaunchArg>::expand_output(&arg.buffer, builder);
123 L::apply::<E, Array<E>, IO>(L::expand_output(&arg.layout, builder), buffer)
124 }
125}
126
mod seal {
    /// Marker supertrait living in a private module, so `LaunchLayout` can only be
    /// implemented for types that this file implements `Sealed` for.
    pub trait Sealed {}
}
130
131pub trait LaunchLayout: LaunchArg + seal::Sealed {
132 type SourceCoordinates: Coordinates;
133 type Coordinates: Coordinates;
134
135 fn apply<
136 E: CubePrimitive,
137 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
138 IO: SliceVisibility,
139 >(
140 value: <Self as CubeType>::ExpandType,
141 view: V::ExpandType,
142 ) -> ViewExpand<E, Self::Coordinates, IO>;
143}
144
145impl<
150 L: Layout
151 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
152 + LaunchArg,
153> seal::Sealed for L
154{
155}
156impl<
157 L: Layout
158 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L::Coordinates, L::SourceCoordinates>>
159 + LaunchArg,
160> LaunchLayout for L
161{
162 type SourceCoordinates = L::SourceCoordinates;
163 type Coordinates = L::Coordinates;
164
165 fn apply<
166 E: CubePrimitive,
167 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
168 IO: SliceVisibility,
169 >(
170 value: L::ExpandType,
171 view: V::ExpandType,
172 ) -> ViewExpand<E, Self::Coordinates, IO> {
173 let l0 = value;
174 let l0 = VirtualLayoutExpand::new::<L::ExpandType>(l0);
175 let view =
176 VirtualViewMutExpand::<E, L::Coordinates, L::SourceCoordinates, V>::new(view, l0);
177 ViewExpand::<E, L::Coordinates, IO> {
178 inner: ViewType::ReadWrite(Arc::new(view)),
179 _io: PhantomData,
180 }
181 }
182}
183
184impl<
185 L0: Layout
186 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
187 + LaunchArg,
188 L1: Layout<SourceCoordinates = L0::Coordinates>
189 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
190 + LaunchArg,
191> seal::Sealed for (L0, L1)
192{
193}
194impl<
195 L0: Layout
196 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L0::Coordinates, L0::SourceCoordinates>>
197 + LaunchArg,
198 L1: Layout<SourceCoordinates = L0::Coordinates>
199 + CubeType<ExpandType: VirtualLayoutOperationsExpand<L1::Coordinates, L1::SourceCoordinates>>
200 + LaunchArg,
201> LaunchLayout for (L0, L1)
202{
203 type SourceCoordinates = L0::SourceCoordinates;
204 type Coordinates = L1::Coordinates;
205
206 fn apply<
207 E: CubePrimitive,
208 V: ViewOperationsMut<E, Self::SourceCoordinates> + 'static,
209 IO: SliceVisibility,
210 >(
211 value: (L0::ExpandType, L1::ExpandType),
212 view: V::ExpandType,
213 ) -> ViewExpand<E, Self::Coordinates, IO> {
214 let (l0, l1) = value;
215 let l0 = VirtualLayoutExpand::new::<L0::ExpandType>(l0);
216 let view =
217 VirtualViewMutExpand::<E, L0::Coordinates, L0::SourceCoordinates, V>::new(view, l0);
218 let l1 = VirtualLayoutExpand::new::<L1::ExpandType>(l1);
219 let view = VirtualViewMutExpand::<
220 E,
221 L1::Coordinates,
222 L1::SourceCoordinates,
223 VirtualViewMut<E, L0::Coordinates, L0::SourceCoordinates, V>,
224 >::new(view, l1);
225 ViewExpand::<E, L1::Coordinates, IO> {
226 inner: ViewType::ReadWrite(Arc::new(view)),
227 _io: PhantomData,
228 }
229 }
230}
231
232mod dynamic {
233 use cubecl_common::quant::scheme::QuantScheme;
234
235 use crate::{
236 quant,
237 tensor::{
238 VirtualViewExpand,
239 layout::{
240 VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch,
241 as_dyn::{
242 IntoDyn, IntoDyn2Layout, IntoDyn2LayoutLaunch, IntoDynLayout,
243 IntoDynLayoutLaunch,
244 },
245 },
246 },
247 };
248
249 use super::*;
250
251 pub enum ViewArg<'a, C: Coordinates, R: Runtime> {
252 Array(ArrayArg<'a, R>, VirtualLayoutLaunch<'a, C, Coords1d, R>),
253 TensorMapTiled(
254 TensorMapArg<'a, R, Tiled>,
255 VirtualLayoutLaunch<'a, C, Sequence<i32>, R>,
256 ),
257 TensorMapIm2col(
258 TensorMapArg<'a, R, Im2col>,
259 VirtualLayoutLaunch<'a, C, (Sequence<i32>, Sequence<i32>), R>,
260 ),
261 Quantized {
262 values: Box<ViewArg<'a, C, R>>,
263 scales: Box<ViewArg<'a, C, R>>,
264 scheme: QuantScheme,
265 },
266 }
267 impl<'a, C: Coordinates, R: Runtime> ViewArg<'a, C, R> {
268 pub fn new<L: Layout<Coordinates = C, SourceCoordinates = Coords1d> + LaunchArg>(
269 buffer: ArrayArg<'a, R>,
270 layout: L::RuntimeArg<'a, R>,
271 ) -> Self {
272 ViewArg::Array(buffer, VirtualLayoutLaunch::new::<L>(layout))
273 }
274
275 pub fn new_tensor_map_tiled<
276 L: Layout<Coordinates = C, SourceCoordinates: IntoDyn> + LaunchArg,
277 >(
278 buffer: TensorMapArg<'a, R, Tiled>,
279 layout: L::RuntimeArg<'a, R>,
280 ) -> Self {
281 let layout = IntoDynLayoutLaunch::new(layout);
282 ViewArg::TensorMapTiled(buffer, VirtualLayoutLaunch::new::<IntoDynLayout<L>>(layout))
283 }
284
285 pub fn new_tensor_map_im2col<
286 L: Layout<Coordinates = C, SourceCoordinates = (P, O)> + LaunchArg,
287 P: IntoDyn,
288 O: IntoDyn,
289 >(
290 buffer: TensorMapArg<'a, R, Im2col>,
291 layout: L::RuntimeArg<'a, R>,
292 ) -> Self {
293 let layout = IntoDyn2LayoutLaunch::new(layout);
294 ViewArg::TensorMapIm2col(
295 buffer,
296 VirtualLayoutLaunch::new::<IntoDyn2Layout<L, P, O>>(layout),
297 )
298 }
299
300 pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self {
303 Self::Quantized {
304 values: Box::new(values),
305 scales: Box::new(scales),
306 scheme,
307 }
308 }
309 }
310 impl<'a, C: Coordinates, R: Runtime> ArgSettings<R> for ViewArg<'a, C, R> {
311 fn register(&self, launcher: &mut KernelLauncher<R>) {
312 match self {
313 ViewArg::Array(buffer, layout) => {
314 buffer.register(launcher);
315 layout.register(launcher);
316 }
317 ViewArg::TensorMapTiled(buffer, layout) => {
318 buffer.register(launcher);
319 layout.register(launcher);
320 }
321 ViewArg::TensorMapIm2col(buffer, layout) => {
322 buffer.register(launcher);
323 layout.register(launcher);
324 }
325 ViewArg::Quantized { values, scales, .. } => {
326 values.register(launcher);
327 scales.register(launcher);
328 }
329 }
330 }
331 }
332 #[derive(Clone)]
333 pub enum ViewCompilationArg<C: Coordinates> {
334 Array {
335 buffer: ArrayCompilationArg,
336 layout: VirtualLayoutCompilationArg<C, Coords1d>,
337 },
338 TensorMapTiled {
339 buffer: TensorMapCompilationArg,
340 layout: VirtualLayoutCompilationArg<C, Sequence<i32>>,
341 },
342 TensorMapIm2col {
343 buffer: TensorMapCompilationArg,
344 layout: VirtualLayoutCompilationArg<C, (Sequence<i32>, Sequence<i32>)>,
345 },
346 Quantized {
347 values: Box<ViewCompilationArg<C>>,
348 scales: Box<ViewCompilationArg<C>>,
349 scheme: QuantScheme,
350 },
351 }
352
353 impl<C: Coordinates + 'static> CompilationArg for ViewCompilationArg<C> {}
354 impl<C: Coordinates> Eq for ViewCompilationArg<C> {}
355 impl<C: Coordinates> PartialEq for ViewCompilationArg<C> {
356 fn eq(&self, other: &Self) -> bool {
357 match (self, other) {
358 (
359 ViewCompilationArg::Array { buffer, layout },
360 ViewCompilationArg::Array {
361 buffer: buffer_other,
362 layout: layout_other,
363 },
364 ) => buffer == buffer_other && layout == layout_other,
365 (
366 ViewCompilationArg::TensorMapTiled { buffer, layout },
367 ViewCompilationArg::TensorMapTiled {
368 buffer: buffer_other,
369 layout: layout_other,
370 },
371 ) => buffer == buffer_other && layout == layout_other,
372 (
373 ViewCompilationArg::TensorMapIm2col { buffer, layout },
374 ViewCompilationArg::TensorMapIm2col {
375 buffer: buffer_other,
376 layout: layout_other,
377 },
378 ) => buffer == buffer_other && layout == layout_other,
379 (
380 ViewCompilationArg::Quantized {
381 values,
382 scales,
383 scheme,
384 },
385 ViewCompilationArg::Quantized {
386 values: values_other,
387 scales: scales_other,
388 scheme: scheme_other,
389 },
390 ) => values == values_other && scales == scales_other && scheme == scheme_other,
391 _ => false,
392 }
393 }
394 }
395 impl<C: Coordinates> core::hash::Hash for ViewCompilationArg<C> {
396 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
397 match self {
398 ViewCompilationArg::Array { buffer, layout } => {
399 buffer.hash(ra_expand_state);
400 layout.hash(ra_expand_state);
401 }
402 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
403 buffer.hash(ra_expand_state);
404 layout.hash(ra_expand_state);
405 }
406 ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
407 buffer.hash(ra_expand_state);
408 layout.hash(ra_expand_state);
409 }
410 ViewCompilationArg::Quantized {
411 values,
412 scales,
413 scheme,
414 } => {
415 values.hash(ra_expand_state);
416 scales.hash(ra_expand_state);
417 scheme.hash(ra_expand_state);
418 }
419 }
420 }
421 }
422 impl<C: Coordinates> core::fmt::Debug for ViewCompilationArg<C> {
423 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
424 match self {
425 ViewCompilationArg::Array { buffer, layout } => f
426 .debug_struct("ArrayView")
427 .field("buffer", &buffer)
428 .field("layout", &layout)
429 .finish(),
430 ViewCompilationArg::TensorMapTiled { buffer, layout } => f
431 .debug_struct("TensorMapTiledView")
432 .field("buffer", &buffer)
433 .field("layout", &layout)
434 .finish(),
435 ViewCompilationArg::TensorMapIm2col { buffer, layout } => f
436 .debug_struct("TensorMapIm2colView")
437 .field("buffer", &buffer)
438 .field("layout", &layout)
439 .finish(),
440 ViewCompilationArg::Quantized {
441 values,
442 scales,
443 scheme,
444 } => f
445 .debug_struct("QuantizedView")
446 .field("values", &values)
447 .field("scales", &scales)
448 .field("scheme", &scheme)
449 .finish(),
450 }
451 }
452 }
453
454 impl<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility> LaunchArg for View<E, C, IO> {
455 type RuntimeArg<'a, R: Runtime> = ViewArg<'a, C, R>;
456 type CompilationArg = ViewCompilationArg<C>;
457
458 fn compilation_arg<'a, R: Runtime>(
459 runtime_arg: &Self::RuntimeArg<'a, R>,
460 ) -> Self::CompilationArg {
461 match runtime_arg {
462 ViewArg::Array(buffer, layout) => {
463 let buffer = Array::<E>::compilation_arg(buffer);
464 let layout = VirtualLayout::<C, Coords1d>::compilation_arg(layout);
465 ViewCompilationArg::Array { buffer, layout }
466 }
467 ViewArg::TensorMapTiled(buffer, layout) => {
468 let buffer = TensorMap::<E, Tiled>::compilation_arg(buffer);
469 let layout = VirtualLayout::<C, Sequence<i32>>::compilation_arg(layout);
470 ViewCompilationArg::TensorMapTiled { buffer, layout }
471 }
472 ViewArg::TensorMapIm2col(buffer, layout) => {
473 let buffer = TensorMap::<E, Im2col>::compilation_arg(buffer);
474 let layout =
475 VirtualLayout::<C, (Sequence<i32>, Sequence<i32>)>::compilation_arg(layout);
476 ViewCompilationArg::TensorMapIm2col { buffer, layout }
477 }
478 ViewArg::Quantized {
479 values,
480 scales,
481 scheme,
482 } => {
483 let values = View::<E, C, IO>::compilation_arg(values);
485 let scales = View::<E, C, IO>::compilation_arg(scales);
486 ViewCompilationArg::Quantized {
487 values: Box::new(values),
488 scales: Box::new(scales),
489 scheme: *scheme,
490 }
491 }
492 }
493 }
494 fn expand(
495 arg: &Self::CompilationArg,
496 builder: &mut KernelBuilder,
497 ) -> <Self as CubeType>::ExpandType {
498 match arg {
499 ViewCompilationArg::Array { buffer, layout } => {
500 let buffer = Array::<E>::expand(buffer, builder);
501 let layout = VirtualLayout::<C, Coords1d>::expand(layout, builder);
502 let view =
503 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
504 ViewExpand::<E, C, IO> {
505 inner: ViewType::ReadWrite(Arc::new(view)),
506 _io: PhantomData,
507 }
508 }
509 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
510 let buffer = TensorMap::<E, Tiled>::expand(buffer, builder);
511 let layout = VirtualLayout::<C, Sequence<i32>>::expand(layout, builder);
512 let view =
513 VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
514 buffer, layout,
515 );
516 ViewExpand::<E, C, IO> {
517 inner: ViewType::ReadWrite(Arc::new(view)),
518 _io: PhantomData,
519 }
520 }
521 ViewCompilationArg::TensorMapIm2col { buffer, layout } => {
522 let buffer = TensorMap::<E, Im2col>::expand(buffer, builder);
523 let layout =
524 VirtualLayout::<C, (Sequence<i32>, Sequence<i32>)>::expand(layout, builder);
525 let view = VirtualViewExpand::<
526 E,
527 C,
528 (Sequence<i32>, Sequence<i32>),
529 TensorMap<E, Im2col>,
530 >::new(buffer, layout);
531 ViewExpand::<E, C, IO> {
532 inner: ViewType::Read(Arc::new(view)),
533 _io: PhantomData,
534 }
535 }
536 ViewCompilationArg::Quantized {
537 values,
538 scales,
539 scheme,
540 } => quant::view::expand_dynamic(values, scales, *scheme, builder),
541 }
542 }
543 fn expand_output(
544 arg: &Self::CompilationArg,
545 builder: &mut KernelBuilder,
546 ) -> <Self as CubeType>::ExpandType {
547 match arg {
548 ViewCompilationArg::Array { buffer, layout } => {
549 let buffer = Array::<E>::expand_output(buffer, builder);
550 let layout = VirtualLayout::<C, Coords1d>::expand_output(layout, builder);
551 let view =
552 VirtualViewMutExpand::<E, C, Coords1d, Array<E>>::new(buffer, layout);
553 ViewExpand::<E, C, IO> {
554 inner: ViewType::ReadWrite(Arc::new(view)),
555 _io: PhantomData,
556 }
557 }
558 ViewCompilationArg::TensorMapTiled { buffer, layout } => {
559 let buffer = TensorMap::<E, Tiled>::expand_output(buffer, builder);
560 let layout = VirtualLayout::<C, Sequence<i32>>::expand_output(layout, builder);
561 let view =
562 VirtualViewMutExpand::<E, C, Sequence<i32>, TensorMap<E, Tiled>>::new(
563 buffer, layout,
564 );
565 ViewExpand::<E, C, IO> {
566 inner: ViewType::ReadWrite(Arc::new(view)),
567 _io: PhantomData,
568 }
569 }
570 ViewCompilationArg::TensorMapIm2col { .. } => {
571 unimplemented!("Im2col tensor maps can't be used as outputs");
572 }
573 ViewCompilationArg::Quantized { .. } => panic!("Quantized views must be readonly"),
574 }
575 }
576 }
577}
578
579pub use dynamic::*;