cubecl_std/tensor/layout/
permuted.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, ir::LineSize};
3
4use crate::{
5 FastDivmod, FastDivmodArgs,
6 tensor::{
7 index_offset_contiguous_fastdivmod,
8 layout::{Coords1d, Layout, LayoutExpand},
9 },
10};
11
12#[derive(CubeType, CubeLaunch, Clone)]
15pub struct PermutedLayout {
16 shape: Sequence<FastDivmod<usize>>,
17 strides: Sequence<usize>,
18 len: usize,
19 #[cube(comptime)]
20 line_size: LineSize,
21}
22
23#[cube]
24impl PermutedLayout {
25 pub fn new(
26 shape: Sequence<FastDivmod<usize>>,
27 strides: Sequence<usize>,
28 len: usize,
29 #[comptime] line_size: LineSize,
30 ) -> Self {
31 PermutedLayout {
32 shape,
33 strides,
34 len,
35 line_size,
36 }
37 }
38}
39
40impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
41 pub fn from_shape_strides(
44 client: &ComputeClient<R>,
45 shape: &[usize],
46 strides: &[usize],
47 line_size: LineSize,
48 ) -> Self {
49 let len = shape.iter().product::<usize>() / line_size;
50
51 let shape = shape
52 .iter()
53 .map(|it| FastDivmodArgs::<usize>::new(client, *it))
54 .collect();
55 let strides = strides.iter().map(|it| ScalarArg::new(*it)).collect();
56
57 Self::new(shape, strides, ScalarArg::new(len), line_size)
58 }
59
60 pub fn from_shapes_strides_ref(
63 client: &ComputeClient<R>,
64 shape: &[usize],
65 reference_shape: &[usize],
66 strides: &[usize],
67 line_size: LineSize,
68 ) -> Self {
69 debug_assert!(
70 shape.len() == reference_shape.len(),
71 "Shape and reference should have the same rank"
72 );
73 debug_assert!(
74 shape
75 .iter()
76 .zip(reference_shape)
77 .all(|(s, r)| s == r || *s == 1),
78 "Shape should be equal to reference or 1 on each dimension"
79 );
80
81 let strides: Vec<usize> = strides
82 .iter()
83 .zip(shape.iter().zip(reference_shape))
84 .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
85 .collect();
86
87 Self::from_shape_strides(client, reference_shape, &strides, line_size)
88 }
89
90 pub fn from_handles_ref(
91 client: &ComputeClient<R>,
92 handle: &TensorHandleRef<'_, R>,
93 reference_handle: &TensorHandleRef<'_, R>,
94 line_size: LineSize,
95 ) -> Self {
96 Self::from_shapes_strides_ref(
97 client,
98 handle.shape,
99 reference_handle.shape,
100 handle.strides,
101 line_size,
102 )
103 }
104
105 pub fn from_handle(
106 client: &ComputeClient<R>,
107 handle: &TensorHandleRef<'_, R>,
108 line_size: LineSize,
109 ) -> Self {
110 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
111 }
112}
113
114#[cube]
115impl Layout for PermutedLayout {
116 type Coordinates = Coords1d;
117 type SourceCoordinates = Coords1d;
118
119 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
120 index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.line_size)
121 }
122
123 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
124 (self.to_source_pos(pos), self.is_in_bounds(pos))
125 }
126
127 fn shape(&self) -> Self::Coordinates {
128 self.len
129 }
130
131 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
132 pos < self.len
133 }
134}