cubecl_std/
reinterpret_slice.rs1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5
6#[derive(CubeType)]
14pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
15 slice: Slice<Line<S>>,
16
17 #[cube(comptime)]
18 line_size: u32,
19
20 #[cube(comptime)]
21 load_many: Option<u32>,
22
23 #[cube(comptime)]
24 _phantom: PhantomData<T>,
25}
26
27#[cube]
28impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
29 pub fn new(slice: Slice<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSlice<S, T> {
30 let source_size = size_of::<S>();
31 let target_size = size_of::<T>();
32 let (optimized_line_size, load_many) =
33 comptime!(optimize_line_size(source_size, line_size, target_size));
34 match comptime!(optimized_line_size) {
35 Some(line_size) => ReinterpretSlice::<S, T> {
36 slice: slice.with_line_size(line_size),
37 line_size,
38 load_many,
39 _phantom: PhantomData,
40 },
41 None => ReinterpretSlice::<S, T> {
42 slice,
43 line_size,
44 load_many,
45 _phantom: PhantomData,
46 },
47 }
48 }
49
50 pub fn read(&self, index: u32) -> T {
51 match comptime!(self.load_many) {
52 Some(amount) => {
53 let first = index * amount;
54 let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
55 #[unroll]
56 for k in 0..amount {
57 let elem = self.slice[first + k];
58 #[unroll]
59 for j in 0..self.line_size {
60 line[k * self.line_size + j] = elem[j];
61 }
62 }
63 T::reinterpret(line)
64 }
65 None => T::reinterpret(self.slice[index]),
66 }
67 }
68}
69
70#[derive(CubeType)]
78pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
79 slice: SliceMut<Line<S>>,
80
81 #[cube(comptime)]
82 line_size: u32,
83
84 #[cube(comptime)]
85 load_many: Option<u32>,
86
87 #[cube(comptime)]
88 _phantom: PhantomData<T>,
89}
90
91#[cube]
92impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
93 pub fn new(slice: SliceMut<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSliceMut<S, T> {
94 let source_size = size_of::<S>();
95 let target_size = size_of::<T>();
96 let (optimized_line_size, load_many) =
97 comptime!(optimize_line_size(source_size, line_size, target_size));
98 match comptime!(optimized_line_size) {
99 Some(line_size) => ReinterpretSliceMut::<S, T> {
100 slice: slice.with_line_size(line_size),
101 line_size,
102 load_many,
103 _phantom: PhantomData,
104 },
105 None => ReinterpretSliceMut::<S, T> {
106 slice,
107 line_size,
108 load_many,
109 _phantom: PhantomData,
110 },
111 }
112 }
113
114 pub fn read(&self, index: u32) -> T {
115 match comptime!(self.load_many) {
116 Some(amount) => {
117 let first = index * amount;
118 let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
119 #[unroll]
120 for k in 0..amount {
121 let elem = self.slice[first + k];
122 #[unroll]
123 for j in 0..self.line_size {
124 line[k * self.line_size + j] = elem[j];
125 }
126 }
127 T::reinterpret(line)
128 }
129 None => T::reinterpret(self.slice[index]),
130 }
131 }
132
133 pub fn write(&mut self, index: u32, value: T) {
134 let reinterpreted = Line::<S>::reinterpret(value);
135 match comptime!(self.load_many) {
136 Some(amount) => {
137 let first = index * amount;
138 let line_size = comptime!(reinterpreted.size() / amount);
139
140 #[unroll]
141 for k in 0..amount {
142 let mut line = Line::empty(line_size);
143 #[unroll]
144 for j in 0..line_size {
145 line[j] = reinterpreted[k * line_size + j];
146 }
147 self.slice[first + k] = line;
148 }
149 }
150 None => self.slice[index] = reinterpreted,
151 }
152 }
153}
154
/// Chooses how a slice of `Line<S>` should be accessed to read or write values of type `T`.
///
/// * `source_size` — byte size of one `S` element.
/// * `line_size` — number of `S` elements per line of the slice.
/// * `target_size` — byte size of one `T` value.
///
/// Returns `(relined, load_many)`:
/// * `(None, None)` when one line already holds exactly `target_size` bytes;
/// * `(Some(n), None)` when a line is larger than a `T`: the slice should be re-lined to
///   `n` elements per line so that each line holds exactly one `T`;
/// * `(None, Some(k))` when a line is smaller than a `T`: each `T` spans `k` consecutive lines.
///
/// # Panics
///
/// Panics if any argument is zero. Also panics with "incompatible number of bytes" when
/// the byte counts cannot be matched exactly. That includes the case where a single `S`
/// element is wider than a `T`, since no line size below one element exists.
fn optimize_line_size(
    source_size: u32,
    line_size: u32,
    target_size: u32,
) -> (Option<u32>, Option<u32>) {
    // Without this, zero sizes would surface as an opaque remainder-by-zero panic below.
    assert!(
        source_size != 0 && line_size != 0 && target_size != 0,
        "element sizes and line size must be non-zero"
    );

    let line_source_size = source_size * line_size;
    match line_source_size.cmp(&target_size) {
        core::cmp::Ordering::Less => {
            if target_size % line_source_size != 0 {
                panic!("incompatible number of bytes");
            }
            (None, Some(target_size / line_source_size))
        }
        core::cmp::Ordering::Greater => {
            if line_source_size % target_size != 0 {
                panic!("incompatible number of bytes");
            }
            let ratio = line_source_size / target_size;
            // The re-lined size `line_size / ratio` must be a whole, non-zero number of
            // elements. Otherwise the division would silently truncate, e.g. a 4-byte
            // element read as 2-byte values with line size 1 would yield a line size of 0.
            if line_size % ratio != 0 {
                panic!("incompatible number of bytes");
            }
            (Some(line_size / ratio), None)
        }
        core::cmp::Ordering::Equal => (None, None),
    }
}

183pub fn size_of<S: CubePrimitive>() -> u32 {
184 unexpanded!()
185}
186
187pub mod size_of {
188 use super::*;
189 #[allow(unused, clippy::all)]
190 pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
191 S::as_elem(context).size() as u32
192 }
193}