Skip to main content

cubecl_core/frontend/container/
shared_memory.rs

1use core::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use crate::{
5    self as cubecl,
6    prelude::{Lined, LinedExpand},
7    unexpanded,
8};
9use cubecl_ir::{LineSize, Marker, VariableKind};
10use cubecl_macros::{cube, intrinsic};
11
12use crate::{
13    frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut},
14    ir::{Scope, Type},
15    prelude::{
16        Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
17    },
18};
19
20pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
21pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
22
23#[derive(Clone, Copy)]
24pub struct Shared<E: CubePrimitive> {
25    _val: PhantomData<E>,
26}
27
28#[derive(Clone, Copy)]
29pub struct SharedMemory<E: CubePrimitive> {
30    _val: PhantomData<E>,
31}
32
33impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
34    fn into_mut(self, _scope: &mut Scope) -> Self {
35        self
36    }
37}
38
39impl<T: CubePrimitive> CubeType for SharedMemory<T> {
40    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
41}
42
43impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
44    fn into_mut(self, _scope: &mut Scope) -> Self {
45        self
46    }
47}
48
49impl<T: CubePrimitive> CubeType for Shared<T> {
50    type ExpandType = ExpandElementTyped<Shared<T>>;
51}
52
53#[cube]
54impl<T: CubePrimitive + Clone> SharedMemory<T> {
55    #[allow(unused_variables)]
56    pub fn new(#[comptime] size: usize) -> Self {
57        intrinsic!(|scope| {
58            scope
59                .create_shared_array(Type::new(T::as_type(scope)), size, None)
60                .into()
61        })
62    }
63
64    #[allow(unused_variables)]
65    pub fn new_lined(
66        #[comptime] size: usize,
67        #[comptime] line_size: LineSize,
68    ) -> SharedMemory<Line<T>> {
69        intrinsic!(|scope| {
70            scope
71                .create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None)
72                .into()
73        })
74    }
75
76    #[allow(clippy::len_without_is_empty)]
77    pub fn len(&self) -> usize {
78        intrinsic!(|_| len_static(&self))
79    }
80
81    pub fn buffer_len(&self) -> usize {
82        self.len()
83    }
84}
85
86#[cube]
87impl<T: CubePrimitive> Shared<T> {
88    pub fn new() -> Self {
89        intrinsic!(|scope| {
90            let var = scope.create_shared(Type::new(T::as_type(scope)));
91            ExpandElementTyped::new(var)
92        })
93    }
94}
95
96pub trait AsRefExpand<T: CubeType> {
97    /// Converts this type into a shared reference of the (usually inferred) input type.
98    fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
99}
100impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
101    fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
102        self
103    }
104}
105pub trait AsMutExpand<T: CubeType> {
106    /// Converts this type into a shared reference of the (usually inferred) input type.
107    fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
108}
109impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
110    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
111        self
112    }
113}
114
115/// Type inference won't allow things like assign to work normally, so we need to manually call
116/// `as_ref` or `as_mut` for those. Things like barrier ops should take `AsRef` so the conversion
117/// is automatic.
118impl<T: CubePrimitive> AsRef<T> for Shared<T> {
119    fn as_ref(&self) -> &T {
120        unexpanded!()
121    }
122}
123impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
124    fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
125        self.expand.into()
126    }
127}
128
129impl<T: CubePrimitive> AsMut<T> for Shared<T> {
130    fn as_mut(&mut self) -> &mut T {
131        unexpanded!()
132    }
133}
134impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
135    fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
136        self.expand.into()
137    }
138}
139
140impl<T: CubePrimitive> Default for Shared<T> {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145impl<T: CubePrimitive> Shared<T> {
146    pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
147        Self::__expand_new(scope)
148    }
149}
150
151#[cube]
152impl<T: CubePrimitive> Shared<Line<T>> {
153    #[allow(unused_variables)]
154    pub fn new_lined(#[comptime] line_size: LineSize) -> SharedMemory<Line<T>> {
155        intrinsic!(|scope| {
156            let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
157            ExpandElementTyped::new(var)
158        })
159    }
160}
161
162#[cube]
163impl<T: CubePrimitive + Clone> SharedMemory<T> {
164    #[allow(unused_variables)]
165    pub fn new_aligned(
166        #[comptime] size: usize,
167        #[comptime] line_size: LineSize,
168        #[comptime] alignment: usize,
169    ) -> SharedMemory<Line<T>> {
170        intrinsic!(|scope| {
171            let var = scope.create_shared_array(
172                Type::new(T::as_type(scope)).line(line_size),
173                size,
174                Some(alignment),
175            );
176            ExpandElementTyped::new(var)
177        })
178    }
179
180    /// Frees the shared memory for reuse, if possible on the target runtime.
181    ///
182    /// # Safety
183    /// *Must* be used in uniform control flow
184    /// *Must not* have any dangling references to this shared memory
185    pub unsafe fn free(self) {
186        intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
187    }
188}
189
190fn len_static<T: CubePrimitive>(
191    shared: &ExpandElementTyped<SharedMemory<T>>,
192) -> ExpandElementTyped<usize> {
193    let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
194        unreachable!("Kind of shared memory is always shared memory")
195    };
196    length.into()
197}
198
199/// Module that contains the implementation details of the index functions.
200mod indexation {
201    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
202
203    use crate::ir::Instruction;
204
205    use super::*;
206
207    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
208
209    #[cube]
210    impl<E: CubePrimitive> SharedMemory<E> {
211        /// Perform an unchecked index into the array
212        ///
213        /// # Safety
214        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
215        /// always in bounds
216        #[allow(unused_variables)]
217        pub unsafe fn index_unchecked(&self, i: usize) -> &E {
218            intrinsic!(|scope| {
219                let out = scope.create_local(self.expand.ty);
220                scope.register(Instruction::new(
221                    Operator::UncheckedIndex(IndexOperator {
222                        list: *self.expand,
223                        index: i.expand.consume(),
224                        line_size: 0,
225                        unroll_factor: 1,
226                    }),
227                    *out,
228                ));
229                out.into()
230            })
231        }
232
233        /// Perform an unchecked index assignment into the array
234        ///
235        /// # Safety
236        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
237        /// always in bounds
238        #[allow(unused_variables)]
239        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
240            intrinsic!(|scope| {
241                scope.register(Instruction::new(
242                    Operator::UncheckedIndexAssign(IndexAssignOperator {
243                        index: i.expand.consume(),
244                        value: value.expand.consume(),
245                        line_size: 0,
246                        unroll_factor: 1,
247                    }),
248                    *self.expand,
249                ));
250            })
251        }
252    }
253}
254
255impl<T: CubePrimitive> List<T> for SharedMemory<T> {
256    fn __expand_read(
257        scope: &mut Scope,
258        this: ExpandElementTyped<SharedMemory<T>>,
259        idx: ExpandElementTyped<usize>,
260    ) -> ExpandElementTyped<T> {
261        index::expand(scope, this, idx)
262    }
263}
264
265impl<T: CubePrimitive> Deref for SharedMemory<T> {
266    type Target = [T];
267
268    fn deref(&self) -> &Self::Target {
269        unexpanded!()
270    }
271}
272
273impl<T: CubePrimitive> DerefMut for SharedMemory<T> {
274    fn deref_mut(&mut self) -> &mut Self::Target {
275        unexpanded!()
276    }
277}
278
279impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
280    fn __expand_read_method(
281        &self,
282        scope: &mut Scope,
283        idx: ExpandElementTyped<usize>,
284    ) -> ExpandElementTyped<T> {
285        index::expand(scope, self.clone(), idx)
286    }
287    fn __expand_read_unchecked_method(
288        &self,
289        scope: &mut Scope,
290        idx: ExpandElementTyped<usize>,
291    ) -> ExpandElementTyped<T> {
292        index_unchecked::expand(scope, self.clone(), idx)
293    }
294
295    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
296        Self::__expand_len_method(self.clone(), scope)
297    }
298}
299
300impl<T: CubePrimitive> Lined for SharedMemory<T> {}
301impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
302    fn line_size(&self) -> LineSize {
303        self.expand.ty.line_size()
304    }
305}
306
307impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
308    fn __expand_write(
309        scope: &mut Scope,
310        this: ExpandElementTyped<SharedMemory<T>>,
311        idx: ExpandElementTyped<usize>,
312        value: ExpandElementTyped<T>,
313    ) {
314        index_assign::expand(scope, this, idx, value);
315    }
316}
317
318impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
319    fn __expand_write_method(
320        &self,
321        scope: &mut Scope,
322        idx: ExpandElementTyped<usize>,
323        value: ExpandElementTyped<T>,
324    ) {
325        index_assign::expand(scope, self.clone(), idx, value);
326    }
327}