cubecl_std/tensor/layout/linear.rs

use alloc::rc::Rc;

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, unexpanded};

use crate::tensor::{
    View, is_contiguous, is_contiguous_pitched,
    launch::ViewArg,
    layout::{
        Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
        permuted::{PermutedLayout, PermutedLayoutLaunch},
        plain::{PlainLayout, PlainLayoutLaunch},
        strided::{StridedLayout, StridedLayoutLaunch},
    },
};

/// Maps a linear, line-based index to a position in a potentially strided tensor. Only applies
/// the necessary level of striding: none, only the last dim (for freshly allocated strided
/// tensors), or all dimensions.
///
/// Treats each index as a line index, with the shape adjusted for the line size.
///
/// `Layout` version of [index_offset_contiguous]
#[derive(CubeType, CubeLaunch, Clone)]
pub enum LinearLayout {
    /// Input is contiguous, no mapping
    Plain(PlainLayout),
    /// Strided tensor, i.e. freshly allocated but not permuted
    Strided(StridedLayout),
    /// Permuted layout, tracks the entire shape/strides and not just the last dim
    Permuted(PermutedLayout),
}
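
// Illustration (hypothetical values, row-major): for shape [2, 3, 4], strides [12, 4, 1] are
// fully contiguous and select `Plain`; a pitched allocation such as strides [24, 8, 1] (padded
// row pitch, innermost dim still contiguous) selects `Strided`; a transposed view such as
// strides [1, 2, 6] selects `Permuted`. See `LinearLayoutArgs::from_shape_strides` below for
// the actual selection logic.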

impl LinearLayout {
    /// Dispatch to the wrapped layout. The non-expanded body is a placeholder
    /// ([`unexpanded!`]); the expand counterpart below performs the actual variant dispatch.
    fn inner(&self) -> &PlainLayout {
        unexpanded!()
    }
}

impl LinearLayoutExpand {
    // Expand-time counterpart of `inner`: erases the concrete variant behind a virtual layout.
    fn __expand_inner_method(
        self,
        _scope: &mut Scope,
    ) -> Rc<dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>> {
        match self {
            LinearLayoutExpand::Plain(layout) => Rc::new(layout),
            LinearLayoutExpand::Strided(layout) => Rc::new(layout),
            LinearLayoutExpand::Permuted(layout) => Rc::new(layout),
        }
    }
}

impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
    /// Construct a linear layout from the shape, strides and line size of the tensor
    pub fn from_shape_strides(
        client: &ComputeClient<R>,
        shape: &[usize],
        strides: &[usize],
        line_size: LineSize,
    ) -> Self {
        if is_contiguous(shape, strides) {
            Self::Plain(PlainLayoutLaunch::from_shape(shape, line_size))
        } else if is_contiguous_pitched(shape, strides) {
            Self::Strided(StridedLayoutLaunch::from_shape_strides(
                client, shape, strides, line_size,
            ))
        } else {
            Self::Permuted(PermutedLayoutLaunch::from_shape_strides(
                client, shape, strides, line_size,
            ))
        }
    }

    /// Construct a possibly broadcast linear layout from shapes/strides and a reference shape
    pub fn from_shape_strides_with_reference(
        client: &ComputeClient<R>,
        shape: &[usize],
        reference_shape: &[usize],
        strides: &[usize],
        line_size: LineSize,
    ) -> Self {
        if shape != reference_shape {
            // Broadcast layouts are always treated as permuted
            Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
                client,
                shape,
                reference_shape,
                strides,
                line_size,
            ))
        } else {
            Self::from_shape_strides(client, shape, strides, line_size)
        }
    }

    /// Construct a linear layout from a tensor handle
    pub fn from_handle(
        client: &ComputeClient<R>,
        handle: &TensorHandleRef<'a, R>,
        line_size: LineSize,
    ) -> Self {
        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
    }

    /// Construct a possibly broadcast linear layout from a tensor handle and a reference handle
    pub fn from_handle_with_reference(
        client: &ComputeClient<R>,
        handle: &TensorHandleRef<'a, R>,
        reference: &TensorHandleRef<'a, R>,
        line_size: LineSize,
    ) -> Self {
        Self::from_shape_strides_with_reference(
            client,
            handle.shape,
            reference.shape,
            handle.strides,
            line_size,
        )
    }
}
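
// Hedged usage sketch (host side). `client` and the concrete shapes, strides and line size are
// illustrative assumptions, not part of this module:
//
// // Fully contiguous input -> `Plain`.
// let layout = LinearLayoutArgs::from_shape_strides(&client, &[2, 3, 4], &[12, 4, 1], line_size);
//
// // Broadcast input: a [1, 3, 1] tensor read against the reference shape [2, 3, 4] is always
// // treated as `Permuted`.
// let broadcast = LinearLayoutArgs::from_shape_strides_with_reference(
//     &client,
//     &[1, 3, 1],
//     &[2, 3, 4],
//     &[3, 1, 1],
//     line_size,
// );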

#[cube]
impl Layout for LinearLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
        self.inner().to_source_pos(pos)
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    fn shape(&self) -> Self::Coordinates {
        self.inner().shape()
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.inner().is_in_bounds(pos)
    }
}

/// [View] with a linear layout inferred from the shape/strides at launch.
/// Useful for elementwise kernels.
pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
/// Launch type for [LinearView].
pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>;
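
// Hedged sketch of a consumer (hypothetical kernel, not part of this module): an elementwise
// kernel can take `LinearView`s and index them with the flat `ABSOLUTE_POS`, leaving any
// striding or broadcasting to the layout behind the view.
//
// #[cube(launch_unchecked)]
// fn elemwise_double<E: Numeric>(
//     input: &LinearView<Line<E>>,
//     output: &mut LinearView<Line<E>, ReadWrite>,
// ) {
//     if ABSOLUTE_POS < input.shape() {
//         output[ABSOLUTE_POS] = input[ABSOLUTE_POS] * Line::new(E::from_int(2));
//     }
// }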

/// Create a linear tensor view from a handle and line size
pub fn linear_view<'a, R: Runtime>(
    client: &ComputeClient<R>,
    handle: &'a TensorHandleRef<'a, R>,
    line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
    let len = handle.shape.iter().product::<usize>();
    let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
    let buffer = unsafe {
        ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
    };
    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
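
// Hedged usage sketch (host side; the handles, cube dimensions and kernel name are illustrative
// assumptions):
//
// let input_view = linear_view(&client, &input_handle, line_size);
// let output_view = linear_view(&client, &output_handle, line_size);
// unsafe {
//     elemwise_double::launch_unchecked::<f32, R>(
//         &client, cube_count, cube_dim, input_view, output_view,
//     );
// }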

/// Create a possibly broadcast linear tensor view from a handle, reference handle and line size
pub fn linear_view_with_reference<'a, R: Runtime>(
    client: &ComputeClient<R>,
    handle: &'a TensorHandleRef<'a, R>,
    reference: &'a TensorHandleRef<'a, R>,
    line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
    let len = handle.shape.iter().product::<usize>();
    let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
    let buffer = unsafe {
        ArrayArg::from_raw_parts_and_size(handle.handle, len, line_size, handle.elem_size)
    };
    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}

/// Create a linear tensor view that aliases the launch input at position `pos`, for in-place
/// operations
pub fn linear_view_alias<'a, R: Runtime>(
    client: &ComputeClient<R>,
    handle: &'a TensorHandleRef<'a, R>,
    line_size: LineSize,
    pos: usize,
) -> LinearViewLaunch<'a, R> {
    let layout = LinearLayoutArgs::from_handle(client, handle, line_size);
    let buffer = ArrayArg::Alias { input_pos: pos };
    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}
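
// Hedged sketch (illustrative): for in-place kernels, the output view can alias an input buffer
// by its position in the launch argument list instead of binding the handle a second time.
//
// let input_view = linear_view(&client, &handle, line_size);
// let output_view = linear_view_alias(&client, &handle, line_size, 0); // aliases input #0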