// cubecl_std/tensor/layout/linear.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, unexpanded};
3
4use crate::tensor::{
5    View, is_contiguous, is_contiguous_pitched,
6    launch::ViewArg,
7    layout::{
8        Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
9        permuted::{PermutedLayout, PermutedLayoutLaunch},
10        plain::{PlainLayout, PlainLayoutLaunch},
11        strided::{StridedLayout, StridedLayoutLaunch},
12    },
13};
14
/// Maps a linear index based on line count to a potentially strided tensor. Only applies the
/// necessary level of striding, either none, only the last dim (for freshly allocated strided
/// tensors), or all dimensions.
///
/// Treats indices as the line index, with the shape being adjusted for line size.
///
/// `Layout` version of [index_offset_contiguous]
#[derive(CubeType, CubeLaunch, Clone)]
pub enum LinearLayout {
    /// Input is contiguous, no mapping is needed at all
    Plain(PlainLayout),
    /// Strided tensor, i.e. freshly allocated but not permuted; only the last
    /// dim needs remapping
    Strided(StridedLayout),
    /// Permuted layout, tracks the entire shape/strides and not just the last dim
    Permuted(PermutedLayout),
}
31
impl LinearLayout {
    /// Returns the active inner layout, type-erased to `&PlainLayout` for the
    /// host-side signature.
    ///
    /// Host-side stub only: `unexpanded!()` marks code that exists purely for
    /// type-checking and is replaced during cube expansion — the real dispatch
    /// lives in `LinearLayoutExpand::__expand_inner_method` below.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}
37
impl LinearLayoutExpand {
    /// Expansion-time counterpart of [`LinearLayout::inner`]: erases the
    /// concrete variant into a shared trait object so the `Layout` methods can
    /// delegate uniformly regardless of which striding level is active.
    fn __expand_inner_method(
        &self,
        _scope: &mut Scope,
    ) -> &dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d> {
        match self {
            // All three expand types implement the same virtual-layout trait,
            // so each arm coerces to the same trait object.
            LinearLayoutExpand::Plain(layout) => layout,
            LinearLayoutExpand::Strided(layout) => layout,
            LinearLayoutExpand::Permuted(layout) => layout,
        }
    }
}
50
51impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
52    /// Construct a linear layout from shapes, strides and line size of the tensor
53    pub fn from_shape_strides(
54        client: &ComputeClient<R::Server>,
55        shape: &[usize],
56        strides: &[usize],
57        line_size: u8,
58    ) -> Self {
59        if is_contiguous(shape, strides) {
60            Self::Plain(PlainLayoutLaunch::from_shape(shape, line_size))
61        } else if is_contiguous_pitched(shape, strides) {
62            Self::Strided(StridedLayoutLaunch::from_shape_strides(
63                client, shape, strides, line_size,
64            ))
65        } else {
66            Self::Permuted(PermutedLayoutLaunch::from_shape_strides(
67                client, shape, strides, line_size,
68            ))
69        }
70    }
71
72    /// Construct a possibly broadcast linear layout from shapes/strides and a reference shape
73    pub fn from_shape_strides_with_reference(
74        client: &ComputeClient<R::Server>,
75        shape: &[usize],
76        reference_shape: &[usize],
77        strides: &[usize],
78        line_size: u8,
79    ) -> Self {
80        if shape != reference_shape {
81            // Broadcast layouts are always treated as permuted
82            Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
83                client,
84                shape,
85                reference_shape,
86                strides,
87                line_size,
88            ))
89        } else {
90            Self::from_shape_strides(client, shape, strides, line_size)
91        }
92    }
93
94    /// Construct a linear layout from a tensor handle
95    pub fn from_handle(
96        client: &ComputeClient<R::Server>,
97        handle: &TensorHandleRef<'a, R>,
98        line_size: u8,
99    ) -> Self {
100        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
101    }
102
103    /// Construct a possibly broadcast linear layout from a tensor handle and reference handle
104    pub fn from_handle_with_reference(
105        client: &ComputeClient<R::Server>,
106        handle: &TensorHandleRef<'a, R>,
107        reference: &TensorHandleRef<'a, R>,
108        line_size: u8,
109    ) -> Self {
110        Self::from_shape_strides_with_reference(
111            client,
112            handle.shape,
113            reference.shape,
114            handle.strides,
115            line_size,
116        )
117    }
118}
119
#[cube]
impl Layout for LinearLayout {
    // Both the kernel-facing coordinates and the source offsets are 1D
    // (linear) indices.
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
        // Delegates to whichever variant is active; `inner()` expands to a
        // virtual dispatch over Plain/Strided/Permuted.
        self.inner().to_source_pos(pos)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
        // Offset plus an in-bounds flag, computed from the same position.
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}
141
/// [`View`] with a linear layout inferred from the shape/strides at launch.
/// Useful for elementwise kernels.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch type for [`LinearView`].
pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>;
147
148/// Create a linear tensor view from a handle and line size
149pub fn linear_view<'a, R: Runtime>(
150    client: &ComputeClient<R::Server>,
151    handle: &'a TensorHandleRef<'a, R>,
152    line_size: u8,
153) -> LinearViewLaunch<'a, R> {
154    let len = handle.shape.iter().product::<usize>();
155    let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
156    let buffer = unsafe {
157        ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
158    };
159    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
160}
161
162/// Create a possibly broadcast linear tensor view from a handle, reference handle and line size
163pub fn linear_view_with_reference<'a, R: Runtime>(
164    client: &ComputeClient<R::Server>,
165    handle: &'a TensorHandleRef<'a, R>,
166    reference: &'a TensorHandleRef<'a, R>,
167    line_size: u8,
168) -> LinearViewLaunch<'a, R> {
169    let len = handle.shape.iter().product::<usize>();
170    let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
171    let buffer = unsafe {
172        ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
173    };
174    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
175}
176
177pub fn linear_view_alias<'a, R: Runtime>(
178    client: &ComputeClient<R::Server>,
179    handle: &'a TensorHandleRef<'a, R>,
180    line_size: u8,
181    pos: usize,
182) -> LinearViewLaunch<'a, R> {
183    let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
184    let buffer = ArrayArg::Alias { input_pos: pos };
185    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
186}