Skip to main content

cubecl_std/tensor/layout/
plain.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::tensor::{
5    launch::{BufferArg, ViewLayoutLaunchArg},
6    layout::{Coords1d, Layout, LayoutExpand},
7};
8
/// Layout for contiguous tensors.
///
/// Maps a 1D coordinate to the identical 1D source position and treats every
/// position in `0..len` as in bounds.
#[derive(CubeType, Clone)]
pub struct PlainLayout {
    // Number of addressable positions (also returned as the layout's shape).
    // When the layout is launched via `ViewLayoutLaunchArg::register`, this is
    // `buffer.len() / ty.vector_size()`, i.e. counted in units of
    // `ty.vector_size()` elements rather than individual scalars.
    len: usize,
}
14
#[cube]
impl PlainLayout {
    /// Creates a plain layout whose valid positions are `0..len`.
    pub fn new(len: usize) -> Self {
        PlainLayout { len }
    }
}
21
22impl ViewLayoutLaunchArg for PlainLayout {
23    type RuntimeArg<R: Runtime> = ();
24    type CompilationArg = ();
25
26    fn register<R: Runtime, B: BufferArg>(
27        _: Self::RuntimeArg<R>,
28        buffer: &B,
29        ty: Type,
30        launcher: &mut KernelLauncher<R>,
31    ) {
32        <usize as LaunchArg>::register(buffer.len() / ty.vector_size(), launcher);
33    }
34
35    fn expand(
36        _: &Self::CompilationArg,
37        _: Type,
38        builder: &mut KernelBuilder,
39    ) -> <Self as CubeType>::ExpandType {
40        let len = <usize as LaunchArg>::expand(&(), builder);
41        PlainLayout::__expand_new(&mut builder.scope, len)
42    }
43}
44
// Identity layout: every 1D coordinate is its own source position, and the
// coordinate space is `0..len`.
#[cube]
impl Layout for PlainLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    /// Returns `pos` unchanged. No bounds check is performed.
    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
        pos
    }

    /// Returns the (unchanged) source position paired with `pos < len`.
    ///
    /// The position is returned even when it is out of bounds; callers must
    /// consult the flag before using it.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    /// Returns the length of the 1D coordinate space.
    fn shape(&self) -> Self::Coordinates {
        self.len
    }

    /// Returns `true` when `pos < len`.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.len
    }
}