// cubecl_core/frontend/container/shared_memory.rs

1use core::marker::PhantomData;
2
3use crate::{
4    self as cubecl,
5    prelude::{Lined, LinedExpand},
6    unexpanded,
7};
8use cubecl_ir::{LineSize, Marker, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12    frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut},
13    ir::{Scope, Type},
14    prelude::{
15        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16    },
17};
18
19pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
21
22#[derive(Clone, Copy)]
23pub struct Shared<E: CubePrimitive> {
24    _val: PhantomData<E>,
25}
26
27#[derive(Clone, Copy)]
28pub struct SharedMemory<E: CubePrimitive> {
29    _val: PhantomData<E>,
30}
31
32impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
33    fn into_mut(self, _scope: &mut Scope) -> Self {
34        self
35    }
36}
37
38impl<T: CubePrimitive> CubeType for SharedMemory<T> {
39    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
40}
41
42impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
43    fn into_mut(self, _scope: &mut Scope) -> Self {
44        self
45    }
46}
47
48impl<T: CubePrimitive> CubeType for Shared<T> {
49    type ExpandType = ExpandElementTyped<Shared<T>>;
50}
51
52#[cube]
53impl<T: CubePrimitive + Clone> SharedMemory<T> {
54    #[allow(unused_variables)]
55    pub fn new(#[comptime] size: usize) -> Self {
56        intrinsic!(|scope| {
57            scope
58                .create_shared_array(Type::new(T::as_type(scope)), size, None)
59                .into()
60        })
61    }
62
63    #[allow(unused_variables)]
64    pub fn new_lined(
65        #[comptime] size: usize,
66        #[comptime] line_size: LineSize,
67    ) -> SharedMemory<Line<T>> {
68        intrinsic!(|scope| {
69            scope
70                .create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None)
71                .into()
72        })
73    }
74
75    #[allow(clippy::len_without_is_empty)]
76    pub fn len(&self) -> usize {
77        intrinsic!(|_| len_static(&self))
78    }
79
80    pub fn buffer_len(&self) -> usize {
81        self.len()
82    }
83}
84
85#[cube]
86impl<T: CubePrimitive> Shared<T> {
87    pub fn new() -> Self {
88        intrinsic!(|scope| {
89            let var = scope.create_shared(Type::new(T::as_type(scope)));
90            ExpandElementTyped::new(var)
91        })
92    }
93}
94
95pub trait AsRefExpand<T: CubeType> {
96    /// Converts this type into a shared reference of the (usually inferred) input type.
97    fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
98}
99impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
100    fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
101        self
102    }
103}
104pub trait AsMutExpand<T: CubeType> {
105    /// Converts this type into a shared reference of the (usually inferred) input type.
106    fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
107}
108impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
109    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
110        self
111    }
112}
113
114/// Type inference won't allow things like assign to work normally, so we need to manually call
115/// `as_ref` or `as_mut` for those. Things like barrier ops should take `AsRef` so the conversion
116/// is automatic.
117impl<T: CubePrimitive> AsRef<T> for Shared<T> {
118    fn as_ref(&self) -> &T {
119        unexpanded!()
120    }
121}
122impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
123    fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
124        self.expand.into()
125    }
126}
127
128impl<T: CubePrimitive> AsMut<T> for Shared<T> {
129    fn as_mut(&mut self) -> &mut T {
130        unexpanded!()
131    }
132}
133impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
134    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
135        self.expand.into()
136    }
137}
138
139impl<T: CubePrimitive> Default for Shared<T> {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144impl<T: CubePrimitive> Shared<T> {
145    pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
146        Self::__expand_new(scope)
147    }
148}
149
150#[cube]
151impl<T: CubePrimitive> Shared<Line<T>> {
152    #[allow(unused_variables)]
153    pub fn new_lined(#[comptime] line_size: LineSize) -> SharedMemory<Line<T>> {
154        intrinsic!(|scope| {
155            let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
156            ExpandElementTyped::new(var)
157        })
158    }
159}
160
161#[cube]
162impl<T: CubePrimitive + Clone> SharedMemory<T> {
163    #[allow(unused_variables)]
164    pub fn new_aligned(
165        #[comptime] size: usize,
166        #[comptime] line_size: LineSize,
167        #[comptime] alignment: usize,
168    ) -> SharedMemory<Line<T>> {
169        intrinsic!(|scope| {
170            let var = scope.create_shared_array(
171                Type::new(T::as_type(scope)).line(line_size),
172                size,
173                Some(alignment),
174            );
175            ExpandElementTyped::new(var)
176        })
177    }
178
179    /// Frees the shared memory for reuse, if possible on the target runtime.
180    ///
181    /// # Safety
182    /// *Must* be used in uniform control flow
183    /// *Must not* have any dangling references to this shared memory
184    pub unsafe fn free(self) {
185        intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
186    }
187}
188
189fn len_static<T: CubePrimitive>(
190    shared: &ExpandElementTyped<SharedMemory<T>>,
191) -> ExpandElementTyped<usize> {
192    let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
193        unreachable!("Kind of shared memory is always shared memory")
194    };
195    length.into()
196}
197
198/// Module that contains the implementation details of the index functions.
199mod indexation {
200    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
201
202    use crate::ir::Instruction;
203
204    use super::*;
205
206    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
207
208    #[cube]
209    impl<E: CubePrimitive> SharedMemory<E> {
210        /// Perform an unchecked index into the array
211        ///
212        /// # Safety
213        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
214        /// always in bounds
215        #[allow(unused_variables)]
216        pub unsafe fn index_unchecked(&self, i: usize) -> &E {
217            intrinsic!(|scope| {
218                let out = scope.create_local(self.expand.ty);
219                scope.register(Instruction::new(
220                    Operator::UncheckedIndex(IndexOperator {
221                        list: *self.expand,
222                        index: i.expand.consume(),
223                        line_size: 0,
224                        unroll_factor: 1,
225                    }),
226                    *out,
227                ));
228                out.into()
229            })
230        }
231
232        /// Perform an unchecked index assignment into the array
233        ///
234        /// # Safety
235        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
236        /// always in bounds
237        #[allow(unused_variables)]
238        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
239            intrinsic!(|scope| {
240                scope.register(Instruction::new(
241                    Operator::UncheckedIndexAssign(IndexAssignOperator {
242                        index: i.expand.consume(),
243                        value: value.expand.consume(),
244                        line_size: 0,
245                        unroll_factor: 1,
246                    }),
247                    *self.expand,
248                ));
249            })
250        }
251    }
252}
253
254impl<T: CubePrimitive> List<T> for SharedMemory<T> {
255    fn __expand_read(
256        scope: &mut Scope,
257        this: ExpandElementTyped<SharedMemory<T>>,
258        idx: ExpandElementTyped<usize>,
259    ) -> ExpandElementTyped<T> {
260        index::expand(scope, this, idx)
261    }
262}
263
264impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
265    fn __expand_read_method(
266        &self,
267        scope: &mut Scope,
268        idx: ExpandElementTyped<usize>,
269    ) -> ExpandElementTyped<T> {
270        index::expand(scope, self.clone(), idx)
271    }
272    fn __expand_read_unchecked_method(
273        &self,
274        scope: &mut Scope,
275        idx: ExpandElementTyped<usize>,
276    ) -> ExpandElementTyped<T> {
277        index_unchecked::expand(scope, self.clone(), idx)
278    }
279
280    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
281        Self::__expand_len_method(self.clone(), scope)
282    }
283}
284
285impl<T: CubePrimitive> Lined for SharedMemory<T> {}
286impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
287    fn line_size(&self) -> LineSize {
288        self.expand.ty.line_size()
289    }
290}
291
292impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
293    fn __expand_write(
294        scope: &mut Scope,
295        this: ExpandElementTyped<SharedMemory<T>>,
296        idx: ExpandElementTyped<usize>,
297        value: ExpandElementTyped<T>,
298    ) {
299        index_assign::expand(scope, this, idx, value);
300    }
301}
302
303impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
304    fn __expand_write_method(
305        &self,
306        scope: &mut Scope,
307        idx: ExpandElementTyped<usize>,
308        value: ExpandElementTyped<T>,
309    ) {
310        index_assign::expand(scope, self.clone(), idx, value);
311    }
312}