Skip to main content

cubecl_std/tensor/layout/
linear.rs

1use alloc::rc::Rc;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::UIntKind, unexpanded, zspace::Shape};
5
6use crate::tensor::{
7    View, is_contiguous, is_contiguous_pitched,
8    launch::{BufferArg, ConcreteLayout, ConcreteLayoutLaunch, ViewArg, ViewLayoutLaunchArg},
9    layout::{
10        Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand,
11        permuted::{PermutedLayout, PermutedLayoutCompilationArg, PermutedLayoutLaunch},
12        plain::PlainLayout,
13        strided::{StridedLayout, StridedLayoutCompilationArg},
14    },
15};
16
17/// Maps a linear index based on vector count to a potentially strided tensor. Only applies the
18/// necessary level of striding, either none, only the last dim (for freshly allocated strided
19/// tensors), or all dimensions.
20///
21/// Treats indices as the vector index, with the shape being adjusted for vector size.
22///
23/// `Layout` version of [`crate::tensor::contiguous::index_offset_contiguous()`]
24#[derive(CubeType, Clone)]
25pub enum LinearViewLayout {
26    /// Input is contiguous, no mapping
27    Plain(PlainLayout),
28    /// Strided tensor, i.e. freshly allocated but not permuted
29    Strided(StridedLayout),
30    /// Permuted layout, tracks the entire shape/strides and not just the last dim
31    Permuted(PermutedLayout),
32}
33
34impl LinearViewLayout {
35    fn inner(&self) -> &PlainLayout {
36        unexpanded!()
37    }
38}
39
40impl LinearViewLayoutExpand {
41    fn __expand_inner_method(
42        self,
43        _scope: &mut Scope,
44    ) -> Rc<dyn VirtualLayoutOperationsExpand<Coords1d, Coords1d>> {
45        match self {
46            LinearViewLayoutExpand::Plain(layout) => Rc::new(layout),
47            LinearViewLayoutExpand::Strided(layout) => Rc::new(layout),
48            LinearViewLayoutExpand::Permuted(layout) => Rc::new(layout),
49        }
50    }
51}
52
53#[derive(Default)]
54pub struct LinearViewLayoutLaunch {
55    reference_shape: Option<Shape>,
56}
57
58impl ViewLayoutLaunchArg for LinearViewLayout {
59    type RuntimeArg<R: Runtime> = LinearViewLayoutLaunch;
60    type CompilationArg = LinearLayoutCompilationArg;
61
62    fn register<R: Runtime, B: BufferArg>(
63        runtime_arg: Self::RuntimeArg<R>,
64        buffer: &B,
65        ty: Type,
66        launcher: &mut KernelLauncher<R>,
67    ) -> Self::CompilationArg {
68        let shape = buffer.shape();
69        match runtime_arg.reference_shape {
70            Some(reference_shape) if reference_shape.as_slice() != shape => {
71                let arg = PermutedLayoutLaunch::from_reference_shape(reference_shape);
72                let comp_arg = PermutedLayout::register(arg, buffer, ty, launcher);
73                LinearLayoutCompilationArg::Permuted(comp_arg)
74            }
75            _ => {
76                let strides = buffer.strides();
77                if is_contiguous(shape, strides) {
78                    PlainLayout::register((), buffer, ty, launcher);
79                    LinearLayoutCompilationArg::Plain
80                } else if is_contiguous_pitched(shape, strides) {
81                    let comp_arg = StridedLayout::register((), buffer, ty, launcher);
82                    LinearLayoutCompilationArg::Strided(comp_arg)
83                } else {
84                    let comp_arg =
85                        PermutedLayout::register(Default::default(), buffer, ty, launcher);
86                    LinearLayoutCompilationArg::Permuted(comp_arg)
87                }
88            }
89        }
90    }
91    fn expand(
92        compilation_arg: &Self::CompilationArg,
93        ty: Type,
94        builder: &mut cubecl::prelude::KernelBuilder,
95    ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
96        match compilation_arg {
97            LinearLayoutCompilationArg::Plain => {
98                LinearViewLayoutExpand::Plain(PlainLayout::expand(&(), ty, builder))
99            }
100            LinearLayoutCompilationArg::Strided(arg) => {
101                LinearViewLayoutExpand::Strided(StridedLayout::expand(arg, ty, builder))
102            }
103            LinearLayoutCompilationArg::Permuted(arg) => {
104                LinearViewLayoutExpand::Permuted(PermutedLayout::expand(arg, ty, builder))
105            }
106        }
107    }
108}
109
110#[derive(Debug, Hash, PartialEq, Eq, Clone)]
111pub enum LinearLayoutCompilationArg {
112    Plain,
113    Strided(StridedLayoutCompilationArg),
114    Permuted(PermutedLayoutCompilationArg),
115}
116
117impl LinearViewLayoutLaunch {
118    /// Construct a linear layout from shapes, strides and vector size of the tensor
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Construct a possibly broadcast linear layout from shapes/strides and a reference shape
124    pub fn from_reference_shape(reference_shape: Shape) -> Self {
125        Self {
126            reference_shape: Some(reference_shape),
127        }
128    }
129
130    /// Construct a possibly broadcast linear layout from a tensor handle and reference handle
131    pub fn from_reference_handle<R: Runtime>(reference: TensorBinding<R>) -> Self {
132        Self::from_reference_shape(reference.shape)
133    }
134}
135
136#[cube]
137impl Layout for LinearViewLayout {
138    type Coordinates = Coords1d;
139    type SourceCoordinates = Coords1d;
140
141    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
142        self.inner().to_source_pos(pos)
143    }
144
145    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
146        (self.to_source_pos(pos), self.is_in_bounds(pos))
147    }
148
149    fn shape(&self) -> Self::Coordinates {
150        self.inner().shape()
151    }
152
153    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
154        self.inner().is_in_bounds(pos)
155    }
156}
157
158/// Concrete version of the layout, so it can be launched on its own
159pub type LinearLayout = ConcreteLayout<LinearViewLayout>;
160pub type LinearLayoutLaunch<R> = ConcreteLayoutLaunch<LinearViewLayout, R>;
161
162/// [`View`] with a linear layout inferred from the shape/strides at launch.
163/// Useful for elementwise kernels.
164pub type LinearView<E, IO = ReadOnly> = View<E, Coords1d, IO>;
165/// Launch type for [`LinearView`].
166pub type LinearViewLaunch<R> = ViewArg<Coords1d, R>;
167
168/// Create a linear layout from a handle and vector size
169pub fn linear_layout<R: Runtime>(
170    handle: &TensorBinding<R>,
171    vector_size: VectorSize,
172) -> LinearLayoutLaunch<R> {
173    LinearLayoutLaunch::from_handle(
174        handle,
175        // Don't care about type size, only vector size
176        Type::new(UIntKind::U32.into()).with_vector_size(vector_size),
177        LinearViewLayoutLaunch::new(),
178    )
179}
180
181/// Create a linear tensor view from a handle
182pub fn linear_view<R: Runtime>(handle: TensorBinding<R>) -> LinearViewLaunch<R> {
183    let layout = LinearViewLayoutLaunch::new();
184    LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.into_tensor_arg(), layout)
185}
186
187/// Create a possibly broadcast linear tensor view from a handle and reference handle
188pub fn linear_view_with_reference<R: Runtime>(
189    handle: TensorBinding<R>,
190    reference: TensorBinding<R>,
191) -> LinearViewLaunch<R> {
192    let layout = LinearViewLayoutLaunch::from_reference_handle(reference);
193    LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.into_tensor_arg(), layout)
194}
195
196pub fn linear_view_alias<R: Runtime>(handle: &TensorBinding<R>, pos: usize) -> LinearViewLaunch<R> {
197    let layout = LinearViewLayoutLaunch::new();
198    LinearViewLaunch::new_tensor::<LinearViewLayout>(handle.as_alias(pos), layout)
199}