cubecl_std/
reinterpret_slice.rs

1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::LineSize, unexpanded};
5
/// This struct takes a slice of `Line<S>` and reinterprets it
/// as a slice of `T`. Semantically, this is equivalent to reinterpreting the slice of `Line<S>`
/// as a slice of `T`. When indexing, the index is valid in the casted list.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
    // Underlying slice of lines. `new` re-lines it (`with_line_size`) when one line is
    // larger than a single `T`.
    slice: Slice<Line<S>>,

    // Line size recorded by `new`: the reduced line size when the slice was re-lined,
    // otherwise the line size passed in by the caller.
    #[cube(comptime)]
    line_size: LineSize,

    // `Some(n)` when one `T` spans `n` consecutive lines of `slice`;
    // `None` when a single line maps to exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    // Ties the target type `T` to the struct without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
26
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
    /// Builds a view over `slice` whose elements are read as `T`.
    ///
    /// `line_size` is the comptime line size of `slice`. The sizes of `S` and `T` are
    /// compared at comptime by `optimize_line_size`, which decides whether the slice is
    /// re-lined with a smaller line size or whether several lines are gathered per `T`.
    ///
    /// # Panics
    ///
    /// `optimize_line_size` panics with "incompatible number of bytes" for incompatible
    /// sizes, for example when neither the line size nor the size of `T` is a multiple
    /// of the other.
    pub fn new(slice: Slice<Line<S>>, #[comptime] line_size: LineSize) -> ReinterpretSlice<S, T> {
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A line is larger than one `T`: re-line the slice. `line_size` in this arm is
            // the shadowed, reduced line size, and it is the value stored in the struct.
            Some(line_size) => ReinterpretSlice::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // The slice is kept as-is; `load_many` tells `read` how many lines form one `T`.
            None => ReinterpretSlice::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the element at `index` of the reinterpreted slice and returns it as a `T`.
    ///
    /// When `load_many` is `Some(amount)`, the `amount` lines starting at `index * amount`
    /// are copied into one line of `amount * line_size` elements, which is then
    /// reinterpreted. Otherwise the line at `index` is reinterpreted directly.
    pub fn read(&self, index: usize) -> T {
        match comptime!(self.load_many) {
            Some(amount) => {
                // Position, in the underlying slice, of the first line that belongs to `index`.
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    // Copy the k-th source line into its segment of the wide line.
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            None => T::reinterpret(self.slice[index]),
        }
    }
}
69
/// This struct takes a mutable slice of `Line<S>` and reinterprets it
/// as a mutable slice of `T`. Semantically, this is equivalent to reinterpreting the slice of `Line<S>`
/// as a mutable slice of `T`. When indexing, the index is valid in the casted list.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
    // Underlying mutable slice of lines. `new` re-lines it (`with_line_size`) when one line
    // is larger than a single `T`.
    slice: SliceMut<Line<S>>,

    // Line size recorded by `new`: the reduced line size when the slice was re-lined,
    // otherwise the line size passed in by the caller.
    #[cube(comptime)]
    line_size: LineSize,

    // `Some(n)` when one `T` spans `n` consecutive lines of `slice`;
    // `None` when a single line maps to exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    // Ties the target type `T` to the struct without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
90
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
    /// Builds a mutable view over `slice` whose elements are read and written as `T`.
    ///
    /// `line_size` is the comptime line size of `slice`. The sizes of `S` and `T` are
    /// compared at comptime by `optimize_line_size`, which decides whether the slice is
    /// re-lined with a smaller line size or whether several lines are used per `T`.
    ///
    /// # Panics
    ///
    /// `optimize_line_size` panics with "incompatible number of bytes" for incompatible
    /// sizes, for example when neither the line size nor the size of `T` is a multiple
    /// of the other.
    pub fn new(
        slice: SliceMut<Line<S>>,
        #[comptime] line_size: LineSize,
    ) -> ReinterpretSliceMut<S, T> {
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A line is larger than one `T`: re-line the slice. `line_size` in this arm is
            // the shadowed, reduced line size, and it is the value stored in the struct.
            Some(line_size) => ReinterpretSliceMut::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // The slice is kept as-is; `load_many` tells `read`/`write` how many lines
            // form one `T`.
            None => ReinterpretSliceMut::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the element at `index` of the reinterpreted slice and returns it as a `T`.
    ///
    /// When `load_many` is `Some(amount)`, the `amount` lines starting at `index * amount`
    /// are copied into one line of `amount * line_size` elements, which is then
    /// reinterpreted. Otherwise the line at `index` is reinterpreted directly.
    pub fn read(&self, index: usize) -> T {
        match comptime!(self.load_many) {
            Some(amount) => {
                // Position, in the underlying slice, of the first line that belongs to `index`.
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    // Copy the k-th source line into its segment of the wide line.
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            None => T::reinterpret(self.slice[index]),
        }
    }

    /// Writes `value` at `index` of the reinterpreted slice.
    ///
    /// `value` is first reinterpreted as a `Line<S>`. When `load_many` is `Some(amount)`,
    /// that line is split into `amount` equally sized lines that are stored at consecutive
    /// positions starting at `index * amount`. Otherwise the reinterpreted line is stored
    /// at `index` directly.
    pub fn write(&mut self, index: usize, value: T) {
        let reinterpreted = Line::<S>::reinterpret(value);
        match comptime!(self.load_many) {
            Some(amount) => {
                let first = index * amount;
                // Number of elements in each of the `amount` lines written back.
                let line_size = comptime!(reinterpreted.size() / amount);

                #[unroll]
                for k in 0..amount {
                    let mut line = Line::empty(line_size);
                    // Extract the k-th segment of the wide line.
                    #[unroll]
                    for j in 0..line_size {
                        line[j] = reinterpreted[k * line_size + j];
                    }
                    self.slice[first + k] = line;
                }
            }
            None => self.slice[index] = reinterpreted,
        }
    }
}
157
158fn optimize_line_size(
159    source_size: u32,
160    line_size: LineSize,
161    target_size: u32,
162) -> (Option<usize>, Option<usize>) {
163    let target_size = target_size as usize;
164    let line_source_size = source_size as usize * line_size;
165    match line_source_size.cmp(&target_size) {
166        core::cmp::Ordering::Less => {
167            if !target_size.is_multiple_of(line_source_size) {
168                panic!("incompatible number of bytes");
169            }
170
171            let ratio = target_size / line_source_size;
172
173            (None, Some(ratio))
174        }
175        core::cmp::Ordering::Greater => {
176            if !line_source_size.is_multiple_of(target_size) {
177                panic!("incompatible number of bytes");
178            }
179            let ratio = line_source_size / target_size;
180
181            (Some(line_size / ratio), None)
182        }
183        core::cmp::Ordering::Equal => (None, None),
184    }
185}
186
/// Size of the cube primitive `S`, for use inside `#[cube]` code.
///
/// This plain-Rust body only invokes `unexpanded!()`; the actual value is produced by
/// the companion `size_of::expand` function, which returns `S::as_type(..).size()`.
/// NOTE(review): the unit is presumably bytes (the callers' panic message speaks of
/// "number of bytes") — confirm against the type's `size()` documentation.
pub fn size_of<S: CubePrimitive>() -> u32 {
    unexpanded!()
}
190
/// Expansion counterpart of the `size_of` function; it shares the function's name so
/// that it can be found next to it (presumably by the `#[cube]` macro — confirm).
pub mod size_of {
    use super::*;
    /// Returns `S::as_type(context).size()` converted to `u32` with an `as` cast.
    #[allow(unused, clippy::all)]
    pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
        S::as_type(context).size() as u32
    }
}