cubecl_std/
reinterpret_slice.rs1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::LineSize, unexpanded};
5
/// Read-only view over a `Slice<Line<S>>` whose contents are read back as values of type `T`.
///
/// Reads go through `T::reinterpret`, i.e. the bytes of the underlying lines are reinterpreted
/// as a `T`; this is not a numeric conversion. The byte size of one `T` must be a multiple, or an
/// exact divisor, of the byte size of one `Line<S>`, otherwise `ReinterpretSlice::new` panics
/// during comptime evaluation.
#[derive(CubeType)]
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
    /// Underlying storage; `new` may re-line it so that one line holds exactly one `T`.
    slice: Slice<Line<S>>,

    /// Line size of `slice` (comptime). After re-lining this is the optimized line size.
    #[cube(comptime)]
    line_size: LineSize,

    /// `Some(n)` when one `T` spans `n` consecutive lines of `slice`; `None` when one line is
    /// exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    /// Ties the view to `T` without storing a `T`.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
26
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
    /// Wraps `slice` so that it can be read as a sequence of `T`.
    ///
    /// `line_size` (comptime) is the line size of `slice`; it is assumed to match the slice's
    /// actual line size — NOTE(review): not checked here. Depending on how the byte size of one
    /// `Line<S>` compares to the byte size of one `T`:
    /// - larger: the slice is re-lined to a smaller line size so one line holds one `T`;
    /// - smaller: the slice is kept as-is and each `T` is assembled from several lines;
    /// - equal: the slice is kept as-is and each line is one `T`.
    ///
    /// # Panics
    ///
    /// During comptime evaluation, if the two byte sizes are not multiples of one another.
    pub fn new(slice: Slice<Line<S>>, #[comptime] line_size: LineSize) -> ReinterpretSlice<S, T> {
        // Sizes are resolved outside `comptime!` so the `#[cube]` macro expands these calls.
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A line is wider than one `T`: shrink the line size so one line == one `T`.
            // The binding shadows the `line_size` argument with the optimized value.
            Some(line_size) => ReinterpretSlice::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // Lines are kept as provided.
            None => ReinterpretSlice::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `index`-th `T` of the reinterpreted view.
    ///
    /// `index` counts `T` values, not lines of the underlying slice.
    pub fn read(&self, index: usize) -> T {
        match comptime!(self.load_many) {
            // One `T` spans `amount` lines: gather them into one wide line, then reinterpret.
            Some(amount) => {
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            // One line is exactly one `T`.
            None => T::reinterpret(self.slice[index]),
        }
    }
}
69
/// Read-write view over a `SliceMut<Line<S>>` whose contents are read and written as values of
/// type `T`.
///
/// Reads use `T::reinterpret` and writes use `Line::<S>::reinterpret`, so values are
/// reinterpreted bit-for-bit rather than numerically converted. The byte size of one `T` must be
/// a multiple, or an exact divisor, of the byte size of one `Line<S>`, otherwise
/// `ReinterpretSliceMut::new` panics during comptime evaluation.
#[derive(CubeType)]
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
    /// Underlying storage; `new` may re-line it so that one line holds exactly one `T`.
    slice: SliceMut<Line<S>>,

    /// Line size of `slice` (comptime). After re-lining this is the optimized line size.
    #[cube(comptime)]
    line_size: LineSize,

    /// `Some(n)` when one `T` spans `n` consecutive lines of `slice`; `None` when one line is
    /// exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    /// Ties the view to `T` without storing a `T`.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
90
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
    /// Wraps `slice` so that it can be read and written as a sequence of `T`.
    ///
    /// `line_size` (comptime) is the line size of `slice`; it is assumed to match the slice's
    /// actual line size — NOTE(review): not checked here. Depending on how the byte size of one
    /// `Line<S>` compares to the byte size of one `T`:
    /// - larger: the slice is re-lined to a smaller line size so one line holds one `T`;
    /// - smaller: the slice is kept as-is and each `T` spans several lines;
    /// - equal: the slice is kept as-is and each line is one `T`.
    ///
    /// # Panics
    ///
    /// During comptime evaluation, if the two byte sizes are not multiples of one another.
    pub fn new(
        slice: SliceMut<Line<S>>,
        #[comptime] line_size: LineSize,
    ) -> ReinterpretSliceMut<S, T> {
        // Sizes are resolved outside `comptime!` so the `#[cube]` macro expands these calls.
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A line is wider than one `T`: shrink the line size so one line == one `T`.
            // The binding shadows the `line_size` argument with the optimized value.
            Some(line_size) => ReinterpretSliceMut::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // Lines are kept as provided.
            None => ReinterpretSliceMut::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `index`-th `T` of the reinterpreted view.
    ///
    /// `index` counts `T` values, not lines of the underlying slice.
    pub fn read(&self, index: usize) -> T {
        match comptime!(self.load_many) {
            // One `T` spans `amount` lines: gather them into one wide line, then reinterpret.
            Some(amount) => {
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            // One line is exactly one `T`.
            None => T::reinterpret(self.slice[index]),
        }
    }

    /// Writes `value` as the `index`-th `T` of the reinterpreted view.
    ///
    /// `index` counts `T` values, not lines of the underlying slice.
    pub fn write(&mut self, index: usize, value: T) {
        let reinterpreted = Line::<S>::reinterpret(value);
        match comptime!(self.load_many) {
            // One `T` spans `amount` lines: scatter the reinterpreted value over them.
            Some(amount) => {
                let first = index * amount;
                // Width of each stored line. Presumably equal to `self.line_size`, since `new`
                // chose `amount` as (bytes of `T`) / (bytes of one line) — confirm.
                let line_size = comptime!(reinterpreted.size() / amount);

                #[unroll]
                for k in 0..amount {
                    let mut line = Line::empty(line_size);
                    #[unroll]
                    for j in 0..line_size {
                        line[j] = reinterpreted[k * line_size + j];
                    }
                    self.slice[first + k] = line;
                }
            }
            // One line is exactly one `T`.
            None => self.slice[index] = reinterpreted,
        }
    }
}
157
158fn optimize_line_size(
159 source_size: u32,
160 line_size: LineSize,
161 target_size: u32,
162) -> (Option<usize>, Option<usize>) {
163 let target_size = target_size as usize;
164 let line_source_size = source_size as usize * line_size;
165 match line_source_size.cmp(&target_size) {
166 core::cmp::Ordering::Less => {
167 if !target_size.is_multiple_of(line_source_size) {
168 panic!("incompatible number of bytes");
169 }
170
171 let ratio = target_size / line_source_size;
172
173 (None, Some(ratio))
174 }
175 core::cmp::Ordering::Greater => {
176 if !line_source_size.is_multiple_of(target_size) {
177 panic!("incompatible number of bytes");
178 }
179 let ratio = line_source_size / target_size;
180
181 (Some(line_size / ratio), None)
182 }
183 core::cmp::Ordering::Equal => (None, None),
184 }
185}
186
/// Size in bytes of one `S` value, for use inside `#[cube]` code.
///
/// This plain function has no implementation of its own: its body is `unexpanded!()`, so it is
/// not meant to be called directly. Inside `#[cube]` code the call is presumably rewritten by
/// the macro to `size_of::expand`, which resolves `S` in the current scope — confirm against
/// the CubeCL macro conventions.
pub fn size_of<S: CubePrimitive>() -> u32 {
    unexpanded!()
}
190
191pub mod size_of {
192 use super::*;
193 #[allow(unused, clippy::all)]
194 pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
195 S::as_type(context).size() as u32
196 }
197}