cubecl_std/
reinterpret_slice.rs

1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, unexpanded};
5
/// A read-only view that reinterprets a slice of `Line<S>` as a slice of `T`.
///
/// Semantically, this is equivalent to casting the slice of `Line<S>` to a slice of `T`.
/// Indices passed to `read` are positions in the reinterpreted slice of `T`, not
/// positions in the underlying slice of lines.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
    // Underlying storage; `new` may re-line it with a smaller line size.
    slice: Slice<Line<S>>,

    // Line size of `slice` as stored here (after any adjustment made by `new`).
    #[cube(comptime)]
    line_size: u32,

    // `Some(n)` when a single `T` spans `n` consecutive lines of `slice`;
    // `None` when one line maps to exactly one `T`.
    #[cube(comptime)]
    load_many: Option<u32>,

    // Ties the target type `T` to the view without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
26
27#[cube]
28impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
29    pub fn new(slice: Slice<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSlice<S, T> {
30        let source_size = size_of::<S>();
31        let target_size = size_of::<T>();
32        let (optimized_line_size, load_many) =
33            comptime!(optimize_line_size(source_size, line_size, target_size));
34        match comptime!(optimized_line_size) {
35            Some(line_size) => ReinterpretSlice::<S, T> {
36                slice: slice.with_line_size(line_size),
37                line_size,
38                load_many,
39                _phantom: PhantomData,
40            },
41            None => ReinterpretSlice::<S, T> {
42                slice,
43                line_size,
44                load_many,
45                _phantom: PhantomData,
46            },
47        }
48    }
49
50    pub fn read(&self, index: u32) -> T {
51        match comptime!(self.load_many) {
52            Some(amount) => {
53                let first = index * amount;
54                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
55                #[unroll]
56                for k in 0..amount {
57                    let elem = self.slice[first + k];
58                    #[unroll]
59                    for j in 0..self.line_size {
60                        line[k * self.line_size + j] = elem[j];
61                    }
62                }
63                T::reinterpret(line)
64            }
65            None => T::reinterpret(self.slice[index]),
66        }
67    }
68}
69
/// A mutable view that reinterprets a mutable slice of `Line<S>` as a mutable slice of `T`.
///
/// Semantically, this is equivalent to casting the slice of `Line<S>` to a mutable
/// slice of `T`. Indices passed to `read` and `write` are positions in the
/// reinterpreted slice of `T`, not positions in the underlying slice of lines.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
    // Underlying storage; `new` may re-line it with a smaller line size.
    slice: SliceMut<Line<S>>,

    // Line size of `slice` as stored here (after any adjustment made by `new`).
    #[cube(comptime)]
    line_size: u32,

    // `Some(n)` when a single `T` spans `n` consecutive lines of `slice`;
    // `None` when one line maps to exactly one `T`.
    #[cube(comptime)]
    load_many: Option<u32>,

    // Ties the target type `T` to the view without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
90
91#[cube]
92impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
93    pub fn new(slice: SliceMut<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSliceMut<S, T> {
94        let source_size = size_of::<S>();
95        let target_size = size_of::<T>();
96        let (optimized_line_size, load_many) =
97            comptime!(optimize_line_size(source_size, line_size, target_size));
98        match comptime!(optimized_line_size) {
99            Some(line_size) => ReinterpretSliceMut::<S, T> {
100                slice: slice.with_line_size(line_size),
101                line_size,
102                load_many,
103                _phantom: PhantomData,
104            },
105            None => ReinterpretSliceMut::<S, T> {
106                slice,
107                line_size,
108                load_many,
109                _phantom: PhantomData,
110            },
111        }
112    }
113
114    pub fn read(&self, index: u32) -> T {
115        match comptime!(self.load_many) {
116            Some(amount) => {
117                let first = index * amount;
118                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
119                #[unroll]
120                for k in 0..amount {
121                    let elem = self.slice[first + k];
122                    #[unroll]
123                    for j in 0..self.line_size {
124                        line[k * self.line_size + j] = elem[j];
125                    }
126                }
127                T::reinterpret(line)
128            }
129            None => T::reinterpret(self.slice[index]),
130        }
131    }
132
133    pub fn write(&mut self, index: u32, value: T) {
134        let reinterpreted = Line::<S>::reinterpret(value);
135        match comptime!(self.load_many) {
136            Some(amount) => {
137                let first = index * amount;
138                let line_size = comptime!(reinterpreted.size() / amount);
139
140                #[unroll]
141                for k in 0..amount {
142                    let mut line = Line::empty(line_size);
143                    #[unroll]
144                    for j in 0..line_size {
145                        line[j] = reinterpreted[k * line_size + j];
146                    }
147                    self.slice[first + k] = line;
148                }
149            }
150            None => self.slice[index] = reinterpreted,
151        }
152    }
153}
154
/// Decides how lines of `line_size` elements of `source_size` bytes must be accessed
/// so that each access covers exactly `target_size` bytes.
///
/// Returns `(optimized_line_size, load_many)`:
///
/// - `(None, None)`: a line already has exactly `target_size` bytes.
/// - `(None, Some(n))`: a line is smaller than the target; `n` consecutive lines
///   make up one target value.
/// - `(Some(l), None)`: a line is larger than the target; `l` is the smaller line
///   size for which one line has exactly `target_size` bytes.
///
/// # Panics
///
/// Panics with `"incompatible number of bytes"` when the sizes cannot be matched:
/// one size is not a multiple of the other, a size is zero while the other is not,
/// or no whole line size yields `target_size` bytes (for instance when a single
/// source element is already larger than the target).
fn optimize_line_size(
    source_size: u32,
    line_size: u32,
    target_size: u32,
) -> (Option<u32>, Option<u32>) {
    let line_source_size = source_size * line_size;
    match line_source_size.cmp(&target_size) {
        core::cmp::Ordering::Less => {
            // An empty line can never add up to a non-empty target; checking for zero
            // first also avoids a bare division-by-zero panic.
            if line_source_size == 0 || target_size % line_source_size != 0 {
                panic!("incompatible number of bytes");
            }

            let ratio = target_size / line_source_size;

            (None, Some(ratio))
        }
        core::cmp::Ordering::Greater => {
            if target_size == 0 || line_source_size % target_size != 0 {
                panic!("incompatible number of bytes");
            }
            let ratio = line_source_size / target_size;

            // The shrunk line must hold a whole number of source elements. Without
            // this check `line_size / ratio` silently floors, e.g. to a line size
            // of 0 when one source element is wider than the target.
            if line_size % ratio != 0 {
                panic!("incompatible number of bytes");
            }

            (Some(line_size / ratio), None)
        }
        core::cmp::Ordering::Equal => (None, None),
    }
}
182
/// Size of the element type behind the cube primitive `S`.
///
/// This is only the signature seen by code written inside `#[cube]` functions: the
/// body is the `unexpanded!()` placeholder, and the value itself is computed by
/// `size_of::expand`.
// NOTE(review): the unit is presumably bytes (callers feed it to a byte-size check) — confirm.
pub fn size_of<S: CubePrimitive>() -> u32 {
    unexpanded!()
}
186
187pub mod size_of {
188    use super::*;
189    #[allow(unused, clippy::all)]
190    pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
191        S::as_elem(context).size() as u32
192    }
193}