cubecl_std/tensor/layout/
linear.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3
4use crate::tensor::{
5 View, is_contiguous, is_contiguous_pitched,
6 launch::ViewArg,
7 layout::{
8 Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
9 permuted::{PermutedLayout, PermutedLayoutLaunch},
10 plain::{PlainLayout, PlainLayoutLaunch},
11 strided::{StridedLayout, StridedLayoutLaunch},
12 },
13};
14
/// Layout mapping a linear (1D) position onto a tensor that may or may not be contiguous.
///
/// The variant is picked on the host by `LinearLayoutArgs::from_shape_strides` (and the
/// `*_with_reference` / `from_handle*` constructors built on it) from the tensor's shape and
/// strides. Inside kernels every `Layout` method forwards to the wrapped variant's layout.
#[derive(CubeType, CubeLaunch, Clone)]
pub enum LinearLayout {
    /// Selected when `is_contiguous(shape, strides)` holds; built from the shape alone.
    Plain(PlainLayout),
    /// Selected when the tensor is not contiguous but `is_contiguous_pitched(shape, strides)`
    /// holds.
    Strided(StridedLayout),
    /// Fallback for any other stride pattern, and the variant used whenever the tensor's shape
    /// differs from a supplied reference shape.
    Permuted(PermutedLayout),
}
31
impl LinearLayout {
    /// Placeholder that lets `#[cube]` code write `self.inner()`.
    ///
    /// This is never meant to execute: its body is `unexpanded!()`. The real dispatch is done by
    /// `LinearLayoutExpand::__expand_inner_method`, which returns the active variant as a
    /// `&dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>`. The `&PlainLayout` return type
    /// appears to exist only so the unexpanded code type-checks.
    /// NOTE(review): presumably calling this directly panics via `unexpanded!()` — confirm in
    /// `cubecl_core`.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}
37
38impl LinearLayoutExpand {
39 fn __expand_inner_method(
40 &self,
41 _scope: &mut Scope,
42 ) -> &dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d> {
43 match self {
44 LinearLayoutExpand::Plain(layout) => layout,
45 LinearLayoutExpand::Strided(layout) => layout,
46 LinearLayoutExpand::Permuted(layout) => layout,
47 }
48 }
49}
50
51impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
52 pub fn from_shape_strides(
54 client: &ComputeClient<R::Server>,
55 shape: &[usize],
56 strides: &[usize],
57 line_size: u8,
58 ) -> Self {
59 if is_contiguous(shape, strides) {
60 Self::Plain(PlainLayoutLaunch::from_shape(shape, line_size))
61 } else if is_contiguous_pitched(shape, strides) {
62 Self::Strided(StridedLayoutLaunch::from_shape_strides(
63 client, shape, strides, line_size,
64 ))
65 } else {
66 Self::Permuted(PermutedLayoutLaunch::from_shape_strides(
67 client, shape, strides, line_size,
68 ))
69 }
70 }
71
72 pub fn from_shape_strides_with_reference(
74 client: &ComputeClient<R::Server>,
75 shape: &[usize],
76 reference_shape: &[usize],
77 strides: &[usize],
78 line_size: u8,
79 ) -> Self {
80 if shape != reference_shape {
81 Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
83 client,
84 shape,
85 reference_shape,
86 strides,
87 line_size,
88 ))
89 } else {
90 Self::from_shape_strides(client, shape, strides, line_size)
91 }
92 }
93
94 pub fn from_handle(
96 client: &ComputeClient<R::Server>,
97 handle: &TensorHandleRef<'a, R>,
98 line_size: u8,
99 ) -> Self {
100 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
101 }
102
103 pub fn from_handle_with_reference(
105 client: &ComputeClient<R::Server>,
106 handle: &TensorHandleRef<'a, R>,
107 reference: &TensorHandleRef<'a, R>,
108 line_size: u8,
109 ) -> Self {
110 Self::from_shape_strides_with_reference(
111 client,
112 handle.shape,
113 reference.shape,
114 handle.strides,
115 line_size,
116 )
117 }
118}
119
// Kernel-side `Layout` implementation for `LinearLayout`. Every query forwards to whichever
// variant the launch arguments selected; inside `#[cube]` code `self.inner()` is presumably
// expanded to `LinearLayoutExpand::__expand_inner_method`, which performs that dispatch.
#[cube]
impl Layout for LinearLayout {
    // Positions are 1D both in this layout and in the source buffer.
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Maps a position in this layout to an offset in the source buffer via the active variant.
    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
        self.inner().to_source_pos(pos)
    }

    // Pairs the source offset with the bounds flag. Both are evaluated unconditionally, so the
    // offset is computed even for out-of-bounds positions; callers should check the flag first.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    // Logical (1D) extent reported by the active variant.
    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }

    // Bounds check delegated to the active variant.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}
141
/// View over a buffer addressed with 1D (`Coords1d`) coordinates; defaults to `ReadOnly` access.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch argument type for a [`LinearView`], as produced by `linear_view` and its variants.
pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>;
147
148pub fn linear_view<'a, R: Runtime>(
150 client: &ComputeClient<R::Server>,
151 handle: &'a TensorHandleRef<'a, R>,
152 line_size: u8,
153) -> LinearViewLaunch<'a, R> {
154 let len = handle.shape.iter().product::<usize>();
155 let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
156 let buffer = unsafe {
157 ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
158 };
159 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
160}
161
162pub fn linear_view_with_reference<'a, R: Runtime>(
164 client: &ComputeClient<R::Server>,
165 handle: &'a TensorHandleRef<'a, R>,
166 reference: &'a TensorHandleRef<'a, R>,
167 line_size: u8,
168) -> LinearViewLaunch<'a, R> {
169 let len = handle.shape.iter().product::<usize>();
170 let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
171 let buffer = unsafe {
172 ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
173 };
174 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
175}
176
177pub fn linear_view_alias<'a, R: Runtime>(
178 client: &ComputeClient<R::Server>,
179 handle: &'a TensorHandleRef<'a, R>,
180 line_size: u8,
181 pos: usize,
182) -> LinearViewLaunch<'a, R> {
183 let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
184 let buffer = ArrayArg::Alias { input_pos: pos };
185 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
186}