cubecl_core/frontend/container/shared_memory.rs

1use core::marker::PhantomData;
2
3use crate::{
4    self as cubecl,
5    prelude::{Lined, LinedExpand},
6    unexpanded,
7};
8use cubecl_ir::{Marker, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12    frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
13    ir::{Scope, Type},
14    prelude::{
15        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16    },
17};
18
19pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
21
22#[derive(Clone, Copy)]
23pub struct Shared<E: CubePrimitive> {
24    _val: PhantomData<E>,
25}
26
27#[derive(Clone, Copy)]
28pub struct SharedMemory<E: CubePrimitive> {
29    _val: PhantomData<E>,
30}
31
32impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
33    fn into_mut(self, _scope: &mut Scope) -> Self {
34        self
35    }
36}
37
38impl<T: CubePrimitive> CubeType for SharedMemory<T> {
39    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
40}
41
42impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
43    fn into_mut(self, _scope: &mut Scope) -> Self {
44        self
45    }
46}
47
48impl<T: CubePrimitive> CubeType for Shared<T> {
49    type ExpandType = ExpandElementTyped<Shared<T>>;
50}
51
52impl<T: CubePrimitive + Clone> SharedMemory<T> {
53    pub fn new<S: Index>(_size: S) -> Self {
54        SharedMemory { _val: PhantomData }
55    }
56
57    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
58        SharedMemory { _val: PhantomData }
59    }
60
61    #[allow(clippy::len_without_is_empty)]
62    pub fn len(&self) -> u32 {
63        unexpanded!()
64    }
65
66    pub fn buffer_len(&self) -> u32 {
67        unexpanded!()
68    }
69
70    pub fn __expand_new_lined(
71        scope: &mut Scope,
72        size: ExpandElementTyped<u32>,
73        line_size: u32,
74    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
75        let size = size
76            .constant()
77            .expect("Shared memory need constant initialization value")
78            .as_u32();
79        let var =
80            scope.create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None);
81        ExpandElementTyped::new(var)
82    }
83
84    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
85        SharedMemory { _val: PhantomData }
86    }
87
88    pub fn __expand_vectorized(
89        scope: &mut Scope,
90        size: ExpandElementTyped<u32>,
91        line_size: u32,
92    ) -> <Self as CubeType>::ExpandType {
93        let size = size
94            .constant()
95            .expect("Shared memory need constant initialization value")
96            .as_u32();
97        let var =
98            scope.create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None);
99        ExpandElementTyped::new(var)
100    }
101
102    pub fn __expand_new(
103        scope: &mut Scope,
104        size: ExpandElementTyped<u32>,
105    ) -> <Self as CubeType>::ExpandType {
106        let size = size
107            .constant()
108            .expect("Shared memory need constant initialization value")
109            .as_u32();
110        let var = scope.create_shared_array(Type::new(T::as_type(scope)), size, None);
111        ExpandElementTyped::new(var)
112    }
113
114    pub fn __expand_len(
115        scope: &mut Scope,
116        this: ExpandElementTyped<Self>,
117    ) -> ExpandElementTyped<u32> {
118        this.__expand_len_method(scope)
119    }
120
121    pub fn __expand_buffer_len(
122        scope: &mut Scope,
123        this: ExpandElementTyped<Self>,
124    ) -> ExpandElementTyped<u32> {
125        this.__expand_buffer_len_method(scope)
126    }
127}
128
#[cube]
impl<T: CubePrimitive> Shared<T> {
    /// Declares a new non-array shared variable of type `T`.
    ///
    /// At expand time the variable is created with `scope.create_shared` (not
    /// `create_shared_array`), using `T`'s IR type with no line size applied.
    pub fn new() -> Self {
        intrinsic!(|scope| {
            let var = scope.create_shared(Type::new(T::as_type(scope)));
            ExpandElementTyped::new(var)
        })
    }
}
138
139pub trait AsRefExpand<T: CubeType> {
140    /// Converts this type into a shared reference of the (usually inferred) input type.
141    fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
142}
143impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
144    fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
145        self
146    }
147}
148pub trait AsMutExpand<T: CubeType> {
149    /// Converts this type into a shared reference of the (usually inferred) input type.
150    fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
151}
152impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
153    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
154        self
155    }
156}
157
158/// Type inference won't allow things like assign to work normally, so we need to manually call
159/// `as_ref` or `as_mut` for those. Things like barrier ops should take `AsRef` so the conversion
160/// is automatic.
161impl<T: CubePrimitive> AsRef<T> for Shared<T> {
162    fn as_ref(&self) -> &T {
163        unexpanded!()
164    }
165}
166impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
167    fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
168        self.expand.into()
169    }
170}
171
172impl<T: CubePrimitive> AsMut<T> for Shared<T> {
173    fn as_mut(&mut self) -> &mut T {
174        unexpanded!()
175    }
176}
177impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
178    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
179        self.expand.into()
180    }
181}
182
183impl<T: CubePrimitive> Default for Shared<T> {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188impl<T: CubePrimitive> Shared<T> {
189    pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
190        Self::__expand_new(scope)
191    }
192}
193
#[cube]
impl<T: CubePrimitive> Shared<Line<T>> {
    /// Declares a shared variable holding a single `Line<T>` with line size `line_size`.
    ///
    /// At expand time the variable is created with `scope.create_shared` (not
    /// `create_shared_array`).
    // NOTE(review): although it is declared on `Shared<Line<T>>` and uses `create_shared`,
    // this function returns `SharedMemory<Line<T>>`. `len_static` only accepts a
    // `VariableKind::SharedArray`. If `create_shared` produces a different kind (as its name
    // suggests), calling `len` on the result would hit that `unreachable!`. The return type
    // was probably meant to be `Shared<Line<T>>` (i.e. `Self`); confirm with callers before
    // changing it, because that is a public signature change.
    // `line_size` is only read inside `intrinsic!`, which is presumably why unused-variable
    // warnings are silenced here.
    #[allow(unused_variables)]
    pub fn new_lined(#[comptime] line_size: u32) -> SharedMemory<Line<T>> {
        intrinsic!(|scope| {
            let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
            ExpandElementTyped::new(var)
        })
    }
}
204
205impl<T: CubePrimitive> ExpandElementTyped<SharedMemory<T>> {
206    pub fn __expand_len_method(self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
207        len_static(&self)
208    }
209
210    pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
211        self.__expand_len_method(scope)
212    }
213}
214
#[cube]
impl<T: CubePrimitive + Clone> SharedMemory<T> {
    /// Declares a shared memory array of `size` elements of type `Line<T>` with line size
    /// `line_size`, passing `Some(alignment)` to `create_shared_array`.
    ///
    /// All three arguments are comptime values.
    // NOTE(review): the unit of `alignment` (presumably bytes) is defined by
    // `Scope::create_shared_array`; confirm there.
    // The parameters are only read inside `intrinsic!`, which is presumably why
    // unused-variable warnings are silenced here.
    #[allow(unused_variables)]
    pub fn new_aligned(
        #[comptime] size: u32,
        #[comptime] line_size: u32,
        #[comptime] alignment: u32,
    ) -> SharedMemory<Line<T>> {
        intrinsic!(|scope| {
            let var = scope.create_shared_array(
                Type::new(T::as_type(scope)).line(line_size),
                size,
                Some(alignment),
            );
            ExpandElementTyped::new(var)
        })
    }

    /// Frees the shared memory for reuse, if possible on the target runtime.
    ///
    /// At expand time this registers a `Marker::Free` for the underlying variable.
    ///
    /// # Safety
    /// - *Must* be used in uniform control flow.
    /// - *Must not* have any dangling references to this shared memory.
    pub unsafe fn free(self) {
        intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
    }
}
242
243fn len_static<T: CubePrimitive>(
244    shared: &ExpandElementTyped<SharedMemory<T>>,
245) -> ExpandElementTyped<u32> {
246    let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
247        unreachable!("Kind of shared memory is always shared memory")
248    };
249    length.into()
250}
251
/// Module that contains the implementation details of the index functions.
///
/// It defines the `unchecked` index and index-assign intrinsics for [`SharedMemory`]. The
/// `List`/`ListMut` impls instead use `index::expand` / `index_assign::expand`.
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    // Local alias for the expanded handle. Nothing in this module references it by hand;
    // presumably the code generated by `#[cube]` below relies on it. Verify before removing.
    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;

    #[cube]
    impl<E: CubePrimitive> SharedMemory<E> {
        /// Reads element `i` of the array without a bounds check.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure the index
        /// is always in bounds.
        // `i` is only read inside `intrinsic!`, which is presumably why unused-variable
        // warnings are silenced here.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: u32) -> &E {
            intrinsic!(|scope| {
                // Fresh local with the same type as the shared variable; it receives the
                // result of the `UncheckedIndex` instruction.
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `line_size: 0` / `unroll_factor: 1` look like
                        // "no override" defaults; their exact meaning is defined by
                        // `IndexOperator` in cubecl_ir. Confirm there.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Writes `value` into element `i` of the array without a bounds check.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure the index
        /// is always in bounds.
        // `i` and `value` are only read inside `intrinsic!`, which is presumably why
        // unused-variable warnings are silenced here.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same defaults as in `index_unchecked`; see
                        // `IndexAssignOperator` in cubecl_ir for their meaning.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    // The shared array itself is the instruction's output variable.
                    *self.expand,
                ));
            })
        }
    }
}
307
308impl<T: CubePrimitive> List<T> for SharedMemory<T> {
309    fn __expand_read(
310        scope: &mut Scope,
311        this: ExpandElementTyped<SharedMemory<T>>,
312        idx: ExpandElementTyped<u32>,
313    ) -> ExpandElementTyped<T> {
314        index::expand(scope, this, idx)
315    }
316}
317
318impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
319    fn __expand_read_method(
320        &self,
321        scope: &mut Scope,
322        idx: ExpandElementTyped<u32>,
323    ) -> ExpandElementTyped<T> {
324        index::expand(scope, self.clone(), idx)
325    }
326    fn __expand_read_unchecked_method(
327        &self,
328        scope: &mut Scope,
329        idx: ExpandElementTyped<u32>,
330    ) -> ExpandElementTyped<T> {
331        index_unchecked::expand(scope, self.clone(), idx)
332    }
333
334    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
335        Self::__expand_len_method(self.clone(), scope)
336    }
337}
338
339impl<T: CubePrimitive> Lined for SharedMemory<T> {}
340impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
341    fn line_size(&self) -> u32 {
342        self.expand.ty.line_size()
343    }
344}
345
346impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
347    fn __expand_write(
348        scope: &mut Scope,
349        this: ExpandElementTyped<SharedMemory<T>>,
350        idx: ExpandElementTyped<u32>,
351        value: ExpandElementTyped<T>,
352    ) {
353        index_assign::expand(scope, this, idx, value);
354    }
355}
356
357impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
358    fn __expand_write_method(
359        &self,
360        scope: &mut Scope,
361        idx: ExpandElementTyped<u32>,
362        value: ExpandElementTyped<T>,
363    ) {
364        index_assign::expand(scope, self.clone(), idx, value);
365    }
366}