cubecl_std/tensor/layout/
linear.rs1use alloc::rc::Rc;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5
6use crate::tensor::{
7 View, is_contiguous, is_contiguous_pitched,
8 launch::ViewArg,
9 layout::{
10 Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
11 permuted::{PermutedLayout, PermutedLayoutLaunch},
12 plain::{PlainLayout, PlainLayoutLaunch},
13 strided::{StridedLayout, StridedLayoutLaunch},
14 },
15};
16
17#[derive(CubeType, CubeLaunch, Clone)]
25pub enum LinearLayout {
26 Plain(PlainLayout),
28 Strided(StridedLayout),
30 Permuted(PermutedLayout),
32}
33
34impl LinearLayout {
35 fn inner(&self) -> &PlainLayout {
36 unexpanded!()
37 }
38}
39
40impl LinearLayoutExpand {
41 fn __expand_inner_method(
42 self,
43 _scope: &mut Scope,
44 ) -> Rc<dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>> {
45 match self {
46 LinearLayoutExpand::Plain(layout) => Rc::new(layout),
47 LinearLayoutExpand::Strided(layout) => Rc::new(layout),
48 LinearLayoutExpand::Permuted(layout) => Rc::new(layout),
49 }
50 }
51}
52
53impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
54 pub fn from_shape_strides(
56 client: &ComputeClient<R>,
57 shape: &[usize],
58 strides: &[usize],
59 line_size: LineSize,
60 ) -> Self {
61 if is_contiguous(shape, strides) {
62 Self::Plain(PlainLayoutLaunch::from_shape(shape, line_size))
63 } else if is_contiguous_pitched(shape, strides) {
64 Self::Strided(StridedLayoutLaunch::from_shape_strides(
65 client, shape, strides, line_size,
66 ))
67 } else {
68 Self::Permuted(PermutedLayoutLaunch::from_shape_strides(
69 client, shape, strides, line_size,
70 ))
71 }
72 }
73
74 pub fn from_shape_strides_with_reference(
76 client: &ComputeClient<R>,
77 shape: &[usize],
78 reference_shape: &[usize],
79 strides: &[usize],
80 line_size: LineSize,
81 ) -> Self {
82 if shape != reference_shape {
83 Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
85 client,
86 shape,
87 reference_shape,
88 strides,
89 line_size,
90 ))
91 } else {
92 Self::from_shape_strides(client, shape, strides, line_size)
93 }
94 }
95
96 pub fn from_handle(
98 client: &ComputeClient<R>,
99 handle: &TensorHandleRef<'a, R>,
100 line_size: LineSize,
101 ) -> Self {
102 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
103 }
104
105 pub fn from_handle_with_reference(
107 client: &ComputeClient<R>,
108 handle: &TensorHandleRef<'a, R>,
109 reference: &TensorHandleRef<'a, R>,
110 line_size: LineSize,
111 ) -> Self {
112 Self::from_shape_strides_with_reference(
113 client,
114 handle.shape,
115 reference.shape,
116 handle.strides,
117 line_size,
118 )
119 }
120}
121
122#[cube]
123impl Layout for LinearLayout {
124 type Coordinates = Coords1d;
125 type SourceCoordinates = Coords1d;
126
127 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
128 self.inner().to_source_pos(pos)
129 }
130
131 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
132 (self.to_source_pos(pos), self.is_in_bounds(pos))
133 }
134
135 fn shape(&self) -> Self::Coordinates {
136 self.inner().shape()
137 }
138
139 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
140 self.inner().is_in_bounds(pos)
141 }
142}
143
144pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
147pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>;
149
150pub fn linear_view<'a, R: Runtime>(
152 client: &ComputeClient<R>,
153 handle: &'a TensorHandleRef<'a, R>,
154 line_size: LineSize,
155) -> LinearViewLaunch<'a, R> {
156 let len = handle.shape.iter().product::<usize>();
157 let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
158 let buffer = unsafe {
159 ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
160 };
161 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
162}
163
164pub fn linear_view_with_reference<'a, R: Runtime>(
166 client: &ComputeClient<R>,
167 handle: &'a TensorHandleRef<'a, R>,
168 reference: &'a TensorHandleRef<'a, R>,
169 line_size: LineSize,
170) -> LinearViewLaunch<'a, R> {
171 let len = handle.shape.iter().product::<usize>();
172 let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
173 let buffer = unsafe {
174 ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
175 };
176 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
177}
178
179pub fn linear_view_alias<'a, R: Runtime>(
180 client: &ComputeClient<R>,
181 handle: &'a TensorHandleRef<'a, R>,
182 line_size: LineSize,
183 pos: usize,
184) -> LinearViewLaunch<'a, R> {
185 let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
186 let buffer = ArrayArg::Alias { input_pos: pos };
187 LinearViewLaunch::new::<LinearLayout>(buffer, layout)
188}