Skip to main content

cubecl_std/
reinterpret_slice.rs

1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::VectorSize, unexpanded};
5
6/// This struct allows to take a slice of `Vector<S>` and reinterpret it
7/// as a slice of `T`. Semantically, this is equivalent to reinterpreting the slice of `Vector<S>`
8/// to a slice of `T`. When indexing, the index is valid in the casted list.
9///
10/// # Warning
11///
12/// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
13#[derive(CubeType)]
14pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
15    // Dummy vector size for downcasting later
16    slice: Slice<S>,
17
18    #[cube(comptime)]
19    vector_size: VectorSize,
20
21    #[cube(comptime)]
22    load_many: Option<usize>,
23
24    #[cube(comptime)]
25    _phantom: PhantomData<T>,
26}
27
28#[cube]
29impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
30    pub fn new(slice: Slice<S>) -> ReinterpretSlice<S, T> {
31        let in_vector_size = slice.vector_size();
32        let source_size = S::Scalar::type_size();
33        let target_size = T::Scalar::type_size();
34        let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
35            source_size,
36            in_vector_size,
37            target_size
38        ));
39        match comptime!(optimized_vector_size) {
40            Some(vector_size) => {
41                let size!(N2) = vector_size;
42                let slice = slice.into_vectorized().with_vector_size::<N2>();
43
44                ReinterpretSlice::<S, T> {
45                    slice: unsafe { slice.downcast_unchecked() },
46                    vector_size,
47                    load_many,
48                    _phantom: PhantomData,
49                }
50            }
51            None => ReinterpretSlice::<S, T> {
52                slice,
53                vector_size: in_vector_size,
54                load_many,
55                _phantom: PhantomData,
56            },
57        }
58    }
59
60    pub fn read(&self, index: usize) -> T {
61        let size!(N) = self.vector_size;
62        let slice = self.slice.into_vectorized().with_vector_size::<N>();
63        match comptime!(self.load_many) {
64            Some(amount) => {
65                let first = index * amount;
66                let size!(N2) = comptime!(amount * self.vector_size);
67                let mut vector = Vector::<S::Scalar, N2>::empty();
68                #[unroll]
69                for k in 0..amount {
70                    let elem = slice[first + k];
71                    #[unroll]
72                    for j in 0..self.vector_size {
73                        vector[k * self.vector_size + j] = elem[j];
74                    }
75                }
76                T::reinterpret(vector)
77            }
78            None => T::reinterpret(slice[index]),
79        }
80    }
81}
82
83/// This struct allows to take a mutable slice of `Vector<S>` and reinterpret it
84/// as a mutable slice of `T`. Semantically, this is equivalent to reinterpreting the slice of `Vector<S>`
85/// to a mutable slice of `T`. When indexing, the index is valid in the casted list.
86///
87/// # Warning
88///
89/// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
90#[derive(CubeType)]
91pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
92    slice: SliceMut<S>,
93
94    #[cube(comptime)]
95    vector_size: VectorSize,
96
97    #[cube(comptime)]
98    load_many: Option<usize>,
99
100    #[cube(comptime)]
101    _phantom: PhantomData<T>,
102}
103
104#[cube]
105impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
106    pub fn new(slice: SliceMut<S>) -> ReinterpretSliceMut<S, T> {
107        let in_vector_size = slice.vector_size();
108        let source_size = S::Scalar::type_size();
109        let target_size = T::Scalar::type_size();
110        let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
111            source_size,
112            in_vector_size,
113            target_size
114        ));
115        match comptime!(optimized_vector_size) {
116            Some(vector_size) => {
117                let size!(N2) = vector_size;
118                let slice = slice.into_vectorized().with_vector_size::<N2>();
119
120                ReinterpretSliceMut::<S, T> {
121                    slice: unsafe { slice.downcast_unchecked() },
122                    vector_size,
123                    load_many,
124                    _phantom: PhantomData,
125                }
126            }
127            None => ReinterpretSliceMut::<S, T> {
128                slice,
129                vector_size: in_vector_size,
130                load_many,
131                _phantom: PhantomData,
132            },
133        }
134    }
135
136    pub fn read(&self, index: usize) -> T {
137        let size!(N) = self.vector_size;
138        let slice = self.slice.into_vectorized().with_vector_size::<N>();
139        match comptime!(self.load_many) {
140            Some(amount) => {
141                let first = index * amount;
142                let size!(N2) = comptime!(amount * self.vector_size);
143                let mut vector = Vector::<S::Scalar, N2>::empty();
144                #[unroll]
145                for k in 0..amount {
146                    let elem = slice[first + k];
147                    #[unroll]
148                    for j in 0..self.vector_size {
149                        vector[k * self.vector_size + j] = elem[j];
150                    }
151                }
152                T::reinterpret(vector)
153            }
154            None => T::reinterpret(slice[index]),
155        }
156    }
157
158    pub fn write(&mut self, index: usize, value: T) {
159        let size!(N) = self.vector_size;
160        let mut slice = self.slice.into_vectorized().with_vector_size::<N>();
161        let size!(N1) = S::reinterpret_vectorization::<T>();
162        let reinterpreted = Vector::<S::Scalar, N1>::reinterpret(value);
163        match comptime!(self.load_many) {
164            Some(amount) => {
165                let first = index * amount;
166                let vector_size = comptime!(reinterpreted.size() / amount);
167
168                #[unroll]
169                for k in 0..amount {
170                    let mut vector = Vector::empty();
171                    #[unroll]
172                    for j in 0..vector_size {
173                        vector[j] = reinterpreted[k * vector_size + j];
174                    }
175                    slice[first + k] = vector;
176                }
177            }
178            None => slice[index] = Vector::cast_from(reinterpreted),
179        }
180    }
181}
182
183fn optimize_vector_size(
184    source_size: usize,
185    vector_size: VectorSize,
186    target_size: usize,
187) -> (Option<usize>, Option<usize>) {
188    let vector_source_size = source_size * vector_size;
189    match vector_source_size.cmp(&target_size) {
190        core::cmp::Ordering::Less => {
191            if !target_size.is_multiple_of(vector_source_size) {
192                panic!("incompatible number of bytes");
193            }
194
195            let ratio = target_size / vector_source_size;
196
197            (None, Some(ratio))
198        }
199        core::cmp::Ordering::Greater => {
200            if !vector_source_size.is_multiple_of(target_size) {
201                panic!("incompatible number of bytes");
202            }
203            let ratio = vector_source_size / target_size;
204
205            (Some(vector_size / ratio), None)
206        }
207        core::cmp::Ordering::Equal => (None, None),
208    }
209}
210
211pub fn size_of<S: CubePrimitive>() -> u32 {
212    unexpanded!()
213}
214
215pub mod size_of {
216    use super::*;
217    #[allow(unused, clippy::all)]
218    pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
219        S::as_type(context).size() as u32
220    }
221}