cubecl_std/tensor/layout/
permuted.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, ir::LineSize};
3
4use crate::{
5 FastDivmod, FastDivmodArgs,
6 tensor::{
7 index_offset_contiguous_fastdivmod,
8 layout::{Coords1d, Layout, LayoutExpand},
9 },
10};
11
12#[derive(CubeType, CubeLaunch, Clone)]
15pub struct PermutedLayout {
16 shape: Sequence<FastDivmod<usize>>,
17 strides: Sequence<usize>,
18 len: usize,
19 #[cube(comptime)]
20 line_size: LineSize,
21}
22
23impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
24 pub fn from_shape_strides(
27 client: &ComputeClient<R>,
28 shape: &[usize],
29 strides: &[usize],
30 line_size: LineSize,
31 ) -> Self {
32 let len = shape.iter().product::<usize>() / line_size;
33
34 let shape = shape
35 .iter()
36 .map(|it| FastDivmodArgs::<usize>::new(client, *it))
37 .collect();
38 let strides = strides.iter().map(|it| ScalarArg::new(*it)).collect();
39
40 Self::new(shape, strides, ScalarArg::new(len), line_size)
41 }
42
43 pub fn from_shapes_strides_ref(
46 client: &ComputeClient<R>,
47 shape: &[usize],
48 reference_shape: &[usize],
49 strides: &[usize],
50 line_size: LineSize,
51 ) -> Self {
52 debug_assert!(
53 shape.len() == reference_shape.len(),
54 "Shape and reference should have the same rank"
55 );
56 debug_assert!(
57 shape
58 .iter()
59 .zip(reference_shape)
60 .all(|(s, r)| s == r || *s == 1),
61 "Shape should be equal to reference or 1 on each dimension"
62 );
63
64 let strides: Vec<usize> = strides
65 .iter()
66 .zip(shape.iter().zip(reference_shape))
67 .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
68 .collect();
69
70 Self::from_shape_strides(client, reference_shape, &strides, line_size)
71 }
72
73 pub fn from_handles_ref(
74 client: &ComputeClient<R>,
75 handle: &TensorHandleRef<'_, R>,
76 reference_handle: &TensorHandleRef<'_, R>,
77 line_size: LineSize,
78 ) -> Self {
79 Self::from_shapes_strides_ref(
80 client,
81 handle.shape,
82 reference_handle.shape,
83 handle.strides,
84 line_size,
85 )
86 }
87
88 pub fn from_handle(
89 client: &ComputeClient<R>,
90 handle: &TensorHandleRef<'_, R>,
91 line_size: LineSize,
92 ) -> Self {
93 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
94 }
95}
96
97#[cube]
98impl Layout for PermutedLayout {
99 type Coordinates = Coords1d;
100 type SourceCoordinates = Coords1d;
101
102 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
103 index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.line_size)
104 }
105
106 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
107 (self.to_source_pos(pos), self.is_in_bounds(pos))
108 }
109
110 fn shape(&self) -> Self::Coordinates {
111 self.len
112 }
113
114 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
115 pos < self.len
116 }
117}