cubecl_std/tensor/layout/
plain.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::tensor::{
5 launch::{BufferArg, ViewLayoutLaunchArg},
6 layout::{Coords1d, Layout, LayoutExpand},
7};
8
9#[derive(CubeType, Clone)]
11pub struct PlainLayout {
12 len: usize,
13}
14
15#[cube]
16impl PlainLayout {
17 pub fn new(len: usize) -> Self {
18 PlainLayout { len }
19 }
20}
21
22impl ViewLayoutLaunchArg for PlainLayout {
23 type RuntimeArg<R: Runtime> = ();
24 type CompilationArg = ();
25
26 fn register<R: Runtime, B: BufferArg>(
27 _: Self::RuntimeArg<R>,
28 buffer: &B,
29 ty: Type,
30 launcher: &mut KernelLauncher<R>,
31 ) {
32 <usize as LaunchArg>::register(buffer.len() / ty.vector_size(), launcher);
33 }
34
35 fn expand(
36 _: &Self::CompilationArg,
37 _: Type,
38 builder: &mut KernelBuilder,
39 ) -> <Self as CubeType>::ExpandType {
40 let len = <usize as LaunchArg>::expand(&(), builder);
41 PlainLayout::__expand_new(&mut builder.scope, len)
42 }
43}
44
45#[cube]
46impl Layout for PlainLayout {
47 type Coordinates = Coords1d;
48 type SourceCoordinates = Coords1d;
49
50 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
51 pos
52 }
53
54 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
55 (self.to_source_pos(pos), self.is_in_bounds(pos))
56 }
57
58 fn shape(&self) -> Self::Coordinates {
59 self.len
60 }
61
62 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
63 pos < self.len
64 }
65}