cubecl_std/tensor/layout/
permuted.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5 FastDivmod, FastDivmodArgs,
6 tensor::{
7 index_offset_contiguous_fastdivmod,
8 layout::{Coords1d, Layout, LayoutExpand},
9 },
10};
11
12#[derive(CubeType, CubeLaunch, Clone)]
15pub struct PermutedLayout {
16 shape: Sequence<FastDivmod>,
17 strides: Sequence<u32>,
18 len: u32,
19 #[cube(comptime)]
20 line_size: u32,
21}
22
23impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
24 pub fn from_shape_strides(
27 client: &ComputeClient<R::Server>,
28 shape: &[usize],
29 strides: &[usize],
30 line_size: u8,
31 ) -> Self {
32 let len = shape.iter().product::<usize>() / line_size as usize;
33
34 let shape = shape
35 .iter()
36 .map(|it| FastDivmodArgs::new(client, *it as u32))
37 .collect();
38 let strides = strides
39 .iter()
40 .map(|it| ScalarArg::new(*it as u32))
41 .collect();
42
43 Self::new(shape, strides, ScalarArg::new(len as u32), line_size as u32)
44 }
45
46 pub fn from_shapes_strides_ref(
49 client: &ComputeClient<R::Server>,
50 shape: &[usize],
51 reference_shape: &[usize],
52 strides: &[usize],
53 line_size: u8,
54 ) -> Self {
55 debug_assert!(
56 shape.len() == reference_shape.len(),
57 "Shape and reference should have the same rank"
58 );
59 debug_assert!(
60 shape
61 .iter()
62 .zip(reference_shape)
63 .all(|(s, r)| s == r || *s == 1),
64 "Shape should be equal to reference or 1 on each dimension"
65 );
66
67 let strides: Vec<usize> = strides
68 .iter()
69 .zip(shape.iter().zip(reference_shape))
70 .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
71 .collect();
72
73 Self::from_shape_strides(client, reference_shape, &strides, line_size)
74 }
75
76 pub fn from_handles_ref(
77 client: &ComputeClient<R::Server>,
78 handle: &TensorHandleRef<'_, R>,
79 reference_handle: &TensorHandleRef<'_, R>,
80 line_size: u8,
81 ) -> Self {
82 Self::from_shapes_strides_ref(
83 client,
84 handle.shape,
85 reference_handle.shape,
86 handle.strides,
87 line_size,
88 )
89 }
90
91 pub fn from_handle(
92 client: &ComputeClient<R::Server>,
93 handle: &TensorHandleRef<'_, R>,
94 line_size: u8,
95 ) -> Self {
96 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
97 }
98}
99
100#[cube]
101impl Layout for PermutedLayout {
102 type Coordinates = Coords1d;
103 type SourceCoordinates = Coords1d;
104
105 fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
106 index_offset_contiguous_fastdivmod(
107 pos,
108 &self.shape,
109 &self.strides,
110 comptime![self.line_size],
111 )
112 }
113
114 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
115 (self.to_source_pos(pos), self.is_in_bounds(pos))
116 }
117
118 fn shape(&self) -> Self::Coordinates {
119 self.len
120 }
121
122 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
123 pos < self.len
124 }
125}