cubecl_std/tensor/layout/
linear.rs1use alloc::rc::Rc;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::UIntKind, unexpanded, zspace::Shape};
5
6use crate::tensor::{
7 View, is_contiguous, is_contiguous_pitched,
8 launch::{BufferArg, ConcreteLayout, ConcreteLayoutLaunch, ViewArg, ViewLayoutLaunchArg},
9 layout::{
10 Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
11 permuted::{PermutedLayout, PermutedLayoutCompilationArg, PermutedLayoutLaunch},
12 plain::PlainLayout,
13 strided::{StridedLayout, StridedLayoutCompilationArg},
14 },
15};
16
/// Layout mapping a linear (1D) view position to a position in the underlying buffer.
///
/// Which variant is used is decided at launch time by the [`ViewLayoutLaunchArg`] impl for
/// this type, based on the buffer's shape and strides and an optional reference shape.
#[derive(CubeType, Clone)]
pub enum LinearViewLayout {
    /// Selected when the buffer's shape and strides pass `is_contiguous`.
    Plain(PlainLayout),
    /// Selected when the buffer is not contiguous but passes `is_contiguous_pitched`.
    Strided(StridedLayout),
    /// Selected when a reference shape differs from the buffer's shape, or when the buffer
    /// is neither contiguous nor pitched-contiguous.
    Permuted(PermutedLayout),
}
33
impl LinearViewLayout {
    /// Placeholder that lets `#[cube]` code call layout operations on "the active variant".
    ///
    /// Never executed directly (`unexpanded!()`). Its expand-time counterpart is
    /// `LinearViewLayoutExpand::__expand_inner_method`, which dispatches on the variant.
    /// NOTE(review): the `&PlainLayout` return type appears to be only a type-checking
    /// stand-in for whichever variant is active — confirm.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}
39
40impl LinearViewLayoutExpand {
41 fn __expand_inner_method(
42 self,
43 _scope: &mut Scope,
44 ) -> Rc<dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>> {
45 match self {
46 LinearViewLayoutExpand::Plain(layout) => Rc::new(layout),
47 LinearViewLayoutExpand::Strided(layout) => Rc::new(layout),
48 LinearViewLayoutExpand::Permuted(layout) => Rc::new(layout),
49 }
50 }
51}
52
/// Runtime (launch-time) argument for [`LinearViewLayout`].
///
/// Optionally carries a reference shape. When it is set and differs from the bound buffer's
/// shape, registration selects a permuted layout built from that reference shape.
#[derive(Default)]
pub struct LinearViewLayoutLaunch {
    // Optional shape compared against the buffer's shape during registration.
    reference_shape: Option<Shape>,
}
57
58impl ViewLayoutLaunchArg for LinearViewLayout {
59 type RuntimeArg<R: Runtime> = LinearViewLayoutLaunch;
60 type CompilationArg = LinearLayoutCompilationArg;
61
62 fn register<R: Runtime, B: BufferArg>(
63 runtime_arg: Self::RuntimeArg<R>,
64 buffer: &B,
65 ty: Type,
66 launcher: &mut KernelLauncher<R>,
67 ) -> Self::CompilationArg {
68 let shape = buffer.shape();
69 match runtime_arg.reference_shape {
70 Some(reference_shape) if reference_shape.as_slice() != shape => {
71 let arg = PermutedLayoutLaunch::from_reference_shape(reference_shape);
72 let comp_arg = PermutedLayout::register(arg, buffer, ty, launcher);
73 LinearLayoutCompilationArg::Permuted(comp_arg)
74 }
75 _ => {
76 let strides = buffer.strides();
77 if is_contiguous(shape, strides) {
78 PlainLayout::register((), buffer, ty, launcher);
79 LinearLayoutCompilationArg::Plain
80 } else if is_contiguous_pitched(shape, strides) {
81 let comp_arg = StridedLayout::register((), buffer, ty, launcher);
82 LinearLayoutCompilationArg::Strided(comp_arg)
83 } else {
84 let comp_arg =
85 PermutedLayout::register(Default::default(), buffer, ty, launcher);
86 LinearLayoutCompilationArg::Permuted(comp_arg)
87 }
88 }
89 }
90 }
91 fn expand(
92 compilation_arg: &Self::CompilationArg,
93 ty: Type,
94 builder: &mut cubecl::prelude::KernelBuilder,
95 ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
96 match compilation_arg {
97 LinearLayoutCompilationArg::Plain => {
98 LinearViewLayoutExpand::Plain(PlainLayout::expand(&(), ty, builder))
99 }
100 LinearLayoutCompilationArg::Strided(arg) => {
101 LinearViewLayoutExpand::Strided(StridedLayout::expand(arg, ty, builder))
102 }
103 LinearLayoutCompilationArg::Permuted(arg) => {
104 LinearViewLayoutExpand::Permuted(PermutedLayout::expand(arg, ty, builder))
105 }
106 }
107 }
108}
109
/// Compilation argument recording which [`LinearViewLayout`] variant was selected at
/// registration, together with that variant's own compilation argument.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum LinearLayoutCompilationArg {
    /// Plain layout; it carries no compilation argument of its own.
    Plain,
    /// Compilation argument of the strided layout.
    Strided(StridedLayoutCompilationArg),
    /// Compilation argument of the permuted layout.
    Permuted(PermutedLayoutCompilationArg),
}
116
117impl LinearViewLayoutLaunch {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn from_reference_shape(reference_shape: Shape) -> Self {
125 Self {
126 reference_shape: Some(reference_shape),
127 }
128 }
129
130 pub fn from_reference_handle<R: Runtime>(reference: TensorBinding<R>) -> Self {
132 Self::from_reference_shape(reference.shape)
133 }
134}
135
#[cube]
impl Layout for LinearViewLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Maps a view position to a buffer position through the active layout variant.
    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
        self.inner().to_source_pos(pos)
    }

    /// Returns the mapped position together with an in-bounds flag.
    ///
    /// The mapped position is computed unconditionally, even when `pos` is out of bounds.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    /// Shape of the view, as reported by the active layout variant.
    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }

    /// Whether `pos` is in bounds, as decided by the active layout variant.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}
157
/// Concrete layout wrapping [`LinearViewLayout`].
pub type LinearLayout = ConcreteLayout<LinearViewLayout>;
/// Launch argument for [`LinearLayout`].
pub type LinearLayoutLaunch<R> = ConcreteLayoutLaunch<LinearViewLayout, R>;

/// A view over elements of type `E` addressed by 1D coordinates; `IO` defaults to `ReadOnly`.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch argument for a [`LinearView`].
pub type LinearViewLaunch<R> = ViewArg<Coords1d, R>;
167
168pub fn linear_layout<R: Runtime>(
170 handle: &TensorBinding<R>,
171 vector_size: VectorSize,
172) -> LinearLayoutLaunch<R> {
173 LinearLayoutLaunch::from_handle(
174 handle,
175 Type::new(UIntKind::U32.into()).with_vector_size(vector_size),
177 LinearViewLayoutLaunch::new(),
178 )
179}
180
181pub fn linear_view<R: Runtime>(handle: TensorBinding<R>) -> LinearViewLaunch<R> {
183 let layout = LinearViewLayoutLaunch::new();
184 LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.into_tensor_arg(), layout)
185}
186
187pub fn linear_view_with_reference<R: Runtime>(
189 handle: TensorBinding<R>,
190 reference: TensorBinding<R>,
191) -> LinearViewLaunch<R> {
192 let layout = LinearViewLayoutLaunch::from_reference_handle(reference);
193 LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.into_tensor_arg(), layout)
194}
195
196pub fn linear_view_alias<R: Runtime>(handle: &TensorBinding<R>, pos: usize) -> LinearViewLaunch<R> {
197 let layout = LinearViewLayoutLaunch::new();
198 LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.as_alias(pos), layout)
199}